mirror of
https://github.com/ollama/ollama.git
synced 2026-02-25 03:26:46 -05:00
Compare commits
1 Commits
main
...
pdevine/me
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8e19bfd20d |
@@ -296,13 +296,19 @@ func normalizeQuantType(quantize string) string {
|
||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||
// - Norms, embeddings, biases, routing gates: no quantization
|
||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
isStackedExpert := strings.Contains(name, ".mlp.experts.gate_up_proj") || strings.Contains(name, ".mlp.experts.down_proj")
|
||||
|
||||
// Use basic name-based check first
|
||||
if !ShouldQuantize(name, "") {
|
||||
if !isStackedExpert && !ShouldQuantize(name, "") {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||
if len(shape) != 2 {
|
||||
// Quantize 2D linear tensors by default. qwen3.5 stacked expert tensors are
|
||||
// also eligible even though they are stored as 3D [experts, out, in].
|
||||
if !isStackedExpert && len(shape) != 2 {
|
||||
return ""
|
||||
}
|
||||
if isStackedExpert && len(shape) != 3 {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -372,9 +378,10 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
}
|
||||
|
||||
// expertGroupRegexp matches expert tensor names and captures the group prefix.
|
||||
// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
|
||||
// Captures: model.layers.{L}.mlp.experts
|
||||
var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
|
||||
// Matches nested and non-nested LLM prefixes and both per-expert ".weight"
|
||||
// tensors and qwen3.5 stacked expert tensors without ".weight".
|
||||
// Captures: model(.language_model(.model)?).layers.{L}.mlp.experts or .shared_experts
|
||||
var expertGroupRegexp = regexp.MustCompile(`^(model(?:\.language_model(?:\.model)?)?\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*(?:\.weight)?$`)
|
||||
|
||||
// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
|
||||
// For example:
|
||||
|
||||
@@ -557,6 +557,9 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
||||
// 3D+ tensors should not be quantized
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
||||
// qwen3.5 stacked expert tensors are an exception: 3D [experts, out, in]
|
||||
{"qwen3.5 stacked gate_up_proj", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int4", true},
|
||||
{"qwen3.5 stacked down_proj", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int4", true},
|
||||
|
||||
// Embeddings should not be quantized regardless of shape
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
||||
@@ -586,6 +589,42 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensorQuantization_StackedExperts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
shape []int32
|
||||
quantize string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "model.language_model.layers.0.mlp.experts.gate_up_proj",
|
||||
shape: []int32{256, 1024, 2048},
|
||||
quantize: "int4",
|
||||
want: "int4",
|
||||
},
|
||||
{
|
||||
name: "model.language_model.layers.0.mlp.experts.down_proj",
|
||||
shape: []int32{256, 2048, 512},
|
||||
quantize: "int4",
|
||||
want: "int8",
|
||||
},
|
||||
{
|
||||
name: "model.language_model.layers.0.mlp.experts.gate_up_proj",
|
||||
shape: []int32{256, 1024, 2050}, // not divisible by 32
|
||||
quantize: "int4",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := GetTensorQuantization(tt.name, tt.shape, tt.quantize); got != tt.want {
|
||||
t.Fatalf("GetTensorQuantization(%q, %v, %q) = %q, want %q", tt.name, tt.shape, tt.quantize, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpertGroupPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -595,10 +634,18 @@ func TestExpertGroupPrefix(t *testing.T) {
|
||||
{"model.layers.1.mlp.experts.0.down_proj.weight", "model.layers.1.mlp.experts"},
|
||||
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
|
||||
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
|
||||
{"model.language_model.layers.0.mlp.experts.0.up_proj.weight", "model.language_model.layers.0.mlp.experts"},
|
||||
{"model.language_model.model.layers.0.mlp.experts.0.up_proj.weight", "model.language_model.model.layers.0.mlp.experts"},
|
||||
{"model.language_model.layers.0.mlp.experts.gate_up_proj", "model.language_model.layers.0.mlp.experts"},
|
||||
{"model.language_model.layers.0.mlp.experts.down_proj", "model.language_model.layers.0.mlp.experts"},
|
||||
{"model.language_model.model.layers.0.mlp.experts.gate_up_proj", "model.language_model.model.layers.0.mlp.experts"},
|
||||
{"model.language_model.model.layers.0.mlp.experts.down_proj", "model.language_model.model.layers.0.mlp.experts"},
|
||||
|
||||
// Shared expert tensors should return their own group prefix
|
||||
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
|
||||
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
|
||||
{"model.language_model.layers.1.mlp.shared_experts.down_proj.weight", "model.language_model.layers.1.mlp.shared_experts"},
|
||||
{"model.language_model.model.layers.1.mlp.shared_experts.down_proj.weight", "model.language_model.model.layers.1.mlp.shared_experts"},
|
||||
|
||||
// Non-expert tensors should return empty string
|
||||
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
|
||||
|
||||
@@ -5,62 +5,341 @@ package mlxrunner
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// CacheEntry stores a single sequence
|
||||
type CacheEntry struct {
|
||||
const defaultPromptCacheBranches = 1
|
||||
|
||||
// HybridCacheEntry stores a single prompt branch with mixed cache types
|
||||
// (e.g. KV + recurrent caches) and coordinates shared operations across them.
|
||||
type HybridCacheEntry struct {
|
||||
Tokens []int32
|
||||
Caches []cache.Cache
|
||||
}
|
||||
|
||||
// CacheEntry is kept as an alias for the current single-entry runner path.
|
||||
// Future multi-entry cache stores should prefer HybridCacheEntry directly.
|
||||
type CacheEntry = HybridCacheEntry
|
||||
|
||||
func promptCacheBranchLimit() int {
|
||||
if v := os.Getenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultPromptCacheBranches
|
||||
}
|
||||
|
||||
func cloneTokens(tokens []int32) []int32 {
|
||||
out := make([]int32, len(tokens))
|
||||
copy(out, tokens)
|
||||
return out
|
||||
}
|
||||
|
||||
func equalTokens(a, b []int32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *Runner) cacheStore() []*HybridCacheEntry {
|
||||
if len(r.caches) == 0 && r.cache != nil {
|
||||
r.caches = []*HybridCacheEntry{r.cache}
|
||||
}
|
||||
if r.cache == nil && len(r.caches) > 0 {
|
||||
r.cache = r.caches[0]
|
||||
}
|
||||
return r.caches
|
||||
}
|
||||
|
||||
func (r *Runner) setCacheStore(entries []*HybridCacheEntry) {
|
||||
r.caches = entries
|
||||
if len(entries) == 0 {
|
||||
r.cache = nil
|
||||
return
|
||||
}
|
||||
r.cache = entries[0]
|
||||
}
|
||||
|
||||
func (r *Runner) touchCacheEntry(idx int) {
|
||||
if idx <= 0 || idx >= len(r.caches) {
|
||||
return
|
||||
}
|
||||
e := r.caches[idx]
|
||||
copy(r.caches[1:idx+1], r.caches[:idx])
|
||||
r.caches[0] = e
|
||||
r.cache = r.caches[0]
|
||||
}
|
||||
|
||||
func (r *Runner) bestCacheEntry(tokens []int32) (idx int, prefix int) {
|
||||
bestIdx, bestPrefix := -1, 0
|
||||
for i, e := range r.cacheStore() {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
p := e.PrefixLen(tokens)
|
||||
if p > bestPrefix {
|
||||
bestIdx, bestPrefix = i, p
|
||||
}
|
||||
}
|
||||
return bestIdx, bestPrefix
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) PrefixLen(tokens []int32) int {
|
||||
if e == nil {
|
||||
return 0
|
||||
}
|
||||
prefix := 0
|
||||
for prefix < len(tokens) && prefix < len(e.Tokens) && tokens[prefix] == e.Tokens[prefix] {
|
||||
prefix++
|
||||
}
|
||||
return prefix
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) Free() {
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
for _, c := range e.Caches {
|
||||
if c != nil {
|
||||
c.Free()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) cachesSlice() []cache.Cache {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.Caches
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) cachesCanTrim() bool {
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
for _, c := range e.Caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
if !c.CanTrim() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) TrimToPrefix(prefix int) {
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
for _, c := range e.Caches {
|
||||
if c == nil || !c.CanTrim() {
|
||||
continue
|
||||
}
|
||||
trim := c.Offset() - prefix
|
||||
if trim > 0 {
|
||||
c.Trim(trim)
|
||||
}
|
||||
}
|
||||
if prefix < len(e.Tokens) {
|
||||
e.Tokens = e.Tokens[:prefix]
|
||||
}
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) RestoreToPrefix(target int) (int, bool) {
|
||||
if e == nil {
|
||||
return 0, false
|
||||
}
|
||||
restorePos := -1
|
||||
sawNonTrimmable := false
|
||||
|
||||
for _, c := range e.Caches {
|
||||
if c == nil || c.CanTrim() {
|
||||
continue
|
||||
}
|
||||
sawNonTrimmable = true
|
||||
|
||||
restorer, ok := c.(cache.CheckpointRestorer)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
pos, ok := restorer.BestCheckpoint(target)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
if restorePos < 0 {
|
||||
restorePos = pos
|
||||
continue
|
||||
}
|
||||
if pos != restorePos {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
if !sawNonTrimmable || restorePos < 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
e.TrimToPrefix(restorePos)
|
||||
for _, c := range e.Caches {
|
||||
if c == nil || c.CanTrim() {
|
||||
continue
|
||||
}
|
||||
restorer, ok := c.(cache.CheckpointRestorer)
|
||||
if !ok || !restorer.RestoreCheckpoint(restorePos) {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
if restorePos < len(e.Tokens) {
|
||||
e.Tokens = e.Tokens[:restorePos]
|
||||
}
|
||||
return restorePos, true
|
||||
}
|
||||
|
||||
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
|
||||
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
|
||||
if r.cache == nil {
|
||||
entries := r.cacheStore()
|
||||
if len(entries) == 0 {
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
}
|
||||
|
||||
// Find longest common prefix
|
||||
prefix := 0
|
||||
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
|
||||
prefix++
|
||||
branchLimit := promptCacheBranchLimit()
|
||||
idx, prefix := r.bestCacheEntry(tokens)
|
||||
if idx < 0 {
|
||||
if branchLimit <= 1 && len(entries) == 1 && entries[0] != nil {
|
||||
entries[0].Free()
|
||||
r.setCacheStore(nil)
|
||||
}
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
}
|
||||
if idx > 0 {
|
||||
r.touchCacheEntry(idx)
|
||||
}
|
||||
base := r.cache
|
||||
if base == nil {
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
}
|
||||
|
||||
working := base
|
||||
forked := false
|
||||
if branchLimit > 1 && prefix > 0 {
|
||||
working = base.Clone()
|
||||
forked = true
|
||||
}
|
||||
|
||||
switch {
|
||||
case prefix == 0:
|
||||
for _, c := range r.cache.Caches {
|
||||
c.Free()
|
||||
if !forked && branchLimit <= 1 {
|
||||
base.Free()
|
||||
r.setCacheStore(nil)
|
||||
}
|
||||
r.cache = nil
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
case prefix < len(r.cache.Tokens):
|
||||
trim := len(r.cache.Tokens) - prefix
|
||||
for _, c := range r.cache.Caches {
|
||||
c.Trim(trim)
|
||||
case prefix < len(working.Tokens):
|
||||
if !working.cachesCanTrim() {
|
||||
if restorePos, ok := working.RestoreToPrefix(prefix); ok {
|
||||
slog.Info("Cache restore", "total", len(tokens), "matched", prefix, "restored", restorePos, "left", len(tokens[restorePos:]))
|
||||
return working.cachesSlice(), tokens[restorePos:]
|
||||
}
|
||||
|
||||
if forked {
|
||||
working.Free()
|
||||
} else if branchLimit <= 1 {
|
||||
base.Free()
|
||||
r.setCacheStore(nil)
|
||||
}
|
||||
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
|
||||
return nil, tokens
|
||||
}
|
||||
r.cache.Tokens = r.cache.Tokens[:prefix]
|
||||
working.TrimToPrefix(prefix)
|
||||
}
|
||||
|
||||
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
|
||||
return r.cache.Caches, tokens[prefix:]
|
||||
return working.cachesSlice(), tokens[prefix:]
|
||||
}
|
||||
|
||||
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
|
||||
r.cache = &CacheEntry{
|
||||
entry := &HybridCacheEntry{
|
||||
Tokens: cloneTokens(tokens),
|
||||
Caches: caches,
|
||||
}
|
||||
|
||||
branchLimit := promptCacheBranchLimit()
|
||||
if branchLimit <= 1 {
|
||||
r.setCacheStore([]*HybridCacheEntry{entry})
|
||||
return
|
||||
}
|
||||
|
||||
entries := r.cacheStore()
|
||||
// Replace any exact-token duplicate branch with the new result.
|
||||
for i := 0; i < len(entries); i++ {
|
||||
if entries[i] == nil || !equalTokens(entries[i].Tokens, entry.Tokens) {
|
||||
continue
|
||||
}
|
||||
entries[i].Free()
|
||||
entries = append(entries[:i], entries[i+1:]...)
|
||||
break
|
||||
}
|
||||
|
||||
entries = append([]*HybridCacheEntry{entry}, entries...)
|
||||
if len(entries) > branchLimit {
|
||||
for _, evicted := range entries[branchLimit:] {
|
||||
if evicted != nil {
|
||||
evicted.Free()
|
||||
}
|
||||
}
|
||||
entries = entries[:branchLimit]
|
||||
}
|
||||
r.setCacheStore(entries)
|
||||
}
|
||||
|
||||
func (c *HybridCacheEntry) Clone() *HybridCacheEntry {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
tokens := make([]int32, len(c.Tokens))
|
||||
copy(tokens, c.Tokens)
|
||||
caches := make([]cache.Cache, len(c.Caches))
|
||||
for i, cc := range c.Caches {
|
||||
if cc != nil {
|
||||
caches[i] = cc.Clone()
|
||||
}
|
||||
}
|
||||
return &HybridCacheEntry{
|
||||
Tokens: tokens,
|
||||
Caches: caches,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CacheEntry) LogCache() {
|
||||
func (c *HybridCacheEntry) LogCache() {
|
||||
if c == nil || len(c.Caches) == 0 {
|
||||
return
|
||||
}
|
||||
var totalBytes int
|
||||
for _, kv := range c.Caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
k, v := kv.State()
|
||||
if k == nil || v == nil {
|
||||
continue
|
||||
}
|
||||
totalBytes += k.NumBytes() + v.NumBytes()
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
|
||||
|
||||
40
x/mlxrunner/cache/cache.go
vendored
40
x/mlxrunner/cache/cache.go
vendored
@@ -3,13 +3,22 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func kvCacheGrowDebugEnabled() bool {
|
||||
return os.Getenv("OLLAMA_MLX_DEBUG_CACHE_GROW") != ""
|
||||
}
|
||||
|
||||
type Cache interface {
|
||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||
State() (keys, values *mlx.Array)
|
||||
Materialize() []*mlx.Array
|
||||
CanTrim() bool
|
||||
Trim(int) int
|
||||
Clone() Cache
|
||||
Free()
|
||||
@@ -17,6 +26,19 @@ type Cache interface {
|
||||
Len() int
|
||||
}
|
||||
|
||||
// CheckpointRecorder is an optional cache capability for recording recurrent
|
||||
// state snapshots at specific token positions.
|
||||
type CheckpointRecorder interface {
|
||||
RecordCheckpoint(pos int)
|
||||
}
|
||||
|
||||
// CheckpointRestorer is an optional cache capability for restoring recurrent
|
||||
// state to a previously recorded checkpoint.
|
||||
type CheckpointRestorer interface {
|
||||
BestCheckpoint(target int) (pos int, ok bool)
|
||||
RestoreCheckpoint(pos int) bool
|
||||
}
|
||||
|
||||
type KVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
@@ -49,6 +71,9 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
c.keys, c.values = newKeys, newValues
|
||||
mlx.Pin(c.keys, c.values)
|
||||
}
|
||||
if kvCacheGrowDebugEnabled() {
|
||||
slog.Info("KVCache grow", "prev", prev, "new_capacity", c.keys.Dim(2), "step", c.step)
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += L
|
||||
@@ -67,6 +92,19 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
|
||||
func (c *KVCache) Materialize() []*mlx.Array {
|
||||
out := make([]*mlx.Array, 0, 2)
|
||||
if c.keys != nil && c.keys.Valid() {
|
||||
out = append(out, c.keys)
|
||||
}
|
||||
if c.values != nil && c.values.Valid() {
|
||||
out = append(out, c.values)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *KVCache) CanTrim() bool { return true }
|
||||
|
||||
func (c *KVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
@@ -190,6 +228,8 @@ func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) CanTrim() bool { return true }
|
||||
|
||||
func (c *RotatingKVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
|
||||
17
x/mlxrunner/cache/cache_test.go
vendored
Normal file
17
x/mlxrunner/cache/cache_test.go
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestKVCacheGrowDebugEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "")
|
||||
if kvCacheGrowDebugEnabled() {
|
||||
t.Fatal("kvCacheGrowDebugEnabled() = true, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "1")
|
||||
if !kvCacheGrowDebugEnabled() {
|
||||
t.Fatal("kvCacheGrowDebugEnabled() = false, want true")
|
||||
}
|
||||
}
|
||||
519
x/mlxrunner/cache/recurrent.go
vendored
Normal file
519
x/mlxrunner/cache/recurrent.go
vendored
Normal file
@@ -0,0 +1,519 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultRecurrentCheckpointCount = 32
|
||||
defaultRecurrentCheckpointInterval = 128
|
||||
defaultRecurrentCheckpointMinPos = 16
|
||||
)
|
||||
|
||||
type recurrentCheckpoint struct {
|
||||
pos int
|
||||
convState *mlx.Array
|
||||
deltaState *mlx.Array
|
||||
}
|
||||
|
||||
type recurrentSlot struct {
|
||||
convState *mlx.Array
|
||||
deltaState *mlx.Array
|
||||
|
||||
checkpoints []recurrentCheckpoint
|
||||
checkpointSize int
|
||||
checkpointNext int
|
||||
checkpointLastPos int
|
||||
|
||||
refs int
|
||||
}
|
||||
|
||||
func getenvInt(name string, def int) int {
|
||||
if v := os.Getenv(name); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func recurrentCheckpointConfig() (count, interval, minPos int) {
|
||||
count = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINTS", defaultRecurrentCheckpointCount)
|
||||
interval = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINT_INTERVAL", defaultRecurrentCheckpointInterval)
|
||||
minPos = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINT_MIN_POS", defaultRecurrentCheckpointMinPos)
|
||||
|
||||
if count < 0 {
|
||||
count = 0
|
||||
}
|
||||
if interval < 0 {
|
||||
interval = 0
|
||||
}
|
||||
if minPos < 0 {
|
||||
minPos = 0
|
||||
}
|
||||
return count, interval, minPos
|
||||
}
|
||||
|
||||
// RecurrentCache stores state for linear-recurrent layers.
|
||||
//
|
||||
// Conv state shape: [B, convTail, convDim]
|
||||
// Delta state shape: [B, numVHeads, headVDim, headKDim]
|
||||
type RecurrentCache struct {
|
||||
slot *recurrentSlot
|
||||
offset int
|
||||
|
||||
convTail int
|
||||
convDim int
|
||||
numVHeads int
|
||||
headVDim int
|
||||
headKDim int
|
||||
|
||||
checkpointCount int
|
||||
checkpointInterval int
|
||||
checkpointMinPos int
|
||||
}
|
||||
|
||||
func newRecurrentSlot(checkpointCount int) *recurrentSlot {
|
||||
s := &recurrentSlot{
|
||||
refs: 1,
|
||||
checkpointLastPos: -1,
|
||||
}
|
||||
if checkpointCount > 0 {
|
||||
s.checkpoints = make([]recurrentCheckpoint, checkpointCount)
|
||||
for i := range s.checkpoints {
|
||||
s.checkpoints[i].pos = -1
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func retainRecurrentSlot(s *recurrentSlot) *recurrentSlot {
|
||||
if s != nil {
|
||||
s.refs++
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
if *dst == v {
|
||||
return
|
||||
}
|
||||
|
||||
// Break dependency chains so recurrent state does not retain the full
|
||||
// per-token compute graph over time.
|
||||
snap := mlx.Snapshot(v)
|
||||
mlx.Eval(snap)
|
||||
|
||||
old := *dst
|
||||
*dst = snap
|
||||
mlx.Pin(snap)
|
||||
|
||||
// Release previous cached state root, then recursively free the transient
|
||||
// incoming graph root now that a detached snapshot is retained in cache.
|
||||
if old != nil && old != snap {
|
||||
mlx.Release(old)
|
||||
}
|
||||
if v != snap && v != old {
|
||||
mlx.Release(v)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
old := *dst
|
||||
*dst = v
|
||||
mlx.Pin(v)
|
||||
if old != nil && old != v {
|
||||
mlx.Release(old)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateDetached(dst **mlx.Array, v *mlx.Array, ensureContiguous bool) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
if *dst == v {
|
||||
return
|
||||
}
|
||||
|
||||
root := v
|
||||
if ensureContiguous {
|
||||
root = mlx.Contiguous(v, false)
|
||||
}
|
||||
detached := mlx.Detach(root)
|
||||
|
||||
old := *dst
|
||||
*dst = detached
|
||||
mlx.Pin(detached)
|
||||
if old != nil && old != detached {
|
||||
mlx.Release(old)
|
||||
}
|
||||
|
||||
// Intentionally do not force-release root/v here. In the fast path, the detached
|
||||
// handle aliases the same MLX value and may still be lazily computed. Releasing the
|
||||
// source handles can invalidate the cached state before the next eval/sweep point.
|
||||
}
|
||||
|
||||
func snapshotPinned(a *mlx.Array) *mlx.Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return nil
|
||||
}
|
||||
snap := mlx.Snapshot(a)
|
||||
mlx.Eval(snap)
|
||||
mlx.Pin(snap)
|
||||
return snap
|
||||
}
|
||||
|
||||
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||
count, interval, minPos := recurrentCheckpointConfig()
|
||||
c := &RecurrentCache{
|
||||
slot: newRecurrentSlot(count),
|
||||
convTail: int(convTail),
|
||||
convDim: int(convDim),
|
||||
numVHeads: int(numVHeads),
|
||||
headVDim: int(headVDim),
|
||||
headKDim: int(headKDim),
|
||||
checkpointCount: count,
|
||||
checkpointInterval: interval,
|
||||
checkpointMinPos: minPos,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func clonePinned(a *mlx.Array) *mlx.Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return nil
|
||||
}
|
||||
clone := a.Clone()
|
||||
mlx.Pin(clone)
|
||||
return clone
|
||||
}
|
||||
|
||||
func releaseCheckpointEntry(e *recurrentCheckpoint) {
|
||||
mlx.Release(e.convState, e.deltaState)
|
||||
e.convState, e.deltaState = nil, nil
|
||||
e.pos = -1
|
||||
}
|
||||
|
||||
func releaseRecurrentSlot(s *recurrentSlot) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.refs--
|
||||
if s.refs > 0 {
|
||||
return
|
||||
}
|
||||
mlx.Release(s.convState, s.deltaState)
|
||||
s.convState, s.deltaState = nil, nil
|
||||
for i := range s.checkpoints {
|
||||
releaseCheckpointEntry(&s.checkpoints[i])
|
||||
}
|
||||
s.checkpointSize = 0
|
||||
s.checkpointNext = 0
|
||||
s.checkpointLastPos = -1
|
||||
}
|
||||
|
||||
func cloneRecurrentSlot(src *recurrentSlot) *recurrentSlot {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
dst := &recurrentSlot{
|
||||
checkpointSize: src.checkpointSize,
|
||||
checkpointNext: src.checkpointNext,
|
||||
checkpointLastPos: src.checkpointLastPos,
|
||||
refs: 1,
|
||||
}
|
||||
if src.convState != nil && src.convState.Valid() {
|
||||
dst.convState = snapshotPinned(src.convState)
|
||||
}
|
||||
if src.deltaState != nil && src.deltaState.Valid() {
|
||||
dst.deltaState = snapshotPinned(src.deltaState)
|
||||
}
|
||||
if len(src.checkpoints) > 0 {
|
||||
dst.checkpoints = make([]recurrentCheckpoint, len(src.checkpoints))
|
||||
for i := range src.checkpoints {
|
||||
dst.checkpoints[i].pos = src.checkpoints[i].pos
|
||||
if src.checkpoints[i].pos < 0 {
|
||||
continue
|
||||
}
|
||||
dst.checkpoints[i].convState = snapshotPinned(src.checkpoints[i].convState)
|
||||
dst.checkpoints[i].deltaState = snapshotPinned(src.checkpoints[i].deltaState)
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) slotOrInit() *recurrentSlot {
|
||||
if c.slot == nil {
|
||||
c.slot = newRecurrentSlot(c.checkpointCount)
|
||||
}
|
||||
return c.slot
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ensureWritableSlot() *recurrentSlot {
|
||||
s := c.slotOrInit()
|
||||
if s.refs <= 1 {
|
||||
return s
|
||||
}
|
||||
c.slot = cloneRecurrentSlot(s)
|
||||
s.refs--
|
||||
return c.slot
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) pruneCheckpointsAfter(pos int) {
|
||||
s := c.ensureWritableSlot()
|
||||
if len(s.checkpoints) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
size := 0
|
||||
next := -1
|
||||
last := -1
|
||||
minPos := int(^uint(0) >> 1)
|
||||
minIdx := 0
|
||||
for i := range s.checkpoints {
|
||||
e := &s.checkpoints[i]
|
||||
if e.pos > pos {
|
||||
releaseCheckpointEntry(e)
|
||||
}
|
||||
if e.pos >= 0 {
|
||||
size++
|
||||
if e.pos > last {
|
||||
last = e.pos
|
||||
}
|
||||
if e.pos < minPos {
|
||||
minPos = e.pos
|
||||
minIdx = i
|
||||
}
|
||||
} else if next == -1 {
|
||||
next = i
|
||||
}
|
||||
}
|
||||
|
||||
s.checkpointSize = size
|
||||
s.checkpointLastPos = last
|
||||
if size == 0 {
|
||||
s.checkpointNext = 0
|
||||
return
|
||||
}
|
||||
if next != -1 {
|
||||
s.checkpointNext = next
|
||||
return
|
||||
}
|
||||
s.checkpointNext = minIdx
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
|
||||
if batch <= 0 {
|
||||
batch = 1
|
||||
}
|
||||
|
||||
s := c.slotOrInit()
|
||||
needConv := s.convState == nil || s.convState.DType() != dtype ||
|
||||
s.convState.Dim(0) != batch || s.convState.Dim(1) != c.convTail || s.convState.Dim(2) != c.convDim
|
||||
needDelta := s.deltaState == nil || s.deltaState.DType() != dtype ||
|
||||
s.deltaState.Dim(0) != batch || s.deltaState.Dim(1) != c.numVHeads || s.deltaState.Dim(2) != c.headVDim || s.deltaState.Dim(3) != c.headKDim
|
||||
if !needConv && !needDelta {
|
||||
return
|
||||
}
|
||||
|
||||
s = c.ensureWritableSlot()
|
||||
if needConv {
|
||||
c.setStateRaw(&s.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
|
||||
}
|
||||
|
||||
if needDelta {
|
||||
c.setStateRaw(&s.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.slotOrInit().convState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
|
||||
s := c.ensureWritableSlot()
|
||||
c.setStateMaterialized(&s.convState, v)
|
||||
}
|
||||
|
||||
// SetConvStateFast stores conv state without forcing an immediate snapshot/eval.
|
||||
// Use only for decode hot paths that accept higher transient memory until the next
|
||||
// sync/sweep point. The conv-state input is usually a slice view, so request a
|
||||
// compact contiguous copy to avoid pinning the whole source buffer.
|
||||
func (c *RecurrentCache) SetConvStateFast(v *mlx.Array) {
|
||||
s := c.ensureWritableSlot()
|
||||
c.setStateDetached(&s.convState, v, true)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.slotOrInit().deltaState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
|
||||
s := c.ensureWritableSlot()
|
||||
c.setStateMaterialized(&s.deltaState, v)
|
||||
}
|
||||
|
||||
// SetDeltaStateFast stores delta state without forcing an immediate snapshot/eval.
|
||||
// Use only for decode hot paths that accept higher transient memory until the next
|
||||
// sync/sweep point.
|
||||
func (c *RecurrentCache) SetDeltaStateFast(v *mlx.Array) {
|
||||
s := c.ensureWritableSlot()
|
||||
c.setStateDetached(&s.deltaState, v, false)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Advance(n int) {
|
||||
c.offset += n
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return keys, values
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
|
||||
c.ensure(1, mlx.DTypeFloat32)
|
||||
s := c.slotOrInit()
|
||||
return s.convState, s.deltaState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Materialize() []*mlx.Array {
|
||||
out := make([]*mlx.Array, 0, 2)
|
||||
s := c.slot
|
||||
if s == nil {
|
||||
return out
|
||||
}
|
||||
if s.convState != nil && s.convState.Valid() {
|
||||
out = append(out, s.convState)
|
||||
}
|
||||
if s.deltaState != nil && s.deltaState.Valid() {
|
||||
out = append(out, s.deltaState)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) RecordCheckpoint(pos int) {
|
||||
s := c.slot
|
||||
if s == nil || len(s.checkpoints) == 0 || pos <= 0 || pos < c.checkpointMinPos {
|
||||
return
|
||||
}
|
||||
if c.offset != pos {
|
||||
// Checkpoints are keyed by logical token position. Ignore callers with a
|
||||
// mismatched position to avoid restoring inconsistent recurrent state.
|
||||
return
|
||||
}
|
||||
if s.convState == nil || s.deltaState == nil || !s.convState.Valid() || !s.deltaState.Valid() {
|
||||
return
|
||||
}
|
||||
if s.checkpointLastPos == pos {
|
||||
return
|
||||
}
|
||||
if s.checkpointLastPos >= 0 && c.checkpointInterval > 0 && pos-s.checkpointLastPos < c.checkpointInterval {
|
||||
return
|
||||
}
|
||||
if s.refs > 1 {
|
||||
s = c.ensureWritableSlot()
|
||||
}
|
||||
|
||||
idx := s.checkpointNext
|
||||
e := &s.checkpoints[idx]
|
||||
releaseCheckpointEntry(e)
|
||||
e.pos = pos
|
||||
e.convState = clonePinned(s.convState)
|
||||
e.deltaState = clonePinned(s.deltaState)
|
||||
|
||||
s.checkpointNext = (idx + 1) % len(s.checkpoints)
|
||||
if s.checkpointSize < len(s.checkpoints) {
|
||||
s.checkpointSize++
|
||||
}
|
||||
s.checkpointLastPos = pos
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) BestCheckpoint(target int) (pos int, ok bool) {
|
||||
s := c.slot
|
||||
if s == nil {
|
||||
return 0, false
|
||||
}
|
||||
best := -1
|
||||
for i := range s.checkpoints {
|
||||
pos := s.checkpoints[i].pos
|
||||
if pos < 0 || pos > target {
|
||||
continue
|
||||
}
|
||||
if pos > best {
|
||||
best = pos
|
||||
}
|
||||
}
|
||||
if best < 0 {
|
||||
return 0, false
|
||||
}
|
||||
return best, true
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) RestoreCheckpoint(pos int) bool {
|
||||
if pos < 0 {
|
||||
return false
|
||||
}
|
||||
s := c.ensureWritableSlot()
|
||||
for i := range s.checkpoints {
|
||||
e := &s.checkpoints[i]
|
||||
if e.pos != pos {
|
||||
continue
|
||||
}
|
||||
if e.convState == nil || e.deltaState == nil || !e.convState.Valid() || !e.deltaState.Valid() {
|
||||
return false
|
||||
}
|
||||
|
||||
c.setStateRaw(&s.convState, e.convState.Clone())
|
||||
c.setStateRaw(&s.deltaState, e.deltaState.Clone())
|
||||
c.offset = pos
|
||||
c.pruneCheckpointsAfter(pos)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) CanTrim() bool { return false }
|
||||
|
||||
func (c *RecurrentCache) Trim(n int) int {
|
||||
// Recurrent state is not directly trimmable; callers should use
|
||||
// checkpoint-based restore instead.
|
||||
_ = n
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Clone() Cache {
|
||||
clone := &RecurrentCache{
|
||||
slot: retainRecurrentSlot(c.slotOrInit()),
|
||||
offset: c.offset,
|
||||
convTail: c.convTail,
|
||||
convDim: c.convDim,
|
||||
numVHeads: c.numVHeads,
|
||||
headVDim: c.headVDim,
|
||||
headKDim: c.headKDim,
|
||||
checkpointCount: c.checkpointCount,
|
||||
checkpointInterval: c.checkpointInterval,
|
||||
checkpointMinPos: c.checkpointMinPos,
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Free() {
|
||||
releaseRecurrentSlot(c.slot)
|
||||
c.slot = nil
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Offset() int { return c.offset }
|
||||
func (c *RecurrentCache) Len() int { return c.offset }
|
||||
125
x/mlxrunner/cache/recurrent_cow_test.go
vendored
Normal file
125
x/mlxrunner/cache/recurrent_cow_test.go
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func requireMLXRuntime(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX runtime unavailable: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestRecurrentCache(t *testing.T) *RecurrentCache {
|
||||
t.Helper()
|
||||
requireMLXRuntime(t)
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINTS", "2")
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINT_INTERVAL", "0")
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINT_MIN_POS", "0")
|
||||
|
||||
c := NewRecurrentCache(2, 3, 1, 2, 2)
|
||||
_ = c.ConvState(1, mlx.DTypeFloat32)
|
||||
_ = c.DeltaState(1, mlx.DTypeFloat32)
|
||||
return c
|
||||
}
|
||||
|
||||
func TestRecurrentCacheCloneSharesSlotUntilMutation(t *testing.T) {
|
||||
c1 := newTestRecurrentCache(t)
|
||||
t.Cleanup(func() {
|
||||
c1.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
|
||||
c1.Advance(8)
|
||||
c1.RecordCheckpoint(8)
|
||||
if got, ok := c1.BestCheckpoint(8); !ok || got != 8 {
|
||||
t.Fatalf("BestCheckpoint(8) = (%d, %v), want (8, true)", got, ok)
|
||||
}
|
||||
|
||||
c2 := c1.Clone().(*RecurrentCache)
|
||||
t.Cleanup(func() {
|
||||
c2.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
|
||||
if c1.slot == nil || c2.slot == nil {
|
||||
t.Fatal("expected non-nil shared slots")
|
||||
}
|
||||
if c1.slot != c2.slot {
|
||||
t.Fatal("clone did not share recurrent slot")
|
||||
}
|
||||
if c1.slot.refs != 2 {
|
||||
t.Fatalf("shared slot refs = %d, want 2", c1.slot.refs)
|
||||
}
|
||||
|
||||
// Read access should not trigger a COW detach.
|
||||
_ = c2.ConvState(1, mlx.DTypeFloat32)
|
||||
_ = c2.DeltaState(1, mlx.DTypeFloat32)
|
||||
if c1.slot != c2.slot {
|
||||
t.Fatal("read access detached shared recurrent slot")
|
||||
}
|
||||
|
||||
// Mutating recurrent state should detach and deep-copy checkpoint metadata.
|
||||
c2.SetConvState(mlx.Zeros(mlx.DTypeFloat32, 1, 2, 3))
|
||||
|
||||
if c1.slot == c2.slot {
|
||||
t.Fatal("SetConvState did not detach shared recurrent slot")
|
||||
}
|
||||
if c1.slot.refs != 1 || c2.slot.refs != 1 {
|
||||
t.Fatalf("post-detach refs = (%d, %d), want (1, 1)", c1.slot.refs, c2.slot.refs)
|
||||
}
|
||||
if len(c1.slot.checkpoints) == 0 || len(c2.slot.checkpoints) == 0 {
|
||||
t.Fatal("expected checkpoint ring to be preserved on detach")
|
||||
}
|
||||
if c1.slot.checkpoints[0].pos != c2.slot.checkpoints[0].pos {
|
||||
t.Fatalf("checkpoint pos mismatch after detach: %d vs %d", c1.slot.checkpoints[0].pos, c2.slot.checkpoints[0].pos)
|
||||
}
|
||||
if c1.slot.checkpoints[0].pos != 8 {
|
||||
t.Fatalf("checkpoint pos = %d, want 8", c1.slot.checkpoints[0].pos)
|
||||
}
|
||||
if c1.slot.checkpoints[0].convState == c2.slot.checkpoints[0].convState {
|
||||
t.Fatal("checkpoint conv state was aliased after COW detach")
|
||||
}
|
||||
if c1.slot.checkpoints[0].deltaState == c2.slot.checkpoints[0].deltaState {
|
||||
t.Fatal("checkpoint delta state was aliased after COW detach")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecurrentCacheFreeKeepsSharedCloneAlive(t *testing.T) {
|
||||
c1 := newTestRecurrentCache(t)
|
||||
c2 := c1.Clone().(*RecurrentCache)
|
||||
t.Cleanup(func() {
|
||||
c1.Free()
|
||||
c2.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
|
||||
if c2.slot == nil || c2.slot.refs != 2 {
|
||||
t.Fatalf("shared clone refs = %d, want 2", func() int {
|
||||
if c2.slot == nil {
|
||||
return 0
|
||||
}
|
||||
return c2.slot.refs
|
||||
}())
|
||||
}
|
||||
|
||||
c1.Free()
|
||||
|
||||
if c2.slot == nil {
|
||||
t.Fatal("clone slot was cleared after freeing sibling clone")
|
||||
}
|
||||
if c2.slot.refs != 1 {
|
||||
t.Fatalf("clone slot refs after sibling Free = %d, want 1", c2.slot.refs)
|
||||
}
|
||||
if state := c2.ConvState(1, mlx.DTypeFloat32); state == nil || !state.Valid() {
|
||||
t.Fatal("clone conv state invalid after freeing sibling clone")
|
||||
}
|
||||
if state := c2.DeltaState(1, mlx.DTypeFloat32); state == nil || !state.Valid() {
|
||||
t.Fatal("clone delta state invalid after freeing sibling clone")
|
||||
}
|
||||
}
|
||||
339
x/mlxrunner/cache_policy_test.go
Normal file
339
x/mlxrunner/cache_policy_test.go
Normal file
@@ -0,0 +1,339 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
cachepkg "github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
type fakeCache struct {
|
||||
canTrim bool
|
||||
trims []int
|
||||
freeCall int
|
||||
offset int
|
||||
}
|
||||
|
||||
func (f *fakeCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { return keys, values }
|
||||
func (f *fakeCache) State() (*mlx.Array, *mlx.Array) { return nil, nil }
|
||||
func (f *fakeCache) Materialize() []*mlx.Array { return nil }
|
||||
func (f *fakeCache) CanTrim() bool { return f.canTrim }
|
||||
func (f *fakeCache) Trim(n int) int {
|
||||
f.trims = append(f.trims, n)
|
||||
f.offset -= n
|
||||
return n
|
||||
}
|
||||
func (f *fakeCache) Clone() cachepkg.Cache { return &fakeCache{canTrim: f.canTrim, offset: f.offset} }
|
||||
func (f *fakeCache) Free() { f.freeCall++ }
|
||||
func (f *fakeCache) Offset() int { return f.offset }
|
||||
func (f *fakeCache) Len() int { return f.offset }
|
||||
|
||||
type fakeCheckpointCache struct {
|
||||
fakeCache
|
||||
bestPos int
|
||||
hasCheckpoint bool
|
||||
restoreCalls []int
|
||||
restoreSuccess bool
|
||||
}
|
||||
|
||||
func (f *fakeCheckpointCache) BestCheckpoint(target int) (int, bool) {
|
||||
if !f.hasCheckpoint || f.bestPos > target {
|
||||
return 0, false
|
||||
}
|
||||
return f.bestPos, true
|
||||
}
|
||||
|
||||
func (f *fakeCheckpointCache) RestoreCheckpoint(pos int) bool {
|
||||
f.restoreCalls = append(f.restoreCalls, pos)
|
||||
if !f.restoreSuccess || pos != f.bestPos {
|
||||
return false
|
||||
}
|
||||
f.offset = pos
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *fakeCheckpointCache) Clone() cachepkg.Cache {
|
||||
clone := *f
|
||||
clone.trims = nil
|
||||
clone.restoreCalls = nil
|
||||
return &clone
|
||||
}
|
||||
|
||||
func TestFindNearestCacheReusesAppendOnlyNonTrimmableCache(t *testing.T) {
|
||||
fc := &fakeCache{canTrim: false, offset: 2}
|
||||
r := &Runner{
|
||||
cache: &CacheEntry{
|
||||
Tokens: []int32{1, 2},
|
||||
Caches: []cachepkg.Cache{fc},
|
||||
},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4})
|
||||
|
||||
if len(gotCaches) != 1 || gotCaches[0] != fc {
|
||||
t.Fatalf("returned caches = %#v, want original cache", gotCaches)
|
||||
}
|
||||
if want := []int32{3, 4}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if fc.freeCall != 0 {
|
||||
t.Fatalf("free calls = %d, want 0", fc.freeCall)
|
||||
}
|
||||
if len(fc.trims) != 0 {
|
||||
t.Fatalf("trim calls = %v, want none", fc.trims)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheDropsNonTrimmableCacheOnDivergence(t *testing.T) {
|
||||
fc := &fakeCache{canTrim: false, offset: 4}
|
||||
r := &Runner{
|
||||
cache: &CacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4},
|
||||
Caches: []cachepkg.Cache{fc},
|
||||
},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
|
||||
|
||||
if gotCaches != nil {
|
||||
t.Fatalf("returned caches = %#v, want nil", gotCaches)
|
||||
}
|
||||
if want := []int32{1, 2, 9}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if fc.freeCall != 1 {
|
||||
t.Fatalf("free calls = %d, want 1", fc.freeCall)
|
||||
}
|
||||
if len(fc.trims) != 0 {
|
||||
t.Fatalf("trim calls = %v, want none", fc.trims)
|
||||
}
|
||||
if r.cache != nil {
|
||||
t.Fatal("runner cache should be cleared on non-trimmable divergence")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheTrimsTrimmableCacheOnDivergence(t *testing.T) {
|
||||
fc := &fakeCache{canTrim: true, offset: 4}
|
||||
r := &Runner{
|
||||
cache: &CacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4},
|
||||
Caches: []cachepkg.Cache{fc},
|
||||
},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
|
||||
|
||||
if len(gotCaches) != 1 || gotCaches[0] != fc {
|
||||
t.Fatalf("returned caches = %#v, want original cache", gotCaches)
|
||||
}
|
||||
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if fc.freeCall != 0 {
|
||||
t.Fatalf("free calls = %d, want 0", fc.freeCall)
|
||||
}
|
||||
if want := []int{2}; !reflect.DeepEqual(fc.trims, want) {
|
||||
t.Fatalf("trim calls = %v, want %v", fc.trims, want)
|
||||
}
|
||||
if want := []int32{1, 2}; !reflect.DeepEqual(r.cache.Tokens, want) {
|
||||
t.Fatalf("cached tokens = %v, want %v", r.cache.Tokens, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheRestoresCheckpointForNonTrimmableCaches(t *testing.T) {
|
||||
kv := &fakeCache{canTrim: true, offset: 7}
|
||||
rc1 := &fakeCheckpointCache{
|
||||
fakeCache: fakeCache{canTrim: false, offset: 7},
|
||||
bestPos: 4,
|
||||
hasCheckpoint: true,
|
||||
restoreSuccess: true,
|
||||
}
|
||||
rc2 := &fakeCheckpointCache{
|
||||
fakeCache: fakeCache{canTrim: false, offset: 7},
|
||||
bestPos: 4,
|
||||
hasCheckpoint: true,
|
||||
restoreSuccess: true,
|
||||
}
|
||||
|
||||
r := &Runner{
|
||||
cache: &CacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4, 5, 6, 7},
|
||||
Caches: []cachepkg.Cache{kv, rc1, rc2},
|
||||
},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 8})
|
||||
|
||||
if len(gotCaches) != 3 {
|
||||
t.Fatalf("returned caches len = %d, want 3", len(gotCaches))
|
||||
}
|
||||
if want := []int32{8}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if want := []int{3}; !reflect.DeepEqual(kv.trims, want) {
|
||||
t.Fatalf("kv trim calls = %v, want %v", kv.trims, want)
|
||||
}
|
||||
if want := []int{4}; !reflect.DeepEqual(rc1.restoreCalls, want) {
|
||||
t.Fatalf("rc1 restore calls = %v, want %v", rc1.restoreCalls, want)
|
||||
}
|
||||
if want := []int{4}; !reflect.DeepEqual(rc2.restoreCalls, want) {
|
||||
t.Fatalf("rc2 restore calls = %v, want %v", rc2.restoreCalls, want)
|
||||
}
|
||||
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(r.cache.Tokens, want) {
|
||||
t.Fatalf("cached tokens = %v, want %v", r.cache.Tokens, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheDropsOnMismatchedCheckpointRestorePoints(t *testing.T) {
|
||||
rc1 := &fakeCheckpointCache{
|
||||
fakeCache: fakeCache{canTrim: false, offset: 7},
|
||||
bestPos: 4,
|
||||
hasCheckpoint: true,
|
||||
restoreSuccess: true,
|
||||
}
|
||||
rc2 := &fakeCheckpointCache{
|
||||
fakeCache: fakeCache{canTrim: false, offset: 7},
|
||||
bestPos: 3,
|
||||
hasCheckpoint: true,
|
||||
restoreSuccess: true,
|
||||
}
|
||||
|
||||
r := &Runner{
|
||||
cache: &CacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4, 5, 6, 7},
|
||||
Caches: []cachepkg.Cache{rc1, rc2},
|
||||
},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 8})
|
||||
|
||||
if gotCaches != nil {
|
||||
t.Fatalf("returned caches = %#v, want nil", gotCaches)
|
||||
}
|
||||
if want := []int32{1, 2, 3, 4, 8}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if rc1.freeCall != 1 || rc2.freeCall != 1 {
|
||||
t.Fatalf("free calls = (%d,%d), want (1,1)", rc1.freeCall, rc2.freeCall)
|
||||
}
|
||||
if len(rc1.restoreCalls) != 0 || len(rc2.restoreCalls) != 0 {
|
||||
t.Fatalf("restore calls = (%v,%v), want none", rc1.restoreCalls, rc2.restoreCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheSelectsBestPrefixAcrossBranches(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "4")
|
||||
|
||||
short := &fakeCache{canTrim: true, offset: 2}
|
||||
long := &fakeCache{canTrim: true, offset: 4}
|
||||
shortEntry := &HybridCacheEntry{
|
||||
Tokens: []int32{1, 2},
|
||||
Caches: []cachepkg.Cache{short},
|
||||
}
|
||||
longEntry := &HybridCacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4},
|
||||
Caches: []cachepkg.Cache{long},
|
||||
}
|
||||
|
||||
r := &Runner{
|
||||
cache: shortEntry,
|
||||
caches: []*HybridCacheEntry{shortEntry, longEntry},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 9})
|
||||
|
||||
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if len(gotCaches) != 1 {
|
||||
t.Fatalf("returned caches len = %d, want 1", len(gotCaches))
|
||||
}
|
||||
if gotCaches[0] == long {
|
||||
t.Fatal("expected cloned cache in multi-branch mode, got original branch cache")
|
||||
}
|
||||
if r.cache != longEntry || r.caches[0] != longEntry {
|
||||
t.Fatal("best branch was not promoted to front of cache store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheForksBranchWithCloneWhenMultiBranchEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "2")
|
||||
|
||||
base := &fakeCache{canTrim: true, offset: 4}
|
||||
baseEntry := &HybridCacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4},
|
||||
Caches: []cachepkg.Cache{base},
|
||||
}
|
||||
r := &Runner{
|
||||
cache: baseEntry,
|
||||
caches: []*HybridCacheEntry{baseEntry},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
|
||||
|
||||
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if len(gotCaches) != 1 {
|
||||
t.Fatalf("returned caches len = %d, want 1", len(gotCaches))
|
||||
}
|
||||
clone, ok := gotCaches[0].(*fakeCache)
|
||||
if !ok {
|
||||
t.Fatalf("returned cache type = %T, want *fakeCache", gotCaches[0])
|
||||
}
|
||||
if clone == base {
|
||||
t.Fatal("expected branch fork to return a cloned cache")
|
||||
}
|
||||
if len(base.trims) != 0 {
|
||||
t.Fatalf("base branch trim calls = %v, want none", base.trims)
|
||||
}
|
||||
if want := []int{2}; !reflect.DeepEqual(clone.trims, want) {
|
||||
t.Fatalf("forked branch trim calls = %v, want %v", clone.trims, want)
|
||||
}
|
||||
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(baseEntry.Tokens, want) {
|
||||
t.Fatalf("base entry tokens = %v, want %v", baseEntry.Tokens, want)
|
||||
}
|
||||
|
||||
r.InsertCache([]int32{1, 2, 9}, gotCaches)
|
||||
if len(r.caches) != 2 {
|
||||
t.Fatalf("cache store len = %d, want 2", len(r.caches))
|
||||
}
|
||||
if want := []int32{1, 2, 9}; !reflect.DeepEqual(r.caches[0].Tokens, want) {
|
||||
t.Fatalf("new branch tokens = %v, want %v", r.caches[0].Tokens, want)
|
||||
}
|
||||
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(r.caches[1].Tokens, want) {
|
||||
t.Fatalf("preserved branch tokens = %v, want %v", r.caches[1].Tokens, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertCacheEvictsOldestBranchWhenStoreFull(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "2")
|
||||
|
||||
f1 := &fakeCache{canTrim: true, offset: 1}
|
||||
f2 := &fakeCache{canTrim: true, offset: 2}
|
||||
f3 := &fakeCache{canTrim: true, offset: 3}
|
||||
r := &Runner{}
|
||||
|
||||
r.InsertCache([]int32{1}, []cachepkg.Cache{f1})
|
||||
r.InsertCache([]int32{1, 2}, []cachepkg.Cache{f2})
|
||||
r.InsertCache([]int32{1, 2, 3}, []cachepkg.Cache{f3})
|
||||
|
||||
if len(r.caches) != 2 {
|
||||
t.Fatalf("cache store len = %d, want 2", len(r.caches))
|
||||
}
|
||||
if f1.freeCall != 1 {
|
||||
t.Fatalf("oldest branch free calls = %d, want 1", f1.freeCall)
|
||||
}
|
||||
if f2.freeCall != 0 || f3.freeCall != 0 {
|
||||
t.Fatalf("unexpected frees for retained branches: f2=%d f3=%d", f2.freeCall, f3.freeCall)
|
||||
}
|
||||
if want := []int32{1, 2, 3}; !reflect.DeepEqual(r.caches[0].Tokens, want) {
|
||||
t.Fatalf("MRU tokens = %v, want %v", r.caches[0].Tokens, want)
|
||||
}
|
||||
if want := []int32{1, 2}; !reflect.DeepEqual(r.caches[1].Tokens, want) {
|
||||
t.Fatalf("LRU tokens = %v, want %v", r.caches[1].Tokens, want)
|
||||
}
|
||||
}
|
||||
@@ -7,4 +7,6 @@ import (
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
|
||||
)
|
||||
|
||||
@@ -267,3 +267,20 @@ func LogArrays() {
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
|
||||
}
|
||||
|
||||
// Release forcibly frees arrays regardless of reference accounting.
|
||||
// Use only for arrays that are known to be unreachable by any live model state.
|
||||
func Release(s ...*Array) (n int) {
|
||||
seen := make(map[*Array]bool, len(s))
|
||||
for _, t := range s {
|
||||
if t == nil || !t.Valid() || seen[t] {
|
||||
continue
|
||||
}
|
||||
seen[t] = true
|
||||
n += t.NumBytes()
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.pinned = false
|
||||
t.ctx.ctx = nil
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
275
x/mlxrunner/mlx/gated_delta_metal.go
Normal file
275
x/mlxrunner/mlx/gated_delta_metal.go
Normal file
@@ -0,0 +1,275 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
gatedDeltaMetalKernelOnce sync.Once
|
||||
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
||||
gatedDeltaMetalDisabled atomic.Bool
|
||||
)
|
||||
|
||||
const gatedDeltaMetalKernelSource = `
|
||||
auto n = thread_position_in_grid.z;
|
||||
auto b_idx = n / Hv;
|
||||
auto hv_idx = n % Hv;
|
||||
auto hk_idx = hv_idx / (Hv / Hk);
|
||||
constexpr int n_per_t = Dk / 32;
|
||||
|
||||
// q, k: [B, T, Hk, Dk]
|
||||
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
|
||||
// v, y: [B, T, Hv, Dv]
|
||||
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
y += b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
|
||||
auto dk_idx = thread_position_in_threadgroup.x;
|
||||
auto dv_idx = thread_position_in_grid.y;
|
||||
|
||||
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||
|
||||
float state[n_per_t];
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = static_cast<float>(i_state[s_idx]);
|
||||
}
|
||||
|
||||
// g: [B, T, Hv]
|
||||
auto g_ = g + b_idx * T * Hv;
|
||||
auto beta_ = beta + b_idx * T * Hv;
|
||||
|
||||
for (int t = 0; t < T; ++t) {
|
||||
float kv_mem = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] * g_[hv_idx];
|
||||
kv_mem += state[i] * k_[s_idx];
|
||||
}
|
||||
kv_mem = simd_sum(kv_mem);
|
||||
|
||||
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
|
||||
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] + k_[s_idx] * delta;
|
||||
out += state[i] * q_[s_idx];
|
||||
}
|
||||
out = simd_sum(out);
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
y[dv_idx] = static_cast<InT>(out);
|
||||
}
|
||||
|
||||
q_ += Hk * Dk;
|
||||
k_ += Hk * Dk;
|
||||
v_ += Hv * Dv;
|
||||
y += Hv * Dv;
|
||||
g_ += Hv;
|
||||
beta_ += Hv;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
o_state[s_idx] = static_cast<InT>(state[i]);
|
||||
}
|
||||
`
|
||||
|
||||
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
||||
vec := C.mlx_vector_string_new()
|
||||
ok := true
|
||||
for _, s := range values {
|
||||
cs := C.CString(s)
|
||||
if C.mlx_vector_string_append_value(vec, cs) != 0 {
|
||||
ok = false
|
||||
}
|
||||
C.free(unsafe.Pointer(cs))
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
cleanup := func() {
|
||||
C.mlx_vector_string_free(vec)
|
||||
}
|
||||
return vec, cleanup, ok
|
||||
}
|
||||
|
||||
func initGatedDeltaMetalKernel() {
|
||||
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
freeInputs()
|
||||
return
|
||||
}
|
||||
defer freeInputs()
|
||||
|
||||
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
freeOutputs()
|
||||
return
|
||||
}
|
||||
defer freeOutputs()
|
||||
|
||||
cName := C.CString("gated_delta_step")
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
cSource := C.CString(gatedDeltaMetalKernelSource)
|
||||
defer C.free(unsafe.Pointer(cSource))
|
||||
cHeader := C.CString("")
|
||||
defer C.free(unsafe.Pointer(cHeader))
|
||||
|
||||
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
|
||||
cName,
|
||||
inputs,
|
||||
outputs,
|
||||
cSource,
|
||||
cHeader,
|
||||
C.bool(true),
|
||||
C.bool(false),
|
||||
)
|
||||
}
|
||||
|
||||
// GatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
|
||||
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
|
||||
func GatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||
if gatedDeltaMetalDisabled.Load() {
|
||||
return nil, nil, false
|
||||
}
|
||||
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
if !q.Valid() || !k.Valid() || !v.Valid() || !g.Valid() || !beta.Valid() || !state.Valid() {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qd := q.Dims()
|
||||
kd := k.Dims()
|
||||
vd := v.Dims()
|
||||
gd := g.Dims()
|
||||
bd := beta.Dims()
|
||||
sd := state.Dims()
|
||||
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
Hv, Dv := vd[2], vd[3]
|
||||
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
dtype := q.DType()
|
||||
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
|
||||
if gatedDeltaMetalDisabled.Load() {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
cfg := C.mlx_fast_metal_kernel_config_new()
|
||||
defer C.mlx_fast_metal_kernel_config_free(cfg)
|
||||
|
||||
cInT := C.CString("InT")
|
||||
defer C.free(unsafe.Pointer(cInT))
|
||||
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
for _, tpl := range []struct {
|
||||
name string
|
||||
value int
|
||||
}{
|
||||
{name: "Dk", value: Dk},
|
||||
{name: "Dv", value: Dv},
|
||||
{name: "Hk", value: Hk},
|
||||
{name: "Hv", value: Hv},
|
||||
} {
|
||||
cn := C.CString(tpl.name)
|
||||
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||
C.free(unsafe.Pointer(cn))
|
||||
if rc != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
threadY := Dv
|
||||
if threadY > 4 {
|
||||
threadY = 4
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
tScalar := FromValue(T)
|
||||
inputs := []C.mlx_array{
|
||||
q.ctx,
|
||||
k.ctx,
|
||||
v.ctx,
|
||||
g.ctx,
|
||||
beta.ctx,
|
||||
state.ctx,
|
||||
tScalar.ctx,
|
||||
}
|
||||
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
y = New("GATED_DELTA_METAL_Y")
|
||||
nextState = New("GATED_DELTA_METAL_STATE")
|
||||
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||
return y, nextState, true
|
||||
}
|
||||
@@ -19,7 +19,8 @@ func doEval(outputs []*Array, async bool) {
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
for _, output := range outputs {
|
||||
if output.Valid() {
|
||||
// Callers may pass optional tensors (e.g. debug-only logprobs) as nil.
|
||||
if output != nil && output.Valid() {
|
||||
C.mlx_vector_array_append_value(vector, output.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,6 +113,35 @@ func Where(condition, a, b *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
|
||||
out := New("CONV1D")
|
||||
C.mlx_conv1d(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
weight.ctx,
|
||||
C.int(stride),
|
||||
C.int(padding),
|
||||
C.int(dilation),
|
||||
C.int(groups),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
if bias != nil && bias.Valid() {
|
||||
out = Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||
out := New("CONTIGUOUS")
|
||||
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||
groups := int32(x.Dim(x.NumDims() - 1))
|
||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||
}
|
||||
|
||||
// Convenience wrappers (function-style for the model code)
|
||||
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
@@ -271,6 +300,24 @@ func Sigmoid(a *Array) *Array {
|
||||
return a.Sigmoid()
|
||||
}
|
||||
|
||||
func Exp(a *Array) *Array {
|
||||
out := New("EXP")
|
||||
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Log(a *Array) *Array {
|
||||
out := New("LOG")
|
||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
||||
out := New("SOFTMAX_AXIS")
|
||||
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
|
||||
mask := New("")
|
||||
sinks := New("")
|
||||
@@ -288,7 +335,11 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
||||
|
||||
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
var w C.mlx_array
|
||||
if weight != nil {
|
||||
w = weight.ctx
|
||||
}
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -378,6 +429,32 @@ func Collect(v any) []*Array {
|
||||
return arrays
|
||||
}
|
||||
|
||||
// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
|
||||
func Snapshot(a *Array) *Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return a
|
||||
}
|
||||
out := New("SNAPSHOT")
|
||||
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// CollectReachable collects arrays from v and all transitive graph inputs.
|
||||
func CollectReachable(v any) []*Array {
|
||||
return Collect(v)
|
||||
}
|
||||
|
||||
// Detach returns a new Array handle that shares the same MLX value but does
|
||||
// not retain Go-side graph input references.
|
||||
func Detach(a *Array) *Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return a
|
||||
}
|
||||
out := New("DETACH")
|
||||
C.mlx_array_set(&out.ctx, a.ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
||||
if !v.IsValid() {
|
||||
return
|
||||
|
||||
@@ -6,7 +6,10 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
@@ -14,6 +17,234 @@ import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
const defaultRecurrentMaterializeInterval = 64
|
||||
const defaultPipelineTimingEvery = 64
|
||||
|
||||
func prefillChunkSize(lowMemoryDecode bool) int {
|
||||
if v := os.Getenv("OLLAMA_MLX_PREFILL_CHUNK"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
if lowMemoryDecode {
|
||||
// Recurrent/no-prompt-cache path favors lower peak memory over prefill throughput.
|
||||
// Keep this conservative to avoid transient prefill spikes and allocator thrash.
|
||||
return 32
|
||||
}
|
||||
return 2 << 10
|
||||
}
|
||||
|
||||
func hasRecurrentCaches(caches []cache.Cache) bool {
|
||||
for _, c := range caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := c.(*cache.RecurrentCache); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// recurrentMaterializeInterval controls periodic recurrent-cache materialization
|
||||
// during async decode. It exists to bound graph/handle growth when using fast
|
||||
// recurrent cache writes; it is primarily a memory/stability tuning knob, not a
|
||||
// throughput knob.
|
||||
func recurrentMaterializeInterval(lowMemoryDecode bool, hasRecurrent bool) int {
|
||||
if lowMemoryDecode || !hasRecurrent {
|
||||
return 0
|
||||
}
|
||||
if v := os.Getenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultRecurrentMaterializeInterval
|
||||
}
|
||||
|
||||
func mlxDebugMemoryEnabled() bool {
|
||||
return os.Getenv("OLLAMA_MLX_DEBUG_MEMORY") != ""
|
||||
}
|
||||
|
||||
// mlxPipelineTimingConfig controls runner-side decode pipeline timing logs. This
|
||||
// is diagnostic-only and intentionally separate from model-specific timing.
|
||||
func mlxPipelineTimingConfig() (enabled bool, every int) {
|
||||
if v, ok := os.LookupEnv("OLLAMA_MLX_PIPELINE_TIMING"); ok {
|
||||
if parsed, err := strconv.ParseBool(v); err == nil {
|
||||
enabled = parsed
|
||||
}
|
||||
}
|
||||
if !enabled {
|
||||
return false, 0
|
||||
}
|
||||
every = defaultPipelineTimingEvery
|
||||
if v := os.Getenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
every = n
|
||||
}
|
||||
}
|
||||
return true, every
|
||||
}
|
||||
|
||||
// mlxComputeLogprobsEnabled restores the old decode-step logprob normalization
|
||||
// path for profiling/experiments. It is off by default because the MLX runner
|
||||
// does not currently populate Response.Logprobs.
|
||||
func mlxComputeLogprobsEnabled() bool {
|
||||
if v, ok := os.LookupEnv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS"); ok {
|
||||
if enabled, err := strconv.ParseBool(v); err == nil {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
// The MLX runner currently does not populate Response.Logprobs, so skip the
|
||||
// full-vocab logprob normalization path unless explicitly requested for
|
||||
// debugging/experiments.
|
||||
return false
|
||||
}
|
||||
|
||||
type pipelineTiming struct {
|
||||
every int
|
||||
|
||||
stepCalls int
|
||||
stepAsync int
|
||||
stepSync int
|
||||
sampleInts int
|
||||
|
||||
stepTotalDur time.Duration
|
||||
forwardDur time.Duration
|
||||
unembedDur time.Duration
|
||||
sliceDur time.Duration
|
||||
logprobsDur time.Duration
|
||||
sampleDur time.Duration
|
||||
pinSweepDur time.Duration
|
||||
asyncEvalDur time.Duration
|
||||
sampleIntDur time.Duration
|
||||
lastEmitCount int
|
||||
}
|
||||
|
||||
func newPipelineTiming() *pipelineTiming {
|
||||
enabled, every := mlxPipelineTimingConfig()
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
pt := &pipelineTiming{every: every}
|
||||
fmt.Fprintf(os.Stderr, "mlx pipeline timing: enabled every=%d\n", every)
|
||||
return pt
|
||||
}
|
||||
|
||||
func (pt *pipelineTiming) recordStep(
|
||||
async bool,
|
||||
total, forward, unembed, slice, logprobs, sample, pinSweep, asyncEval time.Duration,
|
||||
) {
|
||||
if pt == nil {
|
||||
return
|
||||
}
|
||||
pt.stepCalls++
|
||||
if async {
|
||||
pt.stepAsync++
|
||||
} else {
|
||||
pt.stepSync++
|
||||
}
|
||||
pt.stepTotalDur += total
|
||||
pt.forwardDur += forward
|
||||
pt.unembedDur += unembed
|
||||
pt.sliceDur += slice
|
||||
pt.logprobsDur += logprobs
|
||||
pt.sampleDur += sample
|
||||
pt.pinSweepDur += pinSweep
|
||||
pt.asyncEvalDur += asyncEval
|
||||
}
|
||||
|
||||
func (pt *pipelineTiming) recordSampleInt(d time.Duration, decodeCount int) {
|
||||
if pt == nil {
|
||||
return
|
||||
}
|
||||
pt.sampleInts++
|
||||
pt.sampleIntDur += d
|
||||
pt.maybeEmit(false, decodeCount)
|
||||
}
|
||||
|
||||
func (pt *pipelineTiming) maybeEmit(force bool, decodeCount int) {
|
||||
if pt == nil {
|
||||
return
|
||||
}
|
||||
if !force {
|
||||
if pt.every <= 0 || decodeCount <= 0 || decodeCount%pt.every != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
if pt.lastEmitCount == decodeCount {
|
||||
return
|
||||
}
|
||||
pt.lastEmitCount = decodeCount
|
||||
|
||||
msAvg := func(d time.Duration, n int) float64 {
|
||||
if n <= 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(d) / float64(n) / float64(time.Millisecond)
|
||||
}
|
||||
stepResidual := pt.stepTotalDur - pt.forwardDur - pt.unembedDur - pt.sliceDur - pt.logprobsDur - pt.sampleDur - pt.pinSweepDur - pt.asyncEvalDur
|
||||
if stepResidual < 0 {
|
||||
stepResidual = 0
|
||||
}
|
||||
fmt.Fprintf(
|
||||
os.Stderr,
|
||||
"mlx pipeline timing: decode=%d step_calls=%d step_async=%d step_sync=%d avg_step_ms=%.2f fwd_ms=%.2f unembed_ms=%.2f slice_ms=%.2f logprobs_ms=%.2f sample_ms=%.2f pin_sweep_ms=%.2f async_eval_ms=%.2f step_residual_ms=%.2f sample_int_ms=%.2f\n",
|
||||
decodeCount,
|
||||
pt.stepCalls,
|
||||
pt.stepAsync,
|
||||
pt.stepSync,
|
||||
msAvg(pt.stepTotalDur, pt.stepCalls),
|
||||
msAvg(pt.forwardDur, pt.stepCalls),
|
||||
msAvg(pt.unembedDur, pt.stepCalls),
|
||||
msAvg(pt.sliceDur, pt.stepCalls),
|
||||
msAvg(pt.logprobsDur, pt.stepCalls),
|
||||
msAvg(pt.sampleDur, pt.stepCalls),
|
||||
msAvg(pt.pinSweepDur, pt.stepCalls),
|
||||
msAvg(pt.asyncEvalDur, pt.stepCalls),
|
||||
msAvg(stepResidual, pt.stepCalls),
|
||||
msAvg(pt.sampleIntDur, pt.sampleInts),
|
||||
)
|
||||
}
|
||||
|
||||
func finalizeRequestCaches(usePromptCache bool, insertCache func(), freeCaches func(), logMemory func(string, int)) {
|
||||
if usePromptCache {
|
||||
insertCache()
|
||||
logMemory("request_done_cached", -1)
|
||||
return
|
||||
}
|
||||
freeCaches()
|
||||
logMemory("request_done_freed", -1)
|
||||
}
|
||||
|
||||
func recordCacheCheckpoints(caches []cache.Cache, pos int) {
|
||||
if pos <= 0 {
|
||||
return
|
||||
}
|
||||
for _, c := range caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
if recorder, ok := c.(cache.CheckpointRecorder); ok {
|
||||
recorder.RecordCheckpoint(pos)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func freeOwnedCaches(caches []cache.Cache) {
|
||||
for i, c := range caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
c.Free()
|
||||
caches[i] = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
if r.Model == nil {
|
||||
return errors.New("model not loaded")
|
||||
@@ -31,7 +262,24 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
|
||||
caches, tokens := r.FindNearestCache(inputs)
|
||||
usePromptCache := true
|
||||
if m, ok := r.Model.(interface{ DisablePromptCache() bool }); ok && m.DisablePromptCache() {
|
||||
usePromptCache = false
|
||||
}
|
||||
lowMemoryDecode := !usePromptCache
|
||||
if m, ok := r.Model.(interface{ LowMemoryDecode() bool }); ok {
|
||||
lowMemoryDecode = m.LowMemoryDecode()
|
||||
}
|
||||
prefillChunk := prefillChunkSize(lowMemoryDecode)
|
||||
|
||||
var caches []cache.Cache
|
||||
var tokens []int32
|
||||
if usePromptCache {
|
||||
caches, tokens = r.FindNearestCache(inputs)
|
||||
} else {
|
||||
tokens = inputs
|
||||
}
|
||||
|
||||
if len(caches) == 0 {
|
||||
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
caches = cacheFactory.NewCaches()
|
||||
@@ -43,40 +291,140 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
materializeCaches := func() {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
state = append(state, c.Materialize()...)
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return
|
||||
}
|
||||
mlx.Eval(state...)
|
||||
}
|
||||
materializeRecurrentCaches := func() bool {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := c.(*cache.RecurrentCache); !ok {
|
||||
continue
|
||||
}
|
||||
state = append(state, c.Materialize()...)
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return false
|
||||
}
|
||||
mlx.Eval(state...)
|
||||
return true
|
||||
}
|
||||
freeCaches := func() {
|
||||
// Non-prompt-cache requests allocate fresh caches every generation.
|
||||
// Explicitly free cache-owned state (including recurrent checkpoints),
|
||||
// then sweep remaining intermediates.
|
||||
freeOwnedCaches(caches)
|
||||
mlx.Sweep()
|
||||
mlx.ClearCache()
|
||||
}
|
||||
debugMemory := mlxDebugMemoryEnabled()
|
||||
hasRecurrent := hasRecurrentCaches(caches)
|
||||
asyncRecurrentMaterializeEvery := recurrentMaterializeInterval(lowMemoryDecode, hasRecurrent)
|
||||
computeStepLogprobs := mlxComputeLogprobsEnabled()
|
||||
pipelineTiming := newPipelineTiming()
|
||||
logMemory := func(phase string, token int) {
|
||||
if !debugMemory {
|
||||
return
|
||||
}
|
||||
if token >= 0 {
|
||||
slog.Info("MLX memory", "phase", phase, "token", token, "memory", mlx.Memory{})
|
||||
return
|
||||
}
|
||||
slog.Info("MLX memory", "phase", phase, "memory", mlx.Memory{})
|
||||
}
|
||||
logMemory("prefill_start", -1)
|
||||
|
||||
total, processed := len(tokens), 0
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
for total-processed > 1 {
|
||||
n := min(2<<10, total-processed-1)
|
||||
n := min(prefillChunk, total-processed-1)
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
mlx.Sweep()
|
||||
mlx.Eval(func() []*mlx.Array {
|
||||
s := make([]*mlx.Array, 2*len(caches))
|
||||
for i, c := range caches {
|
||||
s[2*i], s[2*i+1] = c.State()
|
||||
}
|
||||
return s
|
||||
}()...)
|
||||
materializeCaches()
|
||||
recordCacheCheckpoints(caches, processed+n)
|
||||
processed += n
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
mlx.ClearCache()
|
||||
}
|
||||
logMemory("prefill_done", -1)
|
||||
|
||||
step := func(token *mlx.Array, async bool) (*mlx.Array, *mlx.Array) {
|
||||
var t0, t time.Time
|
||||
var forwardDur, unembedDur, sliceDur, logprobsDur, sampleDur, pinSweepDur, asyncEvalDur time.Duration
|
||||
if pipelineTiming != nil {
|
||||
t0 = time.Now()
|
||||
t = t0
|
||||
}
|
||||
|
||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
||||
logits := r.Model.Unembed(fwd)
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
if pipelineTiming != nil {
|
||||
forwardDur = time.Since(t)
|
||||
t = time.Now()
|
||||
}
|
||||
|
||||
logprobs := logits.Subtract(logits.Logsumexp(true))
|
||||
sample := request.Sample(logprobs)
|
||||
logits := r.Model.Unembed(fwd)
|
||||
if pipelineTiming != nil {
|
||||
unembedDur = time.Since(t)
|
||||
t = time.Now()
|
||||
}
|
||||
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
if pipelineTiming != nil {
|
||||
sliceDur = time.Since(t)
|
||||
t = time.Now()
|
||||
}
|
||||
|
||||
var logprobs *mlx.Array
|
||||
sampleInput := logits
|
||||
if computeStepLogprobs {
|
||||
logprobs = logits.Subtract(logits.Logsumexp(true))
|
||||
sampleInput = logprobs
|
||||
}
|
||||
if pipelineTiming != nil {
|
||||
logprobsDur = time.Since(t)
|
||||
t = time.Now()
|
||||
}
|
||||
|
||||
sample := request.Sample(sampleInput)
|
||||
if pipelineTiming != nil {
|
||||
sampleDur = time.Since(t)
|
||||
t = time.Now()
|
||||
}
|
||||
|
||||
mlx.Pin(sample, logprobs)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(sample, logprobs)
|
||||
if pipelineTiming != nil {
|
||||
pinSweepDur = time.Since(t)
|
||||
}
|
||||
if async {
|
||||
mlx.AsyncEval(sample, logprobs)
|
||||
if pipelineTiming != nil {
|
||||
asyncEvalDur = time.Since(t)
|
||||
}
|
||||
}
|
||||
if pipelineTiming != nil {
|
||||
pipelineTiming.recordStep(async, time.Since(t0), forwardDur, unembedDur, sliceDur, logprobsDur, sampleDur, pinSweepDur, asyncEvalDur)
|
||||
}
|
||||
|
||||
return sample, logprobs
|
||||
}
|
||||
|
||||
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
|
||||
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed), !lowMemoryDecode)
|
||||
if lowMemoryDecode {
|
||||
// Materialize cache updates to prevent transform graph growth.
|
||||
materializeCaches()
|
||||
}
|
||||
recordCacheCheckpoints(caches, total)
|
||||
logMemory("decode_init", -1)
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
@@ -84,20 +432,34 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
||||
outputs := make([]int32, 0, request.Options.MaxTokens)
|
||||
for i := range request.Options.MaxTokens {
|
||||
nextSample, nextLogprobs := step(sample)
|
||||
|
||||
var nextSample, nextLogprobs *mlx.Array
|
||||
if !lowMemoryDecode {
|
||||
nextSample, nextLogprobs = step(sample, true)
|
||||
}
|
||||
if i == 0 {
|
||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
||||
mlx.Eval(sample)
|
||||
logMemory("decode_first_eval", i)
|
||||
final.PromptTokensDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
var intWaitStart time.Time
|
||||
if pipelineTiming != nil {
|
||||
intWaitStart = time.Now()
|
||||
}
|
||||
output := int32(sample.Int())
|
||||
if pipelineTiming != nil {
|
||||
pipelineTiming.recordSampleInt(time.Since(intWaitStart), len(outputs)+1)
|
||||
}
|
||||
outputs = append(outputs, output)
|
||||
if !lowMemoryDecode {
|
||||
recordCacheCheckpoints(caches, total+len(outputs))
|
||||
}
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
mlx.Unpin(nextSample, nextLogprobs)
|
||||
mlx.Unpin(sample, logprobs)
|
||||
final.Token = int(output)
|
||||
final.DoneReason = 0
|
||||
final.CompletionTokens = i
|
||||
@@ -109,18 +471,53 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
Token: int(output),
|
||||
}
|
||||
|
||||
// For recurrent linear-attention models, avoid async prefetch to reduce
|
||||
// peak memory and clear allocator cache every token.
|
||||
if lowMemoryDecode {
|
||||
mlx.Unpin(sample, logprobs)
|
||||
mlx.Sweep()
|
||||
if i+1 >= request.Options.MaxTokens {
|
||||
break
|
||||
}
|
||||
mlx.ClearCache()
|
||||
sample, logprobs = step(mlx.FromValues([]int32{output}, 1), false)
|
||||
// Materialize cache updates to avoid unbounded transform chains.
|
||||
materializeCaches()
|
||||
recordCacheCheckpoints(caches, total+len(outputs))
|
||||
if i%32 == 0 {
|
||||
logMemory("decode_lowmem_step", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
mlx.Unpin(sample, logprobs)
|
||||
if asyncRecurrentMaterializeEvery > 0 && (i+1)%asyncRecurrentMaterializeEvery == 0 {
|
||||
if materializeRecurrentCaches() {
|
||||
mlx.Sweep()
|
||||
logMemory("decode_async_recurrent_materialize", i)
|
||||
}
|
||||
}
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
if i%64 == 0 {
|
||||
logMemory("decode_async_step", i)
|
||||
}
|
||||
|
||||
sample, logprobs = nextSample, nextLogprobs
|
||||
}
|
||||
|
||||
mlx.Unpin(sample, logprobs)
|
||||
if pipelineTiming != nil {
|
||||
pipelineTiming.maybeEmit(true, len(outputs))
|
||||
}
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
request.Responses <- final
|
||||
r.InsertCache(append(inputs, outputs...), caches)
|
||||
finalizeRequestCaches(usePromptCache,
|
||||
func() { r.InsertCache(append(inputs, outputs...), caches) },
|
||||
freeCaches,
|
||||
logMemory,
|
||||
)
|
||||
mlx.Sweep()
|
||||
|
||||
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
||||
@@ -129,7 +526,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
r.cache.LogCache()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
209
x/mlxrunner/pipeline_helpers_test.go
Normal file
209
x/mlxrunner/pipeline_helpers_test.go
Normal file
@@ -0,0 +1,209 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
type stubCache struct {
|
||||
freeCalls int
|
||||
}
|
||||
|
||||
func (s *stubCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { return keys, values }
|
||||
func (s *stubCache) State() (*mlx.Array, *mlx.Array) { return nil, nil }
|
||||
func (s *stubCache) Materialize() []*mlx.Array { return nil }
|
||||
func (s *stubCache) CanTrim() bool { return true }
|
||||
func (s *stubCache) Trim(int) int { return 0 }
|
||||
func (s *stubCache) Clone() cache.Cache { return s }
|
||||
func (s *stubCache) Free() { s.freeCalls++ }
|
||||
func (s *stubCache) Offset() int { return 0 }
|
||||
func (s *stubCache) Len() int { return 0 }
|
||||
|
||||
func TestPrefillChunkSize(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "")
|
||||
if got := prefillChunkSize(false); got != 2<<10 {
|
||||
t.Fatalf("prefillChunkSize(false) = %d, want %d", got, 2<<10)
|
||||
}
|
||||
if got := prefillChunkSize(true); got != 32 {
|
||||
t.Fatalf("prefillChunkSize(true) = %d, want %d", got, 32)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefillChunkSizeEnvOverride(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "96")
|
||||
if got := prefillChunkSize(false); got != 96 {
|
||||
t.Fatalf("prefillChunkSize(false) with env = %d, want %d", got, 96)
|
||||
}
|
||||
if got := prefillChunkSize(true); got != 96 {
|
||||
t.Fatalf("prefillChunkSize(true) with env = %d, want %d", got, 96)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMLXDebugMemoryEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "")
|
||||
if mlxDebugMemoryEnabled() {
|
||||
t.Fatal("mlxDebugMemoryEnabled() = true, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "1")
|
||||
if !mlxDebugMemoryEnabled() {
|
||||
t.Fatal("mlxDebugMemoryEnabled() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasRecurrentCaches(t *testing.T) {
|
||||
if hasRecurrentCaches(nil) {
|
||||
t.Fatal("hasRecurrentCaches(nil) = true, want false")
|
||||
}
|
||||
|
||||
if hasRecurrentCaches([]cache.Cache{cache.NewKVCache()}) {
|
||||
t.Fatal("hasRecurrentCaches(kv-only) = true, want false")
|
||||
}
|
||||
|
||||
rc := cache.NewRecurrentCache(4, 8, 2, 16, 8)
|
||||
if !hasRecurrentCaches([]cache.Cache{cache.NewKVCache(), rc}) {
|
||||
t.Fatal("hasRecurrentCaches(mixed) = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecurrentMaterializeInterval(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "")
|
||||
|
||||
if got := recurrentMaterializeInterval(true, true); got != 0 {
|
||||
t.Fatalf("recurrentMaterializeInterval(lowmem=true, recurrent=true) = %d, want 0", got)
|
||||
}
|
||||
if got := recurrentMaterializeInterval(false, false); got != 0 {
|
||||
t.Fatalf("recurrentMaterializeInterval(lowmem=false, recurrent=false) = %d, want 0", got)
|
||||
}
|
||||
if got := recurrentMaterializeInterval(false, true); got != defaultRecurrentMaterializeInterval {
|
||||
t.Fatalf("recurrentMaterializeInterval(default) = %d, want %d", got, defaultRecurrentMaterializeInterval)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "16")
|
||||
if got := recurrentMaterializeInterval(false, true); got != 16 {
|
||||
t.Fatalf("recurrentMaterializeInterval(env=16) = %d, want 16", got)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "0")
|
||||
if got := recurrentMaterializeInterval(false, true); got != 0 {
|
||||
t.Fatalf("recurrentMaterializeInterval(env=0) = %d, want 0", got)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "-1")
|
||||
if got := recurrentMaterializeInterval(false, true); got != 0 {
|
||||
t.Fatalf("recurrentMaterializeInterval(env=-1) = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMLXPipelineTimingConfig(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "")
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "")
|
||||
if enabled, every := mlxPipelineTimingConfig(); enabled || every != 0 {
|
||||
t.Fatalf("mlxPipelineTimingConfig() = (%v, %d), want (false, 0)", enabled, every)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "1")
|
||||
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != defaultPipelineTimingEvery {
|
||||
t.Fatalf("mlxPipelineTimingConfig(enabled default) = (%v, %d), want (true, %d)", enabled, every, defaultPipelineTimingEvery)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "16")
|
||||
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != 16 {
|
||||
t.Fatalf("mlxPipelineTimingConfig(enabled env=16) = (%v, %d), want (true, 16)", enabled, every)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "0")
|
||||
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != defaultPipelineTimingEvery {
|
||||
t.Fatalf("mlxPipelineTimingConfig(enabled env=0) = (%v, %d), want (true, %d)", enabled, every, defaultPipelineTimingEvery)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "0")
|
||||
if enabled, every := mlxPipelineTimingConfig(); enabled || every != 0 {
|
||||
t.Fatalf("mlxPipelineTimingConfig(disabled) = (%v, %d), want (false, 0)", enabled, every)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMLXComputeLogprobsEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "")
|
||||
if mlxComputeLogprobsEnabled() {
|
||||
t.Fatal("mlxComputeLogprobsEnabled() = true, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "1")
|
||||
if !mlxComputeLogprobsEnabled() {
|
||||
t.Fatal("mlxComputeLogprobsEnabled() = false with env=1, want true")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "0")
|
||||
if mlxComputeLogprobsEnabled() {
|
||||
t.Fatal("mlxComputeLogprobsEnabled() = true with env=0, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRequestCachesUsesPromptCachePath(t *testing.T) {
|
||||
insertCalls := 0
|
||||
freeCalls := 0
|
||||
logPhase := ""
|
||||
|
||||
finalizeRequestCaches(
|
||||
true,
|
||||
func() { insertCalls++ },
|
||||
func() { freeCalls++ },
|
||||
func(phase string, _ int) { logPhase = phase },
|
||||
)
|
||||
|
||||
if insertCalls != 1 {
|
||||
t.Fatalf("insert calls = %d, want 1", insertCalls)
|
||||
}
|
||||
if freeCalls != 0 {
|
||||
t.Fatalf("free calls = %d, want 0", freeCalls)
|
||||
}
|
||||
if logPhase != "request_done_cached" {
|
||||
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRequestCachesUsesFreePath(t *testing.T) {
|
||||
insertCalls := 0
|
||||
freeCalls := 0
|
||||
logPhase := ""
|
||||
|
||||
finalizeRequestCaches(
|
||||
false,
|
||||
func() { insertCalls++ },
|
||||
func() { freeCalls++ },
|
||||
func(phase string, _ int) { logPhase = phase },
|
||||
)
|
||||
|
||||
if insertCalls != 0 {
|
||||
t.Fatalf("insert calls = %d, want 0", insertCalls)
|
||||
}
|
||||
if freeCalls != 1 {
|
||||
t.Fatalf("free calls = %d, want 1", freeCalls)
|
||||
}
|
||||
if logPhase != "request_done_freed" {
|
||||
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_freed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFreeOwnedCaches(t *testing.T) {
|
||||
a := &stubCache{}
|
||||
b := &stubCache{}
|
||||
caches := []cache.Cache{a, nil, b}
|
||||
|
||||
freeOwnedCaches(caches)
|
||||
|
||||
if a.freeCalls != 1 {
|
||||
t.Fatalf("a free calls = %d, want 1", a.freeCalls)
|
||||
}
|
||||
if b.freeCalls != 1 {
|
||||
t.Fatalf("b free calls = %d, want 1", b.freeCalls)
|
||||
}
|
||||
if caches[0] != nil || caches[2] != nil {
|
||||
t.Fatalf("cache entries not nilled after free: %#v", caches)
|
||||
}
|
||||
}
|
||||
@@ -62,6 +62,39 @@ type Runner struct {
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
cache *CacheEntry
|
||||
caches []*HybridCacheEntry
|
||||
}
|
||||
|
||||
func releaseTensorMap(tensors map[string]*mlx.Array, keep map[*mlx.Array]struct{}) (count int, bytes int) {
|
||||
if len(tensors) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
seen := make(map[*mlx.Array]bool, len(tensors))
|
||||
toRelease := make([]*mlx.Array, 0, len(tensors))
|
||||
for name, arr := range tensors {
|
||||
if arr == nil || !arr.Valid() {
|
||||
delete(tensors, name)
|
||||
continue
|
||||
}
|
||||
if keep != nil {
|
||||
if _, ok := keep[arr]; ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
delete(tensors, name)
|
||||
if seen[arr] {
|
||||
continue
|
||||
}
|
||||
seen[arr] = true
|
||||
toRelease = append(toRelease, arr)
|
||||
}
|
||||
|
||||
if len(toRelease) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
return len(toRelease), mlx.Release(toRelease...)
|
||||
}
|
||||
|
||||
func (r *Runner) Load(modelName string) error {
|
||||
@@ -85,9 +118,33 @@ func (r *Runner) Load(modelName string) error {
|
||||
// Assign weights to model (model-specific logic)
|
||||
loadWeights := base.Weights(m)
|
||||
if err := loadWeights(tensors); err != nil {
|
||||
if count, bytes := releaseTensorMap(tensors, nil); count > 0 {
|
||||
slog.Info("Released tensors after load failure", "count", count, "bytes", mlx.PrettyBytes(bytes))
|
||||
}
|
||||
mlx.Sweep()
|
||||
mlx.ClearCache()
|
||||
return err
|
||||
}
|
||||
|
||||
// Materialize model-owned roots before releasing source tensor handles, then
|
||||
// pin only those roots. This avoids retaining large load-time intermediates
|
||||
// while still protecting shared model tensors from Sweep.
|
||||
roots := mlx.Collect(m)
|
||||
mlx.Eval(roots...)
|
||||
mlx.Pin(roots...)
|
||||
|
||||
keep := make(map[*mlx.Array]struct{})
|
||||
for _, arr := range roots {
|
||||
if arr != nil && arr.Valid() {
|
||||
keep[arr] = struct{}{}
|
||||
}
|
||||
}
|
||||
if count, bytes := releaseTensorMap(tensors, keep); count > 0 {
|
||||
slog.Info("Released unused model tensors", "count", count, "bytes", mlx.PrettyBytes(bytes))
|
||||
}
|
||||
mlx.Sweep()
|
||||
mlx.ClearCache()
|
||||
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
return nil
|
||||
|
||||
@@ -15,6 +15,40 @@ type LinearLayer interface {
|
||||
OutputDim() int32
|
||||
}
|
||||
|
||||
// Conv1d applies 1D convolution over NLC input.
|
||||
type Conv1d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Stride int32
|
||||
Padding int32
|
||||
Dilation int32
|
||||
Groups int32
|
||||
}
|
||||
|
||||
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
|
||||
if stride <= 0 {
|
||||
stride = 1
|
||||
}
|
||||
if dilation <= 0 {
|
||||
dilation = 1
|
||||
}
|
||||
if groups <= 0 {
|
||||
groups = 1
|
||||
}
|
||||
return &Conv1d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
Stride: stride,
|
||||
Padding: padding,
|
||||
Dilation: dilation,
|
||||
Groups: groups,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
|
||||
}
|
||||
|
||||
// Linear applies an affine transformation: y = x @ W.T + b
|
||||
type Linear struct {
|
||||
Weight *mlx.Array
|
||||
|
||||
1913
x/models/qwen3_5/qwen3_5.go
Normal file
1913
x/models/qwen3_5/qwen3_5.go
Normal file
File diff suppressed because it is too large
Load Diff
206
x/models/qwen3_5/qwen3_5_test.go
Normal file
206
x/models/qwen3_5/qwen3_5_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func TestParseConfigNestedDefaults(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model_type": "Qwen3_5MoeForConditionalGeneration",
|
||||
"text_config": {
|
||||
"hidden_size": 4096,
|
||||
"intermediate_size": 14336,
|
||||
"num_hidden_layers": 8,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"linear_num_value_heads": 64,
|
||||
"linear_num_key_heads": 16,
|
||||
"linear_key_head_dim": 128,
|
||||
"linear_value_head_dim": 128,
|
||||
"linear_conv_kernel_dim": 4,
|
||||
"num_experts": 16,
|
||||
"num_experts_per_tok": 4,
|
||||
"moe_intermediate_size": 2048,
|
||||
"shared_expert_intermediate_size": 4096,
|
||||
"rope_parameters": {
|
||||
"rope_theta": 500000,
|
||||
"partial_rotary_factor": 0.5
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.RopeTheta != 500000 {
|
||||
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
|
||||
}
|
||||
if cfg.RopeDim != 64 {
|
||||
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
|
||||
}
|
||||
if cfg.FullAttentionInterval != 4 {
|
||||
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
|
||||
}
|
||||
if !cfg.NormTopKProb {
|
||||
t.Fatalf("norm_topk_prob should default to true for MoE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerSelectionHelpers(t *testing.T) {
|
||||
cfg := &Config{
|
||||
NumHiddenLayers: 6,
|
||||
FullAttentionInterval: 3,
|
||||
NumExperts: 8,
|
||||
DecoderSparseStep: 2,
|
||||
MLPOnlyLayers: []int32{1},
|
||||
}
|
||||
|
||||
if !layerIsLinear(cfg, 0) {
|
||||
t.Fatalf("layer 0 should be linear")
|
||||
}
|
||||
if layerIsLinear(cfg, 2) {
|
||||
t.Fatalf("layer 2 should be full attention")
|
||||
}
|
||||
|
||||
if layerUsesMoE(cfg, 1) {
|
||||
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
|
||||
}
|
||||
if !layerUsesMoE(cfg, 3) {
|
||||
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTensorPathLayout(t *testing.T) {
|
||||
dummy := mlx.New("dummy")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
wantContainer string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "standard",
|
||||
key: "model.embed_tokens.weight",
|
||||
wantContainer: "",
|
||||
wantModel: "model.",
|
||||
},
|
||||
{
|
||||
name: "nested language model with inner model",
|
||||
key: "model.language_model.model.embed_tokens.weight",
|
||||
wantContainer: "model.language_model.",
|
||||
wantModel: "model.",
|
||||
},
|
||||
{
|
||||
name: "nested language model without inner model",
|
||||
key: "model.language_model.embed_tokens.weight",
|
||||
wantContainer: "model.language_model.",
|
||||
wantModel: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
layout := resolveTensorPathLayout(map[string]*mlx.Array{
|
||||
tt.key: dummy,
|
||||
})
|
||||
|
||||
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
|
||||
t.Fatalf(
|
||||
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
|
||||
layout.containerPrefix,
|
||||
layout.modelPrefix,
|
||||
tt.wantContainer,
|
||||
tt.wantModel,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRuntimeToggles(t *testing.T) {
|
||||
m := &Model{}
|
||||
if m.DisablePromptCache() {
|
||||
t.Fatal("DisablePromptCache() = true, want false")
|
||||
}
|
||||
if m.LowMemoryDecode() {
|
||||
t.Fatal("LowMemoryDecode() = true, want false")
|
||||
}
|
||||
if !m.EnableCompile() {
|
||||
t.Fatal("EnableCompile() = false, want true")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_LOW_MEMORY_DECODE", "0")
|
||||
if m.LowMemoryDecode() {
|
||||
t.Fatal("LowMemoryDecode() = true with env override 0, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_LOW_MEMORY_DECODE", "1")
|
||||
if !m.LowMemoryDecode() {
|
||||
t.Fatal("LowMemoryDecode() = false with env override 1, want true")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_ENABLE_COMPILE", "0")
|
||||
if m.EnableCompile() {
|
||||
t.Fatal("EnableCompile() = true with env override 0, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_ENABLE_COMPILE", "1")
|
||||
if !m.EnableCompile() {
|
||||
t.Fatal("EnableCompile() = false with env override, want true")
|
||||
}
|
||||
|
||||
if !qwen35FastRecurrentWrite() {
|
||||
t.Fatal("qwen35FastRecurrentWrite() = false, want true")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_FAST_RECURRENT_WRITE", "0")
|
||||
if qwen35FastRecurrentWrite() {
|
||||
t.Fatal("qwen35FastRecurrentWrite() = true with env override 0, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_FAST_RECURRENT_WRITE", "1")
|
||||
if !qwen35FastRecurrentWrite() {
|
||||
t.Fatal("qwen35FastRecurrentWrite() = false with env override 1, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesLayout(t *testing.T) {
|
||||
m := &Model{
|
||||
Config: &Config{
|
||||
LinearConvKernelDim: 4,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearKeyHeadDim: 8,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 16,
|
||||
},
|
||||
Layers: []*Layer{
|
||||
{IsLinear: true},
|
||||
{IsLinear: false},
|
||||
{IsLinear: true},
|
||||
},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if len(caches) != len(m.Layers) {
|
||||
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
|
||||
}
|
||||
|
||||
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
|
||||
}
|
||||
if _, ok := caches[1].(*cache.KVCache); !ok {
|
||||
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
|
||||
}
|
||||
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
|
||||
}
|
||||
}
|
||||
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
|
||||
package qwen3_5_moe
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||
)
|
||||
|
||||
func init() {
|
||||
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
|
||||
}
|
||||
Reference in New Issue
Block a user