Compare commits

...

1 Commits

Author SHA1 Message Date
Patrick Devine
8e19bfd20d Add qwen3.5-next-moe support to MLX runner and models
This change:
  * adds support for qwen3.5-next-moe models (qwen3-next/qwen3.5-next/qwen3-coder) to the MLX runner
  * introduces recurrent cache support and related MLX ops
  * updates pipeline/runner integration and adds tests
2026-02-24 23:32:59 -08:00
20 changed files with 4622 additions and 46 deletions

View File

@@ -296,13 +296,19 @@ func normalizeQuantType(quantize string) string {
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Norms, embeddings, biases, routing gates: no quantization
func GetTensorQuantization(name string, shape []int32, quantize string) string {
isStackedExpert := strings.Contains(name, ".mlp.experts.gate_up_proj") || strings.Contains(name, ".mlp.experts.down_proj")
// Use basic name-based check first
if !ShouldQuantize(name, "") {
if !isStackedExpert && !ShouldQuantize(name, "") {
return ""
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
// Quantize 2D linear tensors by default. qwen3.5 stacked expert tensors are
// also eligible even though they are stored as 3D [experts, out, in].
if !isStackedExpert && len(shape) != 2 {
return ""
}
if isStackedExpert && len(shape) != 3 {
return ""
}
@@ -372,9 +378,10 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
}
// expertGroupRegexp matches expert tensor names and captures the group prefix.
// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
// Captures: model.layers.{L}.mlp.experts
var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
// Matches nested and non-nested LLM prefixes and both per-expert ".weight"
// tensors and qwen3.5 stacked expert tensors without ".weight".
// Captures: model(.language_model(.model)?).layers.{L}.mlp.experts or .shared_experts
var expertGroupRegexp = regexp.MustCompile(`^(model(?:\.language_model(?:\.model)?)?\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*(?:\.weight)?$`)
// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
// For example:

View File

@@ -557,6 +557,9 @@ func TestShouldQuantizeTensor(t *testing.T) {
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
// qwen3.5 stacked expert tensors are an exception: 3D [experts, out, in]
{"qwen3.5 stacked gate_up_proj", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int4", true},
{"qwen3.5 stacked down_proj", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int4", true},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
@@ -586,6 +589,42 @@ func TestShouldQuantizeTensor(t *testing.T) {
}
}
func TestGetTensorQuantization_StackedExperts(t *testing.T) {
tests := []struct {
name string
shape []int32
quantize string
want string
}{
{
name: "model.language_model.layers.0.mlp.experts.gate_up_proj",
shape: []int32{256, 1024, 2048},
quantize: "int4",
want: "int4",
},
{
name: "model.language_model.layers.0.mlp.experts.down_proj",
shape: []int32{256, 2048, 512},
quantize: "int4",
want: "int8",
},
{
name: "model.language_model.layers.0.mlp.experts.gate_up_proj",
shape: []int32{256, 1024, 2050}, // not divisible by 32
quantize: "int4",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetTensorQuantization(tt.name, tt.shape, tt.quantize); got != tt.want {
t.Fatalf("GetTensorQuantization(%q, %v, %q) = %q, want %q", tt.name, tt.shape, tt.quantize, got, tt.want)
}
})
}
}
func TestExpertGroupPrefix(t *testing.T) {
tests := []struct {
name string
@@ -595,10 +634,18 @@ func TestExpertGroupPrefix(t *testing.T) {
{"model.layers.1.mlp.experts.0.down_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
{"model.language_model.layers.0.mlp.experts.0.up_proj.weight", "model.language_model.layers.0.mlp.experts"},
{"model.language_model.model.layers.0.mlp.experts.0.up_proj.weight", "model.language_model.model.layers.0.mlp.experts"},
{"model.language_model.layers.0.mlp.experts.gate_up_proj", "model.language_model.layers.0.mlp.experts"},
{"model.language_model.layers.0.mlp.experts.down_proj", "model.language_model.layers.0.mlp.experts"},
{"model.language_model.model.layers.0.mlp.experts.gate_up_proj", "model.language_model.model.layers.0.mlp.experts"},
{"model.language_model.model.layers.0.mlp.experts.down_proj", "model.language_model.model.layers.0.mlp.experts"},
// Shared expert tensors should return their own group prefix
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
{"model.language_model.layers.1.mlp.shared_experts.down_proj.weight", "model.language_model.layers.1.mlp.shared_experts"},
{"model.language_model.model.layers.1.mlp.shared_experts.down_proj.weight", "model.language_model.model.layers.1.mlp.shared_experts"},
// Non-expert tensors should return empty string
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts

View File

@@ -5,62 +5,341 @@ package mlxrunner
import (
"fmt"
"log/slog"
"os"
"strconv"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// CacheEntry stores a single sequence
type CacheEntry struct {
const defaultPromptCacheBranches = 1
// HybridCacheEntry stores a single prompt branch with mixed cache types
// (e.g. KV + recurrent caches) and coordinates shared operations across them.
type HybridCacheEntry struct {
Tokens []int32
Caches []cache.Cache
}
// CacheEntry is kept as an alias for the current single-entry runner path.
// Future multi-entry cache stores should prefer HybridCacheEntry directly.
type CacheEntry = HybridCacheEntry
func promptCacheBranchLimit() int {
if v := os.Getenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultPromptCacheBranches
}
func cloneTokens(tokens []int32) []int32 {
out := make([]int32, len(tokens))
copy(out, tokens)
return out
}
func equalTokens(a, b []int32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func (r *Runner) cacheStore() []*HybridCacheEntry {
if len(r.caches) == 0 && r.cache != nil {
r.caches = []*HybridCacheEntry{r.cache}
}
if r.cache == nil && len(r.caches) > 0 {
r.cache = r.caches[0]
}
return r.caches
}
func (r *Runner) setCacheStore(entries []*HybridCacheEntry) {
r.caches = entries
if len(entries) == 0 {
r.cache = nil
return
}
r.cache = entries[0]
}
func (r *Runner) touchCacheEntry(idx int) {
if idx <= 0 || idx >= len(r.caches) {
return
}
e := r.caches[idx]
copy(r.caches[1:idx+1], r.caches[:idx])
r.caches[0] = e
r.cache = r.caches[0]
}
func (r *Runner) bestCacheEntry(tokens []int32) (idx int, prefix int) {
bestIdx, bestPrefix := -1, 0
for i, e := range r.cacheStore() {
if e == nil {
continue
}
p := e.PrefixLen(tokens)
if p > bestPrefix {
bestIdx, bestPrefix = i, p
}
}
return bestIdx, bestPrefix
}
func (e *HybridCacheEntry) PrefixLen(tokens []int32) int {
if e == nil {
return 0
}
prefix := 0
for prefix < len(tokens) && prefix < len(e.Tokens) && tokens[prefix] == e.Tokens[prefix] {
prefix++
}
return prefix
}
func (e *HybridCacheEntry) Free() {
if e == nil {
return
}
for _, c := range e.Caches {
if c != nil {
c.Free()
}
}
}
func (e *HybridCacheEntry) cachesSlice() []cache.Cache {
if e == nil {
return nil
}
return e.Caches
}
func (e *HybridCacheEntry) cachesCanTrim() bool {
if e == nil {
return false
}
for _, c := range e.Caches {
if c == nil {
continue
}
if !c.CanTrim() {
return false
}
}
return true
}
func (e *HybridCacheEntry) TrimToPrefix(prefix int) {
if e == nil {
return
}
for _, c := range e.Caches {
if c == nil || !c.CanTrim() {
continue
}
trim := c.Offset() - prefix
if trim > 0 {
c.Trim(trim)
}
}
if prefix < len(e.Tokens) {
e.Tokens = e.Tokens[:prefix]
}
}
func (e *HybridCacheEntry) RestoreToPrefix(target int) (int, bool) {
if e == nil {
return 0, false
}
restorePos := -1
sawNonTrimmable := false
for _, c := range e.Caches {
if c == nil || c.CanTrim() {
continue
}
sawNonTrimmable = true
restorer, ok := c.(cache.CheckpointRestorer)
if !ok {
return 0, false
}
pos, ok := restorer.BestCheckpoint(target)
if !ok {
return 0, false
}
if restorePos < 0 {
restorePos = pos
continue
}
if pos != restorePos {
return 0, false
}
}
if !sawNonTrimmable || restorePos < 0 {
return 0, false
}
e.TrimToPrefix(restorePos)
for _, c := range e.Caches {
if c == nil || c.CanTrim() {
continue
}
restorer, ok := c.(cache.CheckpointRestorer)
if !ok || !restorer.RestoreCheckpoint(restorePos) {
return 0, false
}
}
if restorePos < len(e.Tokens) {
e.Tokens = e.Tokens[:restorePos]
}
return restorePos, true
}
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
if r.cache == nil {
entries := r.cacheStore()
if len(entries) == 0 {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
// Find longest common prefix
prefix := 0
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
prefix++
branchLimit := promptCacheBranchLimit()
idx, prefix := r.bestCacheEntry(tokens)
if idx < 0 {
if branchLimit <= 1 && len(entries) == 1 && entries[0] != nil {
entries[0].Free()
r.setCacheStore(nil)
}
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
if idx > 0 {
r.touchCacheEntry(idx)
}
base := r.cache
if base == nil {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
working := base
forked := false
if branchLimit > 1 && prefix > 0 {
working = base.Clone()
forked = true
}
switch {
case prefix == 0:
for _, c := range r.cache.Caches {
c.Free()
if !forked && branchLimit <= 1 {
base.Free()
r.setCacheStore(nil)
}
r.cache = nil
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
case prefix < len(r.cache.Tokens):
trim := len(r.cache.Tokens) - prefix
for _, c := range r.cache.Caches {
c.Trim(trim)
case prefix < len(working.Tokens):
if !working.cachesCanTrim() {
if restorePos, ok := working.RestoreToPrefix(prefix); ok {
slog.Info("Cache restore", "total", len(tokens), "matched", prefix, "restored", restorePos, "left", len(tokens[restorePos:]))
return working.cachesSlice(), tokens[restorePos:]
}
if forked {
working.Free()
} else if branchLimit <= 1 {
base.Free()
r.setCacheStore(nil)
}
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
return nil, tokens
}
r.cache.Tokens = r.cache.Tokens[:prefix]
working.TrimToPrefix(prefix)
}
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
return r.cache.Caches, tokens[prefix:]
return working.cachesSlice(), tokens[prefix:]
}
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
r.cache = &CacheEntry{
entry := &HybridCacheEntry{
Tokens: cloneTokens(tokens),
Caches: caches,
}
branchLimit := promptCacheBranchLimit()
if branchLimit <= 1 {
r.setCacheStore([]*HybridCacheEntry{entry})
return
}
entries := r.cacheStore()
// Replace any exact-token duplicate branch with the new result.
for i := 0; i < len(entries); i++ {
if entries[i] == nil || !equalTokens(entries[i].Tokens, entry.Tokens) {
continue
}
entries[i].Free()
entries = append(entries[:i], entries[i+1:]...)
break
}
entries = append([]*HybridCacheEntry{entry}, entries...)
if len(entries) > branchLimit {
for _, evicted := range entries[branchLimit:] {
if evicted != nil {
evicted.Free()
}
}
entries = entries[:branchLimit]
}
r.setCacheStore(entries)
}
func (c *HybridCacheEntry) Clone() *HybridCacheEntry {
if c == nil {
return nil
}
tokens := make([]int32, len(c.Tokens))
copy(tokens, c.Tokens)
caches := make([]cache.Cache, len(c.Caches))
for i, cc := range c.Caches {
if cc != nil {
caches[i] = cc.Clone()
}
}
return &HybridCacheEntry{
Tokens: tokens,
Caches: caches,
}
}
func (c *CacheEntry) LogCache() {
func (c *HybridCacheEntry) LogCache() {
if c == nil || len(c.Caches) == 0 {
return
}
var totalBytes int
for _, kv := range c.Caches {
if kv == nil {
continue
}
k, v := kv.State()
if k == nil || v == nil {
continue
}
totalBytes += k.NumBytes() + v.NumBytes()
}
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))

View File

@@ -3,13 +3,22 @@
package cache
import (
"log/slog"
"os"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func kvCacheGrowDebugEnabled() bool {
return os.Getenv("OLLAMA_MLX_DEBUG_CACHE_GROW") != ""
}
type Cache interface {
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
State() (keys, values *mlx.Array)
Materialize() []*mlx.Array
CanTrim() bool
Trim(int) int
Clone() Cache
Free()
@@ -17,6 +26,19 @@ type Cache interface {
Len() int
}
// CheckpointRecorder is an optional cache capability for recording recurrent
// state snapshots at specific token positions.
type CheckpointRecorder interface {
RecordCheckpoint(pos int)
}
// CheckpointRestorer is an optional cache capability for restoring recurrent
// state to a previously recorded checkpoint.
type CheckpointRestorer interface {
BestCheckpoint(target int) (pos int, ok bool)
RestoreCheckpoint(pos int) bool
}
type KVCache struct {
keys, values *mlx.Array
offset int
@@ -49,6 +71,9 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
if kvCacheGrowDebugEnabled() {
slog.Info("KVCache grow", "prev", prev, "new_capacity", c.keys.Dim(2), "step", c.step)
}
}
c.offset += L
@@ -67,6 +92,19 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) Materialize() []*mlx.Array {
out := make([]*mlx.Array, 0, 2)
if c.keys != nil && c.keys.Valid() {
out = append(out, c.keys)
}
if c.values != nil && c.values.Valid() {
out = append(out, c.values)
}
return out
}
func (c *KVCache) CanTrim() bool { return true }
func (c *KVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n
@@ -190,6 +228,8 @@ func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
return c.keys, c.values
}
func (c *RotatingKVCache) CanTrim() bool { return true }
func (c *RotatingKVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n

17
x/mlxrunner/cache/cache_test.go vendored Normal file
View File

@@ -0,0 +1,17 @@
//go:build mlx
package cache
import "testing"
func TestKVCacheGrowDebugEnabled(t *testing.T) {
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "")
if kvCacheGrowDebugEnabled() {
t.Fatal("kvCacheGrowDebugEnabled() = true, want false")
}
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "1")
if !kvCacheGrowDebugEnabled() {
t.Fatal("kvCacheGrowDebugEnabled() = false, want true")
}
}

519
x/mlxrunner/cache/recurrent.go vendored Normal file
View File

@@ -0,0 +1,519 @@
//go:build mlx
package cache
import (
"os"
"strconv"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
const (
defaultRecurrentCheckpointCount = 32
defaultRecurrentCheckpointInterval = 128
defaultRecurrentCheckpointMinPos = 16
)
type recurrentCheckpoint struct {
pos int
convState *mlx.Array
deltaState *mlx.Array
}
type recurrentSlot struct {
convState *mlx.Array
deltaState *mlx.Array
checkpoints []recurrentCheckpoint
checkpointSize int
checkpointNext int
checkpointLastPos int
refs int
}
func getenvInt(name string, def int) int {
if v := os.Getenv(name); v != "" {
if n, err := strconv.Atoi(v); err == nil {
return n
}
}
return def
}
func recurrentCheckpointConfig() (count, interval, minPos int) {
count = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINTS", defaultRecurrentCheckpointCount)
interval = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINT_INTERVAL", defaultRecurrentCheckpointInterval)
minPos = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINT_MIN_POS", defaultRecurrentCheckpointMinPos)
if count < 0 {
count = 0
}
if interval < 0 {
interval = 0
}
if minPos < 0 {
minPos = 0
}
return count, interval, minPos
}
// RecurrentCache stores state for linear-recurrent layers.
//
// Conv state shape: [B, convTail, convDim]
// Delta state shape: [B, numVHeads, headVDim, headKDim]
type RecurrentCache struct {
slot *recurrentSlot
offset int
convTail int
convDim int
numVHeads int
headVDim int
headKDim int
checkpointCount int
checkpointInterval int
checkpointMinPos int
}
func newRecurrentSlot(checkpointCount int) *recurrentSlot {
s := &recurrentSlot{
refs: 1,
checkpointLastPos: -1,
}
if checkpointCount > 0 {
s.checkpoints = make([]recurrentCheckpoint, checkpointCount)
for i := range s.checkpoints {
s.checkpoints[i].pos = -1
}
}
return s
}
func retainRecurrentSlot(s *recurrentSlot) *recurrentSlot {
if s != nil {
s.refs++
}
return s
}
func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
// Break dependency chains so recurrent state does not retain the full
// per-token compute graph over time.
snap := mlx.Snapshot(v)
mlx.Eval(snap)
old := *dst
*dst = snap
mlx.Pin(snap)
// Release previous cached state root, then recursively free the transient
// incoming graph root now that a detached snapshot is retained in cache.
if old != nil && old != snap {
mlx.Release(old)
}
if v != snap && v != old {
mlx.Release(v)
}
}
func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
if v == nil || !v.Valid() {
return
}
old := *dst
*dst = v
mlx.Pin(v)
if old != nil && old != v {
mlx.Release(old)
}
}
func (c *RecurrentCache) setStateDetached(dst **mlx.Array, v *mlx.Array, ensureContiguous bool) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
root := v
if ensureContiguous {
root = mlx.Contiguous(v, false)
}
detached := mlx.Detach(root)
old := *dst
*dst = detached
mlx.Pin(detached)
if old != nil && old != detached {
mlx.Release(old)
}
// Intentionally do not force-release root/v here. In the fast path, the detached
// handle aliases the same MLX value and may still be lazily computed. Releasing the
// source handles can invalidate the cached state before the next eval/sweep point.
}
func snapshotPinned(a *mlx.Array) *mlx.Array {
if a == nil || !a.Valid() {
return nil
}
snap := mlx.Snapshot(a)
mlx.Eval(snap)
mlx.Pin(snap)
return snap
}
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
count, interval, minPos := recurrentCheckpointConfig()
c := &RecurrentCache{
slot: newRecurrentSlot(count),
convTail: int(convTail),
convDim: int(convDim),
numVHeads: int(numVHeads),
headVDim: int(headVDim),
headKDim: int(headKDim),
checkpointCount: count,
checkpointInterval: interval,
checkpointMinPos: minPos,
}
return c
}
func clonePinned(a *mlx.Array) *mlx.Array {
if a == nil || !a.Valid() {
return nil
}
clone := a.Clone()
mlx.Pin(clone)
return clone
}
func releaseCheckpointEntry(e *recurrentCheckpoint) {
mlx.Release(e.convState, e.deltaState)
e.convState, e.deltaState = nil, nil
e.pos = -1
}
func releaseRecurrentSlot(s *recurrentSlot) {
if s == nil {
return
}
s.refs--
if s.refs > 0 {
return
}
mlx.Release(s.convState, s.deltaState)
s.convState, s.deltaState = nil, nil
for i := range s.checkpoints {
releaseCheckpointEntry(&s.checkpoints[i])
}
s.checkpointSize = 0
s.checkpointNext = 0
s.checkpointLastPos = -1
}
func cloneRecurrentSlot(src *recurrentSlot) *recurrentSlot {
if src == nil {
return nil
}
dst := &recurrentSlot{
checkpointSize: src.checkpointSize,
checkpointNext: src.checkpointNext,
checkpointLastPos: src.checkpointLastPos,
refs: 1,
}
if src.convState != nil && src.convState.Valid() {
dst.convState = snapshotPinned(src.convState)
}
if src.deltaState != nil && src.deltaState.Valid() {
dst.deltaState = snapshotPinned(src.deltaState)
}
if len(src.checkpoints) > 0 {
dst.checkpoints = make([]recurrentCheckpoint, len(src.checkpoints))
for i := range src.checkpoints {
dst.checkpoints[i].pos = src.checkpoints[i].pos
if src.checkpoints[i].pos < 0 {
continue
}
dst.checkpoints[i].convState = snapshotPinned(src.checkpoints[i].convState)
dst.checkpoints[i].deltaState = snapshotPinned(src.checkpoints[i].deltaState)
}
}
return dst
}
func (c *RecurrentCache) slotOrInit() *recurrentSlot {
if c.slot == nil {
c.slot = newRecurrentSlot(c.checkpointCount)
}
return c.slot
}
func (c *RecurrentCache) ensureWritableSlot() *recurrentSlot {
s := c.slotOrInit()
if s.refs <= 1 {
return s
}
c.slot = cloneRecurrentSlot(s)
s.refs--
return c.slot
}
func (c *RecurrentCache) pruneCheckpointsAfter(pos int) {
s := c.ensureWritableSlot()
if len(s.checkpoints) == 0 {
return
}
size := 0
next := -1
last := -1
minPos := int(^uint(0) >> 1)
minIdx := 0
for i := range s.checkpoints {
e := &s.checkpoints[i]
if e.pos > pos {
releaseCheckpointEntry(e)
}
if e.pos >= 0 {
size++
if e.pos > last {
last = e.pos
}
if e.pos < minPos {
minPos = e.pos
minIdx = i
}
} else if next == -1 {
next = i
}
}
s.checkpointSize = size
s.checkpointLastPos = last
if size == 0 {
s.checkpointNext = 0
return
}
if next != -1 {
s.checkpointNext = next
return
}
s.checkpointNext = minIdx
}
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
if batch <= 0 {
batch = 1
}
s := c.slotOrInit()
needConv := s.convState == nil || s.convState.DType() != dtype ||
s.convState.Dim(0) != batch || s.convState.Dim(1) != c.convTail || s.convState.Dim(2) != c.convDim
needDelta := s.deltaState == nil || s.deltaState.DType() != dtype ||
s.deltaState.Dim(0) != batch || s.deltaState.Dim(1) != c.numVHeads || s.deltaState.Dim(2) != c.headVDim || s.deltaState.Dim(3) != c.headKDim
if !needConv && !needDelta {
return
}
s = c.ensureWritableSlot()
if needConv {
c.setStateRaw(&s.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
}
if needDelta {
c.setStateRaw(&s.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
}
}
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.slotOrInit().convState
}
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
s := c.ensureWritableSlot()
c.setStateMaterialized(&s.convState, v)
}
// SetConvStateFast stores conv state without forcing an immediate snapshot/eval.
// Use only for decode hot paths that accept higher transient memory until the next
// sync/sweep point. The conv-state input is usually a slice view, so request a
// compact contiguous copy to avoid pinning the whole source buffer.
func (c *RecurrentCache) SetConvStateFast(v *mlx.Array) {
s := c.ensureWritableSlot()
c.setStateDetached(&s.convState, v, true)
}
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.slotOrInit().deltaState
}
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
s := c.ensureWritableSlot()
c.setStateMaterialized(&s.deltaState, v)
}
// SetDeltaStateFast stores delta state without forcing an immediate snapshot/eval.
// Use only for decode hot paths that accept higher transient memory until the next
// sync/sweep point.
func (c *RecurrentCache) SetDeltaStateFast(v *mlx.Array) {
s := c.ensureWritableSlot()
c.setStateDetached(&s.deltaState, v, false)
}
func (c *RecurrentCache) Advance(n int) {
c.offset += n
}
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return keys, values
}
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
c.ensure(1, mlx.DTypeFloat32)
s := c.slotOrInit()
return s.convState, s.deltaState
}
func (c *RecurrentCache) Materialize() []*mlx.Array {
out := make([]*mlx.Array, 0, 2)
s := c.slot
if s == nil {
return out
}
if s.convState != nil && s.convState.Valid() {
out = append(out, s.convState)
}
if s.deltaState != nil && s.deltaState.Valid() {
out = append(out, s.deltaState)
}
return out
}
func (c *RecurrentCache) RecordCheckpoint(pos int) {
s := c.slot
if s == nil || len(s.checkpoints) == 0 || pos <= 0 || pos < c.checkpointMinPos {
return
}
if c.offset != pos {
// Checkpoints are keyed by logical token position. Ignore callers with a
// mismatched position to avoid restoring inconsistent recurrent state.
return
}
if s.convState == nil || s.deltaState == nil || !s.convState.Valid() || !s.deltaState.Valid() {
return
}
if s.checkpointLastPos == pos {
return
}
if s.checkpointLastPos >= 0 && c.checkpointInterval > 0 && pos-s.checkpointLastPos < c.checkpointInterval {
return
}
if s.refs > 1 {
s = c.ensureWritableSlot()
}
idx := s.checkpointNext
e := &s.checkpoints[idx]
releaseCheckpointEntry(e)
e.pos = pos
e.convState = clonePinned(s.convState)
e.deltaState = clonePinned(s.deltaState)
s.checkpointNext = (idx + 1) % len(s.checkpoints)
if s.checkpointSize < len(s.checkpoints) {
s.checkpointSize++
}
s.checkpointLastPos = pos
}
func (c *RecurrentCache) BestCheckpoint(target int) (pos int, ok bool) {
s := c.slot
if s == nil {
return 0, false
}
best := -1
for i := range s.checkpoints {
pos := s.checkpoints[i].pos
if pos < 0 || pos > target {
continue
}
if pos > best {
best = pos
}
}
if best < 0 {
return 0, false
}
return best, true
}
func (c *RecurrentCache) RestoreCheckpoint(pos int) bool {
if pos < 0 {
return false
}
s := c.ensureWritableSlot()
for i := range s.checkpoints {
e := &s.checkpoints[i]
if e.pos != pos {
continue
}
if e.convState == nil || e.deltaState == nil || !e.convState.Valid() || !e.deltaState.Valid() {
return false
}
c.setStateRaw(&s.convState, e.convState.Clone())
c.setStateRaw(&s.deltaState, e.deltaState.Clone())
c.offset = pos
c.pruneCheckpointsAfter(pos)
return true
}
return false
}
func (c *RecurrentCache) CanTrim() bool { return false }
func (c *RecurrentCache) Trim(n int) int {
// Recurrent state is not directly trimmable; callers should use
// checkpoint-based restore instead.
_ = n
return 0
}
func (c *RecurrentCache) Clone() Cache {
clone := &RecurrentCache{
slot: retainRecurrentSlot(c.slotOrInit()),
offset: c.offset,
convTail: c.convTail,
convDim: c.convDim,
numVHeads: c.numVHeads,
headVDim: c.headVDim,
headKDim: c.headKDim,
checkpointCount: c.checkpointCount,
checkpointInterval: c.checkpointInterval,
checkpointMinPos: c.checkpointMinPos,
}
return clone
}
func (c *RecurrentCache) Free() {
releaseRecurrentSlot(c.slot)
c.slot = nil
c.offset = 0
}
func (c *RecurrentCache) Offset() int { return c.offset }
func (c *RecurrentCache) Len() int { return c.offset }

125
x/mlxrunner/cache/recurrent_cow_test.go vendored Normal file
View File

@@ -0,0 +1,125 @@
//go:build mlx
package cache
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func requireMLXRuntime(t *testing.T) {
t.Helper()
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX runtime unavailable: %v", err)
}
}
func newTestRecurrentCache(t *testing.T) *RecurrentCache {
t.Helper()
requireMLXRuntime(t)
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINTS", "2")
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINT_INTERVAL", "0")
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINT_MIN_POS", "0")
c := NewRecurrentCache(2, 3, 1, 2, 2)
_ = c.ConvState(1, mlx.DTypeFloat32)
_ = c.DeltaState(1, mlx.DTypeFloat32)
return c
}
func TestRecurrentCacheCloneSharesSlotUntilMutation(t *testing.T) {
c1 := newTestRecurrentCache(t)
t.Cleanup(func() {
c1.Free()
mlx.Sweep()
})
c1.Advance(8)
c1.RecordCheckpoint(8)
if got, ok := c1.BestCheckpoint(8); !ok || got != 8 {
t.Fatalf("BestCheckpoint(8) = (%d, %v), want (8, true)", got, ok)
}
c2 := c1.Clone().(*RecurrentCache)
t.Cleanup(func() {
c2.Free()
mlx.Sweep()
})
if c1.slot == nil || c2.slot == nil {
t.Fatal("expected non-nil shared slots")
}
if c1.slot != c2.slot {
t.Fatal("clone did not share recurrent slot")
}
if c1.slot.refs != 2 {
t.Fatalf("shared slot refs = %d, want 2", c1.slot.refs)
}
// Read access should not trigger a COW detach.
_ = c2.ConvState(1, mlx.DTypeFloat32)
_ = c2.DeltaState(1, mlx.DTypeFloat32)
if c1.slot != c2.slot {
t.Fatal("read access detached shared recurrent slot")
}
// Mutating recurrent state should detach and deep-copy checkpoint metadata.
c2.SetConvState(mlx.Zeros(mlx.DTypeFloat32, 1, 2, 3))
if c1.slot == c2.slot {
t.Fatal("SetConvState did not detach shared recurrent slot")
}
if c1.slot.refs != 1 || c2.slot.refs != 1 {
t.Fatalf("post-detach refs = (%d, %d), want (1, 1)", c1.slot.refs, c2.slot.refs)
}
if len(c1.slot.checkpoints) == 0 || len(c2.slot.checkpoints) == 0 {
t.Fatal("expected checkpoint ring to be preserved on detach")
}
if c1.slot.checkpoints[0].pos != c2.slot.checkpoints[0].pos {
t.Fatalf("checkpoint pos mismatch after detach: %d vs %d", c1.slot.checkpoints[0].pos, c2.slot.checkpoints[0].pos)
}
if c1.slot.checkpoints[0].pos != 8 {
t.Fatalf("checkpoint pos = %d, want 8", c1.slot.checkpoints[0].pos)
}
if c1.slot.checkpoints[0].convState == c2.slot.checkpoints[0].convState {
t.Fatal("checkpoint conv state was aliased after COW detach")
}
if c1.slot.checkpoints[0].deltaState == c2.slot.checkpoints[0].deltaState {
t.Fatal("checkpoint delta state was aliased after COW detach")
}
}
func TestRecurrentCacheFreeKeepsSharedCloneAlive(t *testing.T) {
c1 := newTestRecurrentCache(t)
c2 := c1.Clone().(*RecurrentCache)
t.Cleanup(func() {
c1.Free()
c2.Free()
mlx.Sweep()
})
if c2.slot == nil || c2.slot.refs != 2 {
t.Fatalf("shared clone refs = %d, want 2", func() int {
if c2.slot == nil {
return 0
}
return c2.slot.refs
}())
}
c1.Free()
if c2.slot == nil {
t.Fatal("clone slot was cleared after freeing sibling clone")
}
if c2.slot.refs != 1 {
t.Fatalf("clone slot refs after sibling Free = %d, want 1", c2.slot.refs)
}
if state := c2.ConvState(1, mlx.DTypeFloat32); state == nil || !state.Valid() {
t.Fatal("clone conv state invalid after freeing sibling clone")
}
if state := c2.DeltaState(1, mlx.DTypeFloat32); state == nil || !state.Valid() {
t.Fatal("clone delta state invalid after freeing sibling clone")
}
}

View File

@@ -0,0 +1,339 @@
//go:build mlx
package mlxrunner
import (
"reflect"
"testing"
cachepkg "github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type fakeCache struct {
canTrim bool
trims []int
freeCall int
offset int
}
func (f *fakeCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { return keys, values }
func (f *fakeCache) State() (*mlx.Array, *mlx.Array) { return nil, nil }
func (f *fakeCache) Materialize() []*mlx.Array { return nil }
func (f *fakeCache) CanTrim() bool { return f.canTrim }
func (f *fakeCache) Trim(n int) int {
f.trims = append(f.trims, n)
f.offset -= n
return n
}
func (f *fakeCache) Clone() cachepkg.Cache { return &fakeCache{canTrim: f.canTrim, offset: f.offset} }
func (f *fakeCache) Free() { f.freeCall++ }
func (f *fakeCache) Offset() int { return f.offset }
func (f *fakeCache) Len() int { return f.offset }
type fakeCheckpointCache struct {
fakeCache
bestPos int
hasCheckpoint bool
restoreCalls []int
restoreSuccess bool
}
func (f *fakeCheckpointCache) BestCheckpoint(target int) (int, bool) {
if !f.hasCheckpoint || f.bestPos > target {
return 0, false
}
return f.bestPos, true
}
func (f *fakeCheckpointCache) RestoreCheckpoint(pos int) bool {
f.restoreCalls = append(f.restoreCalls, pos)
if !f.restoreSuccess || pos != f.bestPos {
return false
}
f.offset = pos
return true
}
func (f *fakeCheckpointCache) Clone() cachepkg.Cache {
clone := *f
clone.trims = nil
clone.restoreCalls = nil
return &clone
}
func TestFindNearestCacheReusesAppendOnlyNonTrimmableCache(t *testing.T) {
fc := &fakeCache{canTrim: false, offset: 2}
r := &Runner{
cache: &CacheEntry{
Tokens: []int32{1, 2},
Caches: []cachepkg.Cache{fc},
},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4})
if len(gotCaches) != 1 || gotCaches[0] != fc {
t.Fatalf("returned caches = %#v, want original cache", gotCaches)
}
if want := []int32{3, 4}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if fc.freeCall != 0 {
t.Fatalf("free calls = %d, want 0", fc.freeCall)
}
if len(fc.trims) != 0 {
t.Fatalf("trim calls = %v, want none", fc.trims)
}
}
func TestFindNearestCacheDropsNonTrimmableCacheOnDivergence(t *testing.T) {
fc := &fakeCache{canTrim: false, offset: 4}
r := &Runner{
cache: &CacheEntry{
Tokens: []int32{1, 2, 3, 4},
Caches: []cachepkg.Cache{fc},
},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
if gotCaches != nil {
t.Fatalf("returned caches = %#v, want nil", gotCaches)
}
if want := []int32{1, 2, 9}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if fc.freeCall != 1 {
t.Fatalf("free calls = %d, want 1", fc.freeCall)
}
if len(fc.trims) != 0 {
t.Fatalf("trim calls = %v, want none", fc.trims)
}
if r.cache != nil {
t.Fatal("runner cache should be cleared on non-trimmable divergence")
}
}
func TestFindNearestCacheTrimsTrimmableCacheOnDivergence(t *testing.T) {
fc := &fakeCache{canTrim: true, offset: 4}
r := &Runner{
cache: &CacheEntry{
Tokens: []int32{1, 2, 3, 4},
Caches: []cachepkg.Cache{fc},
},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
if len(gotCaches) != 1 || gotCaches[0] != fc {
t.Fatalf("returned caches = %#v, want original cache", gotCaches)
}
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if fc.freeCall != 0 {
t.Fatalf("free calls = %d, want 0", fc.freeCall)
}
if want := []int{2}; !reflect.DeepEqual(fc.trims, want) {
t.Fatalf("trim calls = %v, want %v", fc.trims, want)
}
if want := []int32{1, 2}; !reflect.DeepEqual(r.cache.Tokens, want) {
t.Fatalf("cached tokens = %v, want %v", r.cache.Tokens, want)
}
}
func TestFindNearestCacheRestoresCheckpointForNonTrimmableCaches(t *testing.T) {
kv := &fakeCache{canTrim: true, offset: 7}
rc1 := &fakeCheckpointCache{
fakeCache: fakeCache{canTrim: false, offset: 7},
bestPos: 4,
hasCheckpoint: true,
restoreSuccess: true,
}
rc2 := &fakeCheckpointCache{
fakeCache: fakeCache{canTrim: false, offset: 7},
bestPos: 4,
hasCheckpoint: true,
restoreSuccess: true,
}
r := &Runner{
cache: &CacheEntry{
Tokens: []int32{1, 2, 3, 4, 5, 6, 7},
Caches: []cachepkg.Cache{kv, rc1, rc2},
},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 8})
if len(gotCaches) != 3 {
t.Fatalf("returned caches len = %d, want 3", len(gotCaches))
}
if want := []int32{8}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if want := []int{3}; !reflect.DeepEqual(kv.trims, want) {
t.Fatalf("kv trim calls = %v, want %v", kv.trims, want)
}
if want := []int{4}; !reflect.DeepEqual(rc1.restoreCalls, want) {
t.Fatalf("rc1 restore calls = %v, want %v", rc1.restoreCalls, want)
}
if want := []int{4}; !reflect.DeepEqual(rc2.restoreCalls, want) {
t.Fatalf("rc2 restore calls = %v, want %v", rc2.restoreCalls, want)
}
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(r.cache.Tokens, want) {
t.Fatalf("cached tokens = %v, want %v", r.cache.Tokens, want)
}
}
func TestFindNearestCacheDropsOnMismatchedCheckpointRestorePoints(t *testing.T) {
rc1 := &fakeCheckpointCache{
fakeCache: fakeCache{canTrim: false, offset: 7},
bestPos: 4,
hasCheckpoint: true,
restoreSuccess: true,
}
rc2 := &fakeCheckpointCache{
fakeCache: fakeCache{canTrim: false, offset: 7},
bestPos: 3,
hasCheckpoint: true,
restoreSuccess: true,
}
r := &Runner{
cache: &CacheEntry{
Tokens: []int32{1, 2, 3, 4, 5, 6, 7},
Caches: []cachepkg.Cache{rc1, rc2},
},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 8})
if gotCaches != nil {
t.Fatalf("returned caches = %#v, want nil", gotCaches)
}
if want := []int32{1, 2, 3, 4, 8}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if rc1.freeCall != 1 || rc2.freeCall != 1 {
t.Fatalf("free calls = (%d,%d), want (1,1)", rc1.freeCall, rc2.freeCall)
}
if len(rc1.restoreCalls) != 0 || len(rc2.restoreCalls) != 0 {
t.Fatalf("restore calls = (%v,%v), want none", rc1.restoreCalls, rc2.restoreCalls)
}
}
func TestFindNearestCacheSelectsBestPrefixAcrossBranches(t *testing.T) {
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "4")
short := &fakeCache{canTrim: true, offset: 2}
long := &fakeCache{canTrim: true, offset: 4}
shortEntry := &HybridCacheEntry{
Tokens: []int32{1, 2},
Caches: []cachepkg.Cache{short},
}
longEntry := &HybridCacheEntry{
Tokens: []int32{1, 2, 3, 4},
Caches: []cachepkg.Cache{long},
}
r := &Runner{
cache: shortEntry,
caches: []*HybridCacheEntry{shortEntry, longEntry},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 9})
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if len(gotCaches) != 1 {
t.Fatalf("returned caches len = %d, want 1", len(gotCaches))
}
if gotCaches[0] == long {
t.Fatal("expected cloned cache in multi-branch mode, got original branch cache")
}
if r.cache != longEntry || r.caches[0] != longEntry {
t.Fatal("best branch was not promoted to front of cache store")
}
}
func TestFindNearestCacheForksBranchWithCloneWhenMultiBranchEnabled(t *testing.T) {
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "2")
base := &fakeCache{canTrim: true, offset: 4}
baseEntry := &HybridCacheEntry{
Tokens: []int32{1, 2, 3, 4},
Caches: []cachepkg.Cache{base},
}
r := &Runner{
cache: baseEntry,
caches: []*HybridCacheEntry{baseEntry},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if len(gotCaches) != 1 {
t.Fatalf("returned caches len = %d, want 1", len(gotCaches))
}
clone, ok := gotCaches[0].(*fakeCache)
if !ok {
t.Fatalf("returned cache type = %T, want *fakeCache", gotCaches[0])
}
if clone == base {
t.Fatal("expected branch fork to return a cloned cache")
}
if len(base.trims) != 0 {
t.Fatalf("base branch trim calls = %v, want none", base.trims)
}
if want := []int{2}; !reflect.DeepEqual(clone.trims, want) {
t.Fatalf("forked branch trim calls = %v, want %v", clone.trims, want)
}
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(baseEntry.Tokens, want) {
t.Fatalf("base entry tokens = %v, want %v", baseEntry.Tokens, want)
}
r.InsertCache([]int32{1, 2, 9}, gotCaches)
if len(r.caches) != 2 {
t.Fatalf("cache store len = %d, want 2", len(r.caches))
}
if want := []int32{1, 2, 9}; !reflect.DeepEqual(r.caches[0].Tokens, want) {
t.Fatalf("new branch tokens = %v, want %v", r.caches[0].Tokens, want)
}
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(r.caches[1].Tokens, want) {
t.Fatalf("preserved branch tokens = %v, want %v", r.caches[1].Tokens, want)
}
}
func TestInsertCacheEvictsOldestBranchWhenStoreFull(t *testing.T) {
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "2")
f1 := &fakeCache{canTrim: true, offset: 1}
f2 := &fakeCache{canTrim: true, offset: 2}
f3 := &fakeCache{canTrim: true, offset: 3}
r := &Runner{}
r.InsertCache([]int32{1}, []cachepkg.Cache{f1})
r.InsertCache([]int32{1, 2}, []cachepkg.Cache{f2})
r.InsertCache([]int32{1, 2, 3}, []cachepkg.Cache{f3})
if len(r.caches) != 2 {
t.Fatalf("cache store len = %d, want 2", len(r.caches))
}
if f1.freeCall != 1 {
t.Fatalf("oldest branch free calls = %d, want 1", f1.freeCall)
}
if f2.freeCall != 0 || f3.freeCall != 0 {
t.Fatalf("unexpected frees for retained branches: f2=%d f3=%d", f2.freeCall, f3.freeCall)
}
if want := []int32{1, 2, 3}; !reflect.DeepEqual(r.caches[0].Tokens, want) {
t.Fatalf("MRU tokens = %v, want %v", r.caches[0].Tokens, want)
}
if want := []int32{1, 2}; !reflect.DeepEqual(r.caches[1].Tokens, want) {
t.Fatalf("LRU tokens = %v, want %v", r.caches[1].Tokens, want)
}
}

View File

@@ -7,4 +7,6 @@ import (
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"
_ "github.com/ollama/ollama/x/models/qwen3_5"
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
)

View File

@@ -267,3 +267,20 @@ func LogArrays() {
}
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
}
// Release forcibly frees arrays regardless of reference accounting.
// Use only for arrays that are known to be unreachable by any live model state.
func Release(s ...*Array) (n int) {
seen := make(map[*Array]bool, len(s))
for _, t := range s {
if t == nil || !t.Valid() || seen[t] {
continue
}
seen[t] = true
n += t.NumBytes()
C.mlx_array_free(t.ctx)
t.pinned = false
t.ctx.ctx = nil
}
return n
}

View File

@@ -0,0 +1,275 @@
//go:build mlx
package mlx
// #include <stdlib.h>
// #include "generated.h"
import "C"
import (
"sync"
"sync/atomic"
"unsafe"
)
var (
gatedDeltaMetalKernelOnce sync.Once
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
gatedDeltaMetalDisabled atomic.Bool
)
const gatedDeltaMetalKernelSource = `
auto n = thread_position_in_grid.z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;
// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
// v, y: [B, T, Hv, Dv]
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
y += b_idx * T * Hv * Dv + hv_idx * Dv;
auto dk_idx = thread_position_in_threadgroup.x;
auto dv_idx = thread_position_in_grid.y;
// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = static_cast<float>(i_state[s_idx]);
}
// g: [B, T, Hv]
auto g_ = g + b_idx * T * Hv;
auto beta_ = beta + b_idx * T * Hv;
for (int t = 0; t < T; ++t) {
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * g_[hv_idx];
kv_mem += state[i] * k_[s_idx];
}
kv_mem = simd_sum(kv_mem);
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + k_[s_idx] * delta;
out += state[i] * q_[s_idx];
}
out = simd_sum(out);
if (thread_index_in_simdgroup == 0) {
y[dv_idx] = static_cast<InT>(out);
}
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
g_ += Hv;
beta_ += Hv;
}
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<InT>(state[i]);
}
`
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
vec := C.mlx_vector_string_new()
ok := true
for _, s := range values {
cs := C.CString(s)
if C.mlx_vector_string_append_value(vec, cs) != 0 {
ok = false
}
C.free(unsafe.Pointer(cs))
if !ok {
break
}
}
cleanup := func() {
C.mlx_vector_string_free(vec)
}
return vec, cleanup, ok
}
func initGatedDeltaMetalKernel() {
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
if !ok {
gatedDeltaMetalDisabled.Store(true)
freeInputs()
return
}
defer freeInputs()
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
if !ok {
gatedDeltaMetalDisabled.Store(true)
freeOutputs()
return
}
defer freeOutputs()
cName := C.CString("gated_delta_step")
defer C.free(unsafe.Pointer(cName))
cSource := C.CString(gatedDeltaMetalKernelSource)
defer C.free(unsafe.Pointer(cSource))
cHeader := C.CString("")
defer C.free(unsafe.Pointer(cHeader))
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
cName,
inputs,
outputs,
cSource,
cHeader,
C.bool(true),
C.bool(false),
)
}
// GatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
func GatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
if gatedDeltaMetalDisabled.Load() {
return nil, nil, false
}
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil, false
}
if !q.Valid() || !k.Valid() || !v.Valid() || !g.Valid() || !beta.Valid() || !state.Valid() {
return nil, nil, false
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil, false
}
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
return nil, nil, false
}
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
return nil, nil, false
}
Hv, Dv := vd[2], vd[3]
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil, false
}
if gd[0] != B || gd[1] != T || gd[2] != Hv {
return nil, nil, false
}
if bd[0] != B || bd[1] != T || bd[2] != Hv {
return nil, nil, false
}
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
return nil, nil, false
}
dtype := q.DType()
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
return nil, nil, false
}
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
if gatedDeltaMetalDisabled.Load() {
return nil, nil, false
}
cfg := C.mlx_fast_metal_kernel_config_new()
defer C.mlx_fast_metal_kernel_config_free(cfg)
cInT := C.CString("InT")
defer C.free(unsafe.Pointer(cInT))
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
for _, tpl := range []struct {
name string
value int
}{
{name: "Dk", value: Dk},
{name: "Dv", value: Dv},
{name: "Hk", value: Hk},
{name: "Hv", value: Hv},
} {
cn := C.CString(tpl.name)
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
C.free(unsafe.Pointer(cn))
if rc != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
}
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
threadY := Dv
if threadY > 4 {
threadY = 4
}
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
tScalar := FromValue(T)
inputs := []C.mlx_array{
q.ctx,
k.ctx,
v.ctx,
g.ctx,
beta.ctx,
state.ctx,
tScalar.ctx,
}
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
defer C.mlx_vector_array_free(inVec)
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if int(C.mlx_vector_array_size(outVec)) < 2 {
return nil, nil, false
}
y = New("GATED_DELTA_METAL_Y")
nextState = New("GATED_DELTA_METAL_STATE")
C.mlx_vector_array_get(&y.ctx, outVec, 0)
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
return y, nextState, true
}

View File

@@ -19,7 +19,8 @@ func doEval(outputs []*Array, async bool) {
defer C.mlx_vector_array_free(vector)
for _, output := range outputs {
if output.Valid() {
// Callers may pass optional tensors (e.g. debug-only logprobs) as nil.
if output != nil && output.Valid() {
C.mlx_vector_array_append_value(vector, output.ctx)
}
}

View File

@@ -113,6 +113,35 @@ func Where(condition, a, b *Array) *Array {
return out
}
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
out := New("CONV1D")
C.mlx_conv1d(
&out.ctx,
x.ctx,
weight.ctx,
C.int(stride),
C.int(padding),
C.int(dilation),
C.int(groups),
DefaultStream().ctx,
)
if bias != nil && bias.Valid() {
out = Add(out, bias)
}
return out
}
func Contiguous(a *Array, allowColMajor bool) *Array {
out := New("CONTIGUOUS")
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
return out
}
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
groups := int32(x.Dim(x.NumDims() - 1))
return Conv1d(x, weight, bias, 1, 0, 1, groups)
}
// Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array {
@@ -271,6 +300,24 @@ func Sigmoid(a *Array) *Array {
return a.Sigmoid()
}
func Exp(a *Array) *Array {
out := New("EXP")
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Log(a *Array) *Array {
out := New("LOG")
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
out := New("SOFTMAX_AXIS")
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
return out
}
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
mask := New("")
sinks := New("")
@@ -288,7 +335,11 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
var w C.mlx_array
if weight != nil {
w = weight.ctx
}
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
return out
}
@@ -378,6 +429,32 @@ func Collect(v any) []*Array {
return arrays
}
// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
func Snapshot(a *Array) *Array {
if a == nil || !a.Valid() {
return a
}
out := New("SNAPSHOT")
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// CollectReachable collects arrays from v and all transitive graph inputs.
func CollectReachable(v any) []*Array {
return Collect(v)
}
// Detach returns a new Array handle that shares the same MLX value but does
// not retain Go-side graph input references.
func Detach(a *Array) *Array {
if a == nil || !a.Valid() {
return a
}
out := New("DETACH")
C.mlx_array_set(&out.ctx, a.ctx)
return out
}
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
if !v.IsValid() {
return

View File

@@ -6,7 +6,10 @@ import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"os"
"strconv"
"time"
"github.com/ollama/ollama/logutil"
@@ -14,6 +17,234 @@ import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
const defaultRecurrentMaterializeInterval = 64
const defaultPipelineTimingEvery = 64
func prefillChunkSize(lowMemoryDecode bool) int {
if v := os.Getenv("OLLAMA_MLX_PREFILL_CHUNK"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
if lowMemoryDecode {
// Recurrent/no-prompt-cache path favors lower peak memory over prefill throughput.
// Keep this conservative to avoid transient prefill spikes and allocator thrash.
return 32
}
return 2 << 10
}
func hasRecurrentCaches(caches []cache.Cache) bool {
for _, c := range caches {
if c == nil {
continue
}
if _, ok := c.(*cache.RecurrentCache); ok {
return true
}
}
return false
}
// recurrentMaterializeInterval controls periodic recurrent-cache materialization
// during async decode. It exists to bound graph/handle growth when using fast
// recurrent cache writes; it is primarily a memory/stability tuning knob, not a
// throughput knob.
func recurrentMaterializeInterval(lowMemoryDecode bool, hasRecurrent bool) int {
if lowMemoryDecode || !hasRecurrent {
return 0
}
if v := os.Getenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL"); v != "" {
if n, err := strconv.Atoi(v); err == nil {
if n < 0 {
return 0
}
return n
}
}
return defaultRecurrentMaterializeInterval
}
func mlxDebugMemoryEnabled() bool {
return os.Getenv("OLLAMA_MLX_DEBUG_MEMORY") != ""
}
// mlxPipelineTimingConfig controls runner-side decode pipeline timing logs. This
// is diagnostic-only and intentionally separate from model-specific timing.
func mlxPipelineTimingConfig() (enabled bool, every int) {
if v, ok := os.LookupEnv("OLLAMA_MLX_PIPELINE_TIMING"); ok {
if parsed, err := strconv.ParseBool(v); err == nil {
enabled = parsed
}
}
if !enabled {
return false, 0
}
every = defaultPipelineTimingEvery
if v := os.Getenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
every = n
}
}
return true, every
}
// mlxComputeLogprobsEnabled restores the old decode-step logprob normalization
// path for profiling/experiments. It is off by default because the MLX runner
// does not currently populate Response.Logprobs.
func mlxComputeLogprobsEnabled() bool {
if v, ok := os.LookupEnv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS"); ok {
if enabled, err := strconv.ParseBool(v); err == nil {
return enabled
}
}
// The MLX runner currently does not populate Response.Logprobs, so skip the
// full-vocab logprob normalization path unless explicitly requested for
// debugging/experiments.
return false
}
type pipelineTiming struct {
every int
stepCalls int
stepAsync int
stepSync int
sampleInts int
stepTotalDur time.Duration
forwardDur time.Duration
unembedDur time.Duration
sliceDur time.Duration
logprobsDur time.Duration
sampleDur time.Duration
pinSweepDur time.Duration
asyncEvalDur time.Duration
sampleIntDur time.Duration
lastEmitCount int
}
func newPipelineTiming() *pipelineTiming {
enabled, every := mlxPipelineTimingConfig()
if !enabled {
return nil
}
pt := &pipelineTiming{every: every}
fmt.Fprintf(os.Stderr, "mlx pipeline timing: enabled every=%d\n", every)
return pt
}
func (pt *pipelineTiming) recordStep(
async bool,
total, forward, unembed, slice, logprobs, sample, pinSweep, asyncEval time.Duration,
) {
if pt == nil {
return
}
pt.stepCalls++
if async {
pt.stepAsync++
} else {
pt.stepSync++
}
pt.stepTotalDur += total
pt.forwardDur += forward
pt.unembedDur += unembed
pt.sliceDur += slice
pt.logprobsDur += logprobs
pt.sampleDur += sample
pt.pinSweepDur += pinSweep
pt.asyncEvalDur += asyncEval
}
func (pt *pipelineTiming) recordSampleInt(d time.Duration, decodeCount int) {
if pt == nil {
return
}
pt.sampleInts++
pt.sampleIntDur += d
pt.maybeEmit(false, decodeCount)
}
func (pt *pipelineTiming) maybeEmit(force bool, decodeCount int) {
if pt == nil {
return
}
if !force {
if pt.every <= 0 || decodeCount <= 0 || decodeCount%pt.every != 0 {
return
}
}
if pt.lastEmitCount == decodeCount {
return
}
pt.lastEmitCount = decodeCount
msAvg := func(d time.Duration, n int) float64 {
if n <= 0 {
return 0
}
return float64(d) / float64(n) / float64(time.Millisecond)
}
stepResidual := pt.stepTotalDur - pt.forwardDur - pt.unembedDur - pt.sliceDur - pt.logprobsDur - pt.sampleDur - pt.pinSweepDur - pt.asyncEvalDur
if stepResidual < 0 {
stepResidual = 0
}
fmt.Fprintf(
os.Stderr,
"mlx pipeline timing: decode=%d step_calls=%d step_async=%d step_sync=%d avg_step_ms=%.2f fwd_ms=%.2f unembed_ms=%.2f slice_ms=%.2f logprobs_ms=%.2f sample_ms=%.2f pin_sweep_ms=%.2f async_eval_ms=%.2f step_residual_ms=%.2f sample_int_ms=%.2f\n",
decodeCount,
pt.stepCalls,
pt.stepAsync,
pt.stepSync,
msAvg(pt.stepTotalDur, pt.stepCalls),
msAvg(pt.forwardDur, pt.stepCalls),
msAvg(pt.unembedDur, pt.stepCalls),
msAvg(pt.sliceDur, pt.stepCalls),
msAvg(pt.logprobsDur, pt.stepCalls),
msAvg(pt.sampleDur, pt.stepCalls),
msAvg(pt.pinSweepDur, pt.stepCalls),
msAvg(pt.asyncEvalDur, pt.stepCalls),
msAvg(stepResidual, pt.stepCalls),
msAvg(pt.sampleIntDur, pt.sampleInts),
)
}
func finalizeRequestCaches(usePromptCache bool, insertCache func(), freeCaches func(), logMemory func(string, int)) {
if usePromptCache {
insertCache()
logMemory("request_done_cached", -1)
return
}
freeCaches()
logMemory("request_done_freed", -1)
}
func recordCacheCheckpoints(caches []cache.Cache, pos int) {
if pos <= 0 {
return
}
for _, c := range caches {
if c == nil {
continue
}
if recorder, ok := c.(cache.CheckpointRecorder); ok {
recorder.RecordCheckpoint(pos)
}
}
}
func freeOwnedCaches(caches []cache.Cache) {
for i, c := range caches {
if c == nil {
continue
}
c.Free()
caches[i] = nil
}
}
func (r *Runner) TextGenerationPipeline(request Request) error {
if r.Model == nil {
return errors.New("model not loaded")
@@ -31,7 +262,24 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
inputs := r.Tokenizer.Encode(request.Prompt, true)
caches, tokens := r.FindNearestCache(inputs)
usePromptCache := true
if m, ok := r.Model.(interface{ DisablePromptCache() bool }); ok && m.DisablePromptCache() {
usePromptCache = false
}
lowMemoryDecode := !usePromptCache
if m, ok := r.Model.(interface{ LowMemoryDecode() bool }); ok {
lowMemoryDecode = m.LowMemoryDecode()
}
prefillChunk := prefillChunkSize(lowMemoryDecode)
var caches []cache.Cache
var tokens []int32
if usePromptCache {
caches, tokens = r.FindNearestCache(inputs)
} else {
tokens = inputs
}
if len(caches) == 0 {
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
caches = cacheFactory.NewCaches()
@@ -43,40 +291,140 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
}
materializeCaches := func() {
state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches {
state = append(state, c.Materialize()...)
}
if len(state) == 0 {
return
}
mlx.Eval(state...)
}
materializeRecurrentCaches := func() bool {
state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches {
if c == nil {
continue
}
if _, ok := c.(*cache.RecurrentCache); !ok {
continue
}
state = append(state, c.Materialize()...)
}
if len(state) == 0 {
return false
}
mlx.Eval(state...)
return true
}
freeCaches := func() {
// Non-prompt-cache requests allocate fresh caches every generation.
// Explicitly free cache-owned state (including recurrent checkpoints),
// then sweep remaining intermediates.
freeOwnedCaches(caches)
mlx.Sweep()
mlx.ClearCache()
}
debugMemory := mlxDebugMemoryEnabled()
hasRecurrent := hasRecurrentCaches(caches)
asyncRecurrentMaterializeEvery := recurrentMaterializeInterval(lowMemoryDecode, hasRecurrent)
computeStepLogprobs := mlxComputeLogprobsEnabled()
pipelineTiming := newPipelineTiming()
logMemory := func(phase string, token int) {
if !debugMemory {
return
}
if token >= 0 {
slog.Info("MLX memory", "phase", phase, "token", token, "memory", mlx.Memory{})
return
}
slog.Info("MLX memory", "phase", phase, "memory", mlx.Memory{})
}
logMemory("prefill_start", -1)
total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
n := min(2<<10, total-processed-1)
n := min(prefillChunk, total-processed-1)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
mlx.Sweep()
mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
s[2*i], s[2*i+1] = c.State()
}
return s
}()...)
materializeCaches()
recordCacheCheckpoints(caches, processed+n)
processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
mlx.ClearCache()
}
logMemory("prefill_done", -1)
step := func(token *mlx.Array, async bool) (*mlx.Array, *mlx.Array) {
var t0, t time.Time
var forwardDur, unembedDur, sliceDur, logprobsDur, sampleDur, pinSweepDur, asyncEvalDur time.Duration
if pipelineTiming != nil {
t0 = time.Now()
t = t0
}
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
if pipelineTiming != nil {
forwardDur = time.Since(t)
t = time.Now()
}
logprobs := logits.Subtract(logits.Logsumexp(true))
sample := request.Sample(logprobs)
logits := r.Model.Unembed(fwd)
if pipelineTiming != nil {
unembedDur = time.Since(t)
t = time.Now()
}
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
if pipelineTiming != nil {
sliceDur = time.Since(t)
t = time.Now()
}
var logprobs *mlx.Array
sampleInput := logits
if computeStepLogprobs {
logprobs = logits.Subtract(logits.Logsumexp(true))
sampleInput = logprobs
}
if pipelineTiming != nil {
logprobsDur = time.Since(t)
t = time.Now()
}
sample := request.Sample(sampleInput)
if pipelineTiming != nil {
sampleDur = time.Since(t)
t = time.Now()
}
mlx.Pin(sample, logprobs)
mlx.Sweep()
mlx.AsyncEval(sample, logprobs)
if pipelineTiming != nil {
pinSweepDur = time.Since(t)
}
if async {
mlx.AsyncEval(sample, logprobs)
if pipelineTiming != nil {
asyncEvalDur = time.Since(t)
}
}
if pipelineTiming != nil {
pipelineTiming.recordStep(async, time.Since(t0), forwardDur, unembedDur, sliceDur, logprobsDur, sampleDur, pinSweepDur, asyncEvalDur)
}
return sample, logprobs
}
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed), !lowMemoryDecode)
if lowMemoryDecode {
// Materialize cache updates to prevent transform graph growth.
materializeCaches()
}
recordCacheCheckpoints(caches, total)
logMemory("decode_init", -1)
var b bytes.Buffer
@@ -84,20 +432,34 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
nextSample, nextLogprobs := step(sample)
var nextSample, nextLogprobs *mlx.Array
if !lowMemoryDecode {
nextSample, nextLogprobs = step(sample, true)
}
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
mlx.Eval(sample)
logMemory("decode_first_eval", i)
final.PromptTokensDuration = time.Since(now)
now = time.Now()
}
var intWaitStart time.Time
if pipelineTiming != nil {
intWaitStart = time.Now()
}
output := int32(sample.Int())
if pipelineTiming != nil {
pipelineTiming.recordSampleInt(time.Since(intWaitStart), len(outputs)+1)
}
outputs = append(outputs, output)
if !lowMemoryDecode {
recordCacheCheckpoints(caches, total+len(outputs))
}
if r.Tokenizer.IsEOS(output) {
mlx.Unpin(nextSample, nextLogprobs)
mlx.Unpin(sample, logprobs)
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
@@ -109,18 +471,53 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
Token: int(output),
}
// For recurrent linear-attention models, avoid async prefetch to reduce
// peak memory and clear allocator cache every token.
if lowMemoryDecode {
mlx.Unpin(sample, logprobs)
mlx.Sweep()
if i+1 >= request.Options.MaxTokens {
break
}
mlx.ClearCache()
sample, logprobs = step(mlx.FromValues([]int32{output}, 1), false)
// Materialize cache updates to avoid unbounded transform chains.
materializeCaches()
recordCacheCheckpoints(caches, total+len(outputs))
if i%32 == 0 {
logMemory("decode_lowmem_step", i)
}
continue
}
mlx.Unpin(sample, logprobs)
if asyncRecurrentMaterializeEvery > 0 && (i+1)%asyncRecurrentMaterializeEvery == 0 {
if materializeRecurrentCaches() {
mlx.Sweep()
logMemory("decode_async_recurrent_materialize", i)
}
}
if i%256 == 0 {
mlx.ClearCache()
}
if i%64 == 0 {
logMemory("decode_async_step", i)
}
sample, logprobs = nextSample, nextLogprobs
}
mlx.Unpin(sample, logprobs)
if pipelineTiming != nil {
pipelineTiming.maybeEmit(true, len(outputs))
}
final.CompletionTokensDuration = time.Since(now)
request.Responses <- final
r.InsertCache(append(inputs, outputs...), caches)
finalizeRequestCaches(usePromptCache,
func() { r.InsertCache(append(inputs, outputs...), caches) },
freeCaches,
logMemory,
)
mlx.Sweep()
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
@@ -129,7 +526,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
r.cache.LogCache()
}
}
return nil
}

View File

@@ -0,0 +1,209 @@
//go:build mlx
package mlxrunner
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type stubCache struct {
freeCalls int
}
func (s *stubCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { return keys, values }
func (s *stubCache) State() (*mlx.Array, *mlx.Array) { return nil, nil }
func (s *stubCache) Materialize() []*mlx.Array { return nil }
func (s *stubCache) CanTrim() bool { return true }
func (s *stubCache) Trim(int) int { return 0 }
func (s *stubCache) Clone() cache.Cache { return s }
func (s *stubCache) Free() { s.freeCalls++ }
func (s *stubCache) Offset() int { return 0 }
func (s *stubCache) Len() int { return 0 }
func TestPrefillChunkSize(t *testing.T) {
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "")
if got := prefillChunkSize(false); got != 2<<10 {
t.Fatalf("prefillChunkSize(false) = %d, want %d", got, 2<<10)
}
if got := prefillChunkSize(true); got != 32 {
t.Fatalf("prefillChunkSize(true) = %d, want %d", got, 32)
}
}
func TestPrefillChunkSizeEnvOverride(t *testing.T) {
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "96")
if got := prefillChunkSize(false); got != 96 {
t.Fatalf("prefillChunkSize(false) with env = %d, want %d", got, 96)
}
if got := prefillChunkSize(true); got != 96 {
t.Fatalf("prefillChunkSize(true) with env = %d, want %d", got, 96)
}
}
func TestMLXDebugMemoryEnabled(t *testing.T) {
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "")
if mlxDebugMemoryEnabled() {
t.Fatal("mlxDebugMemoryEnabled() = true, want false")
}
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "1")
if !mlxDebugMemoryEnabled() {
t.Fatal("mlxDebugMemoryEnabled() = false, want true")
}
}
func TestHasRecurrentCaches(t *testing.T) {
if hasRecurrentCaches(nil) {
t.Fatal("hasRecurrentCaches(nil) = true, want false")
}
if hasRecurrentCaches([]cache.Cache{cache.NewKVCache()}) {
t.Fatal("hasRecurrentCaches(kv-only) = true, want false")
}
rc := cache.NewRecurrentCache(4, 8, 2, 16, 8)
if !hasRecurrentCaches([]cache.Cache{cache.NewKVCache(), rc}) {
t.Fatal("hasRecurrentCaches(mixed) = false, want true")
}
}
func TestRecurrentMaterializeInterval(t *testing.T) {
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "")
if got := recurrentMaterializeInterval(true, true); got != 0 {
t.Fatalf("recurrentMaterializeInterval(lowmem=true, recurrent=true) = %d, want 0", got)
}
if got := recurrentMaterializeInterval(false, false); got != 0 {
t.Fatalf("recurrentMaterializeInterval(lowmem=false, recurrent=false) = %d, want 0", got)
}
if got := recurrentMaterializeInterval(false, true); got != defaultRecurrentMaterializeInterval {
t.Fatalf("recurrentMaterializeInterval(default) = %d, want %d", got, defaultRecurrentMaterializeInterval)
}
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "16")
if got := recurrentMaterializeInterval(false, true); got != 16 {
t.Fatalf("recurrentMaterializeInterval(env=16) = %d, want 16", got)
}
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "0")
if got := recurrentMaterializeInterval(false, true); got != 0 {
t.Fatalf("recurrentMaterializeInterval(env=0) = %d, want 0", got)
}
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "-1")
if got := recurrentMaterializeInterval(false, true); got != 0 {
t.Fatalf("recurrentMaterializeInterval(env=-1) = %d, want 0", got)
}
}
func TestMLXPipelineTimingConfig(t *testing.T) {
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "")
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "")
if enabled, every := mlxPipelineTimingConfig(); enabled || every != 0 {
t.Fatalf("mlxPipelineTimingConfig() = (%v, %d), want (false, 0)", enabled, every)
}
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "1")
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != defaultPipelineTimingEvery {
t.Fatalf("mlxPipelineTimingConfig(enabled default) = (%v, %d), want (true, %d)", enabled, every, defaultPipelineTimingEvery)
}
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "16")
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != 16 {
t.Fatalf("mlxPipelineTimingConfig(enabled env=16) = (%v, %d), want (true, 16)", enabled, every)
}
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "0")
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != defaultPipelineTimingEvery {
t.Fatalf("mlxPipelineTimingConfig(enabled env=0) = (%v, %d), want (true, %d)", enabled, every, defaultPipelineTimingEvery)
}
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "0")
if enabled, every := mlxPipelineTimingConfig(); enabled || every != 0 {
t.Fatalf("mlxPipelineTimingConfig(disabled) = (%v, %d), want (false, 0)", enabled, every)
}
}
func TestMLXComputeLogprobsEnabled(t *testing.T) {
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "")
if mlxComputeLogprobsEnabled() {
t.Fatal("mlxComputeLogprobsEnabled() = true, want false")
}
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "1")
if !mlxComputeLogprobsEnabled() {
t.Fatal("mlxComputeLogprobsEnabled() = false with env=1, want true")
}
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "0")
if mlxComputeLogprobsEnabled() {
t.Fatal("mlxComputeLogprobsEnabled() = true with env=0, want false")
}
}
func TestFinalizeRequestCachesUsesPromptCachePath(t *testing.T) {
insertCalls := 0
freeCalls := 0
logPhase := ""
finalizeRequestCaches(
true,
func() { insertCalls++ },
func() { freeCalls++ },
func(phase string, _ int) { logPhase = phase },
)
if insertCalls != 1 {
t.Fatalf("insert calls = %d, want 1", insertCalls)
}
if freeCalls != 0 {
t.Fatalf("free calls = %d, want 0", freeCalls)
}
if logPhase != "request_done_cached" {
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_cached")
}
}
func TestFinalizeRequestCachesUsesFreePath(t *testing.T) {
insertCalls := 0
freeCalls := 0
logPhase := ""
finalizeRequestCaches(
false,
func() { insertCalls++ },
func() { freeCalls++ },
func(phase string, _ int) { logPhase = phase },
)
if insertCalls != 0 {
t.Fatalf("insert calls = %d, want 0", insertCalls)
}
if freeCalls != 1 {
t.Fatalf("free calls = %d, want 1", freeCalls)
}
if logPhase != "request_done_freed" {
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_freed")
}
}
func TestFreeOwnedCaches(t *testing.T) {
a := &stubCache{}
b := &stubCache{}
caches := []cache.Cache{a, nil, b}
freeOwnedCaches(caches)
if a.freeCalls != 1 {
t.Fatalf("a free calls = %d, want 1", a.freeCalls)
}
if b.freeCalls != 1 {
t.Fatalf("b free calls = %d, want 1", b.freeCalls)
}
if caches[0] != nil || caches[2] != nil {
t.Fatalf("cache entries not nilled after free: %#v", caches)
}
}

View File

@@ -62,6 +62,39 @@ type Runner struct {
Tokenizer *tokenizer.Tokenizer
Requests chan Request
cache *CacheEntry
caches []*HybridCacheEntry
}
func releaseTensorMap(tensors map[string]*mlx.Array, keep map[*mlx.Array]struct{}) (count int, bytes int) {
if len(tensors) == 0 {
return 0, 0
}
seen := make(map[*mlx.Array]bool, len(tensors))
toRelease := make([]*mlx.Array, 0, len(tensors))
for name, arr := range tensors {
if arr == nil || !arr.Valid() {
delete(tensors, name)
continue
}
if keep != nil {
if _, ok := keep[arr]; ok {
continue
}
}
delete(tensors, name)
if seen[arr] {
continue
}
seen[arr] = true
toRelease = append(toRelease, arr)
}
if len(toRelease) == 0 {
return 0, 0
}
return len(toRelease), mlx.Release(toRelease...)
}
func (r *Runner) Load(modelName string) error {
@@ -85,9 +118,33 @@ func (r *Runner) Load(modelName string) error {
// Assign weights to model (model-specific logic)
loadWeights := base.Weights(m)
if err := loadWeights(tensors); err != nil {
if count, bytes := releaseTensorMap(tensors, nil); count > 0 {
slog.Info("Released tensors after load failure", "count", count, "bytes", mlx.PrettyBytes(bytes))
}
mlx.Sweep()
mlx.ClearCache()
return err
}
// Materialize model-owned roots before releasing source tensor handles, then
// pin only those roots. This avoids retaining large load-time intermediates
// while still protecting shared model tensors from Sweep.
roots := mlx.Collect(m)
mlx.Eval(roots...)
mlx.Pin(roots...)
keep := make(map[*mlx.Array]struct{})
for _, arr := range roots {
if arr != nil && arr.Valid() {
keep[arr] = struct{}{}
}
}
if count, bytes := releaseTensorMap(tensors, keep); count > 0 {
slog.Info("Released unused model tensors", "count", count, "bytes", mlx.PrettyBytes(bytes))
}
mlx.Sweep()
mlx.ClearCache()
r.Model = m
r.Tokenizer = m.Tokenizer()
return nil

View File

@@ -15,6 +15,40 @@ type LinearLayer interface {
OutputDim() int32
}
// Conv1d applies 1D convolution over NLC input.
type Conv1d struct {
Weight *mlx.Array
Bias *mlx.Array
Stride int32
Padding int32
Dilation int32
Groups int32
}
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
if stride <= 0 {
stride = 1
}
if dilation <= 0 {
dilation = 1
}
if groups <= 0 {
groups = 1
}
return &Conv1d{
Weight: weight,
Bias: bias,
Stride: stride,
Padding: padding,
Dilation: dilation,
Groups: groups,
}
}
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
}
// Linear applies an affine transformation: y = x @ W.T + b
type Linear struct {
Weight *mlx.Array

1913
x/models/qwen3_5/qwen3_5.go Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,206 @@
//go:build mlx
package qwen3_5
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestParseConfigNestedDefaults(t *testing.T) {
data := []byte(`{
"model_type": "Qwen3_5MoeForConditionalGeneration",
"text_config": {
"hidden_size": 4096,
"intermediate_size": 14336,
"num_hidden_layers": 8,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"head_dim": 128,
"linear_num_value_heads": 64,
"linear_num_key_heads": 16,
"linear_key_head_dim": 128,
"linear_value_head_dim": 128,
"linear_conv_kernel_dim": 4,
"num_experts": 16,
"num_experts_per_tok": 4,
"moe_intermediate_size": 2048,
"shared_expert_intermediate_size": 4096,
"rope_parameters": {
"rope_theta": 500000,
"partial_rotary_factor": 0.5
}
}
}`)
cfg, err := parseConfig(data)
if err != nil {
t.Fatalf("parseConfig failed: %v", err)
}
if cfg.RopeTheta != 500000 {
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
}
if cfg.RopeDim != 64 {
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
}
if cfg.FullAttentionInterval != 4 {
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
}
if !cfg.NormTopKProb {
t.Fatalf("norm_topk_prob should default to true for MoE")
}
}
func TestLayerSelectionHelpers(t *testing.T) {
cfg := &Config{
NumHiddenLayers: 6,
FullAttentionInterval: 3,
NumExperts: 8,
DecoderSparseStep: 2,
MLPOnlyLayers: []int32{1},
}
if !layerIsLinear(cfg, 0) {
t.Fatalf("layer 0 should be linear")
}
if layerIsLinear(cfg, 2) {
t.Fatalf("layer 2 should be full attention")
}
if layerUsesMoE(cfg, 1) {
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
}
if !layerUsesMoE(cfg, 3) {
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
}
}
func TestResolveTensorPathLayout(t *testing.T) {
dummy := mlx.New("dummy")
tests := []struct {
name string
key string
wantContainer string
wantModel string
}{
{
name: "standard",
key: "model.embed_tokens.weight",
wantContainer: "",
wantModel: "model.",
},
{
name: "nested language model with inner model",
key: "model.language_model.model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "model.",
},
{
name: "nested language model without inner model",
key: "model.language_model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
layout := resolveTensorPathLayout(map[string]*mlx.Array{
tt.key: dummy,
})
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
t.Fatalf(
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
layout.containerPrefix,
layout.modelPrefix,
tt.wantContainer,
tt.wantModel,
)
}
})
}
}
func TestModelRuntimeToggles(t *testing.T) {
m := &Model{}
if m.DisablePromptCache() {
t.Fatal("DisablePromptCache() = true, want false")
}
if m.LowMemoryDecode() {
t.Fatal("LowMemoryDecode() = true, want false")
}
if !m.EnableCompile() {
t.Fatal("EnableCompile() = false, want true")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_LOW_MEMORY_DECODE", "0")
if m.LowMemoryDecode() {
t.Fatal("LowMemoryDecode() = true with env override 0, want false")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_LOW_MEMORY_DECODE", "1")
if !m.LowMemoryDecode() {
t.Fatal("LowMemoryDecode() = false with env override 1, want true")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_ENABLE_COMPILE", "0")
if m.EnableCompile() {
t.Fatal("EnableCompile() = true with env override 0, want false")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_ENABLE_COMPILE", "1")
if !m.EnableCompile() {
t.Fatal("EnableCompile() = false with env override, want true")
}
if !qwen35FastRecurrentWrite() {
t.Fatal("qwen35FastRecurrentWrite() = false, want true")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_FAST_RECURRENT_WRITE", "0")
if qwen35FastRecurrentWrite() {
t.Fatal("qwen35FastRecurrentWrite() = true with env override 0, want false")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_FAST_RECURRENT_WRITE", "1")
if !qwen35FastRecurrentWrite() {
t.Fatal("qwen35FastRecurrentWrite() = false with env override 1, want true")
}
}
func TestNewCachesLayout(t *testing.T) {
m := &Model{
Config: &Config{
LinearConvKernelDim: 4,
LinearNumKeyHeads: 2,
LinearKeyHeadDim: 8,
LinearNumValueHeads: 4,
LinearValueHeadDim: 16,
},
Layers: []*Layer{
{IsLinear: true},
{IsLinear: false},
{IsLinear: true},
},
}
caches := m.NewCaches()
if len(caches) != len(m.Layers) {
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
}
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
}
if _, ok := caches[1].(*cache.KVCache); !ok {
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
}
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
}
}

View File

@@ -0,0 +1,16 @@
//go:build mlx
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
package qwen3_5_moe
import (
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/qwen3_5"
)
func init() {
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
}