Compare commits


2 Commits

Author | SHA1 | Message | Date
Jeffrey Morgan | da70c3222e | model: support for qwen3.5 architecture (#14378) | 2026-02-24 20:08:05 -08:00
Bruce MacDonald | 9d902d63ce | ggml: ensure tensor size is valid (#14406) | 2026-02-24 21:52:44 -04:00
    When quantizing tensors during model creation, validate that the resulting sizes match what is expected based on the shape.
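
A minimal sketch of the size check this commit describes, assuming the expected byte size is derived from the element count and the quantization block layout (the kind table below is illustrative, not ollama's internal one):

package main

import "fmt"

// kindInfo describes a quantization type: elements per block and bytes per block.
type kindInfo struct {
	blockSize int
	typeSize  int
}

// Illustrative table; the real converter has its own type metadata.
var kinds = map[string]kindInfo{
	"F32":  {blockSize: 1, typeSize: 4},
	"Q4_0": {blockSize: 32, typeSize: 18}, // 32 4-bit weights + an fp16 scale
}

// expectedTensorBytes computes the byte size a tensor of the given kind and
// shape should occupy after quantization.
func expectedTensorBytes(kind string, shape []uint64) (uint64, error) {
	info, ok := kinds[kind]
	if !ok {
		return 0, fmt.Errorf("unknown kind %q", kind)
	}
	numel := uint64(1)
	for _, d := range shape {
		numel *= d
	}
	if numel%uint64(info.blockSize) != 0 {
		return 0, fmt.Errorf("numel %d not divisible by block size %d", numel, info.blockSize)
	}
	return numel / uint64(info.blockSize) * uint64(info.typeSize), nil
}

func main() {
	want, err := expectedTensorBytes("Q4_0", []uint64{4096, 4096})
	if err != nil {
		panic(err)
	}
	got := uint64(9437184) // size reported by the (hypothetical) quantizer
	if got != want {
		panic(fmt.Sprintf("quantized size %d does not match expected %d", got, want))
	}
	fmt.Println("tensor size ok:", want, "bytes")
}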
53 changed files with 2042 additions and 6258 deletions

View File

@@ -320,7 +320,7 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &lfm2Model{}
case "Lfm2VlForConditionalGeneration":
conv = &lfm2VLTextModel{}
case "Qwen3NextForCausalLM":
case "Qwen3NextForCausalLM", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration":
conv = &qwen3NextModel{}
case "NemotronHForCausalLM":
conv = &nemotronHModel{}

View File

@@ -1,6 +1,7 @@
package convert
import (
"encoding/json"
"fmt"
"io/fs"
"math"
@@ -13,8 +14,21 @@ import (
"github.com/ollama/ollama/fs/ggml"
)
type qwen3NextModel struct {
ModelParameters
type qwen3NextRopeScaling struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
MropeSection []int32 `json:"mrope_section"`
}
type qwen3NextRopeParams struct {
MRopeInterleaved bool `json:"mrope_interleaved"`
MropeSection []int32 `json:"mrope_section"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
}
type qwen3NextTextConfig struct {
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
@@ -28,12 +42,13 @@ type qwen3NextModel struct {
// MoE config
NumExperts uint32 `json:"num_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NormTopkProb bool `json:"norm_topk_prob"`
NormTopkProb *bool `json:"norm_topk_prob"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
// Hybrid attention config
FullAttentionInterval uint32 `json:"full_attention_interval"`
FullAttentionInterval uint32 `json:"full_attention_interval"`
LayerTypes []string `json:"layer_types"`
// Linear attention (Gated Delta Net) config
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
@@ -43,16 +58,102 @@ type qwen3NextModel struct {
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
// RoPE config
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeScaling struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
} `json:"rope_scaling"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeScaling qwen3NextRopeScaling `json:"rope_scaling"`
RopeParameters qwen3NextRopeParams `json:"rope_parameters"`
}
type qwen3NextVisionConfig struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_channels"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
DeepstackVisualIndexes []int32 `json:"deepstack_visual_indexes"`
Size struct {
ShortestEdge uint32 `json:"shortest_edge"`
LongestEdge uint32 `json:"longest_edge"`
} `json:"size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
}
type qwen3NextModel struct {
ModelParameters
qwen3NextTextConfig
TextConfig *qwen3NextTextConfig `json:"text_config"`
VisionModel qwen3NextVisionConfig `json:"vision_config"`
ImageTokenID uint32 `json:"image_token_id"`
VisionStartTokenID uint32 `json:"vision_start_token_id"`
VisionEndTokenID uint32 `json:"vision_end_token_id"`
}
var _ ModelConverter = (*qwen3NextModel)(nil)
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
func (q *qwen3NextModel) parseMore(fsys fs.FS) error {
if q.TextConfig != nil {
q.qwen3NextTextConfig = *q.TextConfig
}
if q.RopeTheta == 0 {
q.RopeTheta = q.RopeParameters.RopeTheta
}
if q.PartialRotaryFactor == 0 {
q.PartialRotaryFactor = q.RopeParameters.PartialRotaryFactor
}
if q.RopeScaling.Type == "" && q.RopeParameters.RopeType != "" {
q.RopeScaling.Type = q.RopeParameters.RopeType
}
// Pull vision preprocessing fields when present.
if q.VisionModel.Depth > 0 {
if bts, err := fs.ReadFile(fsys, "preprocessor_config.json"); err == nil {
var pre struct {
Size struct {
ShortestEdge uint32 `json:"shortest_edge"`
LongestEdge uint32 `json:"longest_edge"`
} `json:"size"`
PatchSize uint32 `json:"patch_size"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
MergeSize uint32 `json:"merge_size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
}
if json.Unmarshal(bts, &pre) == nil {
if q.VisionModel.PatchSize == 0 {
q.VisionModel.PatchSize = pre.PatchSize
}
if q.VisionModel.TemporalPatchSize == 0 {
q.VisionModel.TemporalPatchSize = pre.TemporalPatchSize
}
if q.VisionModel.SpatialMergeSize == 0 {
q.VisionModel.SpatialMergeSize = pre.MergeSize
}
if q.VisionModel.Size.ShortestEdge == 0 {
q.VisionModel.Size.ShortestEdge = pre.Size.ShortestEdge
}
if q.VisionModel.Size.LongestEdge == 0 {
q.VisionModel.Size.LongestEdge = pre.Size.LongestEdge
}
if len(q.VisionModel.ImageMean) == 0 {
q.VisionModel.ImageMean = pre.ImageMean
}
if len(q.VisionModel.ImageStd) == 0 {
q.VisionModel.ImageStd = pre.ImageStd
}
}
}
}
if q.NumHiddenLayers == 0 {
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
}
@@ -74,36 +175,96 @@ func (q *qwen3NextModel) parseMore(_ fs.FS) error {
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
}
if q.FullAttentionInterval == 0 {
return fmt.Errorf("qwen3next: full_attention_interval must be set")
}
if q.FullAttentionInterval > q.NumHiddenLayers {
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
hasFull := false
for i := range q.NumHiddenLayers {
if (i+1)%q.FullAttentionInterval == 0 {
hasFull = true
break
}
}
if !hasFull {
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
if _, err := q.kvHeadCounts(); err != nil {
return err
}
return nil
}
func (q *qwen3NextModel) kvHeadCounts() ([]uint32, error) {
if len(q.LayerTypes) > 0 {
kv := make([]uint32, q.NumHiddenLayers)
hasFull := false
hasRecurrent := false
for i := range q.NumHiddenLayers {
layerType := ""
if i < uint32(len(q.LayerTypes)) {
layerType = q.LayerTypes[i]
}
if layerType == "full_attention" {
kv[i] = q.NumKeyValueHeads
hasFull = true
} else {
hasRecurrent = true
}
}
if !hasFull || !hasRecurrent {
return nil, fmt.Errorf("qwen3next: layer_types must include both full_attention and linear_attention")
}
return kv, nil
}
if q.FullAttentionInterval == 0 {
return nil, fmt.Errorf("qwen3next: full_attention_interval must be set")
}
if q.FullAttentionInterval > q.NumHiddenLayers {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
kv := make([]uint32, q.NumHiddenLayers)
hasFull := false
for i := range q.NumHiddenLayers {
if (i+1)%q.FullAttentionInterval == 0 {
kv[i] = q.NumKeyValueHeads
hasFull = true
}
}
if !hasFull {
return nil, fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
return kv, nil
}
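
As a concrete example of the interval path above: with num_hidden_layers=4, full_attention_interval=2, and num_key_value_heads=2, layers 2 and 4 (1-based) are full attention and the rest stay recurrent, yielding [0 2 0 2] as the tests below expect. A standalone sketch:

package main

import "fmt"

func main() {
	const (
		numHiddenLayers       = 4
		fullAttentionInterval = 2
		numKeyValueHeads      = uint32(2)
	)
	kv := make([]uint32, numHiddenLayers)
	for i := 0; i < numHiddenLayers; i++ {
		// Every full_attention_interval-th layer (1-based) is full attention;
		// the rest are recurrent and keep a zero KV head count.
		if (i+1)%fullAttentionInterval == 0 {
			kv[i] = numKeyValueHeads
		}
	}
	fmt.Println(kv) // [0 2 0 2]
}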
func (q *qwen3NextModel) ropeSections() []int32 {
if len(q.RopeParameters.MropeSection) > 0 {
return q.RopeParameters.MropeSection
}
return q.RopeScaling.MropeSection
}
func (q *qwen3NextModel) shouldReorderVHeads() bool {
modelType := strings.ToLower(q.ModelType)
if strings.Contains(modelType, "qwen3_next") || strings.Contains(modelType, "qwen3next") {
return false
}
for _, arch := range q.Architectures {
arch = strings.ToLower(arch)
if strings.Contains(arch, "qwen3next") || strings.Contains(arch, "qwen3_next") {
return false
}
}
// Default to qwen3.5 layout for all other qwen3next-family imports.
return true
}
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen3next"
kv["tokenizer.ggml.pre"] = "qwen2"
arch := "qwen35"
if q.NumExperts > 0 {
arch = "qwen35moe"
}
kv["general.architecture"] = arch
kv["tokenizer.ggml.pre"] = "qwen35"
kv["block_count"] = q.NumHiddenLayers
kv["context_length"] = q.MaxPositionEmbeddings
kv["embedding_length"] = q.HiddenSize
kv["feed_forward_length"] = q.IntermediateSize
kv["attention.head_count"] = q.NumAttentionHeads
headDim := q.HeadDim
if headDim == 0 && q.NumAttentionHeads > 0 {
headDim = q.HiddenSize / q.NumAttentionHeads
@@ -113,18 +274,31 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
kv["rope.freq_base"] = q.RopeTheta
// RoPE dimension count (partial rotary)
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
partialRotary := q.PartialRotaryFactor
if partialRotary > 0 && partialRotary <= 1 {
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
}
// MoE config
if sections := q.ropeSections(); len(sections) > 0 {
kv["mrope_sections"] = sections
kv["rope.mrope_section"] = sections
kv["rope.dimension_sections"] = sections
}
if q.RopeParameters.MRopeInterleaved {
kv["rope.mrope_interleaved"] = true
}
if q.RopeScaling.Type != "" && q.RopeScaling.Type != "default" {
kv["rope.scaling.type"] = q.RopeScaling.Type
kv["rope.scaling.factor"] = q.RopeScaling.Factor
}
if q.NumExperts > 0 {
kv["expert_count"] = q.NumExperts
kv["expert_used_count"] = q.NumExpertsPerToken
kv["norm_top_k_prob"] = q.NormTopkProb
if q.NormTopkProb != nil {
kv["norm_top_k_prob"] = *q.NormTopkProb
}
if q.MoEIntermediateSize > 0 {
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
}
@@ -133,33 +307,66 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
}
}
// SSM/Linear attention config
// d_inner = linear_value_head_dim * linear_num_value_heads
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
kv["ssm.inner_size"] = dInner
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
kv["ssm.state_size"] = q.LinearKeyHeadDim
kv["ssm.group_count"] = q.LinearNumKeyHeads
kv["ssm.time_step_rank"] = q.LinearNumValueHeads
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
interval := q.FullAttentionInterval
kv["full_attention_interval"] = interval
// Build per-layer KV head count array to identify layer types
// 0 = recurrent (linear attention), non-zero = full attention
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
for i := range q.NumHiddenLayers {
// Full attention every full_attention_interval layers (starting at interval-1)
if interval > 0 && (i+1)%interval == 0 {
kvHeadCounts[i] = q.NumKeyValueHeads
}
// else stays 0 (recurrent layer)
if q.shouldReorderVHeads() {
kv["ssm.v_head_reordered"] = true
}
if q.FullAttentionInterval > 0 {
kv["full_attention_interval"] = q.FullAttentionInterval
}
kv["attention.head_count_kv"] = kvHeadCounts
// RoPE scaling
if q.RopeScaling.Type != "" {
kv["rope.scaling.type"] = q.RopeScaling.Type
kv["rope.scaling.factor"] = q.RopeScaling.Factor
if headCounts, err := q.kvHeadCounts(); err == nil {
kv["attention.head_count_kv"] = headCounts
}
if q.VisionModel.Depth > 0 {
kv["vision.block_count"] = q.VisionModel.Depth
kv["vision.embedding_length"] = q.VisionModel.HiddenSize
kv["vision.attention.head_count"] = q.VisionModel.NumHeads
kv["vision.num_channels"] = q.VisionModel.InChannels
if q.VisionModel.PatchSize > 0 {
kv["vision.patch_size"] = q.VisionModel.PatchSize
}
if q.VisionModel.SpatialMergeSize > 0 {
kv["vision.spatial_merge_size"] = q.VisionModel.SpatialMergeSize
}
if q.VisionModel.RMSNormEps > 0 {
kv["vision.attention.layer_norm_epsilon"] = q.VisionModel.RMSNormEps
}
if q.VisionModel.RopeTheta > 0 {
kv["vision.rope.freq_base"] = q.VisionModel.RopeTheta
}
if q.VisionModel.TemporalPatchSize > 0 {
kv["vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize
}
kv["vision.deepstack_visual_indexes"] = q.VisionModel.DeepstackVisualIndexes
if q.VisionModel.Size.ShortestEdge > 0 {
kv["vision.shortest_edge"] = q.VisionModel.Size.ShortestEdge
}
if q.VisionModel.Size.LongestEdge > 0 {
kv["vision.longest_edge"] = q.VisionModel.Size.LongestEdge
}
if len(q.VisionModel.ImageMean) > 0 {
kv["vision.image_mean"] = q.VisionModel.ImageMean
}
if len(q.VisionModel.ImageStd) > 0 {
kv["vision.image_std"] = q.VisionModel.ImageStd
}
}
if q.ImageTokenID > 0 {
kv["image_token_id"] = q.ImageTokenID
}
if q.VisionStartTokenID > 0 {
kv["vision_start_token_id"] = q.VisionStartTokenID
}
if q.VisionEndTokenID > 0 {
kv["vision_end_token_id"] = q.VisionEndTokenID
}
return kv
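
To make the partial-rotary arithmetic in KV concrete: with head_dim=128 and partial_rotary_factor=0.25, only a quarter of each head is rotated, so rope.dimension_count comes out to 32. A minimal sketch of the same computation:

package main

import "fmt"

func main() {
	headDim := uint32(128)
	partialRotary := float32(0.25) // only 25% of each head uses RoPE
	if partialRotary > 0 && partialRotary <= 1 {
		fmt.Println("rope.dimension_count =", uint32(float32(headDim)*partialRotary)) // 32
	}
}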
@@ -168,7 +375,6 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// Create merges for expert tensors - stack individual experts into batched tensors
merges := make([]merge, q.NumHiddenLayers*3)
for i := range q.NumHiddenLayers {
merges[i*3+0] = merge{
@@ -185,16 +391,13 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
}
}
// Merge expert tensors
merged, remaining := mergeTensors(ts, merges...)
out = append(out, merged...)
// Process remaining tensors
for _, t := range remaining {
name := t.Name()
shape := t.Shape()
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
if strings.HasSuffix(name, ".ssm_in.weight") {
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
out = append(out, qkv, gate)
@@ -204,84 +407,299 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
}
switch {
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
// This matches the Python converter behavior for qwen3next
case strings.Contains(name, ".mlp.experts.gate_up_proj"):
out = append(out, slices.Collect(splitDim(t, 1,
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_gate_exps.weight")},
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_up_exps.weight")},
))...)
case strings.Contains(name, ".mlp.experts.down_proj"):
out = append(out, &ggml.Tensor{
Name: strings.NewReplacer(".mlp.experts.down_proj", ".ffn_down_exps.weight").Replace(name),
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
case strings.HasPrefix(name, "v.blk.") && strings.Contains(name, ".attn_qkv"):
out = append(out, slices.Collect(splitDim(t, 0,
split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
))...)
case strings.Contains(name, "patch_embed") && strings.HasSuffix(name, "weight"):
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: append([]uint64{shape[0] * shape[1]}, shape[2:]...),
WriterTo: t,
})
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
t.SetRepacker(q.addOne)
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
// Handle linear attention A_log -> ssm_a (negate and exp)
// Note: name has already been transformed by Replacements at this point
case strings.HasSuffix(name, ".ssm_a"):
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
// Compute -exp(A_log)
result := make([]float32, len(data))
for i, v := range data {
// -exp(v)
result[i] = -float32(math.Exp(float64(v)))
}
return result, nil
})
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
t.SetRepacker(q.repackSSMA())
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".attn_qkv.weight"):
if q.shouldReorderVHeads() {
t.SetRepacker(q.repackAttnQKV())
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".attn_gate.weight"):
if q.shouldReorderVHeads() {
// HF tensor layout is [out_features, in_features]; reorder rows.
t.SetRepacker(q.repackReorderDim(0, int(q.LinearValueHeadDim)))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".ssm_beta.weight"), strings.HasSuffix(name, ".ssm_alpha.weight"):
if q.shouldReorderVHeads() {
// HF tensor layout is [out_features, in_features]; reorder rows.
t.SetRepacker(q.repackReorderDim(0, 1))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".ssm_dt"):
if q.shouldReorderVHeads() {
t.SetRepacker(q.repackReorderDim(0, 1))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".ssm_out.weight"):
if q.shouldReorderVHeads() {
// HF out_proj layout is [out_features, in_features]; reorder columns.
t.SetRepacker(q.repackReorderDim(1, int(q.LinearValueHeadDim)))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
newShape := slices.Clone(shape)
if len(shape) == 3 {
if shape[0] == 1 {
// [1, D, K] -> [D, K]
newShape = []uint64{shape[1], shape[2]}
} else if shape[1] == 1 {
// [D, 1, K] -> [D, K]
newShape = []uint64{shape[0], shape[2]}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
newShape := slices.Clone(shape)
if len(shape) == 2 {
if shape[0] == 1 && shape[1] > 1 {
newShape = []uint64{shape[1]}
} else if shape[1] == 1 && shape[0] > 1 {
newShape = []uint64{shape[0]}
}
if q.shouldReorderVHeads() {
t.SetRepacker(q.repackConv1D())
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: newShape, WriterTo: t})
default:
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
}
}
return out
}
func (q *qwen3NextModel) repackReorderDim(dim, headDim int) Repacker {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
if !q.shouldReorderVHeads() {
return data, nil
}
numK := int(q.LinearNumKeyHeads)
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
return reorderHeadLayout(data, shape, dim, numK, numVPerK, headDim)
}
}
func (q *qwen3NextModel) repackAttnQKV() Repacker {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
if !q.shouldReorderVHeads() || len(shape) != 2 {
return data, nil
}
rows := int(shape[0])
cols := int(shape[1])
numK := int(q.LinearNumKeyHeads)
numV := int(q.LinearNumValueHeads)
headK := int(q.LinearKeyHeadDim)
headV := int(q.LinearValueHeadDim)
qDim := headK * numK
kDim := headK * numK
vDim := headV * numV
qkvDim := qDim + kDim + vDim
switch {
case rows == qkvDim:
// HF layout: [out_features, in_features]. Keep Q/K rows unchanged and
// reorder only V rows from grouped -> tiled head layout.
out := make([]float32, len(data))
qkRows := qDim + kDim
qkSize := qkRows * cols
copy(out[:qkSize], data[:qkSize])
vStart := qkSize
vEnd := vStart + vDim*cols
reorderedV, err := reorderHeadLayout(data[vStart:vEnd], []uint64{uint64(vDim), uint64(cols)}, 0, numK, numV/numK, headV)
if err != nil {
return nil, err
}
copy(out[vStart:vEnd], reorderedV)
copy(out[vEnd:], data[vEnd:])
return out, nil
case cols == qkvDim:
// Fallback for already-transposed [in_features, out_features] tensors.
out := make([]float32, len(data))
copy(out, data)
for r := range rows {
base := r * cols
vStart := base + qDim + kDim
vEnd := vStart + vDim
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vDim)}, 0, numK, numV/numK, headV)
if err != nil {
return nil, err
}
copy(out[vStart:vEnd], reorderedV)
}
return out, nil
default:
return data, nil
}
}
}
func (q *qwen3NextModel) repackConv1D() Repacker {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
if !q.shouldReorderVHeads() {
return data, nil
}
normShape := slices.Clone(shape)
if len(shape) == 3 {
if shape[0] == 1 {
normShape = []uint64{shape[1], shape[2]}
} else if shape[1] == 1 {
normShape = []uint64{shape[0], shape[2]}
}
}
if len(normShape) != 2 {
return data, nil
}
rows := int(normShape[0])
cols := int(normShape[1])
numK := int(q.LinearNumKeyHeads)
numV := int(q.LinearNumValueHeads)
headK := int(q.LinearKeyHeadDim)
headV := int(q.LinearValueHeadDim)
qkChannels := 2 * headK * numK
totalChannels := qkChannels + headV*numV
if qkChannels <= 0 {
return data, nil
}
switch {
case rows == totalChannels:
// HF layout after squeeze: [channels, kernel]
out := make([]float32, len(data))
prefix := qkChannels * cols
copy(out[:prefix], data[:prefix])
reorderedV, err := reorderHeadLayout(data[prefix:], []uint64{uint64(totalChannels - qkChannels), uint64(cols)}, 0, numK, numV/numK, headV)
if err != nil {
return nil, err
}
copy(out[prefix:], reorderedV)
return out, nil
case cols == totalChannels:
// Fallback for transposed [kernel, channels]
out := make([]float32, len(data))
copy(out, data)
vChannels := totalChannels - qkChannels
for r := range rows {
base := r * cols
vStart := base + qkChannels
vEnd := vStart + vChannels
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vChannels)}, 0, numK, numV/numK, headV)
if err != nil {
return nil, err
}
copy(out[vStart:vEnd], reorderedV)
}
return out, nil
default:
return data, nil
}
}
}
func (q *qwen3NextModel) repackSSMA() Repacker {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
result := make([]float32, len(data))
for i, v := range data {
result[i] = -float32(math.Exp(float64(v)))
}
if !q.shouldReorderVHeads() {
return result, nil
}
numK := int(q.LinearNumKeyHeads)
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
return reorderHeadLayout(result, shape, 0, numK, numVPerK, 1)
}
}
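
The repacker above folds the Hugging Face parameterization A = -exp(A_log) into the exported ssm_a tensor before any head reordering; for example A_log = 0 maps to -1.0. A sketch of just the numeric transform:

package main

import (
	"fmt"
	"math"
)

func main() {
	aLog := []float32{0, 1, -1}
	out := make([]float32, len(aLog))
	for i, v := range aLog {
		out[i] = -float32(math.Exp(float64(v))) // ssm_a = -exp(A_log)
	}
	fmt.Println(out) // [-1 -2.7182817 -0.36787945]
}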
func reorderHeadLayout(data []float32, shape []uint64, dim int, numKHeads, numVPerK, headDim int) ([]float32, error) {
if len(shape) == 0 || numKHeads <= 0 || numVPerK <= 0 || headDim <= 0 {
return data, nil
}
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
if dim < 0 {
dim += len(dims)
}
if dim < 0 || dim >= len(dims) {
return data, nil
}
expected := numKHeads * numVPerK * headDim
if dims[dim] != expected {
return data, nil
}
newShape := make([]int, 0, len(dims)+2)
newShape = append(newShape, dims[:dim]...)
newShape = append(newShape, numKHeads, numVPerK, headDim)
newShape = append(newShape, dims[dim+1:]...)
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := tt.Reshape(newShape...); err != nil {
return nil, err
}
perm := make([]int, len(newShape))
for i := range perm {
perm[i] = i
}
perm[dim], perm[dim+1] = perm[dim+1], perm[dim]
tt, err := tensor.Transpose(tt, perm...)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
total := 1
for _, d := range dims {
total *= d
}
if err := tt.Reshape(total); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
}
type qkvzSplitSpec struct {
hidden int
headKDim int
@@ -369,7 +787,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
var err error
// Convert to [hidden, out_features] layout for slicing
tt, err = tensor.Transpose(tt, 1, 0)
if err != nil {
return nil, err
@@ -444,7 +861,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
}
}
// addOne adds 1.0 to all elements in the tensor (for norm weights)
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, int(shape[0]))
@@ -471,10 +887,21 @@ func (q *qwen3NextModel) Replacements() []string {
return []string{
// Embeddings and output
"lm_head", "output",
"model.language_model.embed_tokens", "token_embd",
"model.language_model.norm", "output_norm",
"model.language_model.layers", "blk",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
// Vision
"model.visual", "v",
"patch_embed.proj", "patch_embed",
"blocks", "blk",
"attn.qkv", "attn_qkv",
"attn.proj", "attn_out",
"deepstack_merger_list", "deepstack_merger",
// Layer norms
"input_layernorm", "attn_norm",
"post_attention_layernorm", "post_attention_norm",
@@ -487,9 +914,16 @@ func (q *qwen3NextModel) Replacements() []string {
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
// Linear attention (Gated Delta Net)
// Linear attention (legacy qwen3next)
"linear_attn.in_proj_qkvz", "ssm_in",
"linear_attn.in_proj_ba", "ssm_ba",
// Linear attention (qwen35)
"linear_attn.in_proj_qkv", "attn_qkv",
"linear_attn.in_proj_z", "attn_gate",
"linear_attn.in_proj_a", "ssm_alpha",
"linear_attn.in_proj_b", "ssm_beta",
"linear_attn.conv1d", "ssm_conv1d",
"linear_attn.dt_bias", "ssm_dt",
"linear_attn.dt_proj", "ssm_dt",
@@ -497,14 +931,14 @@ func (q *qwen3NextModel) Replacements() []string {
"linear_attn.norm", "ssm_norm",
"linear_attn.out_proj", "ssm_out",
// MoE (experts are stacked via mergeTensors, not replaced here)
// MoE
"mlp.gate.weight", "ffn_gate_inp.weight",
"mlp.shared_expert.down_proj", "ffn_down_shexp",
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
"mlp.shared_expert.up_proj", "ffn_up_shexp",
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
// Dense FFN (if any layers use it)
// Dense FFN
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",

View File

@@ -0,0 +1,563 @@
package convert
import (
"bytes"
"encoding/binary"
"os"
"slices"
"strings"
"testing"
"github.com/ollama/ollama/fs/ggml"
)
func boolPtr(v bool) *bool {
return &v
}
func readTensorData(t *testing.T, tensor *ggml.Tensor) []float32 {
t.Helper()
var b bytes.Buffer
if _, err := tensor.WriteTo(&b); err != nil {
t.Fatal(err)
}
numel := 1
for _, d := range tensor.Shape {
numel *= int(d)
}
values := make([]float32, numel)
if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
t.Fatal(err)
}
return values
}
func TestQwen3NextLegacyModelTypeDisablesReorder(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_next",
},
}
if m.shouldReorderVHeads() {
t.Fatalf("legacy qwen3_next model_type should not reorder v-head layout")
}
}
func TestQwen3NextLegacyArchitectureDisablesReorder(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
Architectures: []string{"Qwen3NextForCausalLM"},
},
}
if m.shouldReorderVHeads() {
t.Fatalf("legacy Qwen3Next architecture should not reorder v-head layout")
}
}
func TestQwen3NextKVLegacyConfig(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_next",
},
qwen3NextTextConfig: qwen3NextTextConfig{
MaxPositionEmbeddings: 8192,
HiddenSize: 512,
NumHiddenLayers: 4,
IntermediateSize: 2048,
NumAttentionHeads: 8,
NumKeyValueHeads: 2,
HeadDim: 64,
RopeTheta: 1_000_000,
RMSNormEPS: 1e-6,
NumExperts: 8,
NumExpertsPerToken: 2,
NormTopkProb: boolPtr(true),
MoEIntermediateSize: 256,
SharedExpertIntermSize: 512,
FullAttentionInterval: 2,
LinearConvKernelDim: 4,
LinearKeyHeadDim: 64,
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 64,
PartialRotaryFactor: 0.25,
},
}
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
t.Fatal(err)
}
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
if got, want := kv["general.architecture"], "qwen35moe"; got != want {
t.Fatalf("unexpected architecture: got %v want %v", got, want)
}
if got, want := kv["tokenizer.ggml.pre"], "qwen35"; got != want {
t.Fatalf("unexpected tokenizer pre: got %v want %v", got, want)
}
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
if !ok {
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
}
if got, want := headCountKV, []uint32{0, 2, 0, 2}; !slices.Equal(got, want) {
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
}
if _, ok := kv["ssm.v_head_reordered"]; ok {
t.Fatalf("legacy qwen3next should not enable ssm.v_head_reordered")
}
if got, want := kv["norm_top_k_prob"], true; got != want {
t.Fatalf("unexpected norm_top_k_prob: got %v want %v", got, want)
}
}
func TestQwen35MoeOmitsNormTopKProbWhenUnset(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
MaxPositionEmbeddings: 4096,
HiddenSize: 512,
NumHiddenLayers: 4,
IntermediateSize: 2048,
NumAttentionHeads: 8,
NumKeyValueHeads: 2,
HeadDim: 64,
RopeTheta: 1_000_000,
RMSNormEPS: 1e-6,
NumExperts: 8,
NumExpertsPerToken: 2,
FullAttentionInterval: 2,
LinearConvKernelDim: 4,
LinearKeyHeadDim: 64,
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 64,
PartialRotaryFactor: 0.25,
},
}
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
t.Fatal(err)
}
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
if _, ok := kv["norm_top_k_prob"]; ok {
t.Fatalf("expected norm_top_k_prob to be omitted when not set in config")
}
}
func TestQwen35KVFromTextConfig(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
TextConfig: &qwen3NextTextConfig{
MaxPositionEmbeddings: 16384,
HiddenSize: 1024,
NumHiddenLayers: 4,
IntermediateSize: 4096,
NumAttentionHeads: 8,
NumKeyValueHeads: 4,
HeadDim: 128,
RMSNormEPS: 1e-6,
LayerTypes: []string{
"linear_attention",
"full_attention",
"linear_attention",
"full_attention",
},
LinearConvKernelDim: 4,
LinearKeyHeadDim: 128,
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 128,
RopeParameters: qwen3NextRopeParams{
MRopeInterleaved: true,
MropeSection: []int32{11, 11, 10},
RopeType: "default",
RopeTheta: 10_000_000,
PartialRotaryFactor: 0.25,
},
},
VisionModel: qwen3NextVisionConfig{
Depth: 2,
HiddenSize: 128,
NumHeads: 4,
InChannels: 3,
PatchSize: 16,
SpatialMergeSize: 2,
RMSNormEps: 1e-6,
RopeTheta: 10_000,
TemporalPatchSize: 2,
DeepstackVisualIndexes: []int32{1},
},
ImageTokenID: 1001,
VisionStartTokenID: 1002,
VisionEndTokenID: 1003,
}
m.VisionModel.Size.ShortestEdge = 224
m.VisionModel.Size.LongestEdge = 4096
m.VisionModel.ImageMean = []float32{0.5, 0.5, 0.5}
m.VisionModel.ImageStd = []float32{0.2, 0.2, 0.2}
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
t.Fatal(err)
}
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
if got, want := kv["general.architecture"], "qwen35"; got != want {
t.Fatalf("unexpected architecture: got %v want %v", got, want)
}
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
if !ok {
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
}
if got, want := headCountKV, []uint32{0, 4, 0, 4}; !slices.Equal(got, want) {
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
}
if got, ok := kv["ssm.v_head_reordered"].(bool); !ok || !got {
t.Fatalf("expected ssm.v_head_reordered=true, got %v (%T)", kv["ssm.v_head_reordered"], kv["ssm.v_head_reordered"])
}
mrope, ok := kv["mrope_sections"].([]int32)
if !ok {
t.Fatalf("mrope_sections has unexpected type: %T", kv["mrope_sections"])
}
if got, want := mrope, []int32{11, 11, 10}; !slices.Equal(got, want) {
t.Fatalf("unexpected mrope_sections: got %v want %v", got, want)
}
ropeSections, ok := kv["rope.dimension_sections"].([]int32)
if !ok {
t.Fatalf("rope.dimension_sections has unexpected type: %T", kv["rope.dimension_sections"])
}
if got, want := ropeSections, []int32{11, 11, 10}; !slices.Equal(got, want) {
t.Fatalf("unexpected rope.dimension_sections: got %v want %v", got, want)
}
if got, ok := kv["rope.mrope_interleaved"].(bool); !ok || !got {
t.Fatalf("expected rope.mrope_interleaved=true, got %v (%T)", kv["rope.mrope_interleaved"], kv["rope.mrope_interleaved"])
}
if got, want := kv["vision.block_count"], uint32(2); got != want {
t.Fatalf("unexpected vision.block_count: got %v want %v", got, want)
}
}
func TestQwen3NextReplacements(t *testing.T) {
r := strings.NewReplacer((&qwen3NextModel{}).Replacements()...)
if got, want := r.Replace("model.language_model.layers.1.linear_attn.in_proj_qkv.weight"), "blk.1.attn_qkv.weight"; got != want {
t.Fatalf("unexpected language-model replacement: got %q want %q", got, want)
}
if got, want := r.Replace("model.visual.blocks.0.attn.qkv.weight"), "v.blk.0.attn_qkv.weight"; got != want {
t.Fatalf("unexpected vision replacement: got %q want %q", got, want)
}
if got, want := r.Replace("model.layers.1.linear_attn.in_proj_qkvz.weight"), "blk.1.ssm_in.weight"; got != want {
t.Fatalf("unexpected legacy replacement: got %q want %q", got, want)
}
}
func TestQwen35ReordersVHeads(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.attn_gate.weight",
shape: []uint64{4, 2},
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
t.Fatalf("unexpected data: got %v want %v", got, want)
}
}
func TestQwen35ReordersAttnQKVOutputDim(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearKeyHeadDim: 1,
LinearValueHeadDim: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.attn_qkv.weight",
shape: []uint64{8, 2}, // [out_features, in_features] (HF layout)
data: []float32{
0, 1, // q0
2, 3, // q1
4, 5, // k0
6, 7, // k1
10, 11, // v(k0,v0)
12, 13, // v(k0,v1)
20, 21, // v(k1,v0)
22, 23, // v(k1,v1)
},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{
0, 1, 2, 3, 4, 5, 6, 7,
10, 11, 20, 21, 12, 13, 22, 23,
}; !slices.Equal(got, want) {
t.Fatalf("unexpected qkv data: got %v want %v", got, want)
}
}
func TestQwen35ReordersSsmOutInputDim(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.ssm_out.weight",
shape: []uint64{2, 4},
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{0, 2, 1, 3, 4, 6, 5, 7}; !slices.Equal(got, want) {
t.Fatalf("unexpected ssm_out data: got %v want %v", got, want)
}
}
func TestQwen35ReordersSsmBetaRows(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.ssm_beta.weight",
shape: []uint64{4, 2},
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
t.Fatalf("unexpected ssm_beta data: got %v want %v", got, want)
}
}
func TestQwen35ReordersConv1DChannelDim(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearKeyHeadDim: 1,
LinearValueHeadDim: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.ssm_conv1d.weight",
shape: []uint64{8, 2}, // [channels, kernel] after squeeze
data: []float32{
0, 1, // q0
2, 3, // q1
4, 5, // k0
6, 7, // k1
10, 11, // v(k0,v0)
12, 13, // v(k0,v1)
20, 21, // v(k1,v0)
22, 23, // v(k1,v1)
},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{
0, 1, 2, 3, 4, 5, 6, 7,
10, 11, 20, 21, 12, 13, 22, 23,
}; !slices.Equal(got, want) {
t.Fatalf("unexpected conv1d data: got %v want %v", got, want)
}
}
func TestLegacyQwen3NextDoesNotReorderVHeads(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_next",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.attn_gate.weight",
shape: []uint64{4, 1},
data: []float32{0, 1, 2, 3},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{0, 1, 2, 3}; !slices.Equal(got, want) {
t.Fatalf("unexpected data for legacy qwen3next: got %v want %v", got, want)
}
}
func TestQwen35MoePackedExperts(t *testing.T) {
m := &qwen3NextModel{
qwen3NextTextConfig: qwen3NextTextConfig{
NumHiddenLayers: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.mlp.experts.gate_up_proj",
shape: []uint64{2, 4, 3},
data: []float32{
0, 1, 2,
3, 4, 5,
6, 7, 8,
9, 10, 11,
12, 13, 14,
15, 16, 17,
18, 19, 20,
21, 22, 23,
},
},
&fakeTensor{
name: "blk.0.mlp.experts.down_proj",
shape: []uint64{2, 5, 3},
data: make([]float32, 2*5*3),
},
})
get := func(name string) *ggml.Tensor {
for _, tensor := range out {
if tensor.Name == name {
return tensor
}
}
return nil
}
gate := get("blk.0.ffn_gate_exps.weight")
if gate == nil {
t.Fatalf("missing tensor %q", "blk.0.ffn_gate_exps.weight")
}
if got, want := gate.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
t.Fatalf("unexpected gate shape: got %v want %v", got, want)
}
if got, want := readTensorData(t, gate), []float32{
0, 1, 2, 3, 4, 5,
12, 13, 14, 15, 16, 17,
}; !slices.Equal(got, want) {
t.Fatalf("unexpected gate values: got %v want %v", got, want)
}
up := get("blk.0.ffn_up_exps.weight")
if up == nil {
t.Fatalf("missing tensor %q", "blk.0.ffn_up_exps.weight")
}
if got, want := up.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
t.Fatalf("unexpected up shape: got %v want %v", got, want)
}
if got, want := readTensorData(t, up), []float32{
6, 7, 8, 9, 10, 11,
18, 19, 20, 21, 22, 23,
}; !slices.Equal(got, want) {
t.Fatalf("unexpected up values: got %v want %v", got, want)
}
down := get("blk.0.ffn_down_exps.weight")
if down == nil {
t.Fatalf("missing tensor %q", "blk.0.ffn_down_exps.weight")
}
if got, want := down.Shape, []uint64{2, 5, 3}; !slices.Equal(got, want) {
t.Fatalf("unexpected down shape: got %v want %v", got, want)
}
}
func TestQwen35SharedExpertGateKeepsMatrixShape(t *testing.T) {
m := &qwen3NextModel{}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.ffn_gate_inp_shexp.weight",
shape: []uint64{1, 4},
data: []float32{0, 1, 2, 3},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := out[0].Shape, []uint64{1, 4}; !slices.Equal(got, want) {
t.Fatalf("unexpected shared gate shape: got %v want %v", got, want)
}
}

View File

@@ -101,6 +101,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
t.Pre = "deepseek-coder"
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
t.Pre = "qwen2"
case "00431aed57e696b747435f734d1e3b9b1bfd931a121fb5cac7129e97c181e9ba":
t.Pre = "qwen35"
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
// noop, empty pretokenizer
default:

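The hex keys in this switch appear to be SHA-256 digests of the serialized pre_tokenizer block (note the e3b0c4... case, which is the well-known digest of empty input); exactly which bytes are hashed is an internal detail of the converter, so the sketch below is illustrative only:

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// preName maps a digest of the serialized pre_tokenizer to a preset name.
// Assumption: the keys are SHA-256 hex digests; the exact serialization that
// gets hashed is the converter's internal detail.
func preName(serialized []byte) string {
	sum := sha256.Sum256(serialized)
	switch hex.EncodeToString(sum[:]) {
	case "00431aed57e696b747435f734d1e3b9b1bfd931a121fb5cac7129e97c181e9ba":
		return "qwen35"
	case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
		return "default" // SHA-256 of empty input: no pretokenizer configured
	default:
		return "default"
	}
}

func main() {
	fmt.Println(preName(nil)) // empty input hits the well-known empty digest
}
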
View File

@@ -386,6 +386,28 @@ func TestParseTokenizer(t *testing.T) {
Pre: "default",
},
},
{
name: "qwen35 pretokenizer",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"pre_tokenizer": {
"type": "Sequence",
"pretokenizers": [
{
"type": "Split",
"pattern": {
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
}
}
]
}
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{Model: "gpt2"},
Pre: "qwen35",
},
},
}
for _, tt := range cases {

View File

@@ -290,6 +290,7 @@ func (kv KV) OllamaEngineRequired() bool {
"olmo3",
"qwen25vl",
"qwen3", "qwen3moe",
"qwen35", "qwen35moe",
"qwen3next",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
@@ -868,7 +869,12 @@ func (f GGML) SupportsFlashAttention() bool {
return false
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
arch := f.KV().Architecture()
if slices.Contains([]string{"qwen35", "qwen35moe", "qwen3next"}, arch) {
return true
}
if slices.Contains([]string{"gemma2"}, arch) {
return false
}
@@ -892,6 +898,7 @@ func (f GGML) FlashAttention() bool {
"nemotron_h", "nemotron_h_moe",
"olmo3",
"qwen3", "qwen3moe",
"qwen35", "qwen35moe",
"qwen3next",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))

View File

@@ -245,7 +245,22 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
padding := ggufPadding(offset, int64(alignment))
llm.tensorOffset = uint64(offset + padding)
// get file size to validate tensor bounds
fileSize, err := rs.Seek(0, io.SeekEnd)
if err != nil {
return fmt.Errorf("failed to determine file size: %w", err)
}
if _, err := rs.Seek(offset, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek back after size check: %w", err)
}
for _, tensor := range llm.tensors {
tensorEnd := llm.tensorOffset + tensor.Offset + tensor.Size()
if tensorEnd > uint64(fileSize) {
return fmt.Errorf("tensor %q offset+size (%d) exceeds file size (%d)", tensor.Name, tensorEnd, fileSize)
}
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
return fmt.Errorf("failed to get current offset: %w", err)

View File

@@ -11,21 +11,21 @@ import (
)
func TestWriteGGUF(t *testing.T) {
b := bytes.NewBuffer(make([]byte, 2*3))
tensorData := make([]byte, 2*3*4) // 6 F32 elements = 24 bytes
for range 8 {
t.Run("shuffle", func(t *testing.T) {
t.Parallel()
ts := []*Tensor{
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewReader(tensorData)},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewReader(tensorData)},
}
rand.Shuffle(len(ts), func(i, j int) {
@@ -98,4 +98,32 @@ func TestWriteGGUF(t *testing.T) {
}
})
}
t.Run("truncated_tensor_data", func(t *testing.T) {
t.Parallel()
ts := []*Tensor{
{Name: "blk.0.attn.weight", Kind: 0, Shape: []uint64{512, 2}, WriterTo: bytes.NewBuffer(make([]byte, 32))},
}
w, err := os.CreateTemp(t.TempDir(), "truncated_*.bin")
if err != nil {
t.Fatal(err)
}
defer w.Close()
if err := WriteGGUF(w, KV{"general.architecture": "test"}, ts); err != nil {
t.Fatal(err)
}
r, err := os.Open(w.Name())
if err != nil {
t.Fatal(err)
}
defer r.Close()
if _, err := Decode(r, -1); err == nil {
t.Error("Decode should reject GGUF files where tensor data extends beyond file size")
}
})
}
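
The switch from one shared bytes.Buffer to a fresh bytes.NewReader per tensor matters because a Buffer is drained as it is read: only the first tensor written would see any data. A minimal demonstration of the difference:

package main

import (
	"bytes"
	"fmt"
	"io"
)

func main() {
	shared := bytes.NewBuffer(make([]byte, 6))
	n1, _ := io.Copy(io.Discard, shared) // first read drains the buffer
	n2, _ := io.Copy(io.Discard, shared) // second read sees nothing
	fmt.Println(n1, n2) // 6 0

	data := make([]byte, 6)
	r1, _ := io.Copy(io.Discard, bytes.NewReader(data)) // fresh reader per tensor
	r2, _ := io.Copy(io.Discard, bytes.NewReader(data))
	fmt.Println(r1, r2) // 6 6
}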

View File

@@ -11,9 +11,9 @@ import (
)
const (
DefaultCheckpointCount = 32
DefaultCheckpointCount = 24
DefaultCheckpointMinPos = int32(16)
DefaultCheckpointInterval = int32(1280)
DefaultCheckpointInterval = int32(1664)
)
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")

View File

@@ -2,595 +2,58 @@ package qwen3next
import (
"math"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var _ kvcache.Cache = (*HybridCache)(nil)
var (
_ kvcache.Cache = (*HybridCache)(nil)
_ kvcache.CheckpointCache = (*HybridCache)(nil)
)
// HybridCache stores:
// - a standard causal KV cache for full attention layers
// - per-sequence conv state for linear attention layers
// - per-sequence delta state for linear attention layers
//
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
// HybridCache adapts the shared recurrent cache base for Qwen3-Next naming.
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
// Conv state dimensions
convDim int // convKernelSize - 1
convChannels int // d_inner + 2 * num_k_heads * head_k_dim
// Delta state dimensions
deltaStateSize int // headVDim * headVDim * numVHeads
// slot mapping for recurrent state (copy-on-write)
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
// per-layer delta state buffers (allocated lazily)
deltaCtxs map[int]ml.Context
deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
// recurrent checkpoints (per slot)
checkpointCount int
checkpointMinPos int32
checkpointInterval int32
checkpointCtxSize int
checkpoints map[int]*slotCheckpointStore
pendingRestore map[int]checkpointRestore
curCheckpointPos []int32
curCheckpointSlots map[int]int
reserveCheckpoints bool
checkpointConvCtxs map[int]ml.Context
checkpointDeltaCtxs map[int]ml.Context
checkpointReserved map[int]struct{}
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
writableError error
*kvcache.Recurrent
}
func NewHybridCache(
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
convDim, convChannels, deltaStateSize int,
) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
convDim: convDim,
convChannels: convChannels,
deltaStateSize: deltaStateSize,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
deltaCtxs: make(map[int]ml.Context),
deltaStates: make(map[int]ml.Tensor),
checkpointCount: checkpointCountDefault,
checkpointMinPos: checkpointMinPosDefault,
checkpointInterval: checkpointIntervalDefault,
checkpoints: make(map[int]*slotCheckpointStore),
pendingRestore: make(map[int]checkpointRestore),
curCheckpointSlots: make(map[int]int),
checkpointConvCtxs: make(map[int]ml.Context),
checkpointDeltaCtxs: make(map[int]ml.Context),
checkpointReserved: make(map[int]struct{}),
}
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
Shift: shift,
ConvDim: convDim,
ConvChannels: convChannels,
RecurrentStateSize: deltaStateSize,
CheckpointLogPrefix: "qwen3next",
})
return &HybridCache{Recurrent: base}
}
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
c.checkpoints = make(map[int]*slotCheckpointStore)
c.pendingRestore = make(map[int]checkpointRestore)
c.curCheckpointPos = c.curCheckpointPos[:0]
c.curCheckpointSlots = make(map[int]int)
c.checkpointReserved = make(map[int]struct{})
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
if c.checkpointCtxSize < 8 {
c.checkpointCtxSize = 8
}
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *HybridCache) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
for _, ctx := range c.deltaCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointConvCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointDeltaCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *HybridCache) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
// Derive equal-length sequence layout for recurrent layers
seqCounts := make(map[int]int)
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if _, ok := seqCounts[s]; !ok {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
if len(c.curSeqs) == 0 {
return nil
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
}
}
c.curSeqTokens = want
// When reserving memory for estimation, use fake slot assignments
if reserve {
c.curSlots = c.curSlots[:0]
slots := make([]int32, nSeqs)
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
slots[i] = int32(i)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
c.reserveCheckpoints = true
c.planCheckpoints(batch)
return nil
}
// Ensure slots exist for sequences in this batch
c.curSlots = c.curSlots[:0]
var newSlots []int
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
// Zero state for newly allocated slots
if len(newSlots) > 0 {
c.zeroSlots(ctx, newSlots)
}
// Create a tensor for the current slots
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
// Reset writable state for new forward pass
c.writableEnsured = false
c.writableError = nil
c.reserveCheckpoints = false
c.planCheckpoints(batch)
return nil
}
func (c *HybridCache) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, kvcache.ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *HybridCache) freeSlot(slot int) {
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroSlots zeros the recurrent state for the given slots across all layers.
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 {
return
}
inputCtx := ctx.Input()
slotIndices := make([]int32, len(slots))
for i, s := range slots {
slotIndices[i] = int32(s)
}
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
// Zero conv states
if len(c.convStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
// Zero delta states
if len(c.deltaStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
for _, buf := range c.deltaStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
}
// EnsureWritable ensures sequences have private slots (copy-on-write).
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
}
// Rebuild current slots tensor
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
for _, buf := range c.convStates {
rows := buf.Rows(ctx, src)
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
}
for _, buf := range c.deltaStates {
rows := buf.Rows(ctx, src)
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
}
}
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
// Copy-on-write for recurrent state
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
if c.validSlot(dstSlot) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
return
}
if c.validSlot(srcSlot) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *HybridCache) CanResume(seq int, pos int32) bool {
if !c.kv.CanResume(seq, pos) {
return false
}
if pos == 0 {
return true
}
return c.hasCheckpoint(seq, pos)
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if beginIndex > 0 && endIndex != math.MaxInt32 {
return kvcache.ErrNotSupported
}
if beginIndex > 0 {
restore, ok := c.pendingRestore[seq]
if !ok || restore.pos+1 != beginIndex {
return kvcache.ErrNotSupported
}
if !c.restoreComplete(restore) {
return kvcache.ErrNotSupported
}
// If the recurrent slot is shared, detach it before applying a restore.
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
newSlot, err := c.allocSlot()
if err != nil {
return err
}
ctx := c.backend.NewContext()
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
ctx.Compute()
}
ctx.Close()
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
restore.slot = newSlot
c.pendingRestore[seq] = restore
}
}
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
if beginIndex > 0 {
restore := c.pendingRestore[seq]
delete(c.pendingRestore, seq)
return c.applyCheckpointRestore(restore)
}
// Removal invalidates recurrent state
slot, ok := c.slotForSeq[seq]
delete(c.pendingRestore, seq)
if !ok {
return nil
}
if !c.validSlot(slot) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.clearCheckpoints(slot)
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
func (c *HybridCache) validSlot(slot int) bool {
return slot >= 0 && slot < len(c.refCount)
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.curSlotsInput
}
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
func (c *HybridCache) contiguousSlots() (int, bool) {
if len(c.curSlots) == 0 {
return 0, false
}
start := c.curSlots[0]
for i, s := range c.curSlots {
if s != start+i {
return 0, false
}
}
return start, true
}
func (c *HybridCache) seqTokens() int {
return c.curSeqTokens
}
func (c *HybridCache) numSeqs() int {
return len(c.curSeqs)
}
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
c.convStates[layer] = buf
return buf
}
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.deltaStates[layer]; ok {
return buf
}
if _, ok := c.deltaCtxs[layer]; !ok {
c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
// Recurrent delta state must stay in F32.
buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
c.deltaStates[layer] = buf
return buf
}
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
}
// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
}
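// A minimal usage sketch (hypothetical caller code, not part of this file):
// a linear-attention layer reads the conv state, runs its convolution, then
// writes the updated state back so the next step resumes from it:
//
//	state, err := cache.ConvState(ctx, layer) // [convDim, convChannels, nSeqs], F32
//	if err != nil {
//		return nil, err
//	}
//	// ... shift in the new tokens and run ssm_conv over state ...
//	cache.UpdateConvState(ctx, layer, newState)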
// UpdateConvState writes a new conv state for current batch sequences.
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
srcF32 := src.Cast(ctx, ml.DTypeF32)
if start, ok := c.contiguousSlots(); ok {
// Fast path: contiguous slots allow a single view + copy
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
ctx.Forward(srcF32.Copy(ctx, view))
} else {
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
c.captureConvCheckpoint(ctx, layer, srcF32)
}
// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.deltaBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
}
// UpdateDeltaState writes a new delta state for current batch sequences.
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.deltaBuffer(ctx, layer)
src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
srcF32 := src.Cast(ctx, ml.DTypeF32)
if start, ok := c.contiguousSlots(); ok {
// Fast path: contiguous slots allow a single view + copy
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
ctx.Forward(srcF32.Copy(ctx, view))
} else {
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
c.captureDeltaCheckpoint(ctx, layer, srcF32)
}
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
return slices.Clone(c.curSeqs)
}

View File

@@ -1,498 +0,0 @@
package qwen3next
import (
"log/slog"
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
const (
checkpointCountDefault = 32
checkpointMinPosDefault = int32(16)
checkpointIntervalDefault = int32(1280)
)
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
// memory usage while preserving prefix reuse for recurrent state.
type checkpointEntry struct {
pos int32
conv map[int]ml.Tensor
delta map[int]ml.Tensor
}
type slotCheckpointStore struct {
entries []checkpointEntry
size int
next int
lastPos int32
}
type checkpointRestore struct {
slot int
idx int
pos int32
}
func newSlotCheckpointStore(n int) *slotCheckpointStore {
entries := make([]checkpointEntry, n)
for i := range entries {
entries[i].pos = -1
}
return &slotCheckpointStore{
entries: entries,
lastPos: -1,
}
}
func (s *slotCheckpointStore) reset() {
s.size = 0
s.next = 0
s.lastPos = -1
for i := range s.entries {
s.entries[i].pos = -1
}
}
func (s *slotCheckpointStore) record(pos int32) int {
if len(s.entries) == 0 {
return -1
}
idx := s.next
s.next = (s.next + 1) % len(s.entries)
if s.size < len(s.entries) {
s.size++
}
s.entries[idx].pos = pos
s.lastPos = pos
return idx
}
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
bestIdx := -1
bestPos := int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 || pos >= targetPos {
continue
}
if pos > bestPos {
bestPos = pos
bestIdx = i
}
}
if bestIdx < 0 {
return -1, -1, false
}
return bestIdx, bestPos, true
}
func (s *slotCheckpointStore) pruneAfter(pos int32) {
if len(s.entries) == 0 {
s.size = 0
s.next = 0
s.lastPos = -1
return
}
size := 0
next := -1
minPos := int32(math.MaxInt32)
minIdx := 0
for i := range s.entries {
if s.entries[i].pos > pos {
s.entries[i].pos = -1
}
if s.entries[i].pos >= 0 {
size++
if s.entries[i].pos < minPos {
minPos = s.entries[i].pos
minIdx = i
}
} else if next == -1 {
next = i
}
}
s.size = size
if size == 0 {
s.next = 0
s.lastPos = -1
return
}
if next != -1 {
s.next = next
} else {
// Full ring: overwrite the oldest checkpoint next.
s.next = minIdx
}
s.lastPos = pos
}
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
minPos = int32(math.MaxInt32)
maxPos = int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 {
continue
}
size++
if pos < minPos {
minPos = pos
}
if pos > maxPos {
maxPos = pos
}
}
if size == 0 {
minPos = -1
maxPos = -1
}
return size, minPos, maxPos, s.lastPos
}
func (c *HybridCache) planCheckpoints(batch input.Batch) {
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
c.curCheckpointPos = c.curCheckpointPos[:0]
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
return
}
if cap(c.curCheckpointPos) < len(c.curSeqs) {
c.curCheckpointPos = make([]int32, len(c.curSeqs))
} else {
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
}
for i := range c.curCheckpointPos {
c.curCheckpointPos[i] = -1
}
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
posMax := make(map[int]int32, len(c.curSeqs))
for i, seq := range batch.Sequences {
pos := batch.Positions[i]
if cur, ok := posMax[seq]; !ok || pos > cur {
posMax[seq] = pos
}
}
for i, seq := range c.curSeqs {
pos, ok := posMax[seq]
if !ok {
continue
}
if pos < c.checkpointMinPos {
continue
}
slot := c.curSlots[i]
store := c.checkpointStore(slot)
lastPos := store.lastPos
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
c.curCheckpointPos[i] = pos
}
}
}
func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
store, ok := c.checkpoints[slot]
if ok {
return store
}
store = newSlotCheckpointStore(c.checkpointCount)
c.checkpoints[slot] = store
return store
}
func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
if c.checkpointCount == 0 {
return -1
}
if idx, ok := c.curCheckpointSlots[slot]; ok {
return idx
}
store := c.checkpointStore(slot)
idx := store.record(pos)
if idx >= 0 {
c.curCheckpointSlots[slot] = idx
}
return idx
}
func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
if pos <= 0 {
return false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return false
}
store, ok := c.checkpoints[slot]
if !ok {
return false
}
_, _, ok = store.bestIndex(pos)
return ok
}
func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
if targetPos <= 0 {
return 0, false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return 0, false
}
store, ok := c.checkpoints[slot]
if !ok {
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
return 0, false
}
idx, pos, ok := store.bestIndex(targetPos)
if !ok {
size, minPos, maxPos, lastPos := store.window()
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
"min", minPos, "max", maxPos, "last", lastPos)
return 0, false
}
c.pendingRestore[seq] = checkpointRestore{
slot: slot,
idx: idx,
pos: pos,
}
return pos + 1, true
}
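// Restore flow sketch (hypothetical scheduler code; only PrepareRestore and
// Remove are real APIs here): the caller asks where it can resume, trims the
// caches back to that point, and re-decodes the remaining suffix.
//
//	if resumePos, ok := cache.PrepareRestore(seq, targetPos); ok {
//		// Remove applies the pending checkpoint restore internally.
//		if err := cache.Remove(seq, resumePos, math.MaxInt32); err == nil {
//			// re-run the prompt from resumePos onward
//		}
//	}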
func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
entry, ok := c.restoreEntry(restore)
if !ok {
return kvcache.ErrNotSupported
}
ctx := c.backend.NewContext()
defer ctx.Close()
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
for layer, src := range entry.conv {
buf := c.convBuffer(ctx, layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
for layer, src := range entry.delta {
buf := c.deltaBuffer(ctx, layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
if len(entry.conv) > 0 || len(entry.delta) > 0 {
ctx.Compute()
}
store := c.checkpoints[restore.slot]
store.pruneAfter(restore.pos)
return nil
}
func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
_, ok := c.restoreEntry(restore)
return ok
}
func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
store, ok := c.checkpoints[restore.slot]
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
return nil, false
}
entry := &store.entries[restore.idx]
if entry.pos < 0 {
return nil, false
}
if !c.entryComplete(entry) {
return nil, false
}
return entry, true
}
func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
for layer := range c.convStates {
if entry.conv == nil || entry.conv[layer] == nil {
return false
}
}
for layer := range c.deltaStates {
if entry.delta == nil || entry.delta[layer] == nil {
return false
}
}
return true
}
func (c *HybridCache) clearCheckpoints(slot int) {
if store, ok := c.checkpoints[slot]; ok {
store.reset()
}
}
func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
if c.checkpointCount == 0 {
return
}
srcStore, ok := c.checkpoints[srcSlot]
if !ok || srcStore.size == 0 {
return
}
dstStore := c.checkpointStore(dstSlot)
dstStore.size = srcStore.size
dstStore.next = srcStore.next
dstStore.lastPos = srcStore.lastPos
for i := range srcStore.entries {
srcEntry := &srcStore.entries[i]
dstEntry := &dstStore.entries[i]
dstEntry.pos = srcEntry.pos
if srcEntry.conv != nil {
if dstEntry.conv == nil {
dstEntry.conv = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.conv {
dst := c.ensureCheckpointConv(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
if srcEntry.delta != nil {
if dstEntry.delta == nil {
dstEntry.delta = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.delta {
dst := c.ensureCheckpointDelta(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
}
}
func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointConv(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointConv(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointDelta(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointDelta(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
if entry.conv == nil {
entry.conv = make(map[int]ml.Tensor)
}
if t, ok := entry.conv[layer]; ok {
return t
}
ctx, ok := c.checkpointConvCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointConvCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
entry.conv[layer] = t
return t
}
func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
if entry.delta == nil {
entry.delta = make(map[int]ml.Tensor)
}
if t, ok := entry.delta[layer]; ok {
return t
}
ctx, ok := c.checkpointDeltaCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointDeltaCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
entry.delta[layer] = t
return t
}
func (c *HybridCache) reserveCheckpointConv(layer int) {
key := checkpointReserveKey(layer, 0)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointConv(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func (c *HybridCache) reserveCheckpointDelta(layer int) {
key := checkpointReserveKey(layer, 1)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointDelta(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func checkpointReserveKey(layer int, kind int) int {
return layer*2 + kind
}
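// For example, layer 3 reserves under key 6 for conv (kind 0) and key 7 for
// delta (kind 1), so each layer/kind pair is reserved at most once.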

View File

@@ -1,300 +0,0 @@
package qwen3next
import (
"errors"
"math"
"os"
"testing"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
func newTestBackend(tb testing.TB) ml.Backend {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
if err != nil {
tb.Fatal(err)
}
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
_ = f.Close()
tb.Fatal(err)
}
if err := f.Close(); err != nil {
tb.Fatal(err)
}
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
if err != nil {
tb.Fatal(err)
}
tb.Cleanup(func() {
b.Close()
})
return b
}
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
store := newSlotCheckpointStore(2)
store.record(10)
store.record(20)
_, pos, ok := store.bestIndex(15)
if !ok || pos != 10 {
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
}
store.record(30) // overwrite oldest (10)
if _, _, ok := store.bestIndex(15); ok {
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
}
_, pos, ok = store.bestIndex(40)
if !ok || pos != 30 {
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
}
}
func TestHybridCachePrepareRestore(t *testing.T) {
cache := NewHybridCache(nil, 1, 1, 1)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
store := cache.checkpointStore(0)
store.record(5)
store.record(9)
store.record(15)
restorePos, ok := cache.PrepareRestore(1, 12)
if !ok {
t.Fatalf("expected restore ok")
}
if restorePos != 10 {
t.Fatalf("expected restorePos 10, got %d", restorePos)
}
rest, ok := cache.pendingRestore[1]
if !ok {
t.Fatalf("expected pending restore entry")
}
if rest.pos != 9 {
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
}
}
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
store.pruneAfter(20)
if store.lastPos != 20 {
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
}
_, pos, ok := store.bestIndex(25)
if !ok || pos != 20 {
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
}
_, pos, ok = store.bestIndex(35)
if !ok || pos != 20 {
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
}
}
func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
backend := newTestBackend(t)
cache := NewHybridCache(nil, 1, 2, 2)
cache.Init(backend, ml.DTypeF16, 2, 8, 2)
cache.slotForSeq[1] = 0
cache.slotForSeq[2] = 0
cache.refCount[0] = 2
cache.refCount[1] = 0
cache.freeSlots = []int{1}
store := cache.checkpointStore(0)
idx := store.record(9)
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
t.Fatalf("Remove failed: %v", err)
}
if cache.slotForSeq[1] == cache.slotForSeq[2] {
t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
}
if cache.slotForSeq[1] != 1 {
t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
}
if cache.slotForSeq[2] != 0 {
t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
}
if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
}
if _, ok := cache.pendingRestore[1]; ok {
t.Fatalf("expected pending restore to be cleared")
}
}
func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
cache := NewHybridCache(nil, 1, 2, 2)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
// Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
cache.convStates[0] = nil // placeholder to indicate layer 0 exists
cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
store := cache.checkpointStore(0)
idx := store.record(9)
entry := &store.entries[idx]
// Only set conv checkpoint, not delta - making it incomplete
entry.conv = map[int]ml.Tensor{0: nil}
// entry.delta is not set, so checkpoint is incomplete
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
err := cache.Remove(1, 10, math.MaxInt32)
if !errors.Is(err, kvcache.ErrNotSupported) {
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
}
}
func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
cache := NewHybridCache(nil, 1, 2, 2)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
// Don't set convStates/deltaStates - with no layers to check,
// entryComplete will return true as long as entry.pos >= 0
store := cache.checkpointStore(0)
idx := store.record(9)
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
// Test that restoreComplete returns true when no layers need checkpoints
restore := cache.pendingRestore[1]
if !cache.restoreComplete(restore) {
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
}
}
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
// Test that ring buffer wrap-around reuses entries without clearing maps.
store := newSlotCheckpointStore(3)
// Fill the buffer
store.record(10)
store.record(20)
store.record(30)
// Create fake tensor data in the first entry's maps
store.entries[0].conv = make(map[int]ml.Tensor)
store.entries[0].conv[0] = nil // Simulated tensor reference
store.entries[0].delta = make(map[int]ml.Tensor)
store.entries[0].delta[0] = nil // Simulated tensor reference
// Record another entry, which should wrap around and overwrite entry 0
store.record(40)
// Verify the maps are still present (we reuse tensors)
if store.entries[0].conv == nil {
t.Fatalf("expected conv map to be preserved on reuse")
}
if store.entries[0].delta == nil {
t.Fatalf("expected delta map to be preserved on reuse")
}
// Verify the new position was recorded
if store.entries[0].pos != 40 {
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
}
}
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
// Test behavior when buffer is exactly at capacity
store := newSlotCheckpointStore(2)
idx1 := store.record(10)
idx2 := store.record(20)
if idx1 != 0 || idx2 != 1 {
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
}
if store.size != 2 {
t.Fatalf("expected size 2, got %d", store.size)
}
// Verify both checkpoints are accessible
_, pos1, ok1 := store.bestIndex(15)
_, pos2, ok2 := store.bestIndex(25)
if !ok1 || pos1 != 10 {
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
}
if !ok2 || pos2 != 20 {
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
}
}
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
// Test behavior with zero-size buffer
store := newSlotCheckpointStore(0)
idx := store.record(10)
if idx != -1 {
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
}
_, _, ok := store.bestIndex(15)
if ok {
t.Fatalf("expected no checkpoint for empty buffer")
}
}
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
// Test pruning that removes all checkpoints
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
// Prune everything by setting threshold below all positions
store.pruneAfter(5)
if store.size != 0 {
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
}
// When all checkpoints are pruned, lastPos is reset to -1
if store.lastPos != -1 {
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
}
_, _, ok := store.bestIndex(100)
if ok {
t.Fatalf("expected no checkpoint after pruning all")
}
}

View File

@@ -37,7 +37,9 @@ type GatedDeltaNet struct {
// Optimized path: pre-split QKV and gate
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha (legacy qwen3next)
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
@@ -96,7 +98,6 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
headVDim := opts.ssmDInner / numVHeads
convKernelSize := opts.convKernelSize
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
@@ -106,24 +107,40 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
baNewDim := 2 * numVHeads / numKHeads
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
var beta ml.Tensor
var alpha ml.Tensor
switch {
case gdn.SSMBetaAlpha != nil:
// Legacy qwen3next path: in_proj_ba packs beta/alpha grouped by K-head.
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
baNewDim := 2 * numVHeads / numKHeads
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
// Split beta and alpha
betaSize := numVHeads / numKHeads
alphaSize := numVHeads / numKHeads
betaSize := numVHeads / numKHeads
alphaSize := numVHeads / numKHeads
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
// Reshape to merge head dimensions
beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
// Keep beta layout consistent with qwen35 and llama.cpp:
// [1, numVHeads, nSeqTokens, nSeqs]
beta = b.Contiguous(ctx, 1, numVHeads, nSeqTokens, nSeqs)
alpha = a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
case gdn.SSMBeta != nil && gdn.SSMAlpha != nil:
// qwen35 path: beta/alpha are separate projections.
beta = gdn.SSMBeta.Forward(ctx, hiddenStates).Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
alpha = gdn.SSMAlpha.Forward(ctx, hiddenStates).Reshape(ctx, numVHeads, nSeqTokens, nSeqs)
default:
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
}
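// Either way beta ends up as [1, numVHeads, nSeqTokens, nSeqs] and alpha as
// [numVHeads, nSeqTokens, nSeqs], so the gating math below is shared by both
// packings.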
// Compute gate: softplus(alpha + dt_bias) * -A
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
alphaSoftplus := alphaBiased.Softplus(ctx)
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
gate = gate.Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
// Get conv state from cache
@@ -172,16 +189,20 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
// Repeat interleave Q and K if numKHeads != numVHeads
if numKHeads != numVHeads {
repeatFactor := numVHeads / numKHeads
if opts.vHeadReordered {
qConv = qConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
kConv = kConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
} else {
repeatFactor := numVHeads / numKHeads
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
}
}
// Choose computation mode based on sequence length
@@ -189,7 +210,9 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
if nSeqTokens == 1 {
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
} else {
// Use pre-computed masks from opts (created once in Model.Forward)
if opts.masks == nil {
opts.masks = createMasks(ctx)
}
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
}
@@ -310,9 +333,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs)
beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
// Match llama.cpp delta-net-base layout:
// gate/beta: [1, numVHeads, nTokens, nSeqs] -> [1, nTokens, numVHeads, nSeqs]
gate = gate.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
beta = beta.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
// Compute padding
@@ -324,7 +348,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
q = q.Pad(ctx, 0, pad, 0, 0)
k = k.Pad(ctx, 0, pad, 0, 0)
v = v.Pad(ctx, 0, pad, 0, 0)
gate = gate.Pad(ctx, pad, 0, 0, 0)
gate = gate.Pad(ctx, 0, pad, 0, 0)
beta = beta.Pad(ctx, 0, pad, 0, 0)
}
@@ -344,10 +368,12 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
// Reshape gate and cumsum over chunk axis.
// [1, chunkSize, nChunks, H*nSeqs] -> transpose -> [chunkSize, 1, nChunks, H*nSeqs]
gate = gate.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
// g_cumsum = cumsum(gate)
gCumsum := gate.CumSum(ctx)
gCumsum := gate.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs).CumSum(ctx)
// Compute decay mask
gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)

View File

@@ -1,9 +1,12 @@
package qwen3next
import (
"bytes"
"cmp"
"fmt"
"image"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
@@ -11,6 +14,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/model/models/qwen3vl"
"github.com/ollama/ollama/tokenizer"
)
@@ -41,10 +45,15 @@ type Options struct {
ssmNGroup int // num_k_heads
ssmDtRank int // num_v_heads
convKernelSize int // SSM conv kernel size
vHeadReordered bool
// Per-layer type from GGUF metadata
isRecurrent []bool
// RoPE mode config (used by qwen35/qwen35moe)
mropeSections []int
mropeInterleaved bool
// Masks for chunked recurrent attention (allocated lazily during the forward pass)
masks *Masks
}
@@ -54,7 +63,17 @@ func (o Options) headDim() int {
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
opts := []func(*rope.Options){rope.WithTypeNeoX()}
var opts []func(*rope.Options)
if len(o.mropeSections) > 0 {
if o.mropeInterleaved {
opts = append(opts, rope.WithInterleaveMRoPE(o.mropeSections))
} else {
opts = append(opts, rope.WithMRoPE(o.mropeSections))
}
} else {
opts = append(opts, rope.WithTypeNeoX())
}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
opts = append(opts,
@@ -214,20 +233,190 @@ type Model struct {
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Layers []Layer `gguf:"blk"`
Layers []Layer `gguf:"blk"`
Vision *qwen3vl.VisionModel `gguf:"v"`
ImageProcessor *qwen3vl.ImageProcessor
*Options
positionCache []int32
imageToken int32
visionStart int32
visionEnd int32
spatialMergeSize uint32
}
func (m *Model) mapPosition(id int32) int32 {
if id < int32(len(m.positionCache)) {
return m.positionCache[id]
}
if len(m.positionCache) > 0 {
return id - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1
}
return id
}
func (m *Model) buildPositions(ctx ml.Context, batch input.Batch) ml.Tensor {
if len(m.mropeSections) == 0 {
return ctx.Input().FromInts(batch.Positions, len(batch.Positions))
}
// ggml MRoPE expects [time, height, width, extra] for each token.
positionSlice := [][]int32{
make([]int32, len(batch.Positions)),
make([]int32, len(batch.Positions)),
make([]int32, len(batch.Positions)),
make([]int32, len(batch.Positions)),
}
for i, id := range batch.Positions {
p := m.mapPosition(id)
positionSlice[0][i] = p
positionSlice[1][i] = p
positionSlice[2][i] = p
}
if m.Vision != nil {
for _, mi := range batch.Multimodal {
grid, ok := mi.Multimodal[0].Data.(*qwen3vl.Grid)
if !ok {
continue
}
w := max(1, grid.Width/int(m.spatialMergeSize))
for i := range mi.Multimodal[0].Tensor.Dim(1) {
positionSlice[1][mi.Index+i] += int32(i / w)
positionSlice[2][mi.Index+i] += int32(i % w)
}
}
}
return ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
}
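// Layout sketch (illustrative): for two text tokens with mapped positions p0
// and p1, the flattened MRoPE input is the concatenation
// [p0 p1 | p0 p1 | p0 p1 | 0 0], i.e. the time, height, and width channels
// followed by the unused extra channel; image tokens additionally offset the
// height/width channels by their grid row and column.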
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if m.Vision == nil || m.ImageProcessor == nil || len(m.Vision.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
pixelValues, grid, err := m.ImageProcessor.ProcessImage(ctx, img)
if err != nil {
return nil, err
}
visionOutputs, deepstackVisualEmbeds := m.Vision.Forward(ctx, pixelValues, grid)
mm := []input.Multimodal{{Tensor: visionOutputs, Data: grid}}
for i := range deepstackVisualEmbeds {
mm = append(mm, input.Multimodal{Tensor: deepstackVisualEmbeds[i]})
}
return mm, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
m.positionCache = m.positionCache[:0]
var result []*input.Input
appendInput := func(inp *input.Input, position int32) {
result = append(result, inp)
m.positionCache = append(m.positionCache, position)
}
var p int32
for _, inp := range inputs {
if inp.Multimodal == nil {
appendInput(inp, p)
p++
continue
}
grid := inp.Multimodal[0].Data.(*qwen3vl.Grid)
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
appendInput(&input.Input{
Token: m.visionStart,
SameBatch: tokensPerGrid + 1,
}, p)
p++
appendInput(&input.Input{
Token: m.imageToken,
Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash,
}, p)
for range tokensPerGrid - 1 {
appendInput(&input.Input{
Token: m.imageToken,
}, p)
}
gridSpan := max(grid.Width/int(m.spatialMergeSize), grid.Height/int(m.spatialMergeSize))
p = p + int32(gridSpan)
appendInput(&input.Input{
Token: m.visionEnd,
}, p)
p++
}
return result, nil
}
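// Worked example (mirrors the grid math above): an image whose merged grid is
// 4x2 (gridSpan 4), preceded by one text token, yields positions 0 (text),
// 1 (vision start), 2 2 2 2 (image tokens), 6 (vision end), with the next
// token continuing at 7.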
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
positions := m.buildPositions(ctx, batch)
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
if len(batch.Multimodal) > 0 {
hiddenStates = hiddenStates.Duplicate(ctx)
var deepstackVisualEmbeds []ml.Tensor
for _, mi := range batch.Multimodal {
visionOutputs := mi.Multimodal[0].Tensor
ctx.Forward(visionOutputs.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
if len(mi.Multimodal[1:]) > len(deepstackVisualEmbeds) {
deepstackVisualEmbeds = append(deepstackVisualEmbeds, make([]ml.Tensor, len(mi.Multimodal[1:])-len(deepstackVisualEmbeds))...)
}
for i, mm := range mi.Multimodal[1:] {
if deepstackVisualEmbeds[i] == nil {
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
}
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
}
}
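// deepstackVisualEmbeds[i] holds the i-th auxiliary vision embedding, laid
// out like hiddenStates, and is added to the hidden states after layer i in
// the loop below.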
cache := m.Cache.(*HybridCache)
m.Options.masks = nil
for i, layer := range m.Layers {
cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
var err error
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
if err != nil {
return nil, err
}
if i < len(deepstackVisualEmbeds) {
hiddenStates = hiddenStates.Add(ctx, deepstackVisualEmbeds[i])
}
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
cache := m.Cache.(*HybridCache)
// Create masks once per forward pass
m.Options.masks = createMasks(ctx)
// Masks are allocated lazily only for chunked recurrent prefill.
m.Options.masks = nil
for i, layer := range m.Layers {
cache.SetLayer(i)
@@ -249,10 +438,17 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
m.positionCache = nil
if len(m.mropeSections) > 0 {
shift = shift.Repeat(ctx, 1, 4).Reshape(ctx, -1)
}
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
var _ model.Model = (*Model)(nil)
var (
_ model.Model = (*Model)(nil)
_ model.MultimodalProcessor = (*Model)(nil)
)
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
@@ -303,6 +499,22 @@ func New(c fs.Config) (model.Model, error) {
}
}
mropeSections := c.Ints("mrope_sections", nil)
if len(mropeSections) == 0 {
mropeSections = c.Ints("rope.mrope_section", nil)
}
if len(mropeSections) == 0 {
mropeSections = c.Ints("rope.dimension_sections", nil)
}
if len(mropeSections) > 4 {
mropeSections = mropeSections[:4]
}
ropeType := c.String("rope.scaling.type")
if ropeType == "" {
ropeType = c.String("rope.type")
}
opts := &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
@@ -318,7 +530,7 @@ func New(c fs.Config) (model.Model, error) {
valueLength: int(c.Uint("attention.value_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeType: c.String("rope.scaling.type"),
ropeType: ropeType,
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
@@ -331,7 +543,16 @@ func New(c fs.Config) (model.Model, error) {
ssmNGroup: int(c.Uint("ssm.group_count")),
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
convKernelSize: int(c.Uint("ssm.conv_kernel")),
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
isRecurrent: isRecurrent,
mropeSections: slices.Collect(func(yield func(int) bool) {
for _, section := range mropeSections {
if !yield(int(section)) {
return
}
}
}),
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
}
if opts.numKVHeads == 0 {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
@@ -353,6 +574,19 @@ func New(c fs.Config) (model.Model, error) {
return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
}
var vision *qwen3vl.VisionModel
var imageProcessor *qwen3vl.ImageProcessor
if c.Uint("vision.block_count", 0) > 0 {
vision = qwen3vl.NewVisionModel(c)
processor := qwen3vl.NewImageProcessor(c)
imageProcessor = &processor
}
spatialMergeSize := c.Uint("vision.spatial_merge_size", 2)
if spatialMergeSize == 0 {
spatialMergeSize = 2
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
@@ -371,8 +605,14 @@ func New(c fs.Config) (model.Model, error) {
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Layers: layers,
Options: opts,
Layers: layers,
Vision: vision,
ImageProcessor: imageProcessor,
Options: opts,
imageToken: int32(c.Uint("image_token_id", 151655)),
visionStart: int32(c.Uint("vision_start_token_id", 151652)),
visionEnd: int32(c.Uint("vision_end_token_id", 151653)),
spatialMergeSize: spatialMergeSize,
}
m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
@@ -380,5 +620,7 @@ func New(c fs.Config) (model.Model, error) {
}
func init() {
model.Register("qwen35", New)
model.Register("qwen35moe", New)
model.Register("qwen3next", New)
}

View File

@@ -0,0 +1,101 @@
package qwen3next
import (
"testing"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/model/models/qwen3vl"
)
type fakeTensor struct {
*ggml.Tensor
dims []int
}
func (t *fakeTensor) Dim(i int) int {
return t.dims[i]
}
func makeImageInput(hash uint64, width, height, tokens int) *input.Input {
return &input.Input{
Multimodal: []input.Multimodal{{
Tensor: &fakeTensor{dims: []int{1, tokens, 1, 1}},
Data: &qwen3vl.Grid{Width: width, Height: height},
}},
MultimodalHash: hash,
}
}
func TestPostTokenizeMultiImageSpans(t *testing.T) {
m := &Model{
imageToken: 10,
visionStart: 11,
visionEnd: 12,
spatialMergeSize: 2,
}
inputs := []*input.Input{
{Token: 100},
makeImageInput(1, 8, 4, 4),
makeImageInput(2, 4, 8, 4),
{Token: 200},
}
got, err := m.PostTokenize(inputs)
if err != nil {
t.Fatalf("PostTokenize() error = %v", err)
}
want := []struct {
token int32
hash uint64
sameBatch int
hasMM bool
}{
{token: 100},
{token: 11, sameBatch: 5},
{token: 10, hash: 1, hasMM: true},
{token: 10},
{token: 10},
{token: 10},
{token: 12},
{token: 11, sameBatch: 5},
{token: 10, hash: 2, hasMM: true},
{token: 10},
{token: 10},
{token: 10},
{token: 12},
{token: 200},
}
if len(got) != len(want) {
t.Fatalf("len(got) = %d, want %d", len(got), len(want))
}
for i := range want {
if got[i].Token != want[i].token {
t.Fatalf("got[%d].Token = %d, want %d", i, got[i].Token, want[i].token)
}
if got[i].MultimodalHash != want[i].hash {
t.Fatalf("got[%d].MultimodalHash = %d, want %d", i, got[i].MultimodalHash, want[i].hash)
}
if got[i].SameBatch != want[i].sameBatch {
t.Fatalf("got[%d].SameBatch = %d, want %d", i, got[i].SameBatch, want[i].sameBatch)
}
hasMM := len(got[i].Multimodal) > 0
if hasMM != want[i].hasMM {
t.Fatalf("got[%d].hasMM = %v, want %v", i, hasMM, want[i].hasMM)
}
}
wantPositions := []int32{0, 1, 2, 2, 2, 2, 6, 7, 8, 8, 8, 8, 12, 13}
if len(m.positionCache) != len(wantPositions) {
t.Fatalf("len(positionCache) = %d, want %d", len(m.positionCache), len(wantPositions))
}
for i := range wantPositions {
if m.positionCache[i] != wantPositions[i] {
t.Fatalf("positionCache[%d] = %d, want %d", i, m.positionCache[i], wantPositions[i])
}
}
}

View File

@@ -24,8 +24,8 @@ type ImageProcessor struct {
imageStd []float32
}
// newImageProcessor creates a new image processor with default values
func newImageProcessor(c fs.Config) ImageProcessor {
// NewImageProcessor creates a new image processor with default values.
func NewImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 14))
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))

View File

@@ -56,60 +56,46 @@ var (
tokenVisionEnd int32 = 151653
)
type modelInput struct {
*input.Input
position int32
}
// PostTokenize arranges Qwen 3 VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
m.positionCache = m.positionCache[:0]
return slices.Collect(func(yield func(*input.Input) bool) {
for i := range inputs {
s := []modelInput{{Input: inputs[i]}}
if mm := inputs[i].Multimodal; mm != nil {
t := mm[0].Tensor
s = slices.Repeat([]modelInput{
{
position: int32(i + 1),
Input: &input.Input{Token: tokenVision},
},
}, t.Dim(1)+1+1)
var result []*input.Input
appendInput := func(inp *input.Input, position int32) {
result = append(result, inp)
m.positionCache = append(m.positionCache, position)
}
s[0] = modelInput{
Input: &input.Input{Token: tokenVisionStart},
position: int32(i),
}
s[len(s)-1] = modelInput{
Input: &input.Input{Token: tokenVisionEnd},
position: int32(i + mm[0].Data.(*Grid).Width/m.spatialMergeSize + 1),
}
s[1] = modelInput{
Input: &input.Input{
Token: tokenVision,
Multimodal: inputs[i].Multimodal,
MultimodalHash: inputs[i].MultimodalHash,
SameBatch: t.Dim(1),
},
position: int32(i + 1),
}
}
for _, e := range s {
position := e.position
if position == 0 && len(m.positionCache) > 0 {
position = m.positionCache[len(m.positionCache)-1] + 1
}
m.positionCache = append(m.positionCache, position)
if !yield(e.Input) {
return
}
}
var p int32
for _, inp := range inputs {
if inp.Multimodal == nil {
appendInput(inp, p)
p++
continue
}
}), nil
grid := inp.Multimodal[0].Data.(*Grid)
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
appendInput(&input.Input{Token: tokenVisionStart}, p)
p++
appendInput(&input.Input{
Token: tokenVision,
Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash,
SameBatch: tokensPerGrid,
}, p)
for range tokensPerGrid - 1 {
appendInput(&input.Input{Token: tokenVision}, p)
}
p = p + int32(grid.Width/m.spatialMergeSize)
appendInput(&input.Input{Token: tokenVisionEnd}, p)
p++
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
@@ -143,9 +129,13 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
}
deepstackVisualEmbeds = make([]ml.Tensor, len(mi.Multimodal[1:]))
if len(mi.Multimodal[1:]) > len(deepstackVisualEmbeds) {
deepstackVisualEmbeds = append(deepstackVisualEmbeds, make([]ml.Tensor, len(mi.Multimodal[1:])-len(deepstackVisualEmbeds))...)
}
for i, mm := range mi.Multimodal[1:] {
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
if deepstackVisualEmbeds[i] == nil {
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
}
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
}
}
@@ -189,8 +179,8 @@ func New(c fs.Config) (model.Model, error) {
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
VisionModel: NewVisionModel(c),
ImageProcessor: NewImageProcessor(c),
}
m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) {

View File

@@ -238,8 +238,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
return hiddenStates, deepstackStates
}
// newVisionModel creates a new instance of the Qwen vision model
func newVisionModel(c fs.Config) *VisionModel {
// NewVisionModel creates a new instance of the Qwen vision model.
func NewVisionModel(c fs.Config) *VisionModel {
deepstackVisualIndexes := c.Ints("vision.deepstack_visual_indexes")
model := &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),

View File

@@ -49,6 +49,8 @@ func ParserForName(name string) Parser {
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
case "qwen3-thinking":
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
case "qwen3.5":
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
case "qwen3-coder":
p = &Qwen3CoderParser{}
case "qwen3-vl-instruct":

View File

@@ -59,6 +59,7 @@ func TestBuiltInParsersStillWork(t *testing.T) {
{"qwen3-coder"},
{"lfm2"},
{"lfm2-thinking"},
{"qwen3.5"},
{"harmony"},
}

View File

@@ -145,3 +145,26 @@ func TestQwen3ParserToolCall(t *testing.T) {
t.Fatalf("expected unit %q, got %v", "celsius", unit)
}
}
func TestQwen35ParserRespectsNoThink(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
content, thinking, calls, err := parser.Add("Hello! How can I help you today?", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected no thinking, got %q", thinking)
}
if content != "Hello! How can I help you today?" {
t.Fatalf("expected content %q, got %q", "Hello! How can I help you today?", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}

View File

@@ -3,6 +3,7 @@ package renderers
import (
"bytes"
"encoding/json"
"fmt"
"sort"
"strings"
@@ -192,21 +193,25 @@ func lfm2RenderToolCalls(calls []api.ToolCall) string {
return sb.String()
}
func (r *LFM2Renderer) renderMessageContent(message api.Message) string {
func (r *LFM2Renderer) renderMessageContent(message api.Message, imageOffset int) string {
content := lfm2RenderContent(message.Content, r.useImgTags)
if len(message.Images) == 0 {
return content
}
// chatPrompt may already have inserted [img] / [img-n] placeholders.
if strings.Contains(content, "[img]") || strings.Contains(content, "[img-") || strings.Contains(content, "<image>") {
return content
}
var sb strings.Builder
placeholder := lfm2ImagePlaceholder(r.useImgTags)
for range message.Images {
sb.WriteString(placeholder)
if r.useImgTags {
for i := range message.Images {
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset+i))
}
} else {
placeholder := lfm2ImagePlaceholder(false)
if strings.Contains(content, placeholder) {
return content
}
for range message.Images {
sb.WriteString(placeholder)
}
}
sb.WriteString(content)
return sb.String()
@@ -262,6 +267,11 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
}
}
imageOffset := 0
for i := range startIdx {
imageOffset += len(messages[i].Images)
}
for i := startIdx; i < len(messages); i++ {
message := messages[i]
lastMessage := i == len(messages)-1
@@ -271,7 +281,8 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
sb.WriteString(message.Role)
sb.WriteString("\n")
content := r.renderMessageContent(message)
content := r.renderMessageContent(message, imageOffset)
imageOffset += len(message.Images)
if message.Role == "assistant" && !keepPastThinking && i != lastAssistantIndex {
if idx := strings.LastIndex(content, "</think>"); idx >= 0 {
content = strings.TrimSpace(content[idx+len("</think>"):])

View File

@@ -236,16 +236,6 @@ func TestLFM2Renderer_Images(t *testing.T) {
Content: "Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n[img]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "existing_indexed_img_placeholder_not_duplicated",
renderer: &LFM2Renderer{useImgTags: true},
message: api.Message{
Role: "user",
Content: "[img-0]Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
{

View File

@@ -1,6 +1,7 @@
package renderers
import (
"fmt"
"strings"
"github.com/ollama/ollama/api"
@@ -9,10 +10,11 @@ import (
type Qwen3VLRenderer struct {
isThinking bool
useImgTags bool
emitEmptyThinkOnNoThink bool
useImgTags bool
}
func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
func (r *Qwen3VLRenderer) renderContent(content api.Message, imageOffset int) (string, int) {
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
var subSb strings.Builder
for range content.Images {
@@ -20,7 +22,8 @@ func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
// model backends, and so we should eventually parameterize this or
// only output a placeholder such as [img]
if r.useImgTags {
subSb.WriteString("[img]")
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
imageOffset++
} else {
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
}
@@ -28,12 +31,17 @@ func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
// TODO: support videos
subSb.WriteString(content.Content)
return subSb.String()
return subSb.String(), imageOffset
}
func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
var sb strings.Builder
isThinking := r.isThinking
if think != nil {
isThinking = think.Bool()
}
if len(tools) > 0 {
sb.WriteString(imStartTag + "system\n")
if len(messages) > 0 && messages[0].Role == "system" {
@@ -57,7 +65,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
message := messages[i]
if multiStepTool && message.Role == "user" {
// Check if content starts with <tool_response> and ends with </tool_response>
content := r.renderContent(message)
content, _ := r.renderContent(message, 0)
if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
multiStepTool = false
lastQueryIndex = i
@@ -65,8 +73,10 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
}
}
imageOffset := 0
for i, message := range messages {
content := r.renderContent(message)
content, nextImageOffset := r.renderContent(message, imageOffset)
imageOffset = nextImageOffset
lastMessage := i == len(messages)-1
prefill := lastMessage && message.Role == "assistant"
@@ -76,13 +86,13 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
} else if message.Role == "assistant" {
contentReasoning := ""
if r.isThinking {
if isThinking {
if message.Thinking != "" {
contentReasoning = message.Thinking
}
}
if r.isThinking && i > lastQueryIndex {
if isThinking && i > lastQueryIndex {
if i == len(messages)-1 || contentReasoning != "" {
sb.WriteString("<|im_start|>" + message.Role + "\n<think>\n" + strings.Trim(contentReasoning, "\n")) // do we want to add a new line here?
if content != "" {
@@ -125,8 +135,10 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
// prefill at the end
if lastMessage && !prefill {
sb.WriteString("<|im_start|>assistant\n")
if r.isThinking {
if isThinking {
sb.WriteString("<think>\n")
} else if r.emitEmptyThinkOnNoThink {
sb.WriteString("<think>\n\n</think>\n\n")
}
}
}

View File

@@ -101,7 +101,7 @@ Let me analyze this image.`,
},
useImgTags: true,
expected: `<|im_start|>user
[img]Describe this image.<|im_end|>
[img-0]Describe this image.<|im_end|>
<|im_start|>assistant
Let me analyze this image.`,
},
@@ -123,7 +123,7 @@ Let me analyze this image.`,
},
useImgTags: true,
expected: `<|im_start|>user
[img][img]Describe these images.<|im_end|>
[img-0][img-1]Describe these images.<|im_end|>
<|im_start|>assistant
Let me analyze this image.`,
},

View File

@@ -1,6 +1,7 @@
package renderers
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -370,3 +371,74 @@ func TestFormatToolCallArgumentThinkingVL(t *testing.T) {
})
}
}
func TestQwen3VLRendererThinkOverride(t *testing.T) {
msgs := []api.Message{
{Role: "user", Content: "Hello"},
}
renderThinking, err := (&Qwen3VLRenderer{isThinking: true}).Render(msgs, nil, nil)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(renderThinking, "<|im_start|>assistant\n<think>\n") {
t.Fatalf("expected default thinking renderer to emit <think>, got:\n%s", renderThinking)
}
renderNonThinking, err := (&Qwen3VLRenderer{isThinking: true}).Render(msgs, nil, &api.ThinkValue{Value: false})
if err != nil {
t.Fatal(err)
}
if strings.Contains(renderNonThinking, "<think>") {
t.Fatalf("expected think=false override to suppress <think>, got:\n%s", renderNonThinking)
}
renderForcedThinking, err := (&Qwen3VLRenderer{isThinking: false}).Render(msgs, nil, &api.ThinkValue{Value: true})
if err != nil {
t.Fatal(err)
}
if !strings.Contains(renderForcedThinking, "<|im_start|>assistant\n<think>\n") {
t.Fatalf("expected think=true override to emit <think>, got:\n%s", renderForcedThinking)
}
}
func TestQwen3VLRendererThinkOverrideWithExplicitNoThinkPrefill(t *testing.T) {
msgs := []api.Message{
{Role: "user", Content: "Hello"},
}
renderNonThinking, err := (&Qwen3VLRenderer{
isThinking: true,
emitEmptyThinkOnNoThink: true,
}).Render(msgs, nil, &api.ThinkValue{Value: false})
if err != nil {
t.Fatal(err)
}
if !strings.Contains(renderNonThinking, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
t.Fatalf("expected explicit think=false prefill block, got:\n%s", renderNonThinking)
}
}
func TestQwenRendererNameNoThinkBehaviorSplit(t *testing.T) {
msgs := []api.Message{
{Role: "user", Content: "Hello"},
}
thinkFalse := &api.ThinkValue{Value: false}
qwen35Rendered, err := RenderWithRenderer("qwen3.5", msgs, nil, thinkFalse)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(qwen35Rendered, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
t.Fatalf("expected qwen3.5 renderer to emit explicit no-think prefill, got:\n%s", qwen35Rendered)
}
qwen3VLRendered, err := RenderWithRenderer("qwen3-vl-thinking", msgs, nil, thinkFalse)
if err != nil {
t.Fatal(err)
}
if strings.Contains(qwen3VLRendered, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
t.Fatalf("expected qwen3-vl-thinking renderer to keep legacy non-empty no-think behavior, got:\n%s", qwen3VLRendered)
}
}

View File

@@ -56,6 +56,9 @@ func rendererForName(name string) Renderer {
case "qwen3-vl-thinking":
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
return renderer
case "qwen3.5":
renderer := &Qwen3VLRenderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
return renderer
case "cogito":
renderer := &CogitoRenderer{isThinking: true}
return renderer

View File

@@ -29,17 +29,27 @@ func TestRegisterCustomRenderer(t *testing.T) {
}
func TestBuiltInRendererStillWorks(t *testing.T) {
// Test that the built-in renderers still work
tests := []struct {
name string
}{
{name: "qwen3-coder"},
{name: "qwen3.5"},
}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
result, err := RenderWithRenderer("qwen3-coder", messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result == "" {
t.Error("expected non-empty result from qwen3-coder renderer")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := RenderWithRenderer(tt.name, messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result == "" {
t.Fatalf("expected non-empty result from %s renderer", tt.name)
}
})
}
}

View File

@@ -86,6 +86,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
ID: len(images),
Data: i,
}
images = append(images, imgData)
if m.Config.Renderer != "" {
continue
}
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
if !strings.Contains(prompt, "[img]") {
@@ -93,8 +98,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
} else {
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
}
images = append(images, imgData)
}
msgs[currMsgIdx+cnt].Content = prefix + prompt
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
func TestChatPrompt(t *testing.T) {
@@ -330,3 +331,38 @@ func TestChatPromptTokenizeCalls(t *testing.T) {
})
}
}
func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
msgs := []api.Message{
{
Role: "user",
Content: "what do these photos have in common?",
Images: []api.ImageData{[]byte("img-1"), []byte("img-2"), []byte("img-3")},
},
}
originalContent := msgs[0].Content
m := Model{
Config: model.ConfigV2{Renderer: "qwen3-vl-instruct"},
ProjectorPaths: []string{"vision"},
}
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
think := false
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
if err != nil {
t.Fatal(err)
}
if msgs[0].Content != originalContent {
t.Fatalf("renderer path should not mutate message content: got %q, want %q", msgs[0].Content, originalContent)
}
if got, want := len(images), 3; got != want {
t.Fatalf("len(images) = %d, want %d", got, want)
}
if prompt == "" {
t.Fatal("prompt is empty")
}
}

View File

@@ -6,6 +6,7 @@ import (
"log/slog"
"maps"
"os"
"slices"
"strings"
"unsafe"
@@ -33,6 +34,9 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
}
if uint64(len(data)) < q.from.Size() {
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
}
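// Hedged illustration of the guard above: Size() is the byte size implied
// by Shape and Kind, so an F32 tensor declared as [512, 2] must supply
// 512*2*4 = 4096 bytes. A truncated 32-byte payload (as in the short-data
// test cases later in this diff) now fails here with a descriptive error
// instead of corrupting the quantized output.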
var f32s []float32
newType := fsggml.TensorType(q.to.Kind)
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
@@ -58,7 +62,7 @@ func useMoreBits(iLayer, nLayers int) bool {
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
}
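// For intuition (a hedged enumeration): with nLayers = 32 this selects
// layers 0-3 (the first eighth), 28-31 (the last eighth), and every third
// layer in the middle band starting at 6 (6, 9, 12, ..., 27) for the
// higher-bit quantization type.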
func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
func qwen3LinearAttnQuantType(name string) (fsggml.TensorType, bool) {
switch {
// Full attention
case strings.HasSuffix(name, ".attn_q.weight"):
@@ -79,6 +83,10 @@ func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
// SSM
case strings.HasSuffix(name, ".ssm_ba.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".ssm_beta.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".ssm_alpha.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".ssm_out.weight"):
return fsggml.TensorTypeQ4_K, true
@@ -287,8 +295,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
newType := fsggml.TensorType(t.Kind)
if quantize {
if kv.Architecture() == "qwen3next" && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
if qt, ok := qwen3nextQuantType(name); ok {
if slices.Contains([]string{"qwen3next", "qwen35", "qwen35moe"}, kv.Architecture()) && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
if qt, ok := qwen3LinearAttnQuantType(name); ok {
return qt
}
}

View File

@@ -166,6 +166,60 @@ func TestGetTensorNewType(t *testing.T) {
}
}
func TestQwen3LinearAttentionQuantOverride(t *testing.T) {
cases := []struct {
name string
arch string
tensor string
fileType fsggml.FileType
expected fsggml.TensorType
}{
{
name: "qwen35_beta",
arch: "qwen35",
tensor: "blk.0.ssm_beta.weight",
fileType: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "qwen35_alpha",
arch: "qwen35",
tensor: "blk.0.ssm_alpha.weight",
fileType: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "qwen35moe_attn_qkv",
arch: "qwen35moe",
tensor: "blk.0.attn_qkv.weight",
fileType: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "non_qwen35_falls_back",
arch: "foo",
tensor: "blk.0.attn_qkv.weight",
fileType: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ5_K,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
kv := fsggml.KV{"general.architecture": tt.arch}
got := newType(&fsggml.Tensor{
Name: tt.tensor,
Shape: []uint64{256, 256},
Kind: uint32(fsggml.TensorTypeF16),
}, kv, &quantizeState{}, tt.fileType)
if got != tt.expected {
t.Fatalf("unexpected tensor type for %s (%s): got %s want %s", tt.tensor, tt.arch, got, tt.expected)
}
})
}
}
func TestQuantizeModel(t *testing.T) {
cases := []struct {
name string
@@ -173,6 +227,7 @@ func TestQuantizeModel(t *testing.T) {
tensors []*fsggml.Tensor
newType string
expectedTensorTypes map[string]fsggml.TensorType
expectErr bool
}{
{
name: "f16_q4_k",
@@ -253,6 +308,36 @@ func TestQuantizeModel(t *testing.T) {
"output.weight": fsggml.TensorTypeQ8_0,
},
},
{
name: "f32_short_data",
kv: map[string]any{
"general.architecture": "foo",
},
tensors: []*fsggml.Tensor{
{
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF32),
Offset: uint64(0), Shape: []uint64{512, 2},
WriterTo: bytes.NewReader(make([]byte, 32)),
},
},
newType: "Q4_K",
expectErr: true,
},
{
name: "f16_short_data",
kv: map[string]any{
"general.architecture": "foo",
},
tensors: []*fsggml.Tensor{
{
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
Offset: uint64(0), Shape: []uint64{512, 2},
WriterTo: bytes.NewReader(make([]byte, 32)),
},
},
newType: "Q4_K",
expectErr: true,
},
}
for _, tt := range cases {
@@ -264,6 +349,9 @@ func TestQuantizeModel(t *testing.T) {
}
defer fp.Close()
meta, err := fsggml.Decode(fp, -1)
if tt.expectErr && err != nil {
return
}
if err != nil {
t.Fatal(err.Error())
}
@@ -283,6 +371,12 @@ func TestQuantizeModel(t *testing.T) {
}
err = quantize(fp, tmp, meta, ftype, progress)
if tt.expectErr {
if err == nil {
t.Fatal("expected quantize to return an error")
}
return
}
if err != nil {
t.Fatalf("error during quantize: %s", err)
}

View File

@@ -447,7 +447,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
// Some architectures are not safe with num_parallel > 1.
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}

View File

@@ -296,19 +296,13 @@ func normalizeQuantType(quantize string) string {
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Norms, embeddings, biases, routing gates: no quantization
func GetTensorQuantization(name string, shape []int32, quantize string) string {
isStackedExpert := strings.Contains(name, ".mlp.experts.gate_up_proj") || strings.Contains(name, ".mlp.experts.down_proj")
// Use basic name-based check first
if !isStackedExpert && !ShouldQuantize(name, "") {
if !ShouldQuantize(name, "") {
return ""
}
// Quantize 2D linear tensors by default. qwen3.5 stacked expert tensors are
// also eligible even though they are stored as 3D [experts, out, in].
if !isStackedExpert && len(shape) != 2 {
return ""
}
if isStackedExpert && len(shape) != 3 {
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
return ""
}
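// Hedged examples of the simplified policy (return values inferred from the
// doc comment above and the tests elsewhere in this diff):
//
//	GetTensorQuantization("model.layers.0.mlp.down_proj.weight", []int32{4096, 11008}, "int4") // "int8" (down proj is more sensitive)
//	GetTensorQuantization("model.layers.0.input_layernorm.weight", []int32{4096}, "int4")      // ""    (1D norms are skipped)
//	GetTensorQuantization("conv.weight", []int32{64, 64, 3}, "int4")                           // ""    (3D no longer qualifies)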
@@ -378,10 +372,9 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
}
// expertGroupRegexp matches expert tensor names and captures the group prefix.
// Matches nested and non-nested LLM prefixes and both per-expert ".weight"
// tensors and qwen3.5 stacked expert tensors without ".weight".
// Captures: model(.language_model(.model)?).layers.{L}.mlp.experts or .shared_experts
var expertGroupRegexp = regexp.MustCompile(`^(model(?:\.language_model(?:\.model)?)?\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*(?:\.weight)?$`)
// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
// Captures: model.layers.{L}.mlp.experts
var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
// For example:

View File

@@ -557,9 +557,6 @@ func TestShouldQuantizeTensor(t *testing.T) {
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
// qwen3.5 stacked expert tensors are an exception: 3D [experts, out, in]
{"qwen3.5 stacked gate_up_proj", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int4", true},
{"qwen3.5 stacked down_proj", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int4", true},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
@@ -589,42 +586,6 @@ func TestShouldQuantizeTensor(t *testing.T) {
}
}
func TestGetTensorQuantization_StackedExperts(t *testing.T) {
tests := []struct {
name string
shape []int32
quantize string
want string
}{
{
name: "model.language_model.layers.0.mlp.experts.gate_up_proj",
shape: []int32{256, 1024, 2048},
quantize: "int4",
want: "int4",
},
{
name: "model.language_model.layers.0.mlp.experts.down_proj",
shape: []int32{256, 2048, 512},
quantize: "int4",
want: "int8",
},
{
name: "model.language_model.layers.0.mlp.experts.gate_up_proj",
shape: []int32{256, 1024, 2050}, // not divisible by 32
quantize: "int4",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetTensorQuantization(tt.name, tt.shape, tt.quantize); got != tt.want {
t.Fatalf("GetTensorQuantization(%q, %v, %q) = %q, want %q", tt.name, tt.shape, tt.quantize, got, tt.want)
}
})
}
}
func TestExpertGroupPrefix(t *testing.T) {
tests := []struct {
name string
@@ -634,18 +595,10 @@ func TestExpertGroupPrefix(t *testing.T) {
{"model.layers.1.mlp.experts.0.down_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
{"model.language_model.layers.0.mlp.experts.0.up_proj.weight", "model.language_model.layers.0.mlp.experts"},
{"model.language_model.model.layers.0.mlp.experts.0.up_proj.weight", "model.language_model.model.layers.0.mlp.experts"},
{"model.language_model.layers.0.mlp.experts.gate_up_proj", "model.language_model.layers.0.mlp.experts"},
{"model.language_model.layers.0.mlp.experts.down_proj", "model.language_model.layers.0.mlp.experts"},
{"model.language_model.model.layers.0.mlp.experts.gate_up_proj", "model.language_model.model.layers.0.mlp.experts"},
{"model.language_model.model.layers.0.mlp.experts.down_proj", "model.language_model.model.layers.0.mlp.experts"},
// Shared expert tensors should return their own group prefix
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
{"model.language_model.layers.1.mlp.shared_experts.down_proj.weight", "model.language_model.layers.1.mlp.shared_experts"},
{"model.language_model.model.layers.1.mlp.shared_experts.down_proj.weight", "model.language_model.model.layers.1.mlp.shared_experts"},
// Non-expert tensors should return empty string
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts

View File

@@ -5,341 +5,62 @@ package mlxrunner
import (
"fmt"
"log/slog"
"os"
"strconv"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
const defaultPromptCacheBranches = 1
// HybridCacheEntry stores a single prompt branch with mixed cache types
// (e.g. KV + recurrent caches) and coordinates shared operations across them.
type HybridCacheEntry struct {
// CacheEntry stores a single sequence
type CacheEntry struct {
Tokens []int32
Caches []cache.Cache
}
// CacheEntry is kept as an alias for the current single-entry runner path.
// Future multi-entry cache stores should prefer HybridCacheEntry directly.
type CacheEntry = HybridCacheEntry
func promptCacheBranchLimit() int {
if v := os.Getenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultPromptCacheBranches
}
func cloneTokens(tokens []int32) []int32 {
out := make([]int32, len(tokens))
copy(out, tokens)
return out
}
func equalTokens(a, b []int32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func (r *Runner) cacheStore() []*HybridCacheEntry {
if len(r.caches) == 0 && r.cache != nil {
r.caches = []*HybridCacheEntry{r.cache}
}
if r.cache == nil && len(r.caches) > 0 {
r.cache = r.caches[0]
}
return r.caches
}
func (r *Runner) setCacheStore(entries []*HybridCacheEntry) {
r.caches = entries
if len(entries) == 0 {
r.cache = nil
return
}
r.cache = entries[0]
}
func (r *Runner) touchCacheEntry(idx int) {
if idx <= 0 || idx >= len(r.caches) {
return
}
e := r.caches[idx]
copy(r.caches[1:idx+1], r.caches[:idx])
r.caches[0] = e
r.cache = r.caches[0]
}
func (r *Runner) bestCacheEntry(tokens []int32) (idx int, prefix int) {
bestIdx, bestPrefix := -1, 0
for i, e := range r.cacheStore() {
if e == nil {
continue
}
p := e.PrefixLen(tokens)
if p > bestPrefix {
bestIdx, bestPrefix = i, p
}
}
return bestIdx, bestPrefix
}
func (e *HybridCacheEntry) PrefixLen(tokens []int32) int {
if e == nil {
return 0
}
prefix := 0
for prefix < len(tokens) && prefix < len(e.Tokens) && tokens[prefix] == e.Tokens[prefix] {
prefix++
}
return prefix
}
func (e *HybridCacheEntry) Free() {
if e == nil {
return
}
for _, c := range e.Caches {
if c != nil {
c.Free()
}
}
}
func (e *HybridCacheEntry) cachesSlice() []cache.Cache {
if e == nil {
return nil
}
return e.Caches
}
func (e *HybridCacheEntry) cachesCanTrim() bool {
if e == nil {
return false
}
for _, c := range e.Caches {
if c == nil {
continue
}
if !c.CanTrim() {
return false
}
}
return true
}
func (e *HybridCacheEntry) TrimToPrefix(prefix int) {
if e == nil {
return
}
for _, c := range e.Caches {
if c == nil || !c.CanTrim() {
continue
}
trim := c.Offset() - prefix
if trim > 0 {
c.Trim(trim)
}
}
if prefix < len(e.Tokens) {
e.Tokens = e.Tokens[:prefix]
}
}
func (e *HybridCacheEntry) RestoreToPrefix(target int) (int, bool) {
if e == nil {
return 0, false
}
restorePos := -1
sawNonTrimmable := false
for _, c := range e.Caches {
if c == nil || c.CanTrim() {
continue
}
sawNonTrimmable = true
restorer, ok := c.(cache.CheckpointRestorer)
if !ok {
return 0, false
}
pos, ok := restorer.BestCheckpoint(target)
if !ok {
return 0, false
}
if restorePos < 0 {
restorePos = pos
continue
}
if pos != restorePos {
return 0, false
}
}
if !sawNonTrimmable || restorePos < 0 {
return 0, false
}
e.TrimToPrefix(restorePos)
for _, c := range e.Caches {
if c == nil || c.CanTrim() {
continue
}
restorer, ok := c.(cache.CheckpointRestorer)
if !ok || !restorer.RestoreCheckpoint(restorePos) {
return 0, false
}
}
if restorePos < len(e.Tokens) {
e.Tokens = e.Tokens[:restorePos]
}
return restorePos, true
}
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
entries := r.cacheStore()
if len(entries) == 0 {
if r.cache == nil {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
branchLimit := promptCacheBranchLimit()
idx, prefix := r.bestCacheEntry(tokens)
if idx < 0 {
if branchLimit <= 1 && len(entries) == 1 && entries[0] != nil {
entries[0].Free()
r.setCacheStore(nil)
}
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
if idx > 0 {
r.touchCacheEntry(idx)
}
base := r.cache
if base == nil {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
working := base
forked := false
if branchLimit > 1 && prefix > 0 {
working = base.Clone()
forked = true
// Find longest common prefix
prefix := 0
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
prefix++
}
switch {
case prefix == 0:
if !forked && branchLimit <= 1 {
base.Free()
r.setCacheStore(nil)
for _, c := range r.cache.Caches {
c.Free()
}
r.cache = nil
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
case prefix < len(working.Tokens):
if !working.cachesCanTrim() {
if restorePos, ok := working.RestoreToPrefix(prefix); ok {
slog.Info("Cache restore", "total", len(tokens), "matched", prefix, "restored", restorePos, "left", len(tokens[restorePos:]))
return working.cachesSlice(), tokens[restorePos:]
}
if forked {
working.Free()
} else if branchLimit <= 1 {
base.Free()
r.setCacheStore(nil)
}
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
return nil, tokens
case prefix < len(r.cache.Tokens):
trim := len(r.cache.Tokens) - prefix
for _, c := range r.cache.Caches {
c.Trim(trim)
}
working.TrimToPrefix(prefix)
r.cache.Tokens = r.cache.Tokens[:prefix]
}
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
return working.cachesSlice(), tokens[prefix:]
return r.cache.Caches, tokens[prefix:]
}
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
entry := &HybridCacheEntry{
Tokens: cloneTokens(tokens),
Caches: caches,
}
branchLimit := promptCacheBranchLimit()
if branchLimit <= 1 {
r.setCacheStore([]*HybridCacheEntry{entry})
return
}
entries := r.cacheStore()
// Replace any exact-token duplicate branch with the new result.
for i := 0; i < len(entries); i++ {
if entries[i] == nil || !equalTokens(entries[i].Tokens, entry.Tokens) {
continue
}
entries[i].Free()
entries = append(entries[:i], entries[i+1:]...)
break
}
entries = append([]*HybridCacheEntry{entry}, entries...)
if len(entries) > branchLimit {
for _, evicted := range entries[branchLimit:] {
if evicted != nil {
evicted.Free()
}
}
entries = entries[:branchLimit]
}
r.setCacheStore(entries)
}
func (c *HybridCacheEntry) Clone() *HybridCacheEntry {
if c == nil {
return nil
}
tokens := make([]int32, len(c.Tokens))
copy(tokens, c.Tokens)
caches := make([]cache.Cache, len(c.Caches))
for i, cc := range c.Caches {
if cc != nil {
caches[i] = cc.Clone()
}
}
return &HybridCacheEntry{
r.cache = &CacheEntry{
Tokens: tokens,
Caches: caches,
}
}
func (c *HybridCacheEntry) LogCache() {
if c == nil || len(c.Caches) == 0 {
return
}
func (c *CacheEntry) LogCache() {
var totalBytes int
for _, kv := range c.Caches {
if kv == nil {
continue
}
k, v := kv.State()
if k == nil || v == nil {
continue
}
totalBytes += k.NumBytes() + v.NumBytes()
}
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))

View File

@@ -3,22 +3,13 @@
package cache
import (
"log/slog"
"os"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func kvCacheGrowDebugEnabled() bool {
return os.Getenv("OLLAMA_MLX_DEBUG_CACHE_GROW") != ""
}
type Cache interface {
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
State() (keys, values *mlx.Array)
Materialize() []*mlx.Array
CanTrim() bool
Trim(int) int
Clone() Cache
Free()
@@ -26,19 +17,6 @@ type Cache interface {
Len() int
}
// CheckpointRecorder is an optional cache capability for recording recurrent
// state snapshots at specific token positions.
type CheckpointRecorder interface {
RecordCheckpoint(pos int)
}
// CheckpointRestorer is an optional cache capability for restoring recurrent
// state to a previously recorded checkpoint.
type CheckpointRestorer interface {
BestCheckpoint(target int) (pos int, ok bool)
RestoreCheckpoint(pos int) bool
}
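// These optional capabilities were discovered by type assertion; a hedged
// sketch of the consumer side, mirroring the deleted
// HybridCacheEntry.RestoreToPrefix earlier in this diff:
//
//	if restorer, ok := c.(cache.CheckpointRestorer); ok {
//		if pos, ok := restorer.BestCheckpoint(target); ok && restorer.RestoreCheckpoint(pos) {
//			// recurrent state rewound to pos; replay tokens[pos:]
//		}
//	}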
type KVCache struct {
keys, values *mlx.Array
offset int
@@ -71,9 +49,6 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
if kvCacheGrowDebugEnabled() {
slog.Info("KVCache grow", "prev", prev, "new_capacity", c.keys.Dim(2), "step", c.step)
}
}
c.offset += L
@@ -92,19 +67,6 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) Materialize() []*mlx.Array {
out := make([]*mlx.Array, 0, 2)
if c.keys != nil && c.keys.Valid() {
out = append(out, c.keys)
}
if c.values != nil && c.values.Valid() {
out = append(out, c.values)
}
return out
}
func (c *KVCache) CanTrim() bool { return true }
func (c *KVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n
@@ -228,8 +190,6 @@ func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
return c.keys, c.values
}
func (c *RotatingKVCache) CanTrim() bool { return true }
func (c *RotatingKVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n

View File

@@ -1,17 +0,0 @@
//go:build mlx
package cache
import "testing"
func TestKVCacheGrowDebugEnabled(t *testing.T) {
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "")
if kvCacheGrowDebugEnabled() {
t.Fatal("kvCacheGrowDebugEnabled() = true, want false")
}
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "1")
if !kvCacheGrowDebugEnabled() {
t.Fatal("kvCacheGrowDebugEnabled() = false, want true")
}
}

View File

@@ -1,519 +0,0 @@
//go:build mlx
package cache
import (
"os"
"strconv"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
const (
defaultRecurrentCheckpointCount = 32
defaultRecurrentCheckpointInterval = 128
defaultRecurrentCheckpointMinPos = 16
)
type recurrentCheckpoint struct {
pos int
convState *mlx.Array
deltaState *mlx.Array
}
type recurrentSlot struct {
convState *mlx.Array
deltaState *mlx.Array
checkpoints []recurrentCheckpoint
checkpointSize int
checkpointNext int
checkpointLastPos int
refs int
}
func getenvInt(name string, def int) int {
if v := os.Getenv(name); v != "" {
if n, err := strconv.Atoi(v); err == nil {
return n
}
}
return def
}
func recurrentCheckpointConfig() (count, interval, minPos int) {
count = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINTS", defaultRecurrentCheckpointCount)
interval = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINT_INTERVAL", defaultRecurrentCheckpointInterval)
minPos = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINT_MIN_POS", defaultRecurrentCheckpointMinPos)
if count < 0 {
count = 0
}
if interval < 0 {
interval = 0
}
if minPos < 0 {
minPos = 0
}
return count, interval, minPos
}
// RecurrentCache stores state for linear-recurrent layers.
//
// Conv state shape: [B, convTail, convDim]
// Delta state shape: [B, numVHeads, headVDim, headKDim]
type RecurrentCache struct {
slot *recurrentSlot
offset int
convTail int
convDim int
numVHeads int
headVDim int
headKDim int
checkpointCount int
checkpointInterval int
checkpointMinPos int
}
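// Concrete shapes for a small instance (hedged, batch = 1, matching the
// deleted tests later in this diff):
//
//	c := NewRecurrentCache(2, 3, 1, 2, 2)
//	c.ConvState(1, mlx.DTypeFloat32)  // dims [1, 2, 3]
//	c.DeltaState(1, mlx.DTypeFloat32) // dims [1, 1, 2, 2]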
func newRecurrentSlot(checkpointCount int) *recurrentSlot {
s := &recurrentSlot{
refs: 1,
checkpointLastPos: -1,
}
if checkpointCount > 0 {
s.checkpoints = make([]recurrentCheckpoint, checkpointCount)
for i := range s.checkpoints {
s.checkpoints[i].pos = -1
}
}
return s
}
func retainRecurrentSlot(s *recurrentSlot) *recurrentSlot {
if s != nil {
s.refs++
}
return s
}
func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
// Break dependency chains so recurrent state does not retain the full
// per-token compute graph over time.
snap := mlx.Snapshot(v)
mlx.Eval(snap)
old := *dst
*dst = snap
mlx.Pin(snap)
// Release previous cached state root, then recursively free the transient
// incoming graph root now that a detached snapshot is retained in cache.
if old != nil && old != snap {
mlx.Release(old)
}
if v != snap && v != old {
mlx.Release(v)
}
}
func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
if v == nil || !v.Valid() {
return
}
old := *dst
*dst = v
mlx.Pin(v)
if old != nil && old != v {
mlx.Release(old)
}
}
func (c *RecurrentCache) setStateDetached(dst **mlx.Array, v *mlx.Array, ensureContiguous bool) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
root := v
if ensureContiguous {
root = mlx.Contiguous(v, false)
}
detached := mlx.Detach(root)
old := *dst
*dst = detached
mlx.Pin(detached)
if old != nil && old != detached {
mlx.Release(old)
}
// Intentionally do not force-release root/v here. In the fast path, the detached
// handle aliases the same MLX value and may still be lazily computed. Releasing the
// source handles can invalidate the cached state before the next eval/sweep point.
}
func snapshotPinned(a *mlx.Array) *mlx.Array {
if a == nil || !a.Valid() {
return nil
}
snap := mlx.Snapshot(a)
mlx.Eval(snap)
mlx.Pin(snap)
return snap
}
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
count, interval, minPos := recurrentCheckpointConfig()
c := &RecurrentCache{
slot: newRecurrentSlot(count),
convTail: int(convTail),
convDim: int(convDim),
numVHeads: int(numVHeads),
headVDim: int(headVDim),
headKDim: int(headKDim),
checkpointCount: count,
checkpointInterval: interval,
checkpointMinPos: minPos,
}
return c
}
func clonePinned(a *mlx.Array) *mlx.Array {
if a == nil || !a.Valid() {
return nil
}
clone := a.Clone()
mlx.Pin(clone)
return clone
}
func releaseCheckpointEntry(e *recurrentCheckpoint) {
mlx.Release(e.convState, e.deltaState)
e.convState, e.deltaState = nil, nil
e.pos = -1
}
func releaseRecurrentSlot(s *recurrentSlot) {
if s == nil {
return
}
s.refs--
if s.refs > 0 {
return
}
mlx.Release(s.convState, s.deltaState)
s.convState, s.deltaState = nil, nil
for i := range s.checkpoints {
releaseCheckpointEntry(&s.checkpoints[i])
}
s.checkpointSize = 0
s.checkpointNext = 0
s.checkpointLastPos = -1
}
func cloneRecurrentSlot(src *recurrentSlot) *recurrentSlot {
if src == nil {
return nil
}
dst := &recurrentSlot{
checkpointSize: src.checkpointSize,
checkpointNext: src.checkpointNext,
checkpointLastPos: src.checkpointLastPos,
refs: 1,
}
if src.convState != nil && src.convState.Valid() {
dst.convState = snapshotPinned(src.convState)
}
if src.deltaState != nil && src.deltaState.Valid() {
dst.deltaState = snapshotPinned(src.deltaState)
}
if len(src.checkpoints) > 0 {
dst.checkpoints = make([]recurrentCheckpoint, len(src.checkpoints))
for i := range src.checkpoints {
dst.checkpoints[i].pos = src.checkpoints[i].pos
if src.checkpoints[i].pos < 0 {
continue
}
dst.checkpoints[i].convState = snapshotPinned(src.checkpoints[i].convState)
dst.checkpoints[i].deltaState = snapshotPinned(src.checkpoints[i].deltaState)
}
}
return dst
}
func (c *RecurrentCache) slotOrInit() *recurrentSlot {
if c.slot == nil {
c.slot = newRecurrentSlot(c.checkpointCount)
}
return c.slot
}
func (c *RecurrentCache) ensureWritableSlot() *recurrentSlot {
s := c.slotOrInit()
if s.refs <= 1 {
return s
}
c.slot = cloneRecurrentSlot(s)
s.refs--
return c.slot
}
func (c *RecurrentCache) pruneCheckpointsAfter(pos int) {
s := c.ensureWritableSlot()
if len(s.checkpoints) == 0 {
return
}
size := 0
next := -1
last := -1
minPos := int(^uint(0) >> 1)
minIdx := 0
for i := range s.checkpoints {
e := &s.checkpoints[i]
if e.pos > pos {
releaseCheckpointEntry(e)
}
if e.pos >= 0 {
size++
if e.pos > last {
last = e.pos
}
if e.pos < minPos {
minPos = e.pos
minIdx = i
}
} else if next == -1 {
next = i
}
}
s.checkpointSize = size
s.checkpointLastPos = last
if size == 0 {
s.checkpointNext = 0
return
}
if next != -1 {
s.checkpointNext = next
return
}
s.checkpointNext = minIdx
}
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
if batch <= 0 {
batch = 1
}
s := c.slotOrInit()
needConv := s.convState == nil || s.convState.DType() != dtype ||
s.convState.Dim(0) != batch || s.convState.Dim(1) != c.convTail || s.convState.Dim(2) != c.convDim
needDelta := s.deltaState == nil || s.deltaState.DType() != dtype ||
s.deltaState.Dim(0) != batch || s.deltaState.Dim(1) != c.numVHeads || s.deltaState.Dim(2) != c.headVDim || s.deltaState.Dim(3) != c.headKDim
if !needConv && !needDelta {
return
}
s = c.ensureWritableSlot()
if needConv {
c.setStateRaw(&s.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
}
if needDelta {
c.setStateRaw(&s.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
}
}
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.slotOrInit().convState
}
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
s := c.ensureWritableSlot()
c.setStateMaterialized(&s.convState, v)
}
// SetConvStateFast stores conv state without forcing an immediate snapshot/eval.
// Use only for decode hot paths that accept higher transient memory until the next
// sync/sweep point. The conv-state input is usually a slice view, so request a
// compact contiguous copy to avoid pinning the whole source buffer.
func (c *RecurrentCache) SetConvStateFast(v *mlx.Array) {
s := c.ensureWritableSlot()
c.setStateDetached(&s.convState, v, true)
}
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.slotOrInit().deltaState
}
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
s := c.ensureWritableSlot()
c.setStateMaterialized(&s.deltaState, v)
}
// SetDeltaStateFast stores delta state without forcing an immediate snapshot/eval.
// Use only for decode hot paths that accept higher transient memory until the next
// sync/sweep point.
func (c *RecurrentCache) SetDeltaStateFast(v *mlx.Array) {
s := c.ensureWritableSlot()
c.setStateDetached(&s.deltaState, v, false)
}
func (c *RecurrentCache) Advance(n int) {
c.offset += n
}
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return keys, values
}
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
c.ensure(1, mlx.DTypeFloat32)
s := c.slotOrInit()
return s.convState, s.deltaState
}
func (c *RecurrentCache) Materialize() []*mlx.Array {
out := make([]*mlx.Array, 0, 2)
s := c.slot
if s == nil {
return out
}
if s.convState != nil && s.convState.Valid() {
out = append(out, s.convState)
}
if s.deltaState != nil && s.deltaState.Valid() {
out = append(out, s.deltaState)
}
return out
}
func (c *RecurrentCache) RecordCheckpoint(pos int) {
s := c.slot
if s == nil || len(s.checkpoints) == 0 || pos <= 0 || pos < c.checkpointMinPos {
return
}
if c.offset != pos {
// Checkpoints are keyed by logical token position. Ignore callers with a
// mismatched position to avoid restoring inconsistent recurrent state.
return
}
if s.convState == nil || s.deltaState == nil || !s.convState.Valid() || !s.deltaState.Valid() {
return
}
if s.checkpointLastPos == pos {
return
}
if s.checkpointLastPos >= 0 && c.checkpointInterval > 0 && pos-s.checkpointLastPos < c.checkpointInterval {
return
}
if s.refs > 1 {
s = c.ensureWritableSlot()
}
idx := s.checkpointNext
e := &s.checkpoints[idx]
releaseCheckpointEntry(e)
e.pos = pos
e.convState = clonePinned(s.convState)
e.deltaState = clonePinned(s.deltaState)
s.checkpointNext = (idx + 1) % len(s.checkpoints)
if s.checkpointSize < len(s.checkpoints) {
s.checkpointSize++
}
s.checkpointLastPos = pos
}
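// Hedged trace under the defaults above (count 32, interval 128, min pos 16),
// assuming Advance keeps c.offset in step with pos and state is initialized:
//
//	c.Advance(8);  c.RecordCheckpoint(8)   // skipped: 8 < min pos 16
//	c.Advance(12); c.RecordCheckpoint(20)  // recorded: first eligible position
//	c.Advance(64); c.RecordCheckpoint(84)  // skipped: 84-20 < interval 128
//	c.Advance(96); c.RecordCheckpoint(180) // recorded: 180-20 >= 128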
func (c *RecurrentCache) BestCheckpoint(target int) (pos int, ok bool) {
s := c.slot
if s == nil {
return 0, false
}
best := -1
for i := range s.checkpoints {
pos := s.checkpoints[i].pos
if pos < 0 || pos > target {
continue
}
if pos > best {
best = pos
}
}
if best < 0 {
return 0, false
}
return best, true
}
func (c *RecurrentCache) RestoreCheckpoint(pos int) bool {
if pos < 0 {
return false
}
s := c.ensureWritableSlot()
for i := range s.checkpoints {
e := &s.checkpoints[i]
if e.pos != pos {
continue
}
if e.convState == nil || e.deltaState == nil || !e.convState.Valid() || !e.deltaState.Valid() {
return false
}
c.setStateRaw(&s.convState, e.convState.Clone())
c.setStateRaw(&s.deltaState, e.deltaState.Clone())
c.offset = pos
c.pruneCheckpointsAfter(pos)
return true
}
return false
}
func (c *RecurrentCache) CanTrim() bool { return false }
func (c *RecurrentCache) Trim(n int) int {
// Recurrent state is not directly trimmable; callers should use
// checkpoint-based restore instead.
_ = n
return 0
}
func (c *RecurrentCache) Clone() Cache {
clone := &RecurrentCache{
slot: retainRecurrentSlot(c.slotOrInit()),
offset: c.offset,
convTail: c.convTail,
convDim: c.convDim,
numVHeads: c.numVHeads,
headVDim: c.headVDim,
headKDim: c.headKDim,
checkpointCount: c.checkpointCount,
checkpointInterval: c.checkpointInterval,
checkpointMinPos: c.checkpointMinPos,
}
return clone
}
func (c *RecurrentCache) Free() {
releaseRecurrentSlot(c.slot)
c.slot = nil
c.offset = 0
}
func (c *RecurrentCache) Offset() int { return c.offset }
func (c *RecurrentCache) Len() int { return c.offset }

View File

@@ -1,125 +0,0 @@
//go:build mlx
package cache
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func requireMLXRuntime(t *testing.T) {
t.Helper()
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX runtime unavailable: %v", err)
}
}
func newTestRecurrentCache(t *testing.T) *RecurrentCache {
t.Helper()
requireMLXRuntime(t)
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINTS", "2")
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINT_INTERVAL", "0")
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINT_MIN_POS", "0")
c := NewRecurrentCache(2, 3, 1, 2, 2)
_ = c.ConvState(1, mlx.DTypeFloat32)
_ = c.DeltaState(1, mlx.DTypeFloat32)
return c
}
func TestRecurrentCacheCloneSharesSlotUntilMutation(t *testing.T) {
c1 := newTestRecurrentCache(t)
t.Cleanup(func() {
c1.Free()
mlx.Sweep()
})
c1.Advance(8)
c1.RecordCheckpoint(8)
if got, ok := c1.BestCheckpoint(8); !ok || got != 8 {
t.Fatalf("BestCheckpoint(8) = (%d, %v), want (8, true)", got, ok)
}
c2 := c1.Clone().(*RecurrentCache)
t.Cleanup(func() {
c2.Free()
mlx.Sweep()
})
if c1.slot == nil || c2.slot == nil {
t.Fatal("expected non-nil shared slots")
}
if c1.slot != c2.slot {
t.Fatal("clone did not share recurrent slot")
}
if c1.slot.refs != 2 {
t.Fatalf("shared slot refs = %d, want 2", c1.slot.refs)
}
// Read access should not trigger a COW detach.
_ = c2.ConvState(1, mlx.DTypeFloat32)
_ = c2.DeltaState(1, mlx.DTypeFloat32)
if c1.slot != c2.slot {
t.Fatal("read access detached shared recurrent slot")
}
// Mutating recurrent state should detach and deep-copy checkpoint metadata.
c2.SetConvState(mlx.Zeros(mlx.DTypeFloat32, 1, 2, 3))
if c1.slot == c2.slot {
t.Fatal("SetConvState did not detach shared recurrent slot")
}
if c1.slot.refs != 1 || c2.slot.refs != 1 {
t.Fatalf("post-detach refs = (%d, %d), want (1, 1)", c1.slot.refs, c2.slot.refs)
}
if len(c1.slot.checkpoints) == 0 || len(c2.slot.checkpoints) == 0 {
t.Fatal("expected checkpoint ring to be preserved on detach")
}
if c1.slot.checkpoints[0].pos != c2.slot.checkpoints[0].pos {
t.Fatalf("checkpoint pos mismatch after detach: %d vs %d", c1.slot.checkpoints[0].pos, c2.slot.checkpoints[0].pos)
}
if c1.slot.checkpoints[0].pos != 8 {
t.Fatalf("checkpoint pos = %d, want 8", c1.slot.checkpoints[0].pos)
}
if c1.slot.checkpoints[0].convState == c2.slot.checkpoints[0].convState {
t.Fatal("checkpoint conv state was aliased after COW detach")
}
if c1.slot.checkpoints[0].deltaState == c2.slot.checkpoints[0].deltaState {
t.Fatal("checkpoint delta state was aliased after COW detach")
}
}
func TestRecurrentCacheFreeKeepsSharedCloneAlive(t *testing.T) {
c1 := newTestRecurrentCache(t)
c2 := c1.Clone().(*RecurrentCache)
t.Cleanup(func() {
c1.Free()
c2.Free()
mlx.Sweep()
})
if c2.slot == nil || c2.slot.refs != 2 {
t.Fatalf("shared clone refs = %d, want 2", func() int {
if c2.slot == nil {
return 0
}
return c2.slot.refs
}())
}
c1.Free()
if c2.slot == nil {
t.Fatal("clone slot was cleared after freeing sibling clone")
}
if c2.slot.refs != 1 {
t.Fatalf("clone slot refs after sibling Free = %d, want 1", c2.slot.refs)
}
if state := c2.ConvState(1, mlx.DTypeFloat32); state == nil || !state.Valid() {
t.Fatal("clone conv state invalid after freeing sibling clone")
}
if state := c2.DeltaState(1, mlx.DTypeFloat32); state == nil || !state.Valid() {
t.Fatal("clone delta state invalid after freeing sibling clone")
}
}

View File

@@ -1,339 +0,0 @@
//go:build mlx
package mlxrunner
import (
"reflect"
"testing"
cachepkg "github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type fakeCache struct {
canTrim bool
trims []int
freeCall int
offset int
}
func (f *fakeCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { return keys, values }
func (f *fakeCache) State() (*mlx.Array, *mlx.Array) { return nil, nil }
func (f *fakeCache) Materialize() []*mlx.Array { return nil }
func (f *fakeCache) CanTrim() bool { return f.canTrim }
func (f *fakeCache) Trim(n int) int {
f.trims = append(f.trims, n)
f.offset -= n
return n
}
func (f *fakeCache) Clone() cachepkg.Cache { return &fakeCache{canTrim: f.canTrim, offset: f.offset} }
func (f *fakeCache) Free() { f.freeCall++ }
func (f *fakeCache) Offset() int { return f.offset }
func (f *fakeCache) Len() int { return f.offset }
type fakeCheckpointCache struct {
fakeCache
bestPos int
hasCheckpoint bool
restoreCalls []int
restoreSuccess bool
}
func (f *fakeCheckpointCache) BestCheckpoint(target int) (int, bool) {
if !f.hasCheckpoint || f.bestPos > target {
return 0, false
}
return f.bestPos, true
}
func (f *fakeCheckpointCache) RestoreCheckpoint(pos int) bool {
f.restoreCalls = append(f.restoreCalls, pos)
if !f.restoreSuccess || pos != f.bestPos {
return false
}
f.offset = pos
return true
}
func (f *fakeCheckpointCache) Clone() cachepkg.Cache {
clone := *f
clone.trims = nil
clone.restoreCalls = nil
return &clone
}
func TestFindNearestCacheReusesAppendOnlyNonTrimmableCache(t *testing.T) {
fc := &fakeCache{canTrim: false, offset: 2}
r := &Runner{
cache: &CacheEntry{
Tokens: []int32{1, 2},
Caches: []cachepkg.Cache{fc},
},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4})
if len(gotCaches) != 1 || gotCaches[0] != fc {
t.Fatalf("returned caches = %#v, want original cache", gotCaches)
}
if want := []int32{3, 4}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if fc.freeCall != 0 {
t.Fatalf("free calls = %d, want 0", fc.freeCall)
}
if len(fc.trims) != 0 {
t.Fatalf("trim calls = %v, want none", fc.trims)
}
}
func TestFindNearestCacheDropsNonTrimmableCacheOnDivergence(t *testing.T) {
fc := &fakeCache{canTrim: false, offset: 4}
r := &Runner{
cache: &CacheEntry{
Tokens: []int32{1, 2, 3, 4},
Caches: []cachepkg.Cache{fc},
},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
if gotCaches != nil {
t.Fatalf("returned caches = %#v, want nil", gotCaches)
}
if want := []int32{1, 2, 9}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if fc.freeCall != 1 {
t.Fatalf("free calls = %d, want 1", fc.freeCall)
}
if len(fc.trims) != 0 {
t.Fatalf("trim calls = %v, want none", fc.trims)
}
if r.cache != nil {
t.Fatal("runner cache should be cleared on non-trimmable divergence")
}
}
func TestFindNearestCacheTrimsTrimmableCacheOnDivergence(t *testing.T) {
fc := &fakeCache{canTrim: true, offset: 4}
r := &Runner{
cache: &CacheEntry{
Tokens: []int32{1, 2, 3, 4},
Caches: []cachepkg.Cache{fc},
},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
if len(gotCaches) != 1 || gotCaches[0] != fc {
t.Fatalf("returned caches = %#v, want original cache", gotCaches)
}
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if fc.freeCall != 0 {
t.Fatalf("free calls = %d, want 0", fc.freeCall)
}
if want := []int{2}; !reflect.DeepEqual(fc.trims, want) {
t.Fatalf("trim calls = %v, want %v", fc.trims, want)
}
if want := []int32{1, 2}; !reflect.DeepEqual(r.cache.Tokens, want) {
t.Fatalf("cached tokens = %v, want %v", r.cache.Tokens, want)
}
}
func TestFindNearestCacheRestoresCheckpointForNonTrimmableCaches(t *testing.T) {
kv := &fakeCache{canTrim: true, offset: 7}
rc1 := &fakeCheckpointCache{
fakeCache: fakeCache{canTrim: false, offset: 7},
bestPos: 4,
hasCheckpoint: true,
restoreSuccess: true,
}
rc2 := &fakeCheckpointCache{
fakeCache: fakeCache{canTrim: false, offset: 7},
bestPos: 4,
hasCheckpoint: true,
restoreSuccess: true,
}
r := &Runner{
cache: &CacheEntry{
Tokens: []int32{1, 2, 3, 4, 5, 6, 7},
Caches: []cachepkg.Cache{kv, rc1, rc2},
},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 8})
if len(gotCaches) != 3 {
t.Fatalf("returned caches len = %d, want 3", len(gotCaches))
}
if want := []int32{8}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if want := []int{3}; !reflect.DeepEqual(kv.trims, want) {
t.Fatalf("kv trim calls = %v, want %v", kv.trims, want)
}
if want := []int{4}; !reflect.DeepEqual(rc1.restoreCalls, want) {
t.Fatalf("rc1 restore calls = %v, want %v", rc1.restoreCalls, want)
}
if want := []int{4}; !reflect.DeepEqual(rc2.restoreCalls, want) {
t.Fatalf("rc2 restore calls = %v, want %v", rc2.restoreCalls, want)
}
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(r.cache.Tokens, want) {
t.Fatalf("cached tokens = %v, want %v", r.cache.Tokens, want)
}
}
func TestFindNearestCacheDropsOnMismatchedCheckpointRestorePoints(t *testing.T) {
rc1 := &fakeCheckpointCache{
fakeCache: fakeCache{canTrim: false, offset: 7},
bestPos: 4,
hasCheckpoint: true,
restoreSuccess: true,
}
rc2 := &fakeCheckpointCache{
fakeCache: fakeCache{canTrim: false, offset: 7},
bestPos: 3,
hasCheckpoint: true,
restoreSuccess: true,
}
r := &Runner{
cache: &CacheEntry{
Tokens: []int32{1, 2, 3, 4, 5, 6, 7},
Caches: []cachepkg.Cache{rc1, rc2},
},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 8})
if gotCaches != nil {
t.Fatalf("returned caches = %#v, want nil", gotCaches)
}
if want := []int32{1, 2, 3, 4, 8}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if rc1.freeCall != 1 || rc2.freeCall != 1 {
t.Fatalf("free calls = (%d,%d), want (1,1)", rc1.freeCall, rc2.freeCall)
}
if len(rc1.restoreCalls) != 0 || len(rc2.restoreCalls) != 0 {
t.Fatalf("restore calls = (%v,%v), want none", rc1.restoreCalls, rc2.restoreCalls)
}
}
func TestFindNearestCacheSelectsBestPrefixAcrossBranches(t *testing.T) {
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "4")
short := &fakeCache{canTrim: true, offset: 2}
long := &fakeCache{canTrim: true, offset: 4}
shortEntry := &HybridCacheEntry{
Tokens: []int32{1, 2},
Caches: []cachepkg.Cache{short},
}
longEntry := &HybridCacheEntry{
Tokens: []int32{1, 2, 3, 4},
Caches: []cachepkg.Cache{long},
}
r := &Runner{
cache: shortEntry,
caches: []*HybridCacheEntry{shortEntry, longEntry},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 9})
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if len(gotCaches) != 1 {
t.Fatalf("returned caches len = %d, want 1", len(gotCaches))
}
if gotCaches[0] == long {
t.Fatal("expected cloned cache in multi-branch mode, got original branch cache")
}
if r.cache != longEntry || r.caches[0] != longEntry {
t.Fatal("best branch was not promoted to front of cache store")
}
}
func TestFindNearestCacheForksBranchWithCloneWhenMultiBranchEnabled(t *testing.T) {
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "2")
base := &fakeCache{canTrim: true, offset: 4}
baseEntry := &HybridCacheEntry{
Tokens: []int32{1, 2, 3, 4},
Caches: []cachepkg.Cache{base},
}
r := &Runner{
cache: baseEntry,
caches: []*HybridCacheEntry{baseEntry},
}
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
}
if len(gotCaches) != 1 {
t.Fatalf("returned caches len = %d, want 1", len(gotCaches))
}
clone, ok := gotCaches[0].(*fakeCache)
if !ok {
t.Fatalf("returned cache type = %T, want *fakeCache", gotCaches[0])
}
if clone == base {
t.Fatal("expected branch fork to return a cloned cache")
}
if len(base.trims) != 0 {
t.Fatalf("base branch trim calls = %v, want none", base.trims)
}
if want := []int{2}; !reflect.DeepEqual(clone.trims, want) {
t.Fatalf("forked branch trim calls = %v, want %v", clone.trims, want)
}
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(baseEntry.Tokens, want) {
t.Fatalf("base entry tokens = %v, want %v", baseEntry.Tokens, want)
}
r.InsertCache([]int32{1, 2, 9}, gotCaches)
if len(r.caches) != 2 {
t.Fatalf("cache store len = %d, want 2", len(r.caches))
}
if want := []int32{1, 2, 9}; !reflect.DeepEqual(r.caches[0].Tokens, want) {
t.Fatalf("new branch tokens = %v, want %v", r.caches[0].Tokens, want)
}
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(r.caches[1].Tokens, want) {
t.Fatalf("preserved branch tokens = %v, want %v", r.caches[1].Tokens, want)
}
}
func TestInsertCacheEvictsOldestBranchWhenStoreFull(t *testing.T) {
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "2")
f1 := &fakeCache{canTrim: true, offset: 1}
f2 := &fakeCache{canTrim: true, offset: 2}
f3 := &fakeCache{canTrim: true, offset: 3}
r := &Runner{}
r.InsertCache([]int32{1}, []cachepkg.Cache{f1})
r.InsertCache([]int32{1, 2}, []cachepkg.Cache{f2})
r.InsertCache([]int32{1, 2, 3}, []cachepkg.Cache{f3})
if len(r.caches) != 2 {
t.Fatalf("cache store len = %d, want 2", len(r.caches))
}
if f1.freeCall != 1 {
t.Fatalf("oldest branch free calls = %d, want 1", f1.freeCall)
}
if f2.freeCall != 0 || f3.freeCall != 0 {
t.Fatalf("unexpected frees for retained branches: f2=%d f3=%d", f2.freeCall, f3.freeCall)
}
if want := []int32{1, 2, 3}; !reflect.DeepEqual(r.caches[0].Tokens, want) {
t.Fatalf("MRU tokens = %v, want %v", r.caches[0].Tokens, want)
}
if want := []int32{1, 2}; !reflect.DeepEqual(r.caches[1].Tokens, want) {
t.Fatalf("LRU tokens = %v, want %v", r.caches[1].Tokens, want)
}
}

View File

@@ -7,6 +7,4 @@ import (
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"
_ "github.com/ollama/ollama/x/models/qwen3_5"
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
)

View File

@@ -267,20 +267,3 @@ func LogArrays() {
}
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
}
// Release forcibly frees arrays regardless of reference accounting.
// Use only for arrays that are known to be unreachable by any live model state.
func Release(s ...*Array) (n int) {
seen := make(map[*Array]bool, len(s))
for _, t := range s {
if t == nil || !t.Valid() || seen[t] {
continue
}
seen[t] = true
n += t.NumBytes()
C.mlx_array_free(t.ctx)
t.pinned = false
t.ctx.ctx = nil
}
return n
}

View File

@@ -1,275 +0,0 @@
//go:build mlx
package mlx
// #include <stdlib.h>
// #include "generated.h"
import "C"
import (
"sync"
"sync/atomic"
"unsafe"
)
var (
gatedDeltaMetalKernelOnce sync.Once
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
gatedDeltaMetalDisabled atomic.Bool
)
const gatedDeltaMetalKernelSource = `
auto n = thread_position_in_grid.z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;
// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
// v, y: [B, T, Hv, Dv]
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
y += b_idx * T * Hv * Dv + hv_idx * Dv;
auto dk_idx = thread_position_in_threadgroup.x;
auto dv_idx = thread_position_in_grid.y;
// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = static_cast<float>(i_state[s_idx]);
}
// g: [B, T, Hv]
auto g_ = g + b_idx * T * Hv;
auto beta_ = beta + b_idx * T * Hv;
for (int t = 0; t < T; ++t) {
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * g_[hv_idx];
kv_mem += state[i] * k_[s_idx];
}
kv_mem = simd_sum(kv_mem);
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + k_[s_idx] * delta;
out += state[i] * q_[s_idx];
}
out = simd_sum(out);
if (thread_index_in_simdgroup == 0) {
y[dv_idx] = static_cast<InT>(out);
}
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
g_ += Hv;
beta_ += Hv;
}
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<InT>(state[i]);
}
`
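// A hedged scalar reference for one timestep of the kernel above, per
// (batch, value-head) pair: S is the [Dv][Dk] state matrix, q and k have
// length Dk, v has length Dv, and g/beta are this head's scalars for the step.
//
//	func gatedDeltaStep(S [][]float32, q, k, v []float32, g, beta float32) []float32 {
//		y := make([]float32, len(v))
//		for dv := range S {
//			kvMem := float32(0)
//			for dk := range S[dv] {
//				S[dv][dk] *= g             // decay the state
//				kvMem += S[dv][dk] * k[dk] // recall what is stored along k
//			}
//			delta := (v[dv] - kvMem) * beta // residual write strength
//			for dk := range S[dv] {
//				S[dv][dk] += k[dk] * delta // write the residual along k
//				y[dv] += S[dv][dk] * q[dk] // read the output along q
//			}
//		}
//		return y
//	}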
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
vec := C.mlx_vector_string_new()
ok := true
for _, s := range values {
cs := C.CString(s)
if C.mlx_vector_string_append_value(vec, cs) != 0 {
ok = false
}
C.free(unsafe.Pointer(cs))
if !ok {
break
}
}
cleanup := func() {
C.mlx_vector_string_free(vec)
}
return vec, cleanup, ok
}
func initGatedDeltaMetalKernel() {
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
if !ok {
gatedDeltaMetalDisabled.Store(true)
freeInputs()
return
}
defer freeInputs()
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
if !ok {
gatedDeltaMetalDisabled.Store(true)
freeOutputs()
return
}
defer freeOutputs()
cName := C.CString("gated_delta_step")
defer C.free(unsafe.Pointer(cName))
cSource := C.CString(gatedDeltaMetalKernelSource)
defer C.free(unsafe.Pointer(cSource))
cHeader := C.CString("")
defer C.free(unsafe.Pointer(cHeader))
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
cName,
inputs,
outputs,
cSource,
cHeader,
C.bool(true),
C.bool(false),
)
}
// GatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
func GatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
if gatedDeltaMetalDisabled.Load() {
return nil, nil, false
}
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil, false
}
if !q.Valid() || !k.Valid() || !v.Valid() || !g.Valid() || !beta.Valid() || !state.Valid() {
return nil, nil, false
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil, false
}
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
return nil, nil, false
}
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
return nil, nil, false
}
Hv, Dv := vd[2], vd[3]
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil, false
}
if gd[0] != B || gd[1] != T || gd[2] != Hv {
return nil, nil, false
}
if bd[0] != B || bd[1] != T || bd[2] != Hv {
return nil, nil, false
}
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
return nil, nil, false
}
dtype := q.DType()
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
return nil, nil, false
}
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
if gatedDeltaMetalDisabled.Load() {
return nil, nil, false
}
cfg := C.mlx_fast_metal_kernel_config_new()
defer C.mlx_fast_metal_kernel_config_free(cfg)
cInT := C.CString("InT")
defer C.free(unsafe.Pointer(cInT))
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
for _, tpl := range []struct {
name string
value int
}{
{name: "Dk", value: Dk},
{name: "Dv", value: Dv},
{name: "Hk", value: Hk},
{name: "Hv", value: Hv},
} {
cn := C.CString(tpl.name)
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
C.free(unsafe.Pointer(cn))
if rc != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
}
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
threadY := Dv
if threadY > 4 {
threadY = 4
}
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
tScalar := FromValue(T)
inputs := []C.mlx_array{
q.ctx,
k.ctx,
v.ctx,
g.ctx,
beta.ctx,
state.ctx,
tScalar.ctx,
}
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
defer C.mlx_vector_array_free(inVec)
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if int(C.mlx_vector_array_size(outVec)) < 2 {
return nil, nil, false
}
y = New("GATED_DELTA_METAL_Y")
nextState = New("GATED_DELTA_METAL_STATE")
C.mlx_vector_array_get(&y.ctx, outVec, 0)
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
return y, nextState, true
}

View File

@@ -19,8 +19,7 @@ func doEval(outputs []*Array, async bool) {
defer C.mlx_vector_array_free(vector)
for _, output := range outputs {
// Callers may pass optional tensors (e.g. debug-only logprobs) as nil.
if output != nil && output.Valid() {
if output.Valid() {
C.mlx_vector_array_append_value(vector, output.ctx)
}
}

View File

@@ -113,35 +113,6 @@ func Where(condition, a, b *Array) *Array {
return out
}
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
out := New("CONV1D")
C.mlx_conv1d(
&out.ctx,
x.ctx,
weight.ctx,
C.int(stride),
C.int(padding),
C.int(dilation),
C.int(groups),
DefaultStream().ctx,
)
if bias != nil && bias.Valid() {
out = Add(out, bias)
}
return out
}
func Contiguous(a *Array, allowColMajor bool) *Array {
out := New("CONTIGUOUS")
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
return out
}
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
groups := int32(x.Dim(x.NumDims() - 1))
return Conv1d(x, weight, bias, 1, 0, 1, groups)
}
// Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array {
@@ -300,24 +271,6 @@ func Sigmoid(a *Array) *Array {
return a.Sigmoid()
}
func Exp(a *Array) *Array {
out := New("EXP")
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Log(a *Array) *Array {
out := New("LOG")
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
out := New("SOFTMAX_AXIS")
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
return out
}
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
mask := New("")
sinks := New("")
@@ -335,11 +288,7 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
var w C.mlx_array
if weight != nil {
w = weight.ctx
}
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
@@ -429,32 +378,6 @@ func Collect(v any) []*Array {
return arrays
}
// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
func Snapshot(a *Array) *Array {
if a == nil || !a.Valid() {
return a
}
out := New("SNAPSHOT")
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// CollectReachable collects arrays from v and all transitive graph inputs.
func CollectReachable(v any) []*Array {
return Collect(v)
}
// Detach returns a new Array handle that shares the same MLX value but does
// not retain Go-side graph input references.
func Detach(a *Array) *Array {
if a == nil || !a.Valid() {
return a
}
out := New("DETACH")
C.mlx_array_set(&out.ctx, a.ctx)
return out
}
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
if !v.IsValid() {
return

View File

@@ -6,10 +6,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"os"
"strconv"
"time"
"github.com/ollama/ollama/logutil"
@@ -17,234 +14,6 @@ import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
const defaultRecurrentMaterializeInterval = 64
const defaultPipelineTimingEvery = 64
func prefillChunkSize(lowMemoryDecode bool) int {
if v := os.Getenv("OLLAMA_MLX_PREFILL_CHUNK"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
if lowMemoryDecode {
// Recurrent/no-prompt-cache path favors lower peak memory over prefill throughput.
// Keep this conservative to avoid transient prefill spikes and allocator thrash.
return 32
}
return 2 << 10
}
func hasRecurrentCaches(caches []cache.Cache) bool {
for _, c := range caches {
if c == nil {
continue
}
if _, ok := c.(*cache.RecurrentCache); ok {
return true
}
}
return false
}
// recurrentMaterializeInterval controls periodic recurrent-cache materialization
// during async decode. It exists to bound graph/handle growth when using fast
// recurrent cache writes; it is primarily a memory/stability tuning knob, not a
// throughput knob.
func recurrentMaterializeInterval(lowMemoryDecode bool, hasRecurrent bool) int {
if lowMemoryDecode || !hasRecurrent {
return 0
}
if v := os.Getenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL"); v != "" {
if n, err := strconv.Atoi(v); err == nil {
if n < 0 {
return 0
}
return n
}
}
return defaultRecurrentMaterializeInterval
}
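// The returned interval gates periodic recurrent-cache materialization in the
// async decode loop further down in this diff; schematically (step is an
// illustrative loop counter):
//
//	if every > 0 && (step+1)%every == 0 {
//		materializeRecurrentCaches() // bound graph/handle growth
//	}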
func mlxDebugMemoryEnabled() bool {
return os.Getenv("OLLAMA_MLX_DEBUG_MEMORY") != ""
}
// mlxPipelineTimingConfig controls runner-side decode pipeline timing logs. This
// is diagnostic-only and intentionally separate from model-specific timing.
func mlxPipelineTimingConfig() (enabled bool, every int) {
if v, ok := os.LookupEnv("OLLAMA_MLX_PIPELINE_TIMING"); ok {
if parsed, err := strconv.ParseBool(v); err == nil {
enabled = parsed
}
}
if !enabled {
return false, 0
}
every = defaultPipelineTimingEvery
if v := os.Getenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
every = n
}
}
return true, every
}
// mlxComputeLogprobsEnabled restores the old decode-step logprob normalization
// path for profiling/experiments. It is off by default because the MLX runner
// does not currently populate Response.Logprobs.
func mlxComputeLogprobsEnabled() bool {
if v, ok := os.LookupEnv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS"); ok {
if enabled, err := strconv.ParseBool(v); err == nil {
return enabled
}
}
// The MLX runner currently does not populate Response.Logprobs, so skip the
// full-vocab logprob normalization path unless explicitly requested for
// debugging/experiments.
return false
}
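// The full-vocab logprob normalization this toggle guards is a log-softmax
// over the logits row, log p_i = z_i - log(sum_j exp(z_j)); in the step
// function below it appears as:
//
//	logprobs := logits.Subtract(logits.Logsumexp(true))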
type pipelineTiming struct {
every int
stepCalls int
stepAsync int
stepSync int
sampleInts int
stepTotalDur time.Duration
forwardDur time.Duration
unembedDur time.Duration
sliceDur time.Duration
logprobsDur time.Duration
sampleDur time.Duration
pinSweepDur time.Duration
asyncEvalDur time.Duration
sampleIntDur time.Duration
lastEmitCount int
}
func newPipelineTiming() *pipelineTiming {
enabled, every := mlxPipelineTimingConfig()
if !enabled {
return nil
}
pt := &pipelineTiming{every: every}
fmt.Fprintf(os.Stderr, "mlx pipeline timing: enabled every=%d\n", every)
return pt
}
func (pt *pipelineTiming) recordStep(
async bool,
total, forward, unembed, slice, logprobs, sample, pinSweep, asyncEval time.Duration,
) {
if pt == nil {
return
}
pt.stepCalls++
if async {
pt.stepAsync++
} else {
pt.stepSync++
}
pt.stepTotalDur += total
pt.forwardDur += forward
pt.unembedDur += unembed
pt.sliceDur += slice
pt.logprobsDur += logprobs
pt.sampleDur += sample
pt.pinSweepDur += pinSweep
pt.asyncEvalDur += asyncEval
}
func (pt *pipelineTiming) recordSampleInt(d time.Duration, decodeCount int) {
if pt == nil {
return
}
pt.sampleInts++
pt.sampleIntDur += d
pt.maybeEmit(false, decodeCount)
}
func (pt *pipelineTiming) maybeEmit(force bool, decodeCount int) {
if pt == nil {
return
}
if !force {
if pt.every <= 0 || decodeCount <= 0 || decodeCount%pt.every != 0 {
return
}
}
if pt.lastEmitCount == decodeCount {
return
}
pt.lastEmitCount = decodeCount
msAvg := func(d time.Duration, n int) float64 {
if n <= 0 {
return 0
}
return float64(d) / float64(n) / float64(time.Millisecond)
}
stepResidual := pt.stepTotalDur - pt.forwardDur - pt.unembedDur - pt.sliceDur - pt.logprobsDur - pt.sampleDur - pt.pinSweepDur - pt.asyncEvalDur
if stepResidual < 0 {
stepResidual = 0
}
fmt.Fprintf(
os.Stderr,
"mlx pipeline timing: decode=%d step_calls=%d step_async=%d step_sync=%d avg_step_ms=%.2f fwd_ms=%.2f unembed_ms=%.2f slice_ms=%.2f logprobs_ms=%.2f sample_ms=%.2f pin_sweep_ms=%.2f async_eval_ms=%.2f step_residual_ms=%.2f sample_int_ms=%.2f\n",
decodeCount,
pt.stepCalls,
pt.stepAsync,
pt.stepSync,
msAvg(pt.stepTotalDur, pt.stepCalls),
msAvg(pt.forwardDur, pt.stepCalls),
msAvg(pt.unembedDur, pt.stepCalls),
msAvg(pt.sliceDur, pt.stepCalls),
msAvg(pt.logprobsDur, pt.stepCalls),
msAvg(pt.sampleDur, pt.stepCalls),
msAvg(pt.pinSweepDur, pt.stepCalls),
msAvg(pt.asyncEvalDur, pt.stepCalls),
msAvg(stepResidual, pt.stepCalls),
msAvg(pt.sampleIntDur, pt.sampleInts),
)
}
func finalizeRequestCaches(usePromptCache bool, insertCache func(), freeCaches func(), logMemory func(string, int)) {
if usePromptCache {
insertCache()
logMemory("request_done_cached", -1)
return
}
freeCaches()
logMemory("request_done_freed", -1)
}
func recordCacheCheckpoints(caches []cache.Cache, pos int) {
if pos <= 0 {
return
}
for _, c := range caches {
if c == nil {
continue
}
if recorder, ok := c.(cache.CheckpointRecorder); ok {
recorder.RecordCheckpoint(pos)
}
}
}
func freeOwnedCaches(caches []cache.Cache) {
for i, c := range caches {
if c == nil {
continue
}
c.Free()
caches[i] = nil
}
}
func (r *Runner) TextGenerationPipeline(request Request) error {
if r.Model == nil {
return errors.New("model not loaded")
@@ -262,24 +31,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
inputs := r.Tokenizer.Encode(request.Prompt, true)
usePromptCache := true
if m, ok := r.Model.(interface{ DisablePromptCache() bool }); ok && m.DisablePromptCache() {
usePromptCache = false
}
lowMemoryDecode := !usePromptCache
if m, ok := r.Model.(interface{ LowMemoryDecode() bool }); ok {
lowMemoryDecode = m.LowMemoryDecode()
}
prefillChunk := prefillChunkSize(lowMemoryDecode)
var caches []cache.Cache
var tokens []int32
if usePromptCache {
caches, tokens = r.FindNearestCache(inputs)
} else {
tokens = inputs
}
caches, tokens := r.FindNearestCache(inputs)
if len(caches) == 0 {
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
caches = cacheFactory.NewCaches()
@@ -291,140 +43,40 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
}
materializeCaches := func() {
state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches {
state = append(state, c.Materialize()...)
}
if len(state) == 0 {
return
}
mlx.Eval(state...)
}
materializeRecurrentCaches := func() bool {
state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches {
if c == nil {
continue
}
if _, ok := c.(*cache.RecurrentCache); !ok {
continue
}
state = append(state, c.Materialize()...)
}
if len(state) == 0 {
return false
}
mlx.Eval(state...)
return true
}
freeCaches := func() {
// Non-prompt-cache requests allocate fresh caches every generation.
// Explicitly free cache-owned state (including recurrent checkpoints),
// then sweep remaining intermediates.
freeOwnedCaches(caches)
mlx.Sweep()
mlx.ClearCache()
}
debugMemory := mlxDebugMemoryEnabled()
hasRecurrent := hasRecurrentCaches(caches)
asyncRecurrentMaterializeEvery := recurrentMaterializeInterval(lowMemoryDecode, hasRecurrent)
computeStepLogprobs := mlxComputeLogprobsEnabled()
pipelineTiming := newPipelineTiming()
logMemory := func(phase string, token int) {
if !debugMemory {
return
}
if token >= 0 {
slog.Info("MLX memory", "phase", phase, "token", token, "memory", mlx.Memory{})
return
}
slog.Info("MLX memory", "phase", phase, "memory", mlx.Memory{})
}
logMemory("prefill_start", -1)
total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
n := min(prefillChunk, total-processed-1)
n := min(2<<10, total-processed-1)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
mlx.Sweep()
materializeCaches()
recordCacheCheckpoints(caches, processed+n)
mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
s[2*i], s[2*i+1] = c.State()
}
return s
}()...)
processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
mlx.ClearCache()
}
logMemory("prefill_done", -1)
step := func(token *mlx.Array, async bool) (*mlx.Array, *mlx.Array) {
var t0, t time.Time
var forwardDur, unembedDur, sliceDur, logprobsDur, sampleDur, pinSweepDur, asyncEvalDur time.Duration
if pipelineTiming != nil {
t0 = time.Now()
t = t0
}
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
fwd := r.Model.Forward(token.ExpandDims(0), caches)
if pipelineTiming != nil {
forwardDur = time.Since(t)
t = time.Now()
}
logits := r.Model.Unembed(fwd)
if pipelineTiming != nil {
unembedDur = time.Since(t)
t = time.Now()
}
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
if pipelineTiming != nil {
sliceDur = time.Since(t)
t = time.Now()
}
var logprobs *mlx.Array
sampleInput := logits
if computeStepLogprobs {
logprobs = logits.Subtract(logits.Logsumexp(true))
sampleInput = logprobs
}
if pipelineTiming != nil {
logprobsDur = time.Since(t)
t = time.Now()
}
sample := request.Sample(sampleInput)
if pipelineTiming != nil {
sampleDur = time.Since(t)
t = time.Now()
}
logprobs := logits.Subtract(logits.Logsumexp(true))
sample := request.Sample(logprobs)
mlx.Pin(sample, logprobs)
mlx.Sweep()
if pipelineTiming != nil {
pinSweepDur = time.Since(t)
// Reset the timer so asyncEvalDur below measures only the async eval,
// not the preceding pin/sweep; otherwise the step residual is undercounted.
t = time.Now()
}
if async {
mlx.AsyncEval(sample, logprobs)
if pipelineTiming != nil {
asyncEvalDur = time.Since(t)
}
}
if pipelineTiming != nil {
pipelineTiming.recordStep(async, time.Since(t0), forwardDur, unembedDur, sliceDur, logprobsDur, sampleDur, pinSweepDur, asyncEvalDur)
}
mlx.AsyncEval(sample, logprobs)
return sample, logprobs
}
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed), !lowMemoryDecode)
if lowMemoryDecode {
// Materialize cache updates to prevent transform graph growth.
materializeCaches()
}
recordCacheCheckpoints(caches, total)
logMemory("decode_init", -1)
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
var b bytes.Buffer
@@ -432,34 +84,20 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
var nextSample, nextLogprobs *mlx.Array
if !lowMemoryDecode {
nextSample, nextLogprobs = step(sample, true)
}
nextSample, nextLogprobs := step(sample)
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
mlx.Eval(sample)
logMemory("decode_first_eval", i)
final.PromptTokensDuration = time.Since(now)
now = time.Now()
}
var intWaitStart time.Time
if pipelineTiming != nil {
intWaitStart = time.Now()
}
output := int32(sample.Int())
if pipelineTiming != nil {
pipelineTiming.recordSampleInt(time.Since(intWaitStart), len(outputs)+1)
}
outputs = append(outputs, output)
if !lowMemoryDecode {
recordCacheCheckpoints(caches, total+len(outputs))
}
if r.Tokenizer.IsEOS(output) {
mlx.Unpin(nextSample, nextLogprobs)
mlx.Unpin(sample, logprobs)
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
@@ -471,53 +109,18 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
Token: int(output),
}
// For recurrent linear-attention models, avoid async prefetch to reduce
// peak memory and clear allocator cache every token.
if lowMemoryDecode {
mlx.Unpin(sample, logprobs)
mlx.Sweep()
if i+1 >= request.Options.MaxTokens {
break
}
mlx.ClearCache()
sample, logprobs = step(mlx.FromValues([]int32{output}, 1), false)
// Materialize cache updates to avoid unbounded transform chains.
materializeCaches()
recordCacheCheckpoints(caches, total+len(outputs))
if i%32 == 0 {
logMemory("decode_lowmem_step", i)
}
continue
}
mlx.Unpin(sample, logprobs)
if asyncRecurrentMaterializeEvery > 0 && (i+1)%asyncRecurrentMaterializeEvery == 0 {
if materializeRecurrentCaches() {
mlx.Sweep()
logMemory("decode_async_recurrent_materialize", i)
}
}
if i%256 == 0 {
mlx.ClearCache()
}
if i%64 == 0 {
logMemory("decode_async_step", i)
}
sample, logprobs = nextSample, nextLogprobs
}
mlx.Unpin(sample, logprobs)
if pipelineTiming != nil {
pipelineTiming.maybeEmit(true, len(outputs))
}
final.CompletionTokensDuration = time.Since(now)
request.Responses <- final
finalizeRequestCaches(usePromptCache,
func() { r.InsertCache(append(inputs, outputs...), caches) },
freeCaches,
logMemory,
)
r.InsertCache(append(inputs, outputs...), caches)
mlx.Sweep()
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
@@ -526,6 +129,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
r.cache.LogCache()
}
}
return nil
}

View File

@@ -1,209 +0,0 @@
//go:build mlx
package mlxrunner
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type stubCache struct {
freeCalls int
}
func (s *stubCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { return keys, values }
func (s *stubCache) State() (*mlx.Array, *mlx.Array) { return nil, nil }
func (s *stubCache) Materialize() []*mlx.Array { return nil }
func (s *stubCache) CanTrim() bool { return true }
func (s *stubCache) Trim(int) int { return 0 }
func (s *stubCache) Clone() cache.Cache { return s }
func (s *stubCache) Free() { s.freeCalls++ }
func (s *stubCache) Offset() int { return 0 }
func (s *stubCache) Len() int { return 0 }
func TestPrefillChunkSize(t *testing.T) {
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "")
if got := prefillChunkSize(false); got != 2<<10 {
t.Fatalf("prefillChunkSize(false) = %d, want %d", got, 2<<10)
}
if got := prefillChunkSize(true); got != 32 {
t.Fatalf("prefillChunkSize(true) = %d, want %d", got, 32)
}
}
func TestPrefillChunkSizeEnvOverride(t *testing.T) {
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "96")
if got := prefillChunkSize(false); got != 96 {
t.Fatalf("prefillChunkSize(false) with env = %d, want %d", got, 96)
}
if got := prefillChunkSize(true); got != 96 {
t.Fatalf("prefillChunkSize(true) with env = %d, want %d", got, 96)
}
}
func TestMLXDebugMemoryEnabled(t *testing.T) {
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "")
if mlxDebugMemoryEnabled() {
t.Fatal("mlxDebugMemoryEnabled() = true, want false")
}
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "1")
if !mlxDebugMemoryEnabled() {
t.Fatal("mlxDebugMemoryEnabled() = false, want true")
}
}
func TestHasRecurrentCaches(t *testing.T) {
if hasRecurrentCaches(nil) {
t.Fatal("hasRecurrentCaches(nil) = true, want false")
}
if hasRecurrentCaches([]cache.Cache{cache.NewKVCache()}) {
t.Fatal("hasRecurrentCaches(kv-only) = true, want false")
}
rc := cache.NewRecurrentCache(4, 8, 2, 16, 8)
if !hasRecurrentCaches([]cache.Cache{cache.NewKVCache(), rc}) {
t.Fatal("hasRecurrentCaches(mixed) = false, want true")
}
}
func TestRecurrentMaterializeInterval(t *testing.T) {
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "")
if got := recurrentMaterializeInterval(true, true); got != 0 {
t.Fatalf("recurrentMaterializeInterval(lowmem=true, recurrent=true) = %d, want 0", got)
}
if got := recurrentMaterializeInterval(false, false); got != 0 {
t.Fatalf("recurrentMaterializeInterval(lowmem=false, recurrent=false) = %d, want 0", got)
}
if got := recurrentMaterializeInterval(false, true); got != defaultRecurrentMaterializeInterval {
t.Fatalf("recurrentMaterializeInterval(default) = %d, want %d", got, defaultRecurrentMaterializeInterval)
}
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "16")
if got := recurrentMaterializeInterval(false, true); got != 16 {
t.Fatalf("recurrentMaterializeInterval(env=16) = %d, want 16", got)
}
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "0")
if got := recurrentMaterializeInterval(false, true); got != 0 {
t.Fatalf("recurrentMaterializeInterval(env=0) = %d, want 0", got)
}
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "-1")
if got := recurrentMaterializeInterval(false, true); got != 0 {
t.Fatalf("recurrentMaterializeInterval(env=-1) = %d, want 0", got)
}
}
func TestMLXPipelineTimingConfig(t *testing.T) {
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "")
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "")
if enabled, every := mlxPipelineTimingConfig(); enabled || every != 0 {
t.Fatalf("mlxPipelineTimingConfig() = (%v, %d), want (false, 0)", enabled, every)
}
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "1")
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != defaultPipelineTimingEvery {
t.Fatalf("mlxPipelineTimingConfig(enabled default) = (%v, %d), want (true, %d)", enabled, every, defaultPipelineTimingEvery)
}
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "16")
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != 16 {
t.Fatalf("mlxPipelineTimingConfig(enabled env=16) = (%v, %d), want (true, 16)", enabled, every)
}
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "0")
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != defaultPipelineTimingEvery {
t.Fatalf("mlxPipelineTimingConfig(enabled env=0) = (%v, %d), want (true, %d)", enabled, every, defaultPipelineTimingEvery)
}
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "0")
if enabled, every := mlxPipelineTimingConfig(); enabled || every != 0 {
t.Fatalf("mlxPipelineTimingConfig(disabled) = (%v, %d), want (false, 0)", enabled, every)
}
}
func TestMLXComputeLogprobsEnabled(t *testing.T) {
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "")
if mlxComputeLogprobsEnabled() {
t.Fatal("mlxComputeLogprobsEnabled() = true, want false")
}
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "1")
if !mlxComputeLogprobsEnabled() {
t.Fatal("mlxComputeLogprobsEnabled() = false with env=1, want true")
}
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "0")
if mlxComputeLogprobsEnabled() {
t.Fatal("mlxComputeLogprobsEnabled() = true with env=0, want false")
}
}
func TestFinalizeRequestCachesUsesPromptCachePath(t *testing.T) {
insertCalls := 0
freeCalls := 0
logPhase := ""
finalizeRequestCaches(
true,
func() { insertCalls++ },
func() { freeCalls++ },
func(phase string, _ int) { logPhase = phase },
)
if insertCalls != 1 {
t.Fatalf("insert calls = %d, want 1", insertCalls)
}
if freeCalls != 0 {
t.Fatalf("free calls = %d, want 0", freeCalls)
}
if logPhase != "request_done_cached" {
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_cached")
}
}
func TestFinalizeRequestCachesUsesFreePath(t *testing.T) {
insertCalls := 0
freeCalls := 0
logPhase := ""
finalizeRequestCaches(
false,
func() { insertCalls++ },
func() { freeCalls++ },
func(phase string, _ int) { logPhase = phase },
)
if insertCalls != 0 {
t.Fatalf("insert calls = %d, want 0", insertCalls)
}
if freeCalls != 1 {
t.Fatalf("free calls = %d, want 1", freeCalls)
}
if logPhase != "request_done_freed" {
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_freed")
}
}
func TestFreeOwnedCaches(t *testing.T) {
a := &stubCache{}
b := &stubCache{}
caches := []cache.Cache{a, nil, b}
freeOwnedCaches(caches)
if a.freeCalls != 1 {
t.Fatalf("a free calls = %d, want 1", a.freeCalls)
}
if b.freeCalls != 1 {
t.Fatalf("b free calls = %d, want 1", b.freeCalls)
}
if caches[0] != nil || caches[2] != nil {
t.Fatalf("cache entries not nilled after free: %#v", caches)
}
}

View File

@@ -62,39 +62,6 @@ type Runner struct {
Tokenizer *tokenizer.Tokenizer
Requests chan Request
cache *CacheEntry
caches []*HybridCacheEntry
}
func releaseTensorMap(tensors map[string]*mlx.Array, keep map[*mlx.Array]struct{}) (count int, bytes int) {
if len(tensors) == 0 {
return 0, 0
}
seen := make(map[*mlx.Array]bool, len(tensors))
toRelease := make([]*mlx.Array, 0, len(tensors))
for name, arr := range tensors {
if arr == nil || !arr.Valid() {
delete(tensors, name)
continue
}
if keep != nil {
if _, ok := keep[arr]; ok {
continue
}
}
delete(tensors, name)
if seen[arr] {
continue
}
seen[arr] = true
toRelease = append(toRelease, arr)
}
if len(toRelease) == 0 {
return 0, 0
}
return len(toRelease), mlx.Release(toRelease...)
}
func (r *Runner) Load(modelName string) error {
@@ -118,33 +85,9 @@ func (r *Runner) Load(modelName string) error {
// Assign weights to model (model-specific logic)
loadWeights := base.Weights(m)
if err := loadWeights(tensors); err != nil {
if count, bytes := releaseTensorMap(tensors, nil); count > 0 {
slog.Info("Released tensors after load failure", "count", count, "bytes", mlx.PrettyBytes(bytes))
}
mlx.Sweep()
mlx.ClearCache()
return err
}
// Materialize model-owned roots before releasing source tensor handles, then
// pin only those roots. This avoids retaining large load-time intermediates
// while still protecting shared model tensors from Sweep.
roots := mlx.Collect(m)
mlx.Eval(roots...)
mlx.Pin(roots...)
keep := make(map[*mlx.Array]struct{})
for _, arr := range roots {
if arr != nil && arr.Valid() {
keep[arr] = struct{}{}
}
}
if count, bytes := releaseTensorMap(tensors, keep); count > 0 {
slog.Info("Released unused model tensors", "count", count, "bytes", mlx.PrettyBytes(bytes))
}
mlx.Sweep()
mlx.ClearCache()
r.Model = m
r.Tokenizer = m.Tokenizer()
return nil

View File

@@ -15,40 +15,6 @@ type LinearLayer interface {
OutputDim() int32
}
// Conv1d applies 1D convolution over NLC input.
type Conv1d struct {
Weight *mlx.Array
Bias *mlx.Array
Stride int32
Padding int32
Dilation int32
Groups int32
}
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
if stride <= 0 {
stride = 1
}
if dilation <= 0 {
dilation = 1
}
if groups <= 0 {
groups = 1
}
return &Conv1d{
Weight: weight,
Bias: bias,
Stride: stride,
Padding: padding,
Dilation: dilation,
Groups: groups,
}
}
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
}
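// A minimal construction sketch for the layer removed here, with illustrative
// placeholder values (x is NLC input; weight and x are assumed to be created
// elsewhere). Note the constructor clamps non-positive stride/dilation/groups
// to 1:
//
//	conv := NewConv1d(weight, nil /*bias*/, 1 /*stride*/, 0 /*padding*/, 1 /*dilation*/, 1 /*groups*/)
//	y := conv.Forward(x)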
// Linear applies an affine transformation: y = x @ W.T + b
type Linear struct {
Weight *mlx.Array

View File

File diff suppressed because it is too large

View File

@@ -1,206 +0,0 @@
//go:build mlx
package qwen3_5
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestParseConfigNestedDefaults(t *testing.T) {
data := []byte(`{
"model_type": "Qwen3_5MoeForConditionalGeneration",
"text_config": {
"hidden_size": 4096,
"intermediate_size": 14336,
"num_hidden_layers": 8,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"head_dim": 128,
"linear_num_value_heads": 64,
"linear_num_key_heads": 16,
"linear_key_head_dim": 128,
"linear_value_head_dim": 128,
"linear_conv_kernel_dim": 4,
"num_experts": 16,
"num_experts_per_tok": 4,
"moe_intermediate_size": 2048,
"shared_expert_intermediate_size": 4096,
"rope_parameters": {
"rope_theta": 500000,
"partial_rotary_factor": 0.5
}
}
}`)
cfg, err := parseConfig(data)
if err != nil {
t.Fatalf("parseConfig failed: %v", err)
}
if cfg.RopeTheta != 500000 {
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
}
if cfg.RopeDim != 64 {
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
}
if cfg.FullAttentionInterval != 4 {
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
}
if !cfg.NormTopKProb {
t.Fatalf("norm_topk_prob should default to true for MoE")
}
}
func TestLayerSelectionHelpers(t *testing.T) {
cfg := &Config{
NumHiddenLayers: 6,
FullAttentionInterval: 3,
NumExperts: 8,
DecoderSparseStep: 2,
MLPOnlyLayers: []int32{1},
}
if !layerIsLinear(cfg, 0) {
t.Fatalf("layer 0 should be linear")
}
if layerIsLinear(cfg, 2) {
t.Fatalf("layer 2 should be full attention")
}
if layerUsesMoE(cfg, 1) {
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
}
if !layerUsesMoE(cfg, 3) {
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
}
}
func TestResolveTensorPathLayout(t *testing.T) {
dummy := mlx.New("dummy")
tests := []struct {
name string
key string
wantContainer string
wantModel string
}{
{
name: "standard",
key: "model.embed_tokens.weight",
wantContainer: "",
wantModel: "model.",
},
{
name: "nested language model with inner model",
key: "model.language_model.model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "model.",
},
{
name: "nested language model without inner model",
key: "model.language_model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
layout := resolveTensorPathLayout(map[string]*mlx.Array{
tt.key: dummy,
})
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
t.Fatalf(
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
layout.containerPrefix,
layout.modelPrefix,
tt.wantContainer,
tt.wantModel,
)
}
})
}
}
func TestModelRuntimeToggles(t *testing.T) {
m := &Model{}
if m.DisablePromptCache() {
t.Fatal("DisablePromptCache() = true, want false")
}
if m.LowMemoryDecode() {
t.Fatal("LowMemoryDecode() = true, want false")
}
if !m.EnableCompile() {
t.Fatal("EnableCompile() = false, want true")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_LOW_MEMORY_DECODE", "0")
if m.LowMemoryDecode() {
t.Fatal("LowMemoryDecode() = true with env override 0, want false")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_LOW_MEMORY_DECODE", "1")
if !m.LowMemoryDecode() {
t.Fatal("LowMemoryDecode() = false with env override 1, want true")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_ENABLE_COMPILE", "0")
if m.EnableCompile() {
t.Fatal("EnableCompile() = true with env override 0, want false")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_ENABLE_COMPILE", "1")
if !m.EnableCompile() {
t.Fatal("EnableCompile() = false with env override, want true")
}
if !qwen35FastRecurrentWrite() {
t.Fatal("qwen35FastRecurrentWrite() = false, want true")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_FAST_RECURRENT_WRITE", "0")
if qwen35FastRecurrentWrite() {
t.Fatal("qwen35FastRecurrentWrite() = true with env override 0, want false")
}
t.Setenv("OLLAMA_MLX_QWEN3_5_FAST_RECURRENT_WRITE", "1")
if !qwen35FastRecurrentWrite() {
t.Fatal("qwen35FastRecurrentWrite() = false with env override 1, want true")
}
}
func TestNewCachesLayout(t *testing.T) {
m := &Model{
Config: &Config{
LinearConvKernelDim: 4,
LinearNumKeyHeads: 2,
LinearKeyHeadDim: 8,
LinearNumValueHeads: 4,
LinearValueHeadDim: 16,
},
Layers: []*Layer{
{IsLinear: true},
{IsLinear: false},
{IsLinear: true},
},
}
caches := m.NewCaches()
if len(caches) != len(m.Layers) {
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
}
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
}
if _, ok := caches[1].(*cache.KVCache); !ok {
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
}
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
}
}

View File

@@ -1,16 +0,0 @@
//go:build mlx
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
package qwen3_5_moe
import (
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/qwen3_5"
)
func init() {
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
}
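// All four aliases resolve to the same constructor, so loaders keyed on either
// the Qwen3_5Moe or Qwen3Next naming build an identical model. Schematically
// (base.Lookup is hypothetical; the registry's real accessor may differ):
//
//	if newModel, ok := base.Lookup("Qwen3NextMoeForCausalLM"); ok {
//		m := newModel( /* config, weights */ )
//		_ = m
//	}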