mirror of
https://github.com/ollama/ollama.git
synced 2026-02-25 03:26:46 -05:00
Compare commits
2 Commits
pdevine/me
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da70c3222e | ||
|
|
9d902d63ce |
@@ -320,7 +320,7 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &lfm2Model{}
|
||||
case "Lfm2VlForConditionalGeneration":
|
||||
conv = &lfm2VLTextModel{}
|
||||
case "Qwen3NextForCausalLM":
|
||||
case "Qwen3NextForCausalLM", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration":
|
||||
conv = &qwen3NextModel{}
|
||||
case "NemotronHForCausalLM":
|
||||
conv = &nemotronHModel{}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math"
|
||||
@@ -13,8 +14,21 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type qwen3NextModel struct {
|
||||
ModelParameters
|
||||
type qwen3NextRopeScaling struct {
|
||||
Type string `json:"type"`
|
||||
Factor ropeFactor `json:"factor"`
|
||||
MropeSection []int32 `json:"mrope_section"`
|
||||
}
|
||||
|
||||
type qwen3NextRopeParams struct {
|
||||
MRopeInterleaved bool `json:"mrope_interleaved"`
|
||||
MropeSection []int32 `json:"mrope_section"`
|
||||
RopeType string `json:"rope_type"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
}
|
||||
|
||||
type qwen3NextTextConfig struct {
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
@@ -28,12 +42,13 @@ type qwen3NextModel struct {
|
||||
// MoE config
|
||||
NumExperts uint32 `json:"num_experts"`
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
NormTopkProb bool `json:"norm_topk_prob"`
|
||||
NormTopkProb *bool `json:"norm_topk_prob"`
|
||||
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
|
||||
|
||||
// Hybrid attention config
|
||||
FullAttentionInterval uint32 `json:"full_attention_interval"`
|
||||
FullAttentionInterval uint32 `json:"full_attention_interval"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
|
||||
// Linear attention (Gated Delta Net) config
|
||||
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
|
||||
@@ -43,16 +58,102 @@ type qwen3NextModel struct {
|
||||
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
|
||||
|
||||
// RoPE config
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeScaling struct {
|
||||
Type string `json:"type"`
|
||||
Factor ropeFactor `json:"factor"`
|
||||
} `json:"rope_scaling"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeScaling qwen3NextRopeScaling `json:"rope_scaling"`
|
||||
RopeParameters qwen3NextRopeParams `json:"rope_parameters"`
|
||||
}
|
||||
|
||||
type qwen3NextVisionConfig struct {
|
||||
Depth uint32 `json:"depth"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHeads uint32 `json:"num_heads"`
|
||||
InChannels uint32 `json:"in_channels"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
RMSNormEps float32 `json:"layer_norm_epsilon"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
DeepstackVisualIndexes []int32 `json:"deepstack_visual_indexes"`
|
||||
|
||||
Size struct {
|
||||
ShortestEdge uint32 `json:"shortest_edge"`
|
||||
LongestEdge uint32 `json:"longest_edge"`
|
||||
} `json:"size"`
|
||||
|
||||
ImageMean []float32 `json:"image_mean"`
|
||||
ImageStd []float32 `json:"image_std"`
|
||||
}
|
||||
|
||||
type qwen3NextModel struct {
|
||||
ModelParameters
|
||||
qwen3NextTextConfig
|
||||
|
||||
TextConfig *qwen3NextTextConfig `json:"text_config"`
|
||||
VisionModel qwen3NextVisionConfig `json:"vision_config"`
|
||||
|
||||
ImageTokenID uint32 `json:"image_token_id"`
|
||||
VisionStartTokenID uint32 `json:"vision_start_token_id"`
|
||||
VisionEndTokenID uint32 `json:"vision_end_token_id"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*qwen3NextModel)(nil)
|
||||
|
||||
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
|
||||
func (q *qwen3NextModel) parseMore(fsys fs.FS) error {
|
||||
if q.TextConfig != nil {
|
||||
q.qwen3NextTextConfig = *q.TextConfig
|
||||
}
|
||||
|
||||
if q.RopeTheta == 0 {
|
||||
q.RopeTheta = q.RopeParameters.RopeTheta
|
||||
}
|
||||
if q.PartialRotaryFactor == 0 {
|
||||
q.PartialRotaryFactor = q.RopeParameters.PartialRotaryFactor
|
||||
}
|
||||
|
||||
if q.RopeScaling.Type == "" && q.RopeParameters.RopeType != "" {
|
||||
q.RopeScaling.Type = q.RopeParameters.RopeType
|
||||
}
|
||||
|
||||
// Pull vision preprocessing fields when present.
|
||||
if q.VisionModel.Depth > 0 {
|
||||
if bts, err := fs.ReadFile(fsys, "preprocessor_config.json"); err == nil {
|
||||
var pre struct {
|
||||
Size struct {
|
||||
ShortestEdge uint32 `json:"shortest_edge"`
|
||||
LongestEdge uint32 `json:"longest_edge"`
|
||||
} `json:"size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
MergeSize uint32 `json:"merge_size"`
|
||||
ImageMean []float32 `json:"image_mean"`
|
||||
ImageStd []float32 `json:"image_std"`
|
||||
}
|
||||
if json.Unmarshal(bts, &pre) == nil {
|
||||
if q.VisionModel.PatchSize == 0 {
|
||||
q.VisionModel.PatchSize = pre.PatchSize
|
||||
}
|
||||
if q.VisionModel.TemporalPatchSize == 0 {
|
||||
q.VisionModel.TemporalPatchSize = pre.TemporalPatchSize
|
||||
}
|
||||
if q.VisionModel.SpatialMergeSize == 0 {
|
||||
q.VisionModel.SpatialMergeSize = pre.MergeSize
|
||||
}
|
||||
if q.VisionModel.Size.ShortestEdge == 0 {
|
||||
q.VisionModel.Size.ShortestEdge = pre.Size.ShortestEdge
|
||||
}
|
||||
if q.VisionModel.Size.LongestEdge == 0 {
|
||||
q.VisionModel.Size.LongestEdge = pre.Size.LongestEdge
|
||||
}
|
||||
if len(q.VisionModel.ImageMean) == 0 {
|
||||
q.VisionModel.ImageMean = pre.ImageMean
|
||||
}
|
||||
if len(q.VisionModel.ImageStd) == 0 {
|
||||
q.VisionModel.ImageStd = pre.ImageStd
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if q.NumHiddenLayers == 0 {
|
||||
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
|
||||
}
|
||||
@@ -74,36 +175,96 @@ func (q *qwen3NextModel) parseMore(_ fs.FS) error {
|
||||
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
|
||||
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
|
||||
}
|
||||
if q.FullAttentionInterval == 0 {
|
||||
return fmt.Errorf("qwen3next: full_attention_interval must be set")
|
||||
}
|
||||
if q.FullAttentionInterval > q.NumHiddenLayers {
|
||||
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
}
|
||||
|
||||
hasFull := false
|
||||
for i := range q.NumHiddenLayers {
|
||||
if (i+1)%q.FullAttentionInterval == 0 {
|
||||
hasFull = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasFull {
|
||||
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
if _, err := q.kvHeadCounts(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) kvHeadCounts() ([]uint32, error) {
|
||||
if len(q.LayerTypes) > 0 {
|
||||
kv := make([]uint32, q.NumHiddenLayers)
|
||||
hasFull := false
|
||||
hasRecurrent := false
|
||||
for i := range q.NumHiddenLayers {
|
||||
layerType := ""
|
||||
if i < uint32(len(q.LayerTypes)) {
|
||||
layerType = q.LayerTypes[i]
|
||||
}
|
||||
if layerType == "full_attention" {
|
||||
kv[i] = q.NumKeyValueHeads
|
||||
hasFull = true
|
||||
} else {
|
||||
hasRecurrent = true
|
||||
}
|
||||
}
|
||||
if !hasFull || !hasRecurrent {
|
||||
return nil, fmt.Errorf("qwen3next: layer_types must include both full_attention and linear_attention")
|
||||
}
|
||||
return kv, nil
|
||||
}
|
||||
|
||||
if q.FullAttentionInterval == 0 {
|
||||
return nil, fmt.Errorf("qwen3next: full_attention_interval must be set")
|
||||
}
|
||||
if q.FullAttentionInterval > q.NumHiddenLayers {
|
||||
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
}
|
||||
|
||||
kv := make([]uint32, q.NumHiddenLayers)
|
||||
hasFull := false
|
||||
for i := range q.NumHiddenLayers {
|
||||
if (i+1)%q.FullAttentionInterval == 0 {
|
||||
kv[i] = q.NumKeyValueHeads
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
}
|
||||
return kv, nil
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) ropeSections() []int32 {
|
||||
if len(q.RopeParameters.MropeSection) > 0 {
|
||||
return q.RopeParameters.MropeSection
|
||||
}
|
||||
return q.RopeScaling.MropeSection
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) shouldReorderVHeads() bool {
|
||||
modelType := strings.ToLower(q.ModelType)
|
||||
if strings.Contains(modelType, "qwen3_next") || strings.Contains(modelType, "qwen3next") {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, arch := range q.Architectures {
|
||||
arch = strings.ToLower(arch)
|
||||
if strings.Contains(arch, "qwen3next") || strings.Contains(arch, "qwen3_next") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Default to qwen3.5 layout for all other qwen3next-family imports.
|
||||
return true
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen3next"
|
||||
kv["tokenizer.ggml.pre"] = "qwen2"
|
||||
|
||||
arch := "qwen35"
|
||||
if q.NumExperts > 0 {
|
||||
arch = "qwen35moe"
|
||||
}
|
||||
kv["general.architecture"] = arch
|
||||
kv["tokenizer.ggml.pre"] = "qwen35"
|
||||
kv["block_count"] = q.NumHiddenLayers
|
||||
kv["context_length"] = q.MaxPositionEmbeddings
|
||||
kv["embedding_length"] = q.HiddenSize
|
||||
kv["feed_forward_length"] = q.IntermediateSize
|
||||
kv["attention.head_count"] = q.NumAttentionHeads
|
||||
|
||||
headDim := q.HeadDim
|
||||
if headDim == 0 && q.NumAttentionHeads > 0 {
|
||||
headDim = q.HiddenSize / q.NumAttentionHeads
|
||||
@@ -113,18 +274,31 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
||||
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
||||
kv["rope.freq_base"] = q.RopeTheta
|
||||
|
||||
// RoPE dimension count (partial rotary)
|
||||
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
|
||||
partialRotary := q.PartialRotaryFactor
|
||||
if partialRotary > 0 && partialRotary <= 1 {
|
||||
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
|
||||
}
|
||||
|
||||
// MoE config
|
||||
if sections := q.ropeSections(); len(sections) > 0 {
|
||||
kv["mrope_sections"] = sections
|
||||
kv["rope.mrope_section"] = sections
|
||||
kv["rope.dimension_sections"] = sections
|
||||
}
|
||||
if q.RopeParameters.MRopeInterleaved {
|
||||
kv["rope.mrope_interleaved"] = true
|
||||
}
|
||||
|
||||
if q.RopeScaling.Type != "" && q.RopeScaling.Type != "default" {
|
||||
kv["rope.scaling.type"] = q.RopeScaling.Type
|
||||
kv["rope.scaling.factor"] = q.RopeScaling.Factor
|
||||
}
|
||||
|
||||
if q.NumExperts > 0 {
|
||||
kv["expert_count"] = q.NumExperts
|
||||
kv["expert_used_count"] = q.NumExpertsPerToken
|
||||
kv["norm_top_k_prob"] = q.NormTopkProb
|
||||
if q.NormTopkProb != nil {
|
||||
kv["norm_top_k_prob"] = *q.NormTopkProb
|
||||
}
|
||||
if q.MoEIntermediateSize > 0 {
|
||||
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
|
||||
}
|
||||
@@ -133,33 +307,66 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
||||
}
|
||||
}
|
||||
|
||||
// SSM/Linear attention config
|
||||
// d_inner = linear_value_head_dim * linear_num_value_heads
|
||||
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
|
||||
kv["ssm.inner_size"] = dInner
|
||||
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
|
||||
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
|
||||
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
|
||||
kv["ssm.state_size"] = q.LinearKeyHeadDim
|
||||
kv["ssm.group_count"] = q.LinearNumKeyHeads
|
||||
kv["ssm.time_step_rank"] = q.LinearNumValueHeads
|
||||
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
|
||||
interval := q.FullAttentionInterval
|
||||
kv["full_attention_interval"] = interval
|
||||
|
||||
// Build per-layer KV head count array to identify layer types
|
||||
// 0 = recurrent (linear attention), non-zero = full attention
|
||||
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
|
||||
for i := range q.NumHiddenLayers {
|
||||
// Full attention every full_attention_interval layers (starting at interval-1)
|
||||
if interval > 0 && (i+1)%interval == 0 {
|
||||
kvHeadCounts[i] = q.NumKeyValueHeads
|
||||
}
|
||||
// else stays 0 (recurrent layer)
|
||||
if q.shouldReorderVHeads() {
|
||||
kv["ssm.v_head_reordered"] = true
|
||||
}
|
||||
if q.FullAttentionInterval > 0 {
|
||||
kv["full_attention_interval"] = q.FullAttentionInterval
|
||||
}
|
||||
kv["attention.head_count_kv"] = kvHeadCounts
|
||||
|
||||
// RoPE scaling
|
||||
if q.RopeScaling.Type != "" {
|
||||
kv["rope.scaling.type"] = q.RopeScaling.Type
|
||||
kv["rope.scaling.factor"] = q.RopeScaling.Factor
|
||||
if headCounts, err := q.kvHeadCounts(); err == nil {
|
||||
kv["attention.head_count_kv"] = headCounts
|
||||
}
|
||||
|
||||
if q.VisionModel.Depth > 0 {
|
||||
kv["vision.block_count"] = q.VisionModel.Depth
|
||||
kv["vision.embedding_length"] = q.VisionModel.HiddenSize
|
||||
kv["vision.attention.head_count"] = q.VisionModel.NumHeads
|
||||
kv["vision.num_channels"] = q.VisionModel.InChannels
|
||||
if q.VisionModel.PatchSize > 0 {
|
||||
kv["vision.patch_size"] = q.VisionModel.PatchSize
|
||||
}
|
||||
if q.VisionModel.SpatialMergeSize > 0 {
|
||||
kv["vision.spatial_merge_size"] = q.VisionModel.SpatialMergeSize
|
||||
}
|
||||
if q.VisionModel.RMSNormEps > 0 {
|
||||
kv["vision.attention.layer_norm_epsilon"] = q.VisionModel.RMSNormEps
|
||||
}
|
||||
if q.VisionModel.RopeTheta > 0 {
|
||||
kv["vision.rope.freq_base"] = q.VisionModel.RopeTheta
|
||||
}
|
||||
if q.VisionModel.TemporalPatchSize > 0 {
|
||||
kv["vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize
|
||||
}
|
||||
kv["vision.deepstack_visual_indexes"] = q.VisionModel.DeepstackVisualIndexes
|
||||
if q.VisionModel.Size.ShortestEdge > 0 {
|
||||
kv["vision.shortest_edge"] = q.VisionModel.Size.ShortestEdge
|
||||
}
|
||||
if q.VisionModel.Size.LongestEdge > 0 {
|
||||
kv["vision.longest_edge"] = q.VisionModel.Size.LongestEdge
|
||||
}
|
||||
if len(q.VisionModel.ImageMean) > 0 {
|
||||
kv["vision.image_mean"] = q.VisionModel.ImageMean
|
||||
}
|
||||
if len(q.VisionModel.ImageStd) > 0 {
|
||||
kv["vision.image_std"] = q.VisionModel.ImageStd
|
||||
}
|
||||
}
|
||||
|
||||
if q.ImageTokenID > 0 {
|
||||
kv["image_token_id"] = q.ImageTokenID
|
||||
}
|
||||
if q.VisionStartTokenID > 0 {
|
||||
kv["vision_start_token_id"] = q.VisionStartTokenID
|
||||
}
|
||||
if q.VisionEndTokenID > 0 {
|
||||
kv["vision_end_token_id"] = q.VisionEndTokenID
|
||||
}
|
||||
|
||||
return kv
|
||||
@@ -168,7 +375,6 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
||||
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
// Create merges for expert tensors - stack individual experts into batched tensors
|
||||
merges := make([]merge, q.NumHiddenLayers*3)
|
||||
for i := range q.NumHiddenLayers {
|
||||
merges[i*3+0] = merge{
|
||||
@@ -185,16 +391,13 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
// Merge expert tensors
|
||||
merged, remaining := mergeTensors(ts, merges...)
|
||||
out = append(out, merged...)
|
||||
|
||||
// Process remaining tensors
|
||||
for _, t := range remaining {
|
||||
name := t.Name()
|
||||
shape := t.Shape()
|
||||
|
||||
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
|
||||
if strings.HasSuffix(name, ".ssm_in.weight") {
|
||||
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
|
||||
out = append(out, qkv, gate)
|
||||
@@ -204,84 +407,299 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
}
|
||||
|
||||
switch {
|
||||
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
|
||||
// This matches the Python converter behavior for qwen3next
|
||||
case strings.Contains(name, ".mlp.experts.gate_up_proj"):
|
||||
out = append(out, slices.Collect(splitDim(t, 1,
|
||||
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_gate_exps.weight")},
|
||||
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_up_exps.weight")},
|
||||
))...)
|
||||
|
||||
case strings.Contains(name, ".mlp.experts.down_proj"):
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.NewReplacer(".mlp.experts.down_proj", ".ffn_down_exps.weight").Replace(name),
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
case strings.HasPrefix(name, "v.blk.") && strings.Contains(name, ".attn_qkv"):
|
||||
out = append(out, slices.Collect(splitDim(t, 0,
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
|
||||
))...)
|
||||
|
||||
case strings.Contains(name, "patch_embed") && strings.HasSuffix(name, "weight"):
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: append([]uint64{shape[0] * shape[1]}, shape[2:]...),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
|
||||
t.SetRepacker(q.addOne)
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
// Handle linear attention A_log -> ssm_a (negate and exp)
|
||||
// Note: name has already been transformed by Replacements at this point
|
||||
case strings.HasSuffix(name, ".ssm_a"):
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
// Compute -exp(A_log)
|
||||
result := make([]float32, len(data))
|
||||
for i, v := range data {
|
||||
// -exp(v)
|
||||
result[i] = -float32(math.Exp(float64(v)))
|
||||
}
|
||||
return result, nil
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
t.SetRepacker(q.repackSSMA())
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
case strings.HasSuffix(name, ".attn_qkv.weight"):
|
||||
if q.shouldReorderVHeads() {
|
||||
t.SetRepacker(q.repackAttnQKV())
|
||||
}
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
case strings.HasSuffix(name, ".attn_gate.weight"):
|
||||
if q.shouldReorderVHeads() {
|
||||
// HF tensor layout is [out_features, in_features]; reorder rows.
|
||||
t.SetRepacker(q.repackReorderDim(0, int(q.LinearValueHeadDim)))
|
||||
}
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
case strings.HasSuffix(name, ".ssm_beta.weight"), strings.HasSuffix(name, ".ssm_alpha.weight"):
|
||||
if q.shouldReorderVHeads() {
|
||||
// HF tensor layout is [out_features, in_features]; reorder rows.
|
||||
t.SetRepacker(q.repackReorderDim(0, 1))
|
||||
}
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
case strings.HasSuffix(name, ".ssm_dt"):
|
||||
if q.shouldReorderVHeads() {
|
||||
t.SetRepacker(q.repackReorderDim(0, 1))
|
||||
}
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
case strings.HasSuffix(name, ".ssm_out.weight"):
|
||||
if q.shouldReorderVHeads() {
|
||||
// HF out_proj layout is [out_features, in_features]; reorder columns.
|
||||
t.SetRepacker(q.repackReorderDim(1, int(q.LinearValueHeadDim)))
|
||||
}
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
|
||||
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 3 {
|
||||
if shape[0] == 1 {
|
||||
// [1, D, K] -> [D, K]
|
||||
newShape = []uint64{shape[1], shape[2]}
|
||||
} else if shape[1] == 1 {
|
||||
// [D, 1, K] -> [D, K]
|
||||
newShape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
|
||||
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 2 {
|
||||
if shape[0] == 1 && shape[1] > 1 {
|
||||
newShape = []uint64{shape[1]}
|
||||
} else if shape[1] == 1 && shape[0] > 1 {
|
||||
newShape = []uint64{shape[0]}
|
||||
}
|
||||
if q.shouldReorderVHeads() {
|
||||
t.SetRepacker(q.repackConv1D())
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: newShape, WriterTo: t})
|
||||
|
||||
default:
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) repackReorderDim(dim, headDim int) Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
if !q.shouldReorderVHeads() {
|
||||
return data, nil
|
||||
}
|
||||
numK := int(q.LinearNumKeyHeads)
|
||||
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
|
||||
return reorderHeadLayout(data, shape, dim, numK, numVPerK, headDim)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) repackAttnQKV() Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
if !q.shouldReorderVHeads() || len(shape) != 2 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
rows := int(shape[0])
|
||||
cols := int(shape[1])
|
||||
numK := int(q.LinearNumKeyHeads)
|
||||
numV := int(q.LinearNumValueHeads)
|
||||
headK := int(q.LinearKeyHeadDim)
|
||||
headV := int(q.LinearValueHeadDim)
|
||||
qDim := headK * numK
|
||||
kDim := headK * numK
|
||||
vDim := headV * numV
|
||||
qkvDim := qDim + kDim + vDim
|
||||
|
||||
switch {
|
||||
case rows == qkvDim:
|
||||
// HF layout: [out_features, in_features]. Keep Q/K rows unchanged and
|
||||
// reorder only V rows from grouped -> tiled head layout.
|
||||
out := make([]float32, len(data))
|
||||
qkRows := qDim + kDim
|
||||
qkSize := qkRows * cols
|
||||
copy(out[:qkSize], data[:qkSize])
|
||||
|
||||
vStart := qkSize
|
||||
vEnd := vStart + vDim*cols
|
||||
reorderedV, err := reorderHeadLayout(data[vStart:vEnd], []uint64{uint64(vDim), uint64(cols)}, 0, numK, numV/numK, headV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(out[vStart:vEnd], reorderedV)
|
||||
copy(out[vEnd:], data[vEnd:])
|
||||
return out, nil
|
||||
|
||||
case cols == qkvDim:
|
||||
// Fallback for already-transposed [in_features, out_features] tensors.
|
||||
out := make([]float32, len(data))
|
||||
copy(out, data)
|
||||
for r := range rows {
|
||||
base := r * cols
|
||||
vStart := base + qDim + kDim
|
||||
vEnd := vStart + vDim
|
||||
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vDim)}, 0, numK, numV/numK, headV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(out[vStart:vEnd], reorderedV)
|
||||
}
|
||||
return out, nil
|
||||
|
||||
default:
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) repackConv1D() Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
if !q.shouldReorderVHeads() {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
normShape := slices.Clone(shape)
|
||||
if len(shape) == 3 {
|
||||
if shape[0] == 1 {
|
||||
normShape = []uint64{shape[1], shape[2]}
|
||||
} else if shape[1] == 1 {
|
||||
normShape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
}
|
||||
if len(normShape) != 2 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
rows := int(normShape[0])
|
||||
cols := int(normShape[1])
|
||||
numK := int(q.LinearNumKeyHeads)
|
||||
numV := int(q.LinearNumValueHeads)
|
||||
headK := int(q.LinearKeyHeadDim)
|
||||
headV := int(q.LinearValueHeadDim)
|
||||
qkChannels := 2 * headK * numK
|
||||
totalChannels := qkChannels + headV*numV
|
||||
if qkChannels <= 0 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case rows == totalChannels:
|
||||
// HF layout after squeeze: [channels, kernel]
|
||||
out := make([]float32, len(data))
|
||||
prefix := qkChannels * cols
|
||||
copy(out[:prefix], data[:prefix])
|
||||
reorderedV, err := reorderHeadLayout(data[prefix:], []uint64{uint64(totalChannels - qkChannels), uint64(cols)}, 0, numK, numV/numK, headV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(out[prefix:], reorderedV)
|
||||
return out, nil
|
||||
case cols == totalChannels:
|
||||
// Fallback for transposed [kernel, channels]
|
||||
out := make([]float32, len(data))
|
||||
copy(out, data)
|
||||
vChannels := totalChannels - qkChannels
|
||||
for r := range rows {
|
||||
base := r * cols
|
||||
vStart := base + qkChannels
|
||||
vEnd := vStart + vChannels
|
||||
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vChannels)}, 0, numK, numV/numK, headV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(out[vStart:vEnd], reorderedV)
|
||||
}
|
||||
return out, nil
|
||||
default:
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) repackSSMA() Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
result := make([]float32, len(data))
|
||||
for i, v := range data {
|
||||
result[i] = -float32(math.Exp(float64(v)))
|
||||
}
|
||||
if !q.shouldReorderVHeads() {
|
||||
return result, nil
|
||||
}
|
||||
numK := int(q.LinearNumKeyHeads)
|
||||
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
|
||||
return reorderHeadLayout(result, shape, 0, numK, numVPerK, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func reorderHeadLayout(data []float32, shape []uint64, dim int, numKHeads, numVPerK, headDim int) ([]float32, error) {
|
||||
if len(shape) == 0 || numKHeads <= 0 || numVPerK <= 0 || headDim <= 0 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
if dim < 0 {
|
||||
dim += len(dims)
|
||||
}
|
||||
if dim < 0 || dim >= len(dims) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
expected := numKHeads * numVPerK * headDim
|
||||
if dims[dim] != expected {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
newShape := make([]int, 0, len(dims)+2)
|
||||
newShape = append(newShape, dims[:dim]...)
|
||||
newShape = append(newShape, numKHeads, numVPerK, headDim)
|
||||
newShape = append(newShape, dims[dim+1:]...)
|
||||
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
if err := tt.Reshape(newShape...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
perm := make([]int, len(newShape))
|
||||
for i := range perm {
|
||||
perm[i] = i
|
||||
}
|
||||
perm[dim], perm[dim+1] = perm[dim+1], perm[dim]
|
||||
|
||||
tt, err := tensor.Transpose(tt, perm...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
|
||||
total := 1
|
||||
for _, d := range dims {
|
||||
total *= d
|
||||
}
|
||||
if err := tt.Reshape(total); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
}
|
||||
|
||||
type qkvzSplitSpec struct {
|
||||
hidden int
|
||||
headKDim int
|
||||
@@ -369,7 +787,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
var err error
|
||||
|
||||
// Convert to [hidden, out_features] layout for slicing
|
||||
tt, err = tensor.Transpose(tt, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -444,7 +861,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
|
||||
}
|
||||
}
|
||||
|
||||
// addOne adds 1.0 to all elements in the tensor (for norm weights)
|
||||
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
|
||||
ones := tensor.Ones(tensor.Float32, int(shape[0]))
|
||||
@@ -471,10 +887,21 @@ func (q *qwen3NextModel) Replacements() []string {
|
||||
return []string{
|
||||
// Embeddings and output
|
||||
"lm_head", "output",
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"model.language_model.layers", "blk",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.norm", "output_norm",
|
||||
"model.layers", "blk",
|
||||
|
||||
// Vision
|
||||
"model.visual", "v",
|
||||
"patch_embed.proj", "patch_embed",
|
||||
"blocks", "blk",
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.proj", "attn_out",
|
||||
"deepstack_merger_list", "deepstack_merger",
|
||||
|
||||
// Layer norms
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
@@ -487,9 +914,16 @@ func (q *qwen3NextModel) Replacements() []string {
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
|
||||
// Linear attention (Gated Delta Net)
|
||||
// Linear attention (legacy qwen3next)
|
||||
"linear_attn.in_proj_qkvz", "ssm_in",
|
||||
"linear_attn.in_proj_ba", "ssm_ba",
|
||||
|
||||
// Linear attention (qwen35)
|
||||
"linear_attn.in_proj_qkv", "attn_qkv",
|
||||
"linear_attn.in_proj_z", "attn_gate",
|
||||
"linear_attn.in_proj_a", "ssm_alpha",
|
||||
"linear_attn.in_proj_b", "ssm_beta",
|
||||
|
||||
"linear_attn.conv1d", "ssm_conv1d",
|
||||
"linear_attn.dt_bias", "ssm_dt",
|
||||
"linear_attn.dt_proj", "ssm_dt",
|
||||
@@ -497,14 +931,14 @@ func (q *qwen3NextModel) Replacements() []string {
|
||||
"linear_attn.norm", "ssm_norm",
|
||||
"linear_attn.out_proj", "ssm_out",
|
||||
|
||||
// MoE (experts are stacked via mergeTensors, not replaced here)
|
||||
// MoE
|
||||
"mlp.gate.weight", "ffn_gate_inp.weight",
|
||||
"mlp.shared_expert.down_proj", "ffn_down_shexp",
|
||||
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
|
||||
"mlp.shared_expert.up_proj", "ffn_up_shexp",
|
||||
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
|
||||
|
||||
// Dense FFN (if any layers use it)
|
||||
// Dense FFN
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
|
||||
563
convert/convert_qwen3next_test.go
Normal file
563
convert/convert_qwen3next_test.go
Normal file
@@ -0,0 +1,563 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
func boolPtr(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
func readTensorData(t *testing.T, tensor *ggml.Tensor) []float32 {
|
||||
t.Helper()
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tensor.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
numel := 1
|
||||
for _, d := range tensor.Shape {
|
||||
numel *= int(d)
|
||||
}
|
||||
|
||||
values := make([]float32, numel)
|
||||
if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
func TestQwen3NextLegacyModelTypeDisablesReorder(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_next",
|
||||
},
|
||||
}
|
||||
|
||||
if m.shouldReorderVHeads() {
|
||||
t.Fatalf("legacy qwen3_next model_type should not reorder v-head layout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3NextLegacyArchitectureDisablesReorder(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
Architectures: []string{"Qwen3NextForCausalLM"},
|
||||
},
|
||||
}
|
||||
|
||||
if m.shouldReorderVHeads() {
|
||||
t.Fatalf("legacy Qwen3Next architecture should not reorder v-head layout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3NextKVLegacyConfig(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_next",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
MaxPositionEmbeddings: 8192,
|
||||
HiddenSize: 512,
|
||||
NumHiddenLayers: 4,
|
||||
IntermediateSize: 2048,
|
||||
NumAttentionHeads: 8,
|
||||
NumKeyValueHeads: 2,
|
||||
HeadDim: 64,
|
||||
RopeTheta: 1_000_000,
|
||||
RMSNormEPS: 1e-6,
|
||||
|
||||
NumExperts: 8,
|
||||
NumExpertsPerToken: 2,
|
||||
NormTopkProb: boolPtr(true),
|
||||
MoEIntermediateSize: 256,
|
||||
SharedExpertIntermSize: 512,
|
||||
|
||||
FullAttentionInterval: 2,
|
||||
|
||||
LinearConvKernelDim: 4,
|
||||
LinearKeyHeadDim: 64,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 64,
|
||||
|
||||
PartialRotaryFactor: 0.25,
|
||||
},
|
||||
}
|
||||
|
||||
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||
if got, want := kv["general.architecture"], "qwen35moe"; got != want {
|
||||
t.Fatalf("unexpected architecture: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["tokenizer.ggml.pre"], "qwen35"; got != want {
|
||||
t.Fatalf("unexpected tokenizer pre: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
|
||||
if !ok {
|
||||
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
|
||||
}
|
||||
if got, want := headCountKV, []uint32{0, 2, 0, 2}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
if _, ok := kv["ssm.v_head_reordered"]; ok {
|
||||
t.Fatalf("legacy qwen3next should not enable ssm.v_head_reordered")
|
||||
}
|
||||
if got, want := kv["norm_top_k_prob"], true; got != want {
|
||||
t.Fatalf("unexpected norm_top_k_prob: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35MoeOmitsNormTopKProbWhenUnset(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
MaxPositionEmbeddings: 4096,
|
||||
HiddenSize: 512,
|
||||
NumHiddenLayers: 4,
|
||||
IntermediateSize: 2048,
|
||||
NumAttentionHeads: 8,
|
||||
NumKeyValueHeads: 2,
|
||||
HeadDim: 64,
|
||||
RopeTheta: 1_000_000,
|
||||
RMSNormEPS: 1e-6,
|
||||
NumExperts: 8,
|
||||
NumExpertsPerToken: 2,
|
||||
FullAttentionInterval: 2,
|
||||
LinearConvKernelDim: 4,
|
||||
LinearKeyHeadDim: 64,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 64,
|
||||
PartialRotaryFactor: 0.25,
|
||||
},
|
||||
}
|
||||
|
||||
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||
if _, ok := kv["norm_top_k_prob"]; ok {
|
||||
t.Fatalf("expected norm_top_k_prob to be omitted when not set in config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35KVFromTextConfig(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
TextConfig: &qwen3NextTextConfig{
|
||||
MaxPositionEmbeddings: 16384,
|
||||
HiddenSize: 1024,
|
||||
NumHiddenLayers: 4,
|
||||
IntermediateSize: 4096,
|
||||
NumAttentionHeads: 8,
|
||||
NumKeyValueHeads: 4,
|
||||
HeadDim: 128,
|
||||
RMSNormEPS: 1e-6,
|
||||
|
||||
LayerTypes: []string{
|
||||
"linear_attention",
|
||||
"full_attention",
|
||||
"linear_attention",
|
||||
"full_attention",
|
||||
},
|
||||
|
||||
LinearConvKernelDim: 4,
|
||||
LinearKeyHeadDim: 128,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 128,
|
||||
|
||||
RopeParameters: qwen3NextRopeParams{
|
||||
MRopeInterleaved: true,
|
||||
MropeSection: []int32{11, 11, 10},
|
||||
RopeType: "default",
|
||||
RopeTheta: 10_000_000,
|
||||
PartialRotaryFactor: 0.25,
|
||||
},
|
||||
},
|
||||
VisionModel: qwen3NextVisionConfig{
|
||||
Depth: 2,
|
||||
HiddenSize: 128,
|
||||
NumHeads: 4,
|
||||
InChannels: 3,
|
||||
PatchSize: 16,
|
||||
SpatialMergeSize: 2,
|
||||
RMSNormEps: 1e-6,
|
||||
RopeTheta: 10_000,
|
||||
TemporalPatchSize: 2,
|
||||
DeepstackVisualIndexes: []int32{1},
|
||||
},
|
||||
ImageTokenID: 1001,
|
||||
VisionStartTokenID: 1002,
|
||||
VisionEndTokenID: 1003,
|
||||
}
|
||||
m.VisionModel.Size.ShortestEdge = 224
|
||||
m.VisionModel.Size.LongestEdge = 4096
|
||||
m.VisionModel.ImageMean = []float32{0.5, 0.5, 0.5}
|
||||
m.VisionModel.ImageStd = []float32{0.2, 0.2, 0.2}
|
||||
|
||||
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||
if got, want := kv["general.architecture"], "qwen35"; got != want {
|
||||
t.Fatalf("unexpected architecture: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
|
||||
if !ok {
|
||||
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
|
||||
}
|
||||
if got, want := headCountKV, []uint32{0, 4, 0, 4}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
if got, ok := kv["ssm.v_head_reordered"].(bool); !ok || !got {
|
||||
t.Fatalf("expected ssm.v_head_reordered=true, got %v (%T)", kv["ssm.v_head_reordered"], kv["ssm.v_head_reordered"])
|
||||
}
|
||||
|
||||
mrope, ok := kv["mrope_sections"].([]int32)
|
||||
if !ok {
|
||||
t.Fatalf("mrope_sections has unexpected type: %T", kv["mrope_sections"])
|
||||
}
|
||||
if got, want := mrope, []int32{11, 11, 10}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected mrope_sections: got %v want %v", got, want)
|
||||
}
|
||||
ropeSections, ok := kv["rope.dimension_sections"].([]int32)
|
||||
if !ok {
|
||||
t.Fatalf("rope.dimension_sections has unexpected type: %T", kv["rope.dimension_sections"])
|
||||
}
|
||||
if got, want := ropeSections, []int32{11, 11, 10}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected rope.dimension_sections: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
if got, ok := kv["rope.mrope_interleaved"].(bool); !ok || !got {
|
||||
t.Fatalf("expected rope.mrope_interleaved=true, got %v (%T)", kv["rope.mrope_interleaved"], kv["rope.mrope_interleaved"])
|
||||
}
|
||||
|
||||
if got, want := kv["vision.block_count"], uint32(2); got != want {
|
||||
t.Fatalf("unexpected vision.block_count: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3NextReplacements(t *testing.T) {
|
||||
r := strings.NewReplacer((&qwen3NextModel{}).Replacements()...)
|
||||
|
||||
if got, want := r.Replace("model.language_model.layers.1.linear_attn.in_proj_qkv.weight"), "blk.1.attn_qkv.weight"; got != want {
|
||||
t.Fatalf("unexpected language-model replacement: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := r.Replace("model.visual.blocks.0.attn.qkv.weight"), "v.blk.0.attn_qkv.weight"; got != want {
|
||||
t.Fatalf("unexpected vision replacement: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := r.Replace("model.layers.1.linear_attn.in_proj_qkvz.weight"), "blk.1.ssm_in.weight"; got != want {
|
||||
t.Fatalf("unexpected legacy replacement: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ReordersVHeads(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.attn_gate.weight",
|
||||
shape: []uint64{4, 2},
|
||||
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected data: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ReordersAttnQKVOutputDim(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearKeyHeadDim: 1,
|
||||
LinearValueHeadDim: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.attn_qkv.weight",
|
||||
shape: []uint64{8, 2}, // [out_features, in_features] (HF layout)
|
||||
data: []float32{
|
||||
0, 1, // q0
|
||||
2, 3, // q1
|
||||
4, 5, // k0
|
||||
6, 7, // k1
|
||||
10, 11, // v(k0,v0)
|
||||
12, 13, // v(k0,v1)
|
||||
20, 21, // v(k1,v0)
|
||||
22, 23, // v(k1,v1)
|
||||
},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
10, 11, 20, 21, 12, 13, 22, 23,
|
||||
}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected qkv data: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ReordersSsmOutInputDim(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.ssm_out.weight",
|
||||
shape: []uint64{2, 4},
|
||||
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{0, 2, 1, 3, 4, 6, 5, 7}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected ssm_out data: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ReordersSsmBetaRows(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.ssm_beta.weight",
|
||||
shape: []uint64{4, 2},
|
||||
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected ssm_beta data: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ReordersConv1DChannelDim(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearKeyHeadDim: 1,
|
||||
LinearValueHeadDim: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.ssm_conv1d.weight",
|
||||
shape: []uint64{8, 2}, // [channels, kernel] after squeeze
|
||||
data: []float32{
|
||||
0, 1, // q0
|
||||
2, 3, // q1
|
||||
4, 5, // k0
|
||||
6, 7, // k1
|
||||
10, 11, // v(k0,v0)
|
||||
12, 13, // v(k0,v1)
|
||||
20, 21, // v(k1,v0)
|
||||
22, 23, // v(k1,v1)
|
||||
},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
10, 11, 20, 21, 12, 13, 22, 23,
|
||||
}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected conv1d data: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyQwen3NextDoesNotReorderVHeads(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_next",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.attn_gate.weight",
|
||||
shape: []uint64{4, 1},
|
||||
data: []float32{0, 1, 2, 3},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{0, 1, 2, 3}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected data for legacy qwen3next: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35MoePackedExperts(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
NumHiddenLayers: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.mlp.experts.gate_up_proj",
|
||||
shape: []uint64{2, 4, 3},
|
||||
data: []float32{
|
||||
0, 1, 2,
|
||||
3, 4, 5,
|
||||
6, 7, 8,
|
||||
9, 10, 11,
|
||||
12, 13, 14,
|
||||
15, 16, 17,
|
||||
18, 19, 20,
|
||||
21, 22, 23,
|
||||
},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "blk.0.mlp.experts.down_proj",
|
||||
shape: []uint64{2, 5, 3},
|
||||
data: make([]float32, 2*5*3),
|
||||
},
|
||||
})
|
||||
|
||||
get := func(name string) *ggml.Tensor {
|
||||
for _, tensor := range out {
|
||||
if tensor.Name == name {
|
||||
return tensor
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gate := get("blk.0.ffn_gate_exps.weight")
|
||||
if gate == nil {
|
||||
t.Fatalf("missing tensor %q", "blk.0.ffn_gate_exps.weight")
|
||||
}
|
||||
if got, want := gate.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected gate shape: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := readTensorData(t, gate), []float32{
|
||||
0, 1, 2, 3, 4, 5,
|
||||
12, 13, 14, 15, 16, 17,
|
||||
}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected gate values: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
up := get("blk.0.ffn_up_exps.weight")
|
||||
if up == nil {
|
||||
t.Fatalf("missing tensor %q", "blk.0.ffn_up_exps.weight")
|
||||
}
|
||||
if got, want := up.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected up shape: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := readTensorData(t, up), []float32{
|
||||
6, 7, 8, 9, 10, 11,
|
||||
18, 19, 20, 21, 22, 23,
|
||||
}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected up values: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
down := get("blk.0.ffn_down_exps.weight")
|
||||
if down == nil {
|
||||
t.Fatalf("missing tensor %q", "blk.0.ffn_down_exps.weight")
|
||||
}
|
||||
if got, want := down.Shape, []uint64{2, 5, 3}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected down shape: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35SharedExpertGateKeepsMatrixShape(t *testing.T) {
|
||||
m := &qwen3NextModel{}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.ffn_gate_inp_shexp.weight",
|
||||
shape: []uint64{1, 4},
|
||||
data: []float32{0, 1, 2, 3},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := out[0].Shape, []uint64{1, 4}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected shared gate shape: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -101,6 +101,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
||||
t.Pre = "deepseek-coder"
|
||||
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
|
||||
t.Pre = "qwen2"
|
||||
case "00431aed57e696b747435f734d1e3b9b1bfd931a121fb5cac7129e97c181e9ba":
|
||||
t.Pre = "qwen35"
|
||||
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
|
||||
// noop, empty pretokenizer
|
||||
default:
|
||||
|
||||
@@ -386,6 +386,28 @@ func TestParseTokenizer(t *testing.T) {
|
||||
Pre: "default",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "qwen35 pretokenizer",
|
||||
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||
"tokenizer.json": strings.NewReader(`{
|
||||
"pre_tokenizer": {
|
||||
"type": "Sequence",
|
||||
"pretokenizers": [
|
||||
{
|
||||
"type": "Split",
|
||||
"pattern": {
|
||||
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}`),
|
||||
}),
|
||||
want: &Tokenizer{
|
||||
Vocabulary: &Vocabulary{Model: "gpt2"},
|
||||
Pre: "qwen35",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
|
||||
@@ -290,6 +290,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"olmo3",
|
||||
"qwen25vl",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen35", "qwen35moe",
|
||||
"qwen3next",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
"glm4moelite",
|
||||
@@ -868,7 +869,12 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
|
||||
arch := f.KV().Architecture()
|
||||
if slices.Contains([]string{"qwen35", "qwen35moe", "qwen3next"}, arch) {
|
||||
return true
|
||||
}
|
||||
|
||||
if slices.Contains([]string{"gemma2"}, arch) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -892,6 +898,7 @@ func (f GGML) FlashAttention() bool {
|
||||
"nemotron_h", "nemotron_h_moe",
|
||||
"olmo3",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen35", "qwen35moe",
|
||||
"qwen3next",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
}, f.KV().String("general.architecture"))
|
||||
|
||||
@@ -245,7 +245,22 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
||||
padding := ggufPadding(offset, int64(alignment))
|
||||
llm.tensorOffset = uint64(offset + padding)
|
||||
|
||||
// get file size to validate tensor bounds
|
||||
fileSize, err := rs.Seek(0, io.SeekEnd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine file size: %w", err)
|
||||
}
|
||||
|
||||
if _, err := rs.Seek(offset, io.SeekStart); err != nil {
|
||||
return fmt.Errorf("failed to seek back after size check: %w", err)
|
||||
}
|
||||
|
||||
for _, tensor := range llm.tensors {
|
||||
tensorEnd := llm.tensorOffset + tensor.Offset + tensor.Size()
|
||||
if tensorEnd > uint64(fileSize) {
|
||||
return fmt.Errorf("tensor %q offset+size (%d) exceeds file size (%d)", tensor.Name, tensorEnd, fileSize)
|
||||
}
|
||||
|
||||
offset, err := rs.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current offset: %w", err)
|
||||
|
||||
@@ -11,21 +11,21 @@ import (
|
||||
)
|
||||
|
||||
func TestWriteGGUF(t *testing.T) {
|
||||
b := bytes.NewBuffer(make([]byte, 2*3))
|
||||
tensorData := make([]byte, 2*3*4) // 6 F32 elements = 24 bytes
|
||||
for range 8 {
|
||||
t.Run("shuffle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := []*Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewReader(tensorData)},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewReader(tensorData)},
|
||||
}
|
||||
|
||||
rand.Shuffle(len(ts), func(i, j int) {
|
||||
@@ -98,4 +98,32 @@ func TestWriteGGUF(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("truncated_tensor_data", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := []*Tensor{
|
||||
{Name: "blk.0.attn.weight", Kind: 0, Shape: []uint64{512, 2}, WriterTo: bytes.NewBuffer(make([]byte, 32))},
|
||||
}
|
||||
|
||||
w, err := os.CreateTemp(t.TempDir(), "truncated_*.bin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
if err := WriteGGUF(w, KV{"general.architecture": "test"}, ts); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := os.Open(w.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if _, err := Decode(r, -1); err == nil {
|
||||
t.Error("Decode should reject GGUF files where tensor data extends beyond file size")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCheckpointCount = 32
|
||||
DefaultCheckpointCount = 24
|
||||
DefaultCheckpointMinPos = int32(16)
|
||||
DefaultCheckpointInterval = int32(1280)
|
||||
DefaultCheckpointInterval = int32(1664)
|
||||
)
|
||||
|
||||
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")
|
||||
|
||||
@@ -2,595 +2,58 @@ package qwen3next
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var _ kvcache.Cache = (*HybridCache)(nil)
|
||||
var (
|
||||
_ kvcache.Cache = (*HybridCache)(nil)
|
||||
_ kvcache.CheckpointCache = (*HybridCache)(nil)
|
||||
)
|
||||
|
||||
// HybridCache stores:
|
||||
// - a standard causal KV cache for full attention layers
|
||||
// - per-sequence conv state for linear attention layers
|
||||
// - per-sequence delta state for linear attention layers
|
||||
//
|
||||
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
|
||||
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
|
||||
// HybridCache adapts the shared recurrent cache base for Qwen3-Next naming.
|
||||
type HybridCache struct {
|
||||
kv *kvcache.Causal
|
||||
|
||||
backend ml.Backend
|
||||
dtype ml.DType
|
||||
maxSequences int
|
||||
|
||||
// Conv state dimensions
|
||||
convDim int // convKernelSize - 1
|
||||
convChannels int // d_inner + 2 * num_k_heads * head_k_dim
|
||||
|
||||
// Delta state dimensions
|
||||
deltaStateSize int // headVDim * headVDim * numVHeads
|
||||
|
||||
// slot mapping for recurrent state (copy-on-write)
|
||||
slotForSeq map[int]int
|
||||
refCount []int
|
||||
freeSlots []int
|
||||
|
||||
// per-layer conv state buffers (allocated lazily)
|
||||
convCtxs map[int]ml.Context
|
||||
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
|
||||
|
||||
// per-layer delta state buffers (allocated lazily)
|
||||
deltaCtxs map[int]ml.Context
|
||||
deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
|
||||
|
||||
// recurrent checkpoints (per slot)
|
||||
checkpointCount int
|
||||
checkpointMinPos int32
|
||||
checkpointInterval int32
|
||||
checkpointCtxSize int
|
||||
checkpoints map[int]*slotCheckpointStore
|
||||
pendingRestore map[int]checkpointRestore
|
||||
curCheckpointPos []int32
|
||||
curCheckpointSlots map[int]int
|
||||
reserveCheckpoints bool
|
||||
checkpointConvCtxs map[int]ml.Context
|
||||
checkpointDeltaCtxs map[int]ml.Context
|
||||
checkpointReserved map[int]struct{}
|
||||
|
||||
// current forward batch (derived in StartForward)
|
||||
curSeqs []int
|
||||
curSlots []int
|
||||
curSlotsInput ml.Tensor
|
||||
curSeqTokens int
|
||||
|
||||
// track if EnsureWritable has been called for this forward pass
|
||||
writableEnsured bool
|
||||
writableError error
|
||||
*kvcache.Recurrent
|
||||
}
|
||||
|
||||
func NewHybridCache(
|
||||
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
|
||||
convDim, convChannels, deltaStateSize int,
|
||||
) *HybridCache {
|
||||
return &HybridCache{
|
||||
kv: kvcache.NewCausalCache(shift),
|
||||
convDim: convDim,
|
||||
convChannels: convChannels,
|
||||
deltaStateSize: deltaStateSize,
|
||||
slotForSeq: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
deltaCtxs: make(map[int]ml.Context),
|
||||
deltaStates: make(map[int]ml.Tensor),
|
||||
checkpointCount: checkpointCountDefault,
|
||||
checkpointMinPos: checkpointMinPosDefault,
|
||||
checkpointInterval: checkpointIntervalDefault,
|
||||
checkpoints: make(map[int]*slotCheckpointStore),
|
||||
pendingRestore: make(map[int]checkpointRestore),
|
||||
curCheckpointSlots: make(map[int]int),
|
||||
checkpointConvCtxs: make(map[int]ml.Context),
|
||||
checkpointDeltaCtxs: make(map[int]ml.Context),
|
||||
checkpointReserved: make(map[int]struct{}),
|
||||
}
|
||||
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
|
||||
Shift: shift,
|
||||
ConvDim: convDim,
|
||||
ConvChannels: convChannels,
|
||||
RecurrentStateSize: deltaStateSize,
|
||||
CheckpointLogPrefix: "qwen3next",
|
||||
})
|
||||
return &HybridCache{Recurrent: base}
|
||||
}
|
||||
|
||||
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.backend = backend
|
||||
c.dtype = dtype
|
||||
c.maxSequences = maxSequences
|
||||
c.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
c.pendingRestore = make(map[int]checkpointRestore)
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
c.curCheckpointSlots = make(map[int]int)
|
||||
c.checkpointReserved = make(map[int]struct{})
|
||||
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
|
||||
if c.checkpointCtxSize < 8 {
|
||||
c.checkpointCtxSize = 8
|
||||
}
|
||||
|
||||
// initialize slot allocator
|
||||
c.refCount = make([]int, maxSequences)
|
||||
c.freeSlots = c.freeSlots[:0]
|
||||
for i := maxSequences - 1; i >= 0; i-- {
|
||||
c.freeSlots = append(c.freeSlots, i)
|
||||
}
|
||||
|
||||
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Close() {
|
||||
for _, ctx := range c.convCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.deltaCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointConvCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointDeltaCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
c.kv.Close()
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
|
||||
c.kv.SetConfig(config)
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetLayer(layer int) {
|
||||
c.kv.SetLayer(layer)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.kv.Get(ctx)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.kv.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Derive equal-length sequence layout for recurrent layers
|
||||
seqCounts := make(map[int]int)
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
for _, s := range batch.Sequences {
|
||||
if _, ok := seqCounts[s]; !ok {
|
||||
c.curSeqs = append(c.curSeqs, s)
|
||||
}
|
||||
seqCounts[s]++
|
||||
}
|
||||
|
||||
if len(c.curSeqs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
nTokens := len(batch.Sequences)
|
||||
nSeqs := len(c.curSeqs)
|
||||
want := nTokens / nSeqs
|
||||
for _, s := range c.curSeqs {
|
||||
if seqCounts[s] != want {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
c.curSeqTokens = want
|
||||
|
||||
// When reserving memory for estimation, use fake slot assignments
|
||||
if reserve {
|
||||
c.curSlots = c.curSlots[:0]
|
||||
slots := make([]int32, nSeqs)
|
||||
for i := range nSeqs {
|
||||
c.curSlots = append(c.curSlots, i)
|
||||
slots[i] = int32(i)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
c.reserveCheckpoints = true
|
||||
c.planCheckpoints(batch)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure slots exist for sequences in this batch
|
||||
c.curSlots = c.curSlots[:0]
|
||||
var newSlots []int
|
||||
for _, s := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[s]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.slotForSeq[s] = slot
|
||||
c.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
c.curSlots = append(c.curSlots, slot)
|
||||
}
|
||||
|
||||
// Zero state for newly allocated slots
|
||||
if len(newSlots) > 0 {
|
||||
c.zeroSlots(ctx, newSlots)
|
||||
}
|
||||
|
||||
// Create a tensor for the current slots
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
// Reset writable state for new forward pass
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
c.reserveCheckpoints = false
|
||||
c.planCheckpoints(batch)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) allocSlot() (int, error) {
|
||||
if len(c.freeSlots) == 0 {
|
||||
return 0, kvcache.ErrKvCacheFull
|
||||
}
|
||||
slot := c.freeSlots[len(c.freeSlots)-1]
|
||||
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
||||
return slot, nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) freeSlot(slot int) {
|
||||
if slot >= 0 && slot < c.maxSequences {
|
||||
c.freeSlots = append(c.freeSlots, slot)
|
||||
}
|
||||
}
|
||||
|
||||
// zeroSlots zeros the recurrent state for the given slots across all layers.
|
||||
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
|
||||
if len(slots) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
inputCtx := ctx.Input()
|
||||
|
||||
slotIndices := make([]int32, len(slots))
|
||||
for i, s := range slots {
|
||||
slotIndices[i] = int32(s)
|
||||
}
|
||||
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
|
||||
|
||||
// Zero conv states
|
||||
if len(c.convStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
|
||||
for _, buf := range c.convStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
|
||||
// Zero delta states
|
||||
if len(c.deltaStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
|
||||
for _, buf := range c.deltaStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureWritable ensures sequences have private slots (copy-on-write).
|
||||
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
|
||||
for i, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.refCount[slot] <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
c.curSlots[i] = newSlot
|
||||
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
}
|
||||
|
||||
// Rebuild current slots tensor
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
|
||||
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
|
||||
|
||||
for _, buf := range c.convStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
||||
}
|
||||
|
||||
for _, buf := range c.deltaStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
|
||||
// Copy-on-write for recurrent state
|
||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
||||
if c.validSlot(dstSlot) {
|
||||
c.refCount[dstSlot]--
|
||||
if c.refCount[dstSlot] <= 0 {
|
||||
c.refCount[dstSlot] = 0
|
||||
c.freeSlot(dstSlot)
|
||||
}
|
||||
}
|
||||
delete(c.slotForSeq, dstSeq)
|
||||
}
|
||||
|
||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if c.validSlot(srcSlot) {
|
||||
c.slotForSeq[dstSeq] = srcSlot
|
||||
c.refCount[srcSlot]++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CanResume(seq int, pos int32) bool {
|
||||
if !c.kv.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
if pos == 0 {
|
||||
return true
|
||||
}
|
||||
return c.hasCheckpoint(seq, pos)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore, ok := c.pendingRestore[seq]
|
||||
if !ok || restore.pos+1 != beginIndex {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
if !c.restoreComplete(restore) {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
// If the recurrent slot is shared, detach it before applying a restore.
|
||||
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
|
||||
restore.slot = newSlot
|
||||
c.pendingRestore[seq] = restore
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore := c.pendingRestore[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
return c.applyCheckpointRestore(restore)
|
||||
}
|
||||
|
||||
// Removal invalidates recurrent state
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.validSlot(slot) {
|
||||
delete(c.slotForSeq, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.refCount[slot]--
|
||||
if c.refCount[slot] <= 0 {
|
||||
c.refCount[slot] = 0
|
||||
c.clearCheckpoints(slot)
|
||||
c.freeSlot(slot)
|
||||
}
|
||||
delete(c.slotForSeq, seq)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) validSlot(slot int) bool {
|
||||
return slot >= 0 && slot < len(c.refCount)
|
||||
}
|
||||
|
||||
func (c *HybridCache) slotsTensor() ml.Tensor {
|
||||
return c.curSlotsInput
|
||||
}
|
||||
|
||||
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
|
||||
func (c *HybridCache) contiguousSlots() (int, bool) {
|
||||
if len(c.curSlots) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
start := c.curSlots[0]
|
||||
for i, s := range c.curSlots {
|
||||
if s != start+i {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return start, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.curSeqTokens
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return len(c.curSeqs)
|
||||
}
|
||||
|
||||
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.convStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.convCtxs[layer]; !ok {
|
||||
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
|
||||
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
|
||||
c.convStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.deltaStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.deltaCtxs[layer]; !ok {
|
||||
c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
// Recurrent delta state must stay in F32.
|
||||
buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
|
||||
c.deltaStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
|
||||
if !c.writableEnsured {
|
||||
needsWritable := false
|
||||
for _, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
||||
needsWritable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsWritable {
|
||||
if err := c.EnsureWritable(ctx); err != nil {
|
||||
c.writableError = err
|
||||
}
|
||||
}
|
||||
c.writableEnsured = true
|
||||
}
|
||||
}
|
||||
|
||||
// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
|
||||
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
||||
c.ensureWritableOnce(ctx)
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateConvState writes a new conv state for current batch sequences.
|
||||
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
// Fast path: contiguous slots allow a single view + copy
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
|
||||
ctx.Forward(srcF32.Copy(ctx, view))
|
||||
} else {
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
}
|
||||
|
||||
c.captureConvCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
|
||||
// DeltaState returns the delta state for current batch sequences as
|
||||
// [headVDim, headVDim*numVHeads, nSeqs].
|
||||
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
|
||||
c.ensureWritableOnce(ctx)
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
|
||||
return c.RecurrentState(ctx, layer, headVDim, headVDim*numVHeads)
|
||||
}
|
||||
|
||||
// UpdateDeltaState writes a new delta state for current batch sequences.
|
||||
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
// Fast path: contiguous slots allow a single view + copy
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
|
||||
ctx.Forward(srcF32.Copy(ctx, view))
|
||||
} else {
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
c.UpdateRecurrentState(ctx, layer, newState)
|
||||
}
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.SeqTokens()
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return c.NumSeqs()
|
||||
}
|
||||
|
||||
// Keep qwen3next behavior for partial mid-sequence removals.
|
||||
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
c.captureDeltaCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
|
||||
func (c *HybridCache) IsSupportedForBatch() bool {
|
||||
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
||||
}
|
||||
|
||||
// Seqs returns the ordered unique sequences for the current forward pass.
|
||||
func (c *HybridCache) Seqs() []int {
|
||||
return slices.Clone(c.curSeqs)
|
||||
return c.Recurrent.Remove(seq, beginIndex, endIndex)
|
||||
}
|
||||
|
||||
@@ -1,498 +0,0 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
checkpointCountDefault = 32
|
||||
checkpointMinPosDefault = int32(16)
|
||||
checkpointIntervalDefault = int32(1280)
|
||||
)
|
||||
|
||||
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
|
||||
// memory usage while preserving prefix reuse for recurrent state.
|
||||
|
||||
type checkpointEntry struct {
|
||||
pos int32
|
||||
conv map[int]ml.Tensor
|
||||
delta map[int]ml.Tensor
|
||||
}
|
||||
|
||||
type slotCheckpointStore struct {
|
||||
entries []checkpointEntry
|
||||
size int
|
||||
next int
|
||||
lastPos int32
|
||||
}
|
||||
|
||||
type checkpointRestore struct {
|
||||
slot int
|
||||
idx int
|
||||
pos int32
|
||||
}
|
||||
|
||||
func newSlotCheckpointStore(n int) *slotCheckpointStore {
|
||||
entries := make([]checkpointEntry, n)
|
||||
for i := range entries {
|
||||
entries[i].pos = -1
|
||||
}
|
||||
return &slotCheckpointStore{
|
||||
entries: entries,
|
||||
lastPos: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) reset() {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
for i := range s.entries {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) record(pos int32) int {
|
||||
if len(s.entries) == 0 {
|
||||
return -1
|
||||
}
|
||||
idx := s.next
|
||||
s.next = (s.next + 1) % len(s.entries)
|
||||
if s.size < len(s.entries) {
|
||||
s.size++
|
||||
}
|
||||
s.entries[idx].pos = pos
|
||||
s.lastPos = pos
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
|
||||
bestIdx := -1
|
||||
bestPos := int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 || pos >= targetPos {
|
||||
continue
|
||||
}
|
||||
if pos > bestPos {
|
||||
bestPos = pos
|
||||
bestIdx = i
|
||||
}
|
||||
}
|
||||
if bestIdx < 0 {
|
||||
return -1, -1, false
|
||||
}
|
||||
return bestIdx, bestPos, true
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) pruneAfter(pos int32) {
|
||||
if len(s.entries) == 0 {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
|
||||
size := 0
|
||||
next := -1
|
||||
minPos := int32(math.MaxInt32)
|
||||
minIdx := 0
|
||||
for i := range s.entries {
|
||||
if s.entries[i].pos > pos {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
if s.entries[i].pos >= 0 {
|
||||
size++
|
||||
if s.entries[i].pos < minPos {
|
||||
minPos = s.entries[i].pos
|
||||
minIdx = i
|
||||
}
|
||||
} else if next == -1 {
|
||||
next = i
|
||||
}
|
||||
}
|
||||
|
||||
s.size = size
|
||||
if size == 0 {
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
if next != -1 {
|
||||
s.next = next
|
||||
} else {
|
||||
// Full ring: overwrite the oldest checkpoint next.
|
||||
s.next = minIdx
|
||||
}
|
||||
s.lastPos = pos
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
|
||||
minPos = int32(math.MaxInt32)
|
||||
maxPos = int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
size++
|
||||
if pos < minPos {
|
||||
minPos = pos
|
||||
}
|
||||
if pos > maxPos {
|
||||
maxPos = pos
|
||||
}
|
||||
}
|
||||
if size == 0 {
|
||||
minPos = -1
|
||||
maxPos = -1
|
||||
}
|
||||
return size, minPos, maxPos, s.lastPos
|
||||
}
|
||||
|
||||
func (c *HybridCache) planCheckpoints(batch input.Batch) {
|
||||
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cap(c.curCheckpointPos) < len(c.curSeqs) {
|
||||
c.curCheckpointPos = make([]int32, len(c.curSeqs))
|
||||
} else {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
|
||||
}
|
||||
for i := range c.curCheckpointPos {
|
||||
c.curCheckpointPos[i] = -1
|
||||
}
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
|
||||
posMax := make(map[int]int32, len(c.curSeqs))
|
||||
for i, seq := range batch.Sequences {
|
||||
pos := batch.Positions[i]
|
||||
if cur, ok := posMax[seq]; !ok || pos > cur {
|
||||
posMax[seq] = pos
|
||||
}
|
||||
}
|
||||
|
||||
for i, seq := range c.curSeqs {
|
||||
pos, ok := posMax[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if pos < c.checkpointMinPos {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
store := c.checkpointStore(slot)
|
||||
lastPos := store.lastPos
|
||||
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
|
||||
c.curCheckpointPos[i] = pos
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
|
||||
store, ok := c.checkpoints[slot]
|
||||
if ok {
|
||||
return store
|
||||
}
|
||||
store = newSlotCheckpointStore(c.checkpointCount)
|
||||
c.checkpoints[slot] = store
|
||||
return store
|
||||
}
|
||||
|
||||
func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
|
||||
if c.checkpointCount == 0 {
|
||||
return -1
|
||||
}
|
||||
if idx, ok := c.curCheckpointSlots[slot]; ok {
|
||||
return idx
|
||||
}
|
||||
store := c.checkpointStore(slot)
|
||||
idx := store.record(pos)
|
||||
if idx >= 0 {
|
||||
c.curCheckpointSlots[slot] = idx
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
|
||||
if pos <= 0 {
|
||||
return false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, _, ok = store.bestIndex(pos)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
|
||||
if targetPos <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
|
||||
return 0, false
|
||||
}
|
||||
idx, pos, ok := store.bestIndex(targetPos)
|
||||
if !ok {
|
||||
size, minPos, maxPos, lastPos := store.window()
|
||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
|
||||
"min", minPos, "max", maxPos, "last", lastPos)
|
||||
return 0, false
|
||||
}
|
||||
c.pendingRestore[seq] = checkpointRestore{
|
||||
slot: slot,
|
||||
idx: idx,
|
||||
pos: pos,
|
||||
}
|
||||
return pos + 1, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
|
||||
entry, ok := c.restoreEntry(restore)
|
||||
if !ok {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
|
||||
for layer, src := range entry.conv {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
for layer, src := range entry.delta {
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
|
||||
if len(entry.conv) > 0 || len(entry.delta) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
store := c.checkpoints[restore.slot]
|
||||
store.pruneAfter(restore.pos)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
|
||||
_, ok := c.restoreEntry(restore)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
|
||||
store, ok := c.checkpoints[restore.slot]
|
||||
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
|
||||
return nil, false
|
||||
}
|
||||
entry := &store.entries[restore.idx]
|
||||
if entry.pos < 0 {
|
||||
return nil, false
|
||||
}
|
||||
if !c.entryComplete(entry) {
|
||||
return nil, false
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
|
||||
for layer := range c.convStates {
|
||||
if entry.conv == nil || entry.conv[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for layer := range c.deltaStates {
|
||||
if entry.delta == nil || entry.delta[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *HybridCache) clearCheckpoints(slot int) {
|
||||
if store, ok := c.checkpoints[slot]; ok {
|
||||
store.reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
srcStore, ok := c.checkpoints[srcSlot]
|
||||
if !ok || srcStore.size == 0 {
|
||||
return
|
||||
}
|
||||
dstStore := c.checkpointStore(dstSlot)
|
||||
dstStore.size = srcStore.size
|
||||
dstStore.next = srcStore.next
|
||||
dstStore.lastPos = srcStore.lastPos
|
||||
|
||||
for i := range srcStore.entries {
|
||||
srcEntry := &srcStore.entries[i]
|
||||
dstEntry := &dstStore.entries[i]
|
||||
dstEntry.pos = srcEntry.pos
|
||||
if srcEntry.conv != nil {
|
||||
if dstEntry.conv == nil {
|
||||
dstEntry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.conv {
|
||||
dst := c.ensureCheckpointConv(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
if srcEntry.delta != nil {
|
||||
if dstEntry.delta == nil {
|
||||
dstEntry.delta = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.delta {
|
||||
dst := c.ensureCheckpointDelta(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointConv(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointConv(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointDelta(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointDelta(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.conv == nil {
|
||||
entry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.conv[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointConvCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointConvCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
|
||||
entry.conv[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.delta == nil {
|
||||
entry.delta = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.delta[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointDeltaCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointDeltaCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
|
||||
entry.delta[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *HybridCache) reserveCheckpointConv(layer int) {
|
||||
key := checkpointReserveKey(layer, 0)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointConv(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func (c *HybridCache) reserveCheckpointDelta(layer int) {
|
||||
key := checkpointReserveKey(layer, 1)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointDelta(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func checkpointReserveKey(layer int, kind int) int {
|
||||
return layer*2 + kind
|
||||
}
|
||||
@@ -1,300 +0,0 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func newTestBackend(tb testing.TB) ml.Backend {
|
||||
tb.Helper()
|
||||
|
||||
f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
|
||||
_ = f.Close()
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
tb.Cleanup(func() {
|
||||
b.Close()
|
||||
})
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
|
||||
store := newSlotCheckpointStore(2)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
|
||||
_, pos, ok := store.bestIndex(15)
|
||||
if !ok || pos != 10 {
|
||||
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
store.record(30) // overwrite oldest (10)
|
||||
|
||||
if _, _, ok := store.bestIndex(15); ok {
|
||||
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(40)
|
||||
if !ok || pos != 30 {
|
||||
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCachePrepareRestore(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 1, 1)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
store := cache.checkpointStore(0)
|
||||
store.record(5)
|
||||
store.record(9)
|
||||
store.record(15)
|
||||
|
||||
restorePos, ok := cache.PrepareRestore(1, 12)
|
||||
if !ok {
|
||||
t.Fatalf("expected restore ok")
|
||||
}
|
||||
if restorePos != 10 {
|
||||
t.Fatalf("expected restorePos 10, got %d", restorePos)
|
||||
}
|
||||
rest, ok := cache.pendingRestore[1]
|
||||
if !ok {
|
||||
t.Fatalf("expected pending restore entry")
|
||||
}
|
||||
if rest.pos != 9 {
|
||||
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
store.pruneAfter(20)
|
||||
|
||||
if store.lastPos != 20 {
|
||||
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, pos, ok := store.bestIndex(25)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(35)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
|
||||
backend := newTestBackend(t)
|
||||
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.Init(backend, ml.DTypeF16, 2, 8, 2)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.slotForSeq[2] = 0
|
||||
cache.refCount[0] = 2
|
||||
cache.refCount[1] = 0
|
||||
cache.freeSlots = []int{1}
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
|
||||
t.Fatalf("Remove failed: %v", err)
|
||||
}
|
||||
|
||||
if cache.slotForSeq[1] == cache.slotForSeq[2] {
|
||||
t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
|
||||
}
|
||||
if cache.slotForSeq[1] != 1 {
|
||||
t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
|
||||
}
|
||||
if cache.slotForSeq[2] != 0 {
|
||||
t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
|
||||
}
|
||||
if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
|
||||
t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
|
||||
}
|
||||
if _, ok := cache.pendingRestore[1]; ok {
|
||||
t.Fatalf("expected pending restore to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
|
||||
cache.convStates[0] = nil // placeholder to indicate layer 0 exists
|
||||
cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
entry := &store.entries[idx]
|
||||
// Only set conv checkpoint, not delta - making it incomplete
|
||||
entry.conv = map[int]ml.Tensor{0: nil}
|
||||
// entry.delta is not set, so checkpoint is incomplete
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
err := cache.Remove(1, 10, math.MaxInt32)
|
||||
if !errors.Is(err, kvcache.ErrNotSupported) {
|
||||
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Don't set convStates/deltaStates - with no layers to check,
|
||||
// entryComplete will return true as long as entry.pos >= 0
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
// Test that restoreComplete returns true when no layers need checkpoints
|
||||
restore := cache.pendingRestore[1]
|
||||
if !cache.restoreComplete(restore) {
|
||||
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
|
||||
// Test that ring buffer wrap-around reuses entries without clearing maps.
|
||||
store := newSlotCheckpointStore(3)
|
||||
|
||||
// Fill the buffer
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
// Create fake tensor data in the first entry's maps
|
||||
store.entries[0].conv = make(map[int]ml.Tensor)
|
||||
store.entries[0].conv[0] = nil // Simulated tensor reference
|
||||
store.entries[0].delta = make(map[int]ml.Tensor)
|
||||
store.entries[0].delta[0] = nil // Simulated tensor reference
|
||||
|
||||
// Record another entry, which should wrap around and overwrite entry 0
|
||||
store.record(40)
|
||||
|
||||
// Verify the maps are still present (we reuse tensors)
|
||||
if store.entries[0].conv == nil {
|
||||
t.Fatalf("expected conv map to be preserved on reuse")
|
||||
}
|
||||
if store.entries[0].delta == nil {
|
||||
t.Fatalf("expected delta map to be preserved on reuse")
|
||||
}
|
||||
|
||||
// Verify the new position was recorded
|
||||
if store.entries[0].pos != 40 {
|
||||
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
|
||||
// Test behavior when buffer is exactly at capacity
|
||||
store := newSlotCheckpointStore(2)
|
||||
|
||||
idx1 := store.record(10)
|
||||
idx2 := store.record(20)
|
||||
|
||||
if idx1 != 0 || idx2 != 1 {
|
||||
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
|
||||
}
|
||||
|
||||
if store.size != 2 {
|
||||
t.Fatalf("expected size 2, got %d", store.size)
|
||||
}
|
||||
|
||||
// Verify both checkpoints are accessible
|
||||
_, pos1, ok1 := store.bestIndex(15)
|
||||
_, pos2, ok2 := store.bestIndex(25)
|
||||
|
||||
if !ok1 || pos1 != 10 {
|
||||
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
|
||||
}
|
||||
if !ok2 || pos2 != 20 {
|
||||
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
|
||||
// Test behavior with zero-size buffer
|
||||
store := newSlotCheckpointStore(0)
|
||||
|
||||
idx := store.record(10)
|
||||
if idx != -1 {
|
||||
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(15)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint for empty buffer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
|
||||
// Test pruning that removes all checkpoints
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
// Prune everything by setting threshold below all positions
|
||||
store.pruneAfter(5)
|
||||
|
||||
if store.size != 0 {
|
||||
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
|
||||
}
|
||||
// When all checkpoints are pruned, lastPos is reset to -1
|
||||
if store.lastPos != -1 {
|
||||
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(100)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint after pruning all")
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,9 @@ type GatedDeltaNet struct {
|
||||
// Optimized path: pre-split QKV and gate
|
||||
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
|
||||
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
|
||||
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
|
||||
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha (legacy qwen3next)
|
||||
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
||||
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
@@ -96,7 +98,6 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
headVDim := opts.ssmDInner / numVHeads
|
||||
convKernelSize := opts.convKernelSize
|
||||
|
||||
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
|
||||
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
|
||||
|
||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||
@@ -106,24 +107,40 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
|
||||
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
|
||||
|
||||
baNewDim := 2 * numVHeads / numKHeads
|
||||
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
|
||||
var beta ml.Tensor
|
||||
var alpha ml.Tensor
|
||||
switch {
|
||||
case gdn.SSMBetaAlpha != nil:
|
||||
// Legacy qwen3next path: in_proj_ba packs beta/alpha grouped by K-head.
|
||||
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
|
||||
baNewDim := 2 * numVHeads / numKHeads
|
||||
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
|
||||
|
||||
// Split beta and alpha
|
||||
betaSize := numVHeads / numKHeads
|
||||
alphaSize := numVHeads / numKHeads
|
||||
betaSize := numVHeads / numKHeads
|
||||
alphaSize := numVHeads / numKHeads
|
||||
|
||||
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
|
||||
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
|
||||
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
|
||||
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
|
||||
|
||||
// Reshape to merge head dimensions
|
||||
beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
|
||||
alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||
// Keep beta layout consistent with qwen35 and llama.cpp:
|
||||
// [1, numVHeads, nSeqTokens, nSeqs]
|
||||
beta = b.Contiguous(ctx, 1, numVHeads, nSeqTokens, nSeqs)
|
||||
alpha = a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
case gdn.SSMBeta != nil && gdn.SSMAlpha != nil:
|
||||
// qwen35 path: beta/alpha are separate projections.
|
||||
beta = gdn.SSMBeta.Forward(ctx, hiddenStates).Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
|
||||
alpha = gdn.SSMAlpha.Forward(ctx, hiddenStates).Reshape(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
default:
|
||||
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
||||
}
|
||||
|
||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||
alphaSoftplus := alphaBiased.Softplus(ctx)
|
||||
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
|
||||
gate = gate.Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
|
||||
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
// Get conv state from cache
|
||||
@@ -172,16 +189,20 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
|
||||
// Repeat interleave Q and K if numKHeads != numVHeads
|
||||
if numKHeads != numVHeads {
|
||||
repeatFactor := numVHeads / numKHeads
|
||||
if opts.vHeadReordered {
|
||||
qConv = qConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
|
||||
kConv = kConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
|
||||
} else {
|
||||
repeatFactor := numVHeads / numKHeads
|
||||
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
|
||||
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
|
||||
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
|
||||
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
}
|
||||
}
|
||||
|
||||
// Choose computation mode based on sequence length
|
||||
@@ -189,7 +210,9 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
if nSeqTokens == 1 {
|
||||
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
|
||||
} else {
|
||||
// Use pre-computed masks from opts (created once in Model.Forward)
|
||||
if opts.masks == nil {
|
||||
opts.masks = createMasks(ctx)
|
||||
}
|
||||
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
|
||||
}
|
||||
|
||||
@@ -310,9 +333,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
|
||||
gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs)
|
||||
|
||||
beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
// Match llama.cpp delta-net-base layout:
|
||||
// gate/beta: [1, numVHeads, nTokens, nSeqs] -> [1, nTokens, numVHeads, nSeqs]
|
||||
gate = gate.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
|
||||
beta = beta.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
|
||||
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||
|
||||
// Compute padding
|
||||
@@ -324,7 +348,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
q = q.Pad(ctx, 0, pad, 0, 0)
|
||||
k = k.Pad(ctx, 0, pad, 0, 0)
|
||||
v = v.Pad(ctx, 0, pad, 0, 0)
|
||||
gate = gate.Pad(ctx, pad, 0, 0, 0)
|
||||
gate = gate.Pad(ctx, 0, pad, 0, 0)
|
||||
beta = beta.Pad(ctx, 0, pad, 0, 0)
|
||||
}
|
||||
|
||||
@@ -344,10 +368,12 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
|
||||
gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
|
||||
// Reshape gate and cumsum over chunk axis.
|
||||
// [1, chunkSize, nChunks, H*nSeqs] -> transpose -> [chunkSize, 1, nChunks, H*nSeqs]
|
||||
gate = gate.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
|
||||
// g_cumsum = cumsum(gate)
|
||||
gCumsum := gate.CumSum(ctx)
|
||||
gCumsum := gate.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs).CumSum(ctx)
|
||||
|
||||
// Compute decay mask
|
||||
gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"fmt"
|
||||
"image"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
@@ -11,6 +14,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/model/models/qwen3vl"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
@@ -41,10 +45,15 @@ type Options struct {
|
||||
ssmNGroup int // num_k_heads
|
||||
ssmDtRank int // num_v_heads
|
||||
convKernelSize int // SSM conv kernel size
|
||||
vHeadReordered bool
|
||||
|
||||
// Per-layer type from GGUF metadata
|
||||
isRecurrent []bool
|
||||
|
||||
// RoPE mode config (used by qwen35/qwen35moe)
|
||||
mropeSections []int
|
||||
mropeInterleaved bool
|
||||
|
||||
// Pre-computed masks for chunked attention (created once per forward pass)
|
||||
masks *Masks
|
||||
}
|
||||
@@ -54,7 +63,17 @@ func (o Options) headDim() int {
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
opts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
var opts []func(*rope.Options)
|
||||
if len(o.mropeSections) > 0 {
|
||||
if o.mropeInterleaved {
|
||||
opts = append(opts, rope.WithInterleaveMRoPE(o.mropeSections))
|
||||
} else {
|
||||
opts = append(opts, rope.WithMRoPE(o.mropeSections))
|
||||
}
|
||||
} else {
|
||||
opts = append(opts, rope.WithTypeNeoX())
|
||||
}
|
||||
|
||||
if o.ropeType == "yarn" {
|
||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||
opts = append(opts,
|
||||
@@ -214,20 +233,190 @@ type Model struct {
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Layers []Layer `gguf:"blk"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
Vision *qwen3vl.VisionModel `gguf:"v"`
|
||||
|
||||
ImageProcessor *qwen3vl.ImageProcessor
|
||||
|
||||
*Options
|
||||
|
||||
positionCache []int32
|
||||
imageToken int32
|
||||
visionStart int32
|
||||
visionEnd int32
|
||||
spatialMergeSize uint32
|
||||
}
|
||||
|
||||
func (m *Model) mapPosition(id int32) int32 {
|
||||
if id < int32(len(m.positionCache)) {
|
||||
return m.positionCache[id]
|
||||
}
|
||||
if len(m.positionCache) > 0 {
|
||||
return id - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func (m *Model) buildPositions(ctx ml.Context, batch input.Batch) ml.Tensor {
|
||||
if len(m.mropeSections) == 0 {
|
||||
return ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
}
|
||||
|
||||
// ggml MRoPE expects [time, height, width, extra] for each token.
|
||||
positionSlice := [][]int32{
|
||||
make([]int32, len(batch.Positions)),
|
||||
make([]int32, len(batch.Positions)),
|
||||
make([]int32, len(batch.Positions)),
|
||||
make([]int32, len(batch.Positions)),
|
||||
}
|
||||
|
||||
for i, id := range batch.Positions {
|
||||
p := m.mapPosition(id)
|
||||
positionSlice[0][i] = p
|
||||
positionSlice[1][i] = p
|
||||
positionSlice[2][i] = p
|
||||
}
|
||||
|
||||
if m.Vision != nil {
|
||||
for _, mi := range batch.Multimodal {
|
||||
grid, ok := mi.Multimodal[0].Data.(*qwen3vl.Grid)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
w := max(1, grid.Width/int(m.spatialMergeSize))
|
||||
for i := range mi.Multimodal[0].Tensor.Dim(1) {
|
||||
positionSlice[1][mi.Index+i] += int32(i / w)
|
||||
positionSlice[2][mi.Index+i] += int32(i % w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
if m.Vision == nil || m.ImageProcessor == nil || len(m.Vision.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues, grid, err := m.ImageProcessor.ProcessImage(ctx, img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
visionOutputs, deepstackVisualEmbeds := m.Vision.Forward(ctx, pixelValues, grid)
|
||||
mm := []input.Multimodal{{Tensor: visionOutputs, Data: grid}}
|
||||
for i := range deepstackVisualEmbeds {
|
||||
mm = append(mm, input.Multimodal{Tensor: deepstackVisualEmbeds[i]})
|
||||
}
|
||||
|
||||
return mm, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
m.positionCache = m.positionCache[:0]
|
||||
var result []*input.Input
|
||||
appendInput := func(inp *input.Input, position int32) {
|
||||
result = append(result, inp)
|
||||
m.positionCache = append(m.positionCache, position)
|
||||
}
|
||||
|
||||
var p int32
|
||||
for _, inp := range inputs {
|
||||
if inp.Multimodal == nil {
|
||||
appendInput(inp, p)
|
||||
p++
|
||||
continue
|
||||
}
|
||||
|
||||
grid := inp.Multimodal[0].Data.(*qwen3vl.Grid)
|
||||
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
|
||||
|
||||
appendInput(&input.Input{
|
||||
Token: m.visionStart,
|
||||
SameBatch: tokensPerGrid + 1,
|
||||
}, p)
|
||||
p++
|
||||
|
||||
appendInput(&input.Input{
|
||||
Token: m.imageToken,
|
||||
Multimodal: inp.Multimodal,
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
}, p)
|
||||
|
||||
for range tokensPerGrid - 1 {
|
||||
appendInput(&input.Input{
|
||||
Token: m.imageToken,
|
||||
}, p)
|
||||
}
|
||||
|
||||
gridSpan := max(grid.Width/int(m.spatialMergeSize), grid.Height/int(m.spatialMergeSize))
|
||||
p = p + int32(gridSpan)
|
||||
appendInput(&input.Input{
|
||||
Token: m.visionEnd,
|
||||
}, p)
|
||||
p++
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
positions := m.buildPositions(ctx, batch)
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
if len(batch.Multimodal) > 0 {
|
||||
hiddenStates = hiddenStates.Duplicate(ctx)
|
||||
|
||||
var deepstackVisualEmbeds []ml.Tensor
|
||||
for _, mi := range batch.Multimodal {
|
||||
visionOutputs := mi.Multimodal[0].Tensor
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
|
||||
if len(mi.Multimodal[1:]) > len(deepstackVisualEmbeds) {
|
||||
deepstackVisualEmbeds = append(deepstackVisualEmbeds, make([]ml.Tensor, len(mi.Multimodal[1:])-len(deepstackVisualEmbeds))...)
|
||||
}
|
||||
for i, mm := range mi.Multimodal[1:] {
|
||||
if deepstackVisualEmbeds[i] == nil {
|
||||
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
|
||||
}
|
||||
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
|
||||
}
|
||||
}
|
||||
|
||||
cache := m.Cache.(*HybridCache)
|
||||
m.Options.masks = nil
|
||||
for i, layer := range m.Layers {
|
||||
cache.SetLayer(i)
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
var err error
|
||||
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if i < len(deepstackVisualEmbeds) {
|
||||
hiddenStates = hiddenStates.Add(ctx, deepstackVisualEmbeds[i])
|
||||
}
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
cache := m.Cache.(*HybridCache)
|
||||
|
||||
// Create masks once per forward pass
|
||||
m.Options.masks = createMasks(ctx)
|
||||
// Masks are allocated lazily only for chunked recurrent prefill.
|
||||
m.Options.masks = nil
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
cache.SetLayer(i)
|
||||
@@ -249,10 +438,17 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
m.positionCache = nil
|
||||
if len(m.mropeSections) > 0 {
|
||||
shift = shift.Repeat(ctx, 1, 4).Reshape(ctx, -1)
|
||||
}
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
var _ model.Model = (*Model)(nil)
|
||||
var (
|
||||
_ model.Model = (*Model)(nil)
|
||||
_ model.MultimodalProcessor = (*Model)(nil)
|
||||
)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
@@ -303,6 +499,22 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
}
|
||||
|
||||
mropeSections := c.Ints("mrope_sections", nil)
|
||||
if len(mropeSections) == 0 {
|
||||
mropeSections = c.Ints("rope.mrope_section", nil)
|
||||
}
|
||||
if len(mropeSections) == 0 {
|
||||
mropeSections = c.Ints("rope.dimension_sections", nil)
|
||||
}
|
||||
if len(mropeSections) > 4 {
|
||||
mropeSections = mropeSections[:4]
|
||||
}
|
||||
|
||||
ropeType := c.String("rope.scaling.type")
|
||||
if ropeType == "" {
|
||||
ropeType = c.String("rope.type")
|
||||
}
|
||||
|
||||
opts := &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
@@ -318,7 +530,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeType: c.String("rope.scaling.type"),
|
||||
ropeType: ropeType,
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
@@ -331,7 +543,16 @@ func New(c fs.Config) (model.Model, error) {
|
||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
|
||||
isRecurrent: isRecurrent,
|
||||
mropeSections: slices.Collect(func(yield func(int) bool) {
|
||||
for _, section := range mropeSections {
|
||||
if !yield(int(section)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
||||
}
|
||||
if opts.numKVHeads == 0 {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
||||
@@ -353,6 +574,19 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
|
||||
}
|
||||
|
||||
var vision *qwen3vl.VisionModel
|
||||
var imageProcessor *qwen3vl.ImageProcessor
|
||||
if c.Uint("vision.block_count", 0) > 0 {
|
||||
vision = qwen3vl.NewVisionModel(c)
|
||||
processor := qwen3vl.NewImageProcessor(c)
|
||||
imageProcessor = &processor
|
||||
}
|
||||
|
||||
spatialMergeSize := c.Uint("vision.spatial_merge_size", 2)
|
||||
if spatialMergeSize == 0 {
|
||||
spatialMergeSize = 2
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
@@ -371,8 +605,14 @@ func New(c fs.Config) (model.Model, error) {
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: opts,
|
||||
Layers: layers,
|
||||
Vision: vision,
|
||||
ImageProcessor: imageProcessor,
|
||||
Options: opts,
|
||||
imageToken: int32(c.Uint("image_token_id", 151655)),
|
||||
visionStart: int32(c.Uint("vision_start_token_id", 151652)),
|
||||
visionEnd: int32(c.Uint("vision_end_token_id", 151653)),
|
||||
spatialMergeSize: spatialMergeSize,
|
||||
}
|
||||
|
||||
m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
|
||||
@@ -380,5 +620,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("qwen35", New)
|
||||
model.Register("qwen35moe", New)
|
||||
model.Register("qwen3next", New)
|
||||
}
|
||||
|
||||
101
model/models/qwen3next/model_posttokenize_test.go
Normal file
101
model/models/qwen3next/model_posttokenize_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/model/models/qwen3vl"
|
||||
)
|
||||
|
||||
type fakeTensor struct {
|
||||
*ggml.Tensor
|
||||
dims []int
|
||||
}
|
||||
|
||||
func (t *fakeTensor) Dim(i int) int {
|
||||
return t.dims[i]
|
||||
}
|
||||
|
||||
func makeImageInput(hash uint64, width, height, tokens int) *input.Input {
|
||||
return &input.Input{
|
||||
Multimodal: []input.Multimodal{{
|
||||
Tensor: &fakeTensor{dims: []int{1, tokens, 1, 1}},
|
||||
Data: &qwen3vl.Grid{Width: width, Height: height},
|
||||
}},
|
||||
MultimodalHash: hash,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostTokenizeMultiImageSpans(t *testing.T) {
|
||||
m := &Model{
|
||||
imageToken: 10,
|
||||
visionStart: 11,
|
||||
visionEnd: 12,
|
||||
spatialMergeSize: 2,
|
||||
}
|
||||
|
||||
inputs := []*input.Input{
|
||||
{Token: 100},
|
||||
makeImageInput(1, 8, 4, 4),
|
||||
makeImageInput(2, 4, 8, 4),
|
||||
{Token: 200},
|
||||
}
|
||||
|
||||
got, err := m.PostTokenize(inputs)
|
||||
if err != nil {
|
||||
t.Fatalf("PostTokenize() error = %v", err)
|
||||
}
|
||||
|
||||
want := []struct {
|
||||
token int32
|
||||
hash uint64
|
||||
sameBatch int
|
||||
hasMM bool
|
||||
}{
|
||||
{token: 100},
|
||||
{token: 11, sameBatch: 5},
|
||||
{token: 10, hash: 1, hasMM: true},
|
||||
{token: 10},
|
||||
{token: 10},
|
||||
{token: 10},
|
||||
{token: 12},
|
||||
{token: 11, sameBatch: 5},
|
||||
{token: 10, hash: 2, hasMM: true},
|
||||
{token: 10},
|
||||
{token: 10},
|
||||
{token: 10},
|
||||
{token: 12},
|
||||
{token: 200},
|
||||
}
|
||||
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("len(got) = %d, want %d", len(got), len(want))
|
||||
}
|
||||
|
||||
for i := range want {
|
||||
if got[i].Token != want[i].token {
|
||||
t.Fatalf("got[%d].Token = %d, want %d", i, got[i].Token, want[i].token)
|
||||
}
|
||||
if got[i].MultimodalHash != want[i].hash {
|
||||
t.Fatalf("got[%d].MultimodalHash = %d, want %d", i, got[i].MultimodalHash, want[i].hash)
|
||||
}
|
||||
if got[i].SameBatch != want[i].sameBatch {
|
||||
t.Fatalf("got[%d].SameBatch = %d, want %d", i, got[i].SameBatch, want[i].sameBatch)
|
||||
}
|
||||
hasMM := len(got[i].Multimodal) > 0
|
||||
if hasMM != want[i].hasMM {
|
||||
t.Fatalf("got[%d].hasMM = %v, want %v", i, hasMM, want[i].hasMM)
|
||||
}
|
||||
}
|
||||
|
||||
wantPositions := []int32{0, 1, 2, 2, 2, 2, 6, 7, 8, 8, 8, 8, 12, 13}
|
||||
if len(m.positionCache) != len(wantPositions) {
|
||||
t.Fatalf("len(positionCache) = %d, want %d", len(m.positionCache), len(wantPositions))
|
||||
}
|
||||
for i := range wantPositions {
|
||||
if m.positionCache[i] != wantPositions[i] {
|
||||
t.Fatalf("positionCache[%d] = %d, want %d", i, m.positionCache[i], wantPositions[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -24,8 +24,8 @@ type ImageProcessor struct {
|
||||
imageStd []float32
|
||||
}
|
||||
|
||||
// newImageProcessor creates a new image processor with default values
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
// NewImageProcessor creates a new image processor with default values.
|
||||
func NewImageProcessor(c fs.Config) ImageProcessor {
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||
|
||||
|
||||
@@ -56,60 +56,46 @@ var (
|
||||
tokenVisionEnd int32 = 151653
|
||||
)
|
||||
|
||||
type modelInput struct {
|
||||
*input.Input
|
||||
position int32
|
||||
}
|
||||
|
||||
// PostTokenize arranges Qwen 3 VL's inputs for the forward pass
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
m.positionCache = m.positionCache[:0]
|
||||
return slices.Collect(func(yield func(*input.Input) bool) {
|
||||
for i := range inputs {
|
||||
s := []modelInput{{Input: inputs[i]}}
|
||||
if mm := inputs[i].Multimodal; mm != nil {
|
||||
t := mm[0].Tensor
|
||||
s = slices.Repeat([]modelInput{
|
||||
{
|
||||
position: int32(i + 1),
|
||||
Input: &input.Input{Token: tokenVision},
|
||||
},
|
||||
}, t.Dim(1)+1+1)
|
||||
var result []*input.Input
|
||||
appendInput := func(inp *input.Input, position int32) {
|
||||
result = append(result, inp)
|
||||
m.positionCache = append(m.positionCache, position)
|
||||
}
|
||||
|
||||
s[0] = modelInput{
|
||||
Input: &input.Input{Token: tokenVisionStart},
|
||||
position: int32(i),
|
||||
}
|
||||
|
||||
s[len(s)-1] = modelInput{
|
||||
Input: &input.Input{Token: tokenVisionEnd},
|
||||
position: int32(i + mm[0].Data.(*Grid).Width/m.spatialMergeSize + 1),
|
||||
}
|
||||
|
||||
s[1] = modelInput{
|
||||
Input: &input.Input{
|
||||
Token: tokenVision,
|
||||
Multimodal: inputs[i].Multimodal,
|
||||
MultimodalHash: inputs[i].MultimodalHash,
|
||||
SameBatch: t.Dim(1),
|
||||
},
|
||||
position: int32(i + 1),
|
||||
}
|
||||
}
|
||||
|
||||
for _, e := range s {
|
||||
position := e.position
|
||||
if position == 0 && len(m.positionCache) > 0 {
|
||||
position = m.positionCache[len(m.positionCache)-1] + 1
|
||||
}
|
||||
|
||||
m.positionCache = append(m.positionCache, position)
|
||||
if !yield(e.Input) {
|
||||
return
|
||||
}
|
||||
}
|
||||
var p int32
|
||||
for _, inp := range inputs {
|
||||
if inp.Multimodal == nil {
|
||||
appendInput(inp, p)
|
||||
p++
|
||||
continue
|
||||
}
|
||||
}), nil
|
||||
|
||||
grid := inp.Multimodal[0].Data.(*Grid)
|
||||
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
|
||||
|
||||
appendInput(&input.Input{Token: tokenVisionStart}, p)
|
||||
p++
|
||||
|
||||
appendInput(&input.Input{
|
||||
Token: tokenVision,
|
||||
Multimodal: inp.Multimodal,
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
SameBatch: tokensPerGrid,
|
||||
}, p)
|
||||
|
||||
for range tokensPerGrid - 1 {
|
||||
appendInput(&input.Input{Token: tokenVision}, p)
|
||||
}
|
||||
|
||||
p = p + int32(grid.Width/m.spatialMergeSize)
|
||||
appendInput(&input.Input{Token: tokenVisionEnd}, p)
|
||||
p++
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
@@ -143,9 +129,13 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
}
|
||||
}
|
||||
|
||||
deepstackVisualEmbeds = make([]ml.Tensor, len(mi.Multimodal[1:]))
|
||||
if len(mi.Multimodal[1:]) > len(deepstackVisualEmbeds) {
|
||||
deepstackVisualEmbeds = append(deepstackVisualEmbeds, make([]ml.Tensor, len(mi.Multimodal[1:])-len(deepstackVisualEmbeds))...)
|
||||
}
|
||||
for i, mm := range mi.Multimodal[1:] {
|
||||
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
|
||||
if deepstackVisualEmbeds[i] == nil {
|
||||
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
|
||||
}
|
||||
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
|
||||
}
|
||||
}
|
||||
@@ -189,8 +179,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: NewVisionModel(c),
|
||||
ImageProcessor: NewImageProcessor(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) {
|
||||
|
||||
@@ -238,8 +238,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
return hiddenStates, deepstackStates
|
||||
}
|
||||
|
||||
// newVisionModel creates a new instance of the Qwen vision model
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
// NewVisionModel creates a new instance of the Qwen vision model.
|
||||
func NewVisionModel(c fs.Config) *VisionModel {
|
||||
deepstackVisualIndexes := c.Ints("vision.deepstack_visual_indexes")
|
||||
model := &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
|
||||
|
||||
@@ -49,6 +49,8 @@ func ParserForName(name string) Parser {
|
||||
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
case "qwen3-thinking":
|
||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
case "qwen3.5":
|
||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
case "qwen3-coder":
|
||||
p = &Qwen3CoderParser{}
|
||||
case "qwen3-vl-instruct":
|
||||
|
||||
@@ -59,6 +59,7 @@ func TestBuiltInParsersStillWork(t *testing.T) {
|
||||
{"qwen3-coder"},
|
||||
{"lfm2"},
|
||||
{"lfm2-thinking"},
|
||||
{"qwen3.5"},
|
||||
{"harmony"},
|
||||
}
|
||||
|
||||
|
||||
@@ -145,3 +145,26 @@ func TestQwen3ParserToolCall(t *testing.T) {
|
||||
t.Fatalf("expected unit %q, got %v", "celsius", unit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
content, thinking, calls, err := parser.Add("Hello! How can I help you today?", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no thinking, got %q", thinking)
|
||||
}
|
||||
if content != "Hello! How can I help you today?" {
|
||||
t.Fatalf("expected content %q, got %q", "Hello! How can I help you today?", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package renderers
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
@@ -192,21 +193,25 @@ func lfm2RenderToolCalls(calls []api.ToolCall) string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *LFM2Renderer) renderMessageContent(message api.Message) string {
|
||||
func (r *LFM2Renderer) renderMessageContent(message api.Message, imageOffset int) string {
|
||||
content := lfm2RenderContent(message.Content, r.useImgTags)
|
||||
if len(message.Images) == 0 {
|
||||
return content
|
||||
}
|
||||
|
||||
// chatPrompt may already have inserted [img] / [img-n] placeholders.
|
||||
if strings.Contains(content, "[img]") || strings.Contains(content, "[img-") || strings.Contains(content, "<image>") {
|
||||
return content
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
placeholder := lfm2ImagePlaceholder(r.useImgTags)
|
||||
for range message.Images {
|
||||
sb.WriteString(placeholder)
|
||||
if r.useImgTags {
|
||||
for i := range message.Images {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset+i))
|
||||
}
|
||||
} else {
|
||||
placeholder := lfm2ImagePlaceholder(false)
|
||||
if strings.Contains(content, placeholder) {
|
||||
return content
|
||||
}
|
||||
for range message.Images {
|
||||
sb.WriteString(placeholder)
|
||||
}
|
||||
}
|
||||
sb.WriteString(content)
|
||||
return sb.String()
|
||||
@@ -262,6 +267,11 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
|
||||
}
|
||||
}
|
||||
|
||||
imageOffset := 0
|
||||
for i := range startIdx {
|
||||
imageOffset += len(messages[i].Images)
|
||||
}
|
||||
|
||||
for i := startIdx; i < len(messages); i++ {
|
||||
message := messages[i]
|
||||
lastMessage := i == len(messages)-1
|
||||
@@ -271,7 +281,8 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
|
||||
sb.WriteString(message.Role)
|
||||
sb.WriteString("\n")
|
||||
|
||||
content := r.renderMessageContent(message)
|
||||
content := r.renderMessageContent(message, imageOffset)
|
||||
imageOffset += len(message.Images)
|
||||
if message.Role == "assistant" && !keepPastThinking && i != lastAssistantIndex {
|
||||
if idx := strings.LastIndex(content, "</think>"); idx >= 0 {
|
||||
content = strings.TrimSpace(content[idx+len("</think>"):])
|
||||
|
||||
@@ -236,16 +236,6 @@ func TestLFM2Renderer_Images(t *testing.T) {
|
||||
Content: "Describe this image.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
expected: "<|startoftext|><|im_start|>user\n[img]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "existing_indexed_img_placeholder_not_duplicated",
|
||||
renderer: &LFM2Renderer{useImgTags: true},
|
||||
message: api.Message{
|
||||
Role: "user",
|
||||
Content: "[img-0]Describe this image.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -9,10 +10,11 @@ import (
|
||||
type Qwen3VLRenderer struct {
|
||||
isThinking bool
|
||||
|
||||
useImgTags bool
|
||||
emitEmptyThinkOnNoThink bool
|
||||
useImgTags bool
|
||||
}
|
||||
|
||||
func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
|
||||
func (r *Qwen3VLRenderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
||||
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
||||
var subSb strings.Builder
|
||||
for range content.Images {
|
||||
@@ -20,7 +22,8 @@ func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
|
||||
// model backends, and so we should eventually parameterize this or
|
||||
// only output a placeholder such as [img]
|
||||
if r.useImgTags {
|
||||
subSb.WriteString("[img]")
|
||||
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||
imageOffset++
|
||||
} else {
|
||||
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||
}
|
||||
@@ -28,12 +31,17 @@ func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
|
||||
// TODO: support videos
|
||||
|
||||
subSb.WriteString(content.Content)
|
||||
return subSb.String()
|
||||
return subSb.String(), imageOffset
|
||||
}
|
||||
|
||||
func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||
func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
isThinking := r.isThinking
|
||||
if think != nil {
|
||||
isThinking = think.Bool()
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString(imStartTag + "system\n")
|
||||
if len(messages) > 0 && messages[0].Role == "system" {
|
||||
@@ -57,7 +65,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
|
||||
message := messages[i]
|
||||
if multiStepTool && message.Role == "user" {
|
||||
// Check if content starts with <tool_response> and ends with </tool_response>
|
||||
content := r.renderContent(message)
|
||||
content, _ := r.renderContent(message, 0)
|
||||
if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
|
||||
multiStepTool = false
|
||||
lastQueryIndex = i
|
||||
@@ -65,8 +73,10 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
|
||||
}
|
||||
}
|
||||
|
||||
imageOffset := 0
|
||||
for i, message := range messages {
|
||||
content := r.renderContent(message)
|
||||
content, nextImageOffset := r.renderContent(message, imageOffset)
|
||||
imageOffset = nextImageOffset
|
||||
|
||||
lastMessage := i == len(messages)-1
|
||||
prefill := lastMessage && message.Role == "assistant"
|
||||
@@ -76,13 +86,13 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
|
||||
} else if message.Role == "assistant" {
|
||||
contentReasoning := ""
|
||||
|
||||
if r.isThinking {
|
||||
if isThinking {
|
||||
if message.Thinking != "" {
|
||||
contentReasoning = message.Thinking
|
||||
}
|
||||
}
|
||||
|
||||
if r.isThinking && i > lastQueryIndex {
|
||||
if isThinking && i > lastQueryIndex {
|
||||
if i == len(messages)-1 || contentReasoning != "" {
|
||||
sb.WriteString("<|im_start|>" + message.Role + "\n<think>\n" + strings.Trim(contentReasoning, "\n")) // do we want to add a new line here?
|
||||
if content != "" {
|
||||
@@ -125,8 +135,10 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
|
||||
// prefill at the end
|
||||
if lastMessage && !prefill {
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
if r.isThinking {
|
||||
if isThinking {
|
||||
sb.WriteString("<think>\n")
|
||||
} else if r.emitEmptyThinkOnNoThink {
|
||||
sb.WriteString("<think>\n\n</think>\n\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,7 +101,7 @@ Let me analyze this image.`,
|
||||
},
|
||||
useImgTags: true,
|
||||
expected: `<|im_start|>user
|
||||
[img]Describe this image.<|im_end|>
|
||||
[img-0]Describe this image.<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Let me analyze this image.`,
|
||||
},
|
||||
@@ -123,7 +123,7 @@ Let me analyze this image.`,
|
||||
},
|
||||
useImgTags: true,
|
||||
expected: `<|im_start|>user
|
||||
[img][img]Describe these images.<|im_end|>
|
||||
[img-0][img-1]Describe these images.<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Let me analyze this image.`,
|
||||
},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -370,3 +371,74 @@ func TestFormatToolCallArgumentThinkingVL(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3VLRendererThinkOverride(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
renderThinking, err := (&Qwen3VLRenderer{isThinking: true}).Render(msgs, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(renderThinking, "<|im_start|>assistant\n<think>\n") {
|
||||
t.Fatalf("expected default thinking renderer to emit <think>, got:\n%s", renderThinking)
|
||||
}
|
||||
|
||||
renderNonThinking, err := (&Qwen3VLRenderer{isThinking: true}).Render(msgs, nil, &api.ThinkValue{Value: false})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(renderNonThinking, "<think>") {
|
||||
t.Fatalf("expected think=false override to suppress <think>, got:\n%s", renderNonThinking)
|
||||
}
|
||||
|
||||
renderForcedThinking, err := (&Qwen3VLRenderer{isThinking: false}).Render(msgs, nil, &api.ThinkValue{Value: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(renderForcedThinking, "<|im_start|>assistant\n<think>\n") {
|
||||
t.Fatalf("expected think=true override to emit <think>, got:\n%s", renderForcedThinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3VLRendererThinkOverrideWithExplicitNoThinkPrefill(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
renderNonThinking, err := (&Qwen3VLRenderer{
|
||||
isThinking: true,
|
||||
emitEmptyThinkOnNoThink: true,
|
||||
}).Render(msgs, nil, &api.ThinkValue{Value: false})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !strings.Contains(renderNonThinking, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
|
||||
t.Fatalf("expected explicit think=false prefill block, got:\n%s", renderNonThinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRendererNameNoThinkBehaviorSplit(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
thinkFalse := &api.ThinkValue{Value: false}
|
||||
|
||||
qwen35Rendered, err := RenderWithRenderer("qwen3.5", msgs, nil, thinkFalse)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(qwen35Rendered, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
|
||||
t.Fatalf("expected qwen3.5 renderer to emit explicit no-think prefill, got:\n%s", qwen35Rendered)
|
||||
}
|
||||
|
||||
qwen3VLRendered, err := RenderWithRenderer("qwen3-vl-thinking", msgs, nil, thinkFalse)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(qwen3VLRendered, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
|
||||
t.Fatalf("expected qwen3-vl-thinking renderer to keep legacy non-empty no-think behavior, got:\n%s", qwen3VLRendered)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,6 +56,9 @@ func rendererForName(name string) Renderer {
|
||||
case "qwen3-vl-thinking":
|
||||
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
|
||||
return renderer
|
||||
case "qwen3.5":
|
||||
renderer := &Qwen3VLRenderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
|
||||
return renderer
|
||||
case "cogito":
|
||||
renderer := &CogitoRenderer{isThinking: true}
|
||||
return renderer
|
||||
|
||||
@@ -29,17 +29,27 @@ func TestRegisterCustomRenderer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuiltInRendererStillWorks(t *testing.T) {
|
||||
// Test that qwen3-coder still works
|
||||
tests := []struct {
|
||||
name string
|
||||
}{
|
||||
{name: "qwen3-coder"},
|
||||
{name: "qwen3.5"},
|
||||
}
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
result, err := RenderWithRenderer("qwen3-coder", messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result == "" {
|
||||
t.Error("expected non-empty result from qwen3-coder renderer")
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := RenderWithRenderer(tt.name, messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result == "" {
|
||||
t.Fatalf("expected non-empty result from %s renderer", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -86,6 +86,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
ID: len(images),
|
||||
Data: i,
|
||||
}
|
||||
images = append(images, imgData)
|
||||
|
||||
if m.Config.Renderer != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
|
||||
if !strings.Contains(prompt, "[img]") {
|
||||
@@ -93,8 +98,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
} else {
|
||||
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
|
||||
}
|
||||
|
||||
images = append(images, imgData)
|
||||
}
|
||||
msgs[currMsgIdx+cnt].Content = prefix + prompt
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestChatPrompt(t *testing.T) {
|
||||
@@ -330,3 +331,38 @@ func TestChatPromptTokenizeCalls(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "what do these photos have in common?",
|
||||
Images: []api.ImageData{[]byte("img-1"), []byte("img-2"), []byte("img-3")},
|
||||
},
|
||||
}
|
||||
originalContent := msgs[0].Content
|
||||
|
||||
m := Model{
|
||||
Config: model.ConfigV2{Renderer: "qwen3-vl-instruct"},
|
||||
ProjectorPaths: []string{"vision"},
|
||||
}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||
think := false
|
||||
|
||||
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if msgs[0].Content != originalContent {
|
||||
t.Fatalf("renderer path should not mutate message content: got %q, want %q", msgs[0].Content, originalContent)
|
||||
}
|
||||
|
||||
if got, want := len(images), 3; got != want {
|
||||
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if prompt == "" {
|
||||
t.Fatal("prompt is empty")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
@@ -33,6 +34,9 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
|
||||
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
||||
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
|
||||
}
|
||||
if uint64(len(data)) < q.from.Size() {
|
||||
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
|
||||
}
|
||||
var f32s []float32
|
||||
newType := fsggml.TensorType(q.to.Kind)
|
||||
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
|
||||
@@ -58,7 +62,7 @@ func useMoreBits(iLayer, nLayers int) bool {
|
||||
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
|
||||
}
|
||||
|
||||
func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
|
||||
func qwen3LinearAttnQuantType(name string) (fsggml.TensorType, bool) {
|
||||
switch {
|
||||
// Full attention
|
||||
case strings.HasSuffix(name, ".attn_q.weight"):
|
||||
@@ -79,6 +83,10 @@ func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
|
||||
// SSM
|
||||
case strings.HasSuffix(name, ".ssm_ba.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ssm_beta.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ssm_alpha.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ssm_out.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
|
||||
@@ -287,8 +295,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
|
||||
newType := fsggml.TensorType(t.Kind)
|
||||
if quantize {
|
||||
if kv.Architecture() == "qwen3next" && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
|
||||
if qt, ok := qwen3nextQuantType(name); ok {
|
||||
if slices.Contains([]string{"qwen3next", "qwen35", "qwen35moe"}, kv.Architecture()) && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
|
||||
if qt, ok := qwen3LinearAttnQuantType(name); ok {
|
||||
return qt
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,6 +166,60 @@ func TestGetTensorNewType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3LinearAttentionQuantOverride(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
arch string
|
||||
tensor string
|
||||
fileType fsggml.FileType
|
||||
expected fsggml.TensorType
|
||||
}{
|
||||
{
|
||||
name: "qwen35_beta",
|
||||
arch: "qwen35",
|
||||
tensor: "blk.0.ssm_beta.weight",
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ4_K,
|
||||
},
|
||||
{
|
||||
name: "qwen35_alpha",
|
||||
arch: "qwen35",
|
||||
tensor: "blk.0.ssm_alpha.weight",
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ4_K,
|
||||
},
|
||||
{
|
||||
name: "qwen35moe_attn_qkv",
|
||||
arch: "qwen35moe",
|
||||
tensor: "blk.0.attn_qkv.weight",
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ4_K,
|
||||
},
|
||||
{
|
||||
name: "non_qwen35_falls_back",
|
||||
arch: "foo",
|
||||
tensor: "blk.0.attn_qkv.weight",
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ5_K,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
kv := fsggml.KV{"general.architecture": tt.arch}
|
||||
got := newType(&fsggml.Tensor{
|
||||
Name: tt.tensor,
|
||||
Shape: []uint64{256, 256},
|
||||
Kind: uint32(fsggml.TensorTypeF16),
|
||||
}, kv, &quantizeState{}, tt.fileType)
|
||||
|
||||
if got != tt.expected {
|
||||
t.Fatalf("unexpected tensor type for %s (%s): got %s want %s", tt.tensor, tt.arch, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuantizeModel(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
@@ -173,6 +227,7 @@ func TestQuantizeModel(t *testing.T) {
|
||||
tensors []*fsggml.Tensor
|
||||
newType string
|
||||
expectedTensorTypes map[string]fsggml.TensorType
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "f16_q4_k",
|
||||
@@ -253,6 +308,36 @@ func TestQuantizeModel(t *testing.T) {
|
||||
"output.weight": fsggml.TensorTypeQ8_0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "f32_short_data",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "foo",
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF32),
|
||||
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||
WriterTo: bytes.NewReader(make([]byte, 32)),
|
||||
},
|
||||
},
|
||||
newType: "Q4_K",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "f16_short_data",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "foo",
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
|
||||
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||
WriterTo: bytes.NewReader(make([]byte, 32)),
|
||||
},
|
||||
},
|
||||
newType: "Q4_K",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
@@ -264,6 +349,9 @@ func TestQuantizeModel(t *testing.T) {
|
||||
}
|
||||
defer fp.Close()
|
||||
meta, err := fsggml.Decode(fp, -1)
|
||||
if tt.expectErr && err != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
@@ -283,6 +371,12 @@ func TestQuantizeModel(t *testing.T) {
|
||||
}
|
||||
|
||||
err = quantize(fp, tmp, meta, ftype, progress)
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected quantize to return an error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("error during quantize: %s", err)
|
||||
}
|
||||
|
||||
@@ -447,7 +447,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
|
||||
|
||||
// Some architectures are not safe with num_parallel > 1.
|
||||
// ref: https://github.com/ollama/ollama/issues/4165
|
||||
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
||||
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
||||
numParallel = 1
|
||||
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
|
||||
}
|
||||
|
||||
@@ -296,19 +296,13 @@ func normalizeQuantType(quantize string) string {
|
||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||
// - Norms, embeddings, biases, routing gates: no quantization
|
||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
isStackedExpert := strings.Contains(name, ".mlp.experts.gate_up_proj") || strings.Contains(name, ".mlp.experts.down_proj")
|
||||
|
||||
// Use basic name-based check first
|
||||
if !isStackedExpert && !ShouldQuantize(name, "") {
|
||||
if !ShouldQuantize(name, "") {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Quantize 2D linear tensors by default. qwen3.5 stacked expert tensors are
|
||||
// also eligible even though they are stored as 3D [experts, out, in].
|
||||
if !isStackedExpert && len(shape) != 2 {
|
||||
return ""
|
||||
}
|
||||
if isStackedExpert && len(shape) != 3 {
|
||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||
if len(shape) != 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -378,10 +372,9 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
}
|
||||
|
||||
// expertGroupRegexp matches expert tensor names and captures the group prefix.
|
||||
// Matches nested and non-nested LLM prefixes and both per-expert ".weight"
|
||||
// tensors and qwen3.5 stacked expert tensors without ".weight".
|
||||
// Captures: model(.language_model(.model)?).layers.{L}.mlp.experts or .shared_experts
|
||||
var expertGroupRegexp = regexp.MustCompile(`^(model(?:\.language_model(?:\.model)?)?\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*(?:\.weight)?$`)
|
||||
// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
|
||||
// Captures: model.layers.{L}.mlp.experts
|
||||
var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
|
||||
|
||||
// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
|
||||
// For example:
|
||||
|
||||
@@ -557,9 +557,6 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
||||
// 3D+ tensors should not be quantized
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
||||
// qwen3.5 stacked expert tensors are an exception: 3D [experts, out, in]
|
||||
{"qwen3.5 stacked gate_up_proj", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int4", true},
|
||||
{"qwen3.5 stacked down_proj", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int4", true},
|
||||
|
||||
// Embeddings should not be quantized regardless of shape
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
||||
@@ -589,42 +586,6 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensorQuantization_StackedExperts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
shape []int32
|
||||
quantize string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "model.language_model.layers.0.mlp.experts.gate_up_proj",
|
||||
shape: []int32{256, 1024, 2048},
|
||||
quantize: "int4",
|
||||
want: "int4",
|
||||
},
|
||||
{
|
||||
name: "model.language_model.layers.0.mlp.experts.down_proj",
|
||||
shape: []int32{256, 2048, 512},
|
||||
quantize: "int4",
|
||||
want: "int8",
|
||||
},
|
||||
{
|
||||
name: "model.language_model.layers.0.mlp.experts.gate_up_proj",
|
||||
shape: []int32{256, 1024, 2050}, // not divisible by 32
|
||||
quantize: "int4",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := GetTensorQuantization(tt.name, tt.shape, tt.quantize); got != tt.want {
|
||||
t.Fatalf("GetTensorQuantization(%q, %v, %q) = %q, want %q", tt.name, tt.shape, tt.quantize, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpertGroupPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -634,18 +595,10 @@ func TestExpertGroupPrefix(t *testing.T) {
|
||||
{"model.layers.1.mlp.experts.0.down_proj.weight", "model.layers.1.mlp.experts"},
|
||||
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
|
||||
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
|
||||
{"model.language_model.layers.0.mlp.experts.0.up_proj.weight", "model.language_model.layers.0.mlp.experts"},
|
||||
{"model.language_model.model.layers.0.mlp.experts.0.up_proj.weight", "model.language_model.model.layers.0.mlp.experts"},
|
||||
{"model.language_model.layers.0.mlp.experts.gate_up_proj", "model.language_model.layers.0.mlp.experts"},
|
||||
{"model.language_model.layers.0.mlp.experts.down_proj", "model.language_model.layers.0.mlp.experts"},
|
||||
{"model.language_model.model.layers.0.mlp.experts.gate_up_proj", "model.language_model.model.layers.0.mlp.experts"},
|
||||
{"model.language_model.model.layers.0.mlp.experts.down_proj", "model.language_model.model.layers.0.mlp.experts"},
|
||||
|
||||
// Shared expert tensors should return their own group prefix
|
||||
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
|
||||
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
|
||||
{"model.language_model.layers.1.mlp.shared_experts.down_proj.weight", "model.language_model.layers.1.mlp.shared_experts"},
|
||||
{"model.language_model.model.layers.1.mlp.shared_experts.down_proj.weight", "model.language_model.model.layers.1.mlp.shared_experts"},
|
||||
|
||||
// Non-expert tensors should return empty string
|
||||
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
|
||||
|
||||
@@ -5,341 +5,62 @@ package mlxrunner
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
const defaultPromptCacheBranches = 1
|
||||
|
||||
// HybridCacheEntry stores a single prompt branch with mixed cache types
|
||||
// (e.g. KV + recurrent caches) and coordinates shared operations across them.
|
||||
type HybridCacheEntry struct {
|
||||
// CacheEntry stores a single sequence
|
||||
type CacheEntry struct {
|
||||
Tokens []int32
|
||||
Caches []cache.Cache
|
||||
}
|
||||
|
||||
// CacheEntry is kept as an alias for the current single-entry runner path.
|
||||
// Future multi-entry cache stores should prefer HybridCacheEntry directly.
|
||||
type CacheEntry = HybridCacheEntry
|
||||
|
||||
func promptCacheBranchLimit() int {
|
||||
if v := os.Getenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultPromptCacheBranches
|
||||
}
|
||||
|
||||
func cloneTokens(tokens []int32) []int32 {
|
||||
out := make([]int32, len(tokens))
|
||||
copy(out, tokens)
|
||||
return out
|
||||
}
|
||||
|
||||
func equalTokens(a, b []int32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *Runner) cacheStore() []*HybridCacheEntry {
|
||||
if len(r.caches) == 0 && r.cache != nil {
|
||||
r.caches = []*HybridCacheEntry{r.cache}
|
||||
}
|
||||
if r.cache == nil && len(r.caches) > 0 {
|
||||
r.cache = r.caches[0]
|
||||
}
|
||||
return r.caches
|
||||
}
|
||||
|
||||
func (r *Runner) setCacheStore(entries []*HybridCacheEntry) {
|
||||
r.caches = entries
|
||||
if len(entries) == 0 {
|
||||
r.cache = nil
|
||||
return
|
||||
}
|
||||
r.cache = entries[0]
|
||||
}
|
||||
|
||||
func (r *Runner) touchCacheEntry(idx int) {
|
||||
if idx <= 0 || idx >= len(r.caches) {
|
||||
return
|
||||
}
|
||||
e := r.caches[idx]
|
||||
copy(r.caches[1:idx+1], r.caches[:idx])
|
||||
r.caches[0] = e
|
||||
r.cache = r.caches[0]
|
||||
}
|
||||
|
||||
func (r *Runner) bestCacheEntry(tokens []int32) (idx int, prefix int) {
|
||||
bestIdx, bestPrefix := -1, 0
|
||||
for i, e := range r.cacheStore() {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
p := e.PrefixLen(tokens)
|
||||
if p > bestPrefix {
|
||||
bestIdx, bestPrefix = i, p
|
||||
}
|
||||
}
|
||||
return bestIdx, bestPrefix
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) PrefixLen(tokens []int32) int {
|
||||
if e == nil {
|
||||
return 0
|
||||
}
|
||||
prefix := 0
|
||||
for prefix < len(tokens) && prefix < len(e.Tokens) && tokens[prefix] == e.Tokens[prefix] {
|
||||
prefix++
|
||||
}
|
||||
return prefix
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) Free() {
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
for _, c := range e.Caches {
|
||||
if c != nil {
|
||||
c.Free()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) cachesSlice() []cache.Cache {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.Caches
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) cachesCanTrim() bool {
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
for _, c := range e.Caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
if !c.CanTrim() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) TrimToPrefix(prefix int) {
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
for _, c := range e.Caches {
|
||||
if c == nil || !c.CanTrim() {
|
||||
continue
|
||||
}
|
||||
trim := c.Offset() - prefix
|
||||
if trim > 0 {
|
||||
c.Trim(trim)
|
||||
}
|
||||
}
|
||||
if prefix < len(e.Tokens) {
|
||||
e.Tokens = e.Tokens[:prefix]
|
||||
}
|
||||
}
|
||||
|
||||
func (e *HybridCacheEntry) RestoreToPrefix(target int) (int, bool) {
|
||||
if e == nil {
|
||||
return 0, false
|
||||
}
|
||||
restorePos := -1
|
||||
sawNonTrimmable := false
|
||||
|
||||
for _, c := range e.Caches {
|
||||
if c == nil || c.CanTrim() {
|
||||
continue
|
||||
}
|
||||
sawNonTrimmable = true
|
||||
|
||||
restorer, ok := c.(cache.CheckpointRestorer)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
pos, ok := restorer.BestCheckpoint(target)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
if restorePos < 0 {
|
||||
restorePos = pos
|
||||
continue
|
||||
}
|
||||
if pos != restorePos {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
if !sawNonTrimmable || restorePos < 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
e.TrimToPrefix(restorePos)
|
||||
for _, c := range e.Caches {
|
||||
if c == nil || c.CanTrim() {
|
||||
continue
|
||||
}
|
||||
restorer, ok := c.(cache.CheckpointRestorer)
|
||||
if !ok || !restorer.RestoreCheckpoint(restorePos) {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
if restorePos < len(e.Tokens) {
|
||||
e.Tokens = e.Tokens[:restorePos]
|
||||
}
|
||||
return restorePos, true
|
||||
}
|
||||
|
||||
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
|
||||
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
|
||||
entries := r.cacheStore()
|
||||
if len(entries) == 0 {
|
||||
if r.cache == nil {
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
}
|
||||
|
||||
branchLimit := promptCacheBranchLimit()
|
||||
idx, prefix := r.bestCacheEntry(tokens)
|
||||
if idx < 0 {
|
||||
if branchLimit <= 1 && len(entries) == 1 && entries[0] != nil {
|
||||
entries[0].Free()
|
||||
r.setCacheStore(nil)
|
||||
}
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
}
|
||||
if idx > 0 {
|
||||
r.touchCacheEntry(idx)
|
||||
}
|
||||
base := r.cache
|
||||
if base == nil {
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
}
|
||||
|
||||
working := base
|
||||
forked := false
|
||||
if branchLimit > 1 && prefix > 0 {
|
||||
working = base.Clone()
|
||||
forked = true
|
||||
// Find longest common prefix
|
||||
prefix := 0
|
||||
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
|
||||
prefix++
|
||||
}
|
||||
|
||||
switch {
|
||||
case prefix == 0:
|
||||
if !forked && branchLimit <= 1 {
|
||||
base.Free()
|
||||
r.setCacheStore(nil)
|
||||
for _, c := range r.cache.Caches {
|
||||
c.Free()
|
||||
}
|
||||
r.cache = nil
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
case prefix < len(working.Tokens):
|
||||
if !working.cachesCanTrim() {
|
||||
if restorePos, ok := working.RestoreToPrefix(prefix); ok {
|
||||
slog.Info("Cache restore", "total", len(tokens), "matched", prefix, "restored", restorePos, "left", len(tokens[restorePos:]))
|
||||
return working.cachesSlice(), tokens[restorePos:]
|
||||
}
|
||||
|
||||
if forked {
|
||||
working.Free()
|
||||
} else if branchLimit <= 1 {
|
||||
base.Free()
|
||||
r.setCacheStore(nil)
|
||||
}
|
||||
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
|
||||
return nil, tokens
|
||||
case prefix < len(r.cache.Tokens):
|
||||
trim := len(r.cache.Tokens) - prefix
|
||||
for _, c := range r.cache.Caches {
|
||||
c.Trim(trim)
|
||||
}
|
||||
working.TrimToPrefix(prefix)
|
||||
r.cache.Tokens = r.cache.Tokens[:prefix]
|
||||
}
|
||||
|
||||
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
|
||||
return working.cachesSlice(), tokens[prefix:]
|
||||
return r.cache.Caches, tokens[prefix:]
|
||||
}
|
||||
|
||||
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
|
||||
entry := &HybridCacheEntry{
|
||||
Tokens: cloneTokens(tokens),
|
||||
Caches: caches,
|
||||
}
|
||||
|
||||
branchLimit := promptCacheBranchLimit()
|
||||
if branchLimit <= 1 {
|
||||
r.setCacheStore([]*HybridCacheEntry{entry})
|
||||
return
|
||||
}
|
||||
|
||||
entries := r.cacheStore()
|
||||
// Replace any exact-token duplicate branch with the new result.
|
||||
for i := 0; i < len(entries); i++ {
|
||||
if entries[i] == nil || !equalTokens(entries[i].Tokens, entry.Tokens) {
|
||||
continue
|
||||
}
|
||||
entries[i].Free()
|
||||
entries = append(entries[:i], entries[i+1:]...)
|
||||
break
|
||||
}
|
||||
|
||||
entries = append([]*HybridCacheEntry{entry}, entries...)
|
||||
if len(entries) > branchLimit {
|
||||
for _, evicted := range entries[branchLimit:] {
|
||||
if evicted != nil {
|
||||
evicted.Free()
|
||||
}
|
||||
}
|
||||
entries = entries[:branchLimit]
|
||||
}
|
||||
r.setCacheStore(entries)
|
||||
}
|
||||
|
||||
func (c *HybridCacheEntry) Clone() *HybridCacheEntry {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
tokens := make([]int32, len(c.Tokens))
|
||||
copy(tokens, c.Tokens)
|
||||
caches := make([]cache.Cache, len(c.Caches))
|
||||
for i, cc := range c.Caches {
|
||||
if cc != nil {
|
||||
caches[i] = cc.Clone()
|
||||
}
|
||||
}
|
||||
return &HybridCacheEntry{
|
||||
r.cache = &CacheEntry{
|
||||
Tokens: tokens,
|
||||
Caches: caches,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCacheEntry) LogCache() {
|
||||
if c == nil || len(c.Caches) == 0 {
|
||||
return
|
||||
}
|
||||
func (c *CacheEntry) LogCache() {
|
||||
var totalBytes int
|
||||
for _, kv := range c.Caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
k, v := kv.State()
|
||||
if k == nil || v == nil {
|
||||
continue
|
||||
}
|
||||
totalBytes += k.NumBytes() + v.NumBytes()
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
|
||||
|
||||
40
x/mlxrunner/cache/cache.go
vendored
40
x/mlxrunner/cache/cache.go
vendored
@@ -3,22 +3,13 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func kvCacheGrowDebugEnabled() bool {
|
||||
return os.Getenv("OLLAMA_MLX_DEBUG_CACHE_GROW") != ""
|
||||
}
|
||||
|
||||
type Cache interface {
|
||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||
State() (keys, values *mlx.Array)
|
||||
Materialize() []*mlx.Array
|
||||
CanTrim() bool
|
||||
Trim(int) int
|
||||
Clone() Cache
|
||||
Free()
|
||||
@@ -26,19 +17,6 @@ type Cache interface {
|
||||
Len() int
|
||||
}
|
||||
|
||||
// CheckpointRecorder is an optional cache capability for recording recurrent
|
||||
// state snapshots at specific token positions.
|
||||
type CheckpointRecorder interface {
|
||||
RecordCheckpoint(pos int)
|
||||
}
|
||||
|
||||
// CheckpointRestorer is an optional cache capability for restoring recurrent
|
||||
// state to a previously recorded checkpoint.
|
||||
type CheckpointRestorer interface {
|
||||
BestCheckpoint(target int) (pos int, ok bool)
|
||||
RestoreCheckpoint(pos int) bool
|
||||
}
|
||||
|
||||
type KVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
@@ -71,9 +49,6 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
c.keys, c.values = newKeys, newValues
|
||||
mlx.Pin(c.keys, c.values)
|
||||
}
|
||||
if kvCacheGrowDebugEnabled() {
|
||||
slog.Info("KVCache grow", "prev", prev, "new_capacity", c.keys.Dim(2), "step", c.step)
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += L
|
||||
@@ -92,19 +67,6 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
|
||||
func (c *KVCache) Materialize() []*mlx.Array {
|
||||
out := make([]*mlx.Array, 0, 2)
|
||||
if c.keys != nil && c.keys.Valid() {
|
||||
out = append(out, c.keys)
|
||||
}
|
||||
if c.values != nil && c.values.Valid() {
|
||||
out = append(out, c.values)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *KVCache) CanTrim() bool { return true }
|
||||
|
||||
func (c *KVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
@@ -228,8 +190,6 @@ func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) CanTrim() bool { return true }
|
||||
|
||||
func (c *RotatingKVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
|
||||
17
x/mlxrunner/cache/cache_test.go
vendored
17
x/mlxrunner/cache/cache_test.go
vendored
@@ -1,17 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestKVCacheGrowDebugEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "")
|
||||
if kvCacheGrowDebugEnabled() {
|
||||
t.Fatal("kvCacheGrowDebugEnabled() = true, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "1")
|
||||
if !kvCacheGrowDebugEnabled() {
|
||||
t.Fatal("kvCacheGrowDebugEnabled() = false, want true")
|
||||
}
|
||||
}
|
||||
519
x/mlxrunner/cache/recurrent.go
vendored
519
x/mlxrunner/cache/recurrent.go
vendored
@@ -1,519 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultRecurrentCheckpointCount = 32
|
||||
defaultRecurrentCheckpointInterval = 128
|
||||
defaultRecurrentCheckpointMinPos = 16
|
||||
)
|
||||
|
||||
type recurrentCheckpoint struct {
|
||||
pos int
|
||||
convState *mlx.Array
|
||||
deltaState *mlx.Array
|
||||
}
|
||||
|
||||
type recurrentSlot struct {
|
||||
convState *mlx.Array
|
||||
deltaState *mlx.Array
|
||||
|
||||
checkpoints []recurrentCheckpoint
|
||||
checkpointSize int
|
||||
checkpointNext int
|
||||
checkpointLastPos int
|
||||
|
||||
refs int
|
||||
}
|
||||
|
||||
func getenvInt(name string, def int) int {
|
||||
if v := os.Getenv(name); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func recurrentCheckpointConfig() (count, interval, minPos int) {
|
||||
count = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINTS", defaultRecurrentCheckpointCount)
|
||||
interval = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINT_INTERVAL", defaultRecurrentCheckpointInterval)
|
||||
minPos = getenvInt("OLLAMA_MLX_RECURRENT_CHECKPOINT_MIN_POS", defaultRecurrentCheckpointMinPos)
|
||||
|
||||
if count < 0 {
|
||||
count = 0
|
||||
}
|
||||
if interval < 0 {
|
||||
interval = 0
|
||||
}
|
||||
if minPos < 0 {
|
||||
minPos = 0
|
||||
}
|
||||
return count, interval, minPos
|
||||
}
|
||||
|
||||
// RecurrentCache stores state for linear-recurrent layers.
|
||||
//
|
||||
// Conv state shape: [B, convTail, convDim]
|
||||
// Delta state shape: [B, numVHeads, headVDim, headKDim]
|
||||
type RecurrentCache struct {
|
||||
slot *recurrentSlot
|
||||
offset int
|
||||
|
||||
convTail int
|
||||
convDim int
|
||||
numVHeads int
|
||||
headVDim int
|
||||
headKDim int
|
||||
|
||||
checkpointCount int
|
||||
checkpointInterval int
|
||||
checkpointMinPos int
|
||||
}
|
||||
|
||||
func newRecurrentSlot(checkpointCount int) *recurrentSlot {
|
||||
s := &recurrentSlot{
|
||||
refs: 1,
|
||||
checkpointLastPos: -1,
|
||||
}
|
||||
if checkpointCount > 0 {
|
||||
s.checkpoints = make([]recurrentCheckpoint, checkpointCount)
|
||||
for i := range s.checkpoints {
|
||||
s.checkpoints[i].pos = -1
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func retainRecurrentSlot(s *recurrentSlot) *recurrentSlot {
|
||||
if s != nil {
|
||||
s.refs++
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
if *dst == v {
|
||||
return
|
||||
}
|
||||
|
||||
// Break dependency chains so recurrent state does not retain the full
|
||||
// per-token compute graph over time.
|
||||
snap := mlx.Snapshot(v)
|
||||
mlx.Eval(snap)
|
||||
|
||||
old := *dst
|
||||
*dst = snap
|
||||
mlx.Pin(snap)
|
||||
|
||||
// Release previous cached state root, then recursively free the transient
|
||||
// incoming graph root now that a detached snapshot is retained in cache.
|
||||
if old != nil && old != snap {
|
||||
mlx.Release(old)
|
||||
}
|
||||
if v != snap && v != old {
|
||||
mlx.Release(v)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
old := *dst
|
||||
*dst = v
|
||||
mlx.Pin(v)
|
||||
if old != nil && old != v {
|
||||
mlx.Release(old)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateDetached(dst **mlx.Array, v *mlx.Array, ensureContiguous bool) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
if *dst == v {
|
||||
return
|
||||
}
|
||||
|
||||
root := v
|
||||
if ensureContiguous {
|
||||
root = mlx.Contiguous(v, false)
|
||||
}
|
||||
detached := mlx.Detach(root)
|
||||
|
||||
old := *dst
|
||||
*dst = detached
|
||||
mlx.Pin(detached)
|
||||
if old != nil && old != detached {
|
||||
mlx.Release(old)
|
||||
}
|
||||
|
||||
// Intentionally do not force-release root/v here. In the fast path, the detached
|
||||
// handle aliases the same MLX value and may still be lazily computed. Releasing the
|
||||
// source handles can invalidate the cached state before the next eval/sweep point.
|
||||
}
|
||||
|
||||
func snapshotPinned(a *mlx.Array) *mlx.Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return nil
|
||||
}
|
||||
snap := mlx.Snapshot(a)
|
||||
mlx.Eval(snap)
|
||||
mlx.Pin(snap)
|
||||
return snap
|
||||
}
|
||||
|
||||
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||
count, interval, minPos := recurrentCheckpointConfig()
|
||||
c := &RecurrentCache{
|
||||
slot: newRecurrentSlot(count),
|
||||
convTail: int(convTail),
|
||||
convDim: int(convDim),
|
||||
numVHeads: int(numVHeads),
|
||||
headVDim: int(headVDim),
|
||||
headKDim: int(headKDim),
|
||||
checkpointCount: count,
|
||||
checkpointInterval: interval,
|
||||
checkpointMinPos: minPos,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func clonePinned(a *mlx.Array) *mlx.Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return nil
|
||||
}
|
||||
clone := a.Clone()
|
||||
mlx.Pin(clone)
|
||||
return clone
|
||||
}
|
||||
|
||||
func releaseCheckpointEntry(e *recurrentCheckpoint) {
|
||||
mlx.Release(e.convState, e.deltaState)
|
||||
e.convState, e.deltaState = nil, nil
|
||||
e.pos = -1
|
||||
}
|
||||
|
||||
func releaseRecurrentSlot(s *recurrentSlot) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.refs--
|
||||
if s.refs > 0 {
|
||||
return
|
||||
}
|
||||
mlx.Release(s.convState, s.deltaState)
|
||||
s.convState, s.deltaState = nil, nil
|
||||
for i := range s.checkpoints {
|
||||
releaseCheckpointEntry(&s.checkpoints[i])
|
||||
}
|
||||
s.checkpointSize = 0
|
||||
s.checkpointNext = 0
|
||||
s.checkpointLastPos = -1
|
||||
}
|
||||
|
||||
func cloneRecurrentSlot(src *recurrentSlot) *recurrentSlot {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
dst := &recurrentSlot{
|
||||
checkpointSize: src.checkpointSize,
|
||||
checkpointNext: src.checkpointNext,
|
||||
checkpointLastPos: src.checkpointLastPos,
|
||||
refs: 1,
|
||||
}
|
||||
if src.convState != nil && src.convState.Valid() {
|
||||
dst.convState = snapshotPinned(src.convState)
|
||||
}
|
||||
if src.deltaState != nil && src.deltaState.Valid() {
|
||||
dst.deltaState = snapshotPinned(src.deltaState)
|
||||
}
|
||||
if len(src.checkpoints) > 0 {
|
||||
dst.checkpoints = make([]recurrentCheckpoint, len(src.checkpoints))
|
||||
for i := range src.checkpoints {
|
||||
dst.checkpoints[i].pos = src.checkpoints[i].pos
|
||||
if src.checkpoints[i].pos < 0 {
|
||||
continue
|
||||
}
|
||||
dst.checkpoints[i].convState = snapshotPinned(src.checkpoints[i].convState)
|
||||
dst.checkpoints[i].deltaState = snapshotPinned(src.checkpoints[i].deltaState)
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) slotOrInit() *recurrentSlot {
|
||||
if c.slot == nil {
|
||||
c.slot = newRecurrentSlot(c.checkpointCount)
|
||||
}
|
||||
return c.slot
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ensureWritableSlot() *recurrentSlot {
|
||||
s := c.slotOrInit()
|
||||
if s.refs <= 1 {
|
||||
return s
|
||||
}
|
||||
c.slot = cloneRecurrentSlot(s)
|
||||
s.refs--
|
||||
return c.slot
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) pruneCheckpointsAfter(pos int) {
|
||||
s := c.ensureWritableSlot()
|
||||
if len(s.checkpoints) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
size := 0
|
||||
next := -1
|
||||
last := -1
|
||||
minPos := int(^uint(0) >> 1)
|
||||
minIdx := 0
|
||||
for i := range s.checkpoints {
|
||||
e := &s.checkpoints[i]
|
||||
if e.pos > pos {
|
||||
releaseCheckpointEntry(e)
|
||||
}
|
||||
if e.pos >= 0 {
|
||||
size++
|
||||
if e.pos > last {
|
||||
last = e.pos
|
||||
}
|
||||
if e.pos < minPos {
|
||||
minPos = e.pos
|
||||
minIdx = i
|
||||
}
|
||||
} else if next == -1 {
|
||||
next = i
|
||||
}
|
||||
}
|
||||
|
||||
s.checkpointSize = size
|
||||
s.checkpointLastPos = last
|
||||
if size == 0 {
|
||||
s.checkpointNext = 0
|
||||
return
|
||||
}
|
||||
if next != -1 {
|
||||
s.checkpointNext = next
|
||||
return
|
||||
}
|
||||
s.checkpointNext = minIdx
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
|
||||
if batch <= 0 {
|
||||
batch = 1
|
||||
}
|
||||
|
||||
s := c.slotOrInit()
|
||||
needConv := s.convState == nil || s.convState.DType() != dtype ||
|
||||
s.convState.Dim(0) != batch || s.convState.Dim(1) != c.convTail || s.convState.Dim(2) != c.convDim
|
||||
needDelta := s.deltaState == nil || s.deltaState.DType() != dtype ||
|
||||
s.deltaState.Dim(0) != batch || s.deltaState.Dim(1) != c.numVHeads || s.deltaState.Dim(2) != c.headVDim || s.deltaState.Dim(3) != c.headKDim
|
||||
if !needConv && !needDelta {
|
||||
return
|
||||
}
|
||||
|
||||
s = c.ensureWritableSlot()
|
||||
if needConv {
|
||||
c.setStateRaw(&s.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
|
||||
}
|
||||
|
||||
if needDelta {
|
||||
c.setStateRaw(&s.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.slotOrInit().convState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
|
||||
s := c.ensureWritableSlot()
|
||||
c.setStateMaterialized(&s.convState, v)
|
||||
}
|
||||
|
||||
// SetConvStateFast stores conv state without forcing an immediate snapshot/eval.
|
||||
// Use only for decode hot paths that accept higher transient memory until the next
|
||||
// sync/sweep point. The conv-state input is usually a slice view, so request a
|
||||
// compact contiguous copy to avoid pinning the whole source buffer.
|
||||
func (c *RecurrentCache) SetConvStateFast(v *mlx.Array) {
|
||||
s := c.ensureWritableSlot()
|
||||
c.setStateDetached(&s.convState, v, true)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.slotOrInit().deltaState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
|
||||
s := c.ensureWritableSlot()
|
||||
c.setStateMaterialized(&s.deltaState, v)
|
||||
}
|
||||
|
||||
// SetDeltaStateFast stores delta state without forcing an immediate snapshot/eval.
|
||||
// Use only for decode hot paths that accept higher transient memory until the next
|
||||
// sync/sweep point.
|
||||
func (c *RecurrentCache) SetDeltaStateFast(v *mlx.Array) {
|
||||
s := c.ensureWritableSlot()
|
||||
c.setStateDetached(&s.deltaState, v, false)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Advance(n int) {
|
||||
c.offset += n
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return keys, values
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
|
||||
c.ensure(1, mlx.DTypeFloat32)
|
||||
s := c.slotOrInit()
|
||||
return s.convState, s.deltaState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Materialize() []*mlx.Array {
|
||||
out := make([]*mlx.Array, 0, 2)
|
||||
s := c.slot
|
||||
if s == nil {
|
||||
return out
|
||||
}
|
||||
if s.convState != nil && s.convState.Valid() {
|
||||
out = append(out, s.convState)
|
||||
}
|
||||
if s.deltaState != nil && s.deltaState.Valid() {
|
||||
out = append(out, s.deltaState)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) RecordCheckpoint(pos int) {
|
||||
s := c.slot
|
||||
if s == nil || len(s.checkpoints) == 0 || pos <= 0 || pos < c.checkpointMinPos {
|
||||
return
|
||||
}
|
||||
if c.offset != pos {
|
||||
// Checkpoints are keyed by logical token position. Ignore callers with a
|
||||
// mismatched position to avoid restoring inconsistent recurrent state.
|
||||
return
|
||||
}
|
||||
if s.convState == nil || s.deltaState == nil || !s.convState.Valid() || !s.deltaState.Valid() {
|
||||
return
|
||||
}
|
||||
if s.checkpointLastPos == pos {
|
||||
return
|
||||
}
|
||||
if s.checkpointLastPos >= 0 && c.checkpointInterval > 0 && pos-s.checkpointLastPos < c.checkpointInterval {
|
||||
return
|
||||
}
|
||||
if s.refs > 1 {
|
||||
s = c.ensureWritableSlot()
|
||||
}
|
||||
|
||||
idx := s.checkpointNext
|
||||
e := &s.checkpoints[idx]
|
||||
releaseCheckpointEntry(e)
|
||||
e.pos = pos
|
||||
e.convState = clonePinned(s.convState)
|
||||
e.deltaState = clonePinned(s.deltaState)
|
||||
|
||||
s.checkpointNext = (idx + 1) % len(s.checkpoints)
|
||||
if s.checkpointSize < len(s.checkpoints) {
|
||||
s.checkpointSize++
|
||||
}
|
||||
s.checkpointLastPos = pos
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) BestCheckpoint(target int) (pos int, ok bool) {
|
||||
s := c.slot
|
||||
if s == nil {
|
||||
return 0, false
|
||||
}
|
||||
best := -1
|
||||
for i := range s.checkpoints {
|
||||
pos := s.checkpoints[i].pos
|
||||
if pos < 0 || pos > target {
|
||||
continue
|
||||
}
|
||||
if pos > best {
|
||||
best = pos
|
||||
}
|
||||
}
|
||||
if best < 0 {
|
||||
return 0, false
|
||||
}
|
||||
return best, true
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) RestoreCheckpoint(pos int) bool {
|
||||
if pos < 0 {
|
||||
return false
|
||||
}
|
||||
s := c.ensureWritableSlot()
|
||||
for i := range s.checkpoints {
|
||||
e := &s.checkpoints[i]
|
||||
if e.pos != pos {
|
||||
continue
|
||||
}
|
||||
if e.convState == nil || e.deltaState == nil || !e.convState.Valid() || !e.deltaState.Valid() {
|
||||
return false
|
||||
}
|
||||
|
||||
c.setStateRaw(&s.convState, e.convState.Clone())
|
||||
c.setStateRaw(&s.deltaState, e.deltaState.Clone())
|
||||
c.offset = pos
|
||||
c.pruneCheckpointsAfter(pos)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) CanTrim() bool { return false }
|
||||
|
||||
func (c *RecurrentCache) Trim(n int) int {
|
||||
// Recurrent state is not directly trimmable; callers should use
|
||||
// checkpoint-based restore instead.
|
||||
_ = n
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Clone() Cache {
|
||||
clone := &RecurrentCache{
|
||||
slot: retainRecurrentSlot(c.slotOrInit()),
|
||||
offset: c.offset,
|
||||
convTail: c.convTail,
|
||||
convDim: c.convDim,
|
||||
numVHeads: c.numVHeads,
|
||||
headVDim: c.headVDim,
|
||||
headKDim: c.headKDim,
|
||||
checkpointCount: c.checkpointCount,
|
||||
checkpointInterval: c.checkpointInterval,
|
||||
checkpointMinPos: c.checkpointMinPos,
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Free() {
|
||||
releaseRecurrentSlot(c.slot)
|
||||
c.slot = nil
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Offset() int { return c.offset }
|
||||
func (c *RecurrentCache) Len() int { return c.offset }
|
||||
125
x/mlxrunner/cache/recurrent_cow_test.go
vendored
125
x/mlxrunner/cache/recurrent_cow_test.go
vendored
@@ -1,125 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func requireMLXRuntime(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX runtime unavailable: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestRecurrentCache(t *testing.T) *RecurrentCache {
|
||||
t.Helper()
|
||||
requireMLXRuntime(t)
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINTS", "2")
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINT_INTERVAL", "0")
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_CHECKPOINT_MIN_POS", "0")
|
||||
|
||||
c := NewRecurrentCache(2, 3, 1, 2, 2)
|
||||
_ = c.ConvState(1, mlx.DTypeFloat32)
|
||||
_ = c.DeltaState(1, mlx.DTypeFloat32)
|
||||
return c
|
||||
}
|
||||
|
||||
func TestRecurrentCacheCloneSharesSlotUntilMutation(t *testing.T) {
|
||||
c1 := newTestRecurrentCache(t)
|
||||
t.Cleanup(func() {
|
||||
c1.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
|
||||
c1.Advance(8)
|
||||
c1.RecordCheckpoint(8)
|
||||
if got, ok := c1.BestCheckpoint(8); !ok || got != 8 {
|
||||
t.Fatalf("BestCheckpoint(8) = (%d, %v), want (8, true)", got, ok)
|
||||
}
|
||||
|
||||
c2 := c1.Clone().(*RecurrentCache)
|
||||
t.Cleanup(func() {
|
||||
c2.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
|
||||
if c1.slot == nil || c2.slot == nil {
|
||||
t.Fatal("expected non-nil shared slots")
|
||||
}
|
||||
if c1.slot != c2.slot {
|
||||
t.Fatal("clone did not share recurrent slot")
|
||||
}
|
||||
if c1.slot.refs != 2 {
|
||||
t.Fatalf("shared slot refs = %d, want 2", c1.slot.refs)
|
||||
}
|
||||
|
||||
// Read access should not trigger a COW detach.
|
||||
_ = c2.ConvState(1, mlx.DTypeFloat32)
|
||||
_ = c2.DeltaState(1, mlx.DTypeFloat32)
|
||||
if c1.slot != c2.slot {
|
||||
t.Fatal("read access detached shared recurrent slot")
|
||||
}
|
||||
|
||||
// Mutating recurrent state should detach and deep-copy checkpoint metadata.
|
||||
c2.SetConvState(mlx.Zeros(mlx.DTypeFloat32, 1, 2, 3))
|
||||
|
||||
if c1.slot == c2.slot {
|
||||
t.Fatal("SetConvState did not detach shared recurrent slot")
|
||||
}
|
||||
if c1.slot.refs != 1 || c2.slot.refs != 1 {
|
||||
t.Fatalf("post-detach refs = (%d, %d), want (1, 1)", c1.slot.refs, c2.slot.refs)
|
||||
}
|
||||
if len(c1.slot.checkpoints) == 0 || len(c2.slot.checkpoints) == 0 {
|
||||
t.Fatal("expected checkpoint ring to be preserved on detach")
|
||||
}
|
||||
if c1.slot.checkpoints[0].pos != c2.slot.checkpoints[0].pos {
|
||||
t.Fatalf("checkpoint pos mismatch after detach: %d vs %d", c1.slot.checkpoints[0].pos, c2.slot.checkpoints[0].pos)
|
||||
}
|
||||
if c1.slot.checkpoints[0].pos != 8 {
|
||||
t.Fatalf("checkpoint pos = %d, want 8", c1.slot.checkpoints[0].pos)
|
||||
}
|
||||
if c1.slot.checkpoints[0].convState == c2.slot.checkpoints[0].convState {
|
||||
t.Fatal("checkpoint conv state was aliased after COW detach")
|
||||
}
|
||||
if c1.slot.checkpoints[0].deltaState == c2.slot.checkpoints[0].deltaState {
|
||||
t.Fatal("checkpoint delta state was aliased after COW detach")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecurrentCacheFreeKeepsSharedCloneAlive(t *testing.T) {
|
||||
c1 := newTestRecurrentCache(t)
|
||||
c2 := c1.Clone().(*RecurrentCache)
|
||||
t.Cleanup(func() {
|
||||
c1.Free()
|
||||
c2.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
|
||||
if c2.slot == nil || c2.slot.refs != 2 {
|
||||
t.Fatalf("shared clone refs = %d, want 2", func() int {
|
||||
if c2.slot == nil {
|
||||
return 0
|
||||
}
|
||||
return c2.slot.refs
|
||||
}())
|
||||
}
|
||||
|
||||
c1.Free()
|
||||
|
||||
if c2.slot == nil {
|
||||
t.Fatal("clone slot was cleared after freeing sibling clone")
|
||||
}
|
||||
if c2.slot.refs != 1 {
|
||||
t.Fatalf("clone slot refs after sibling Free = %d, want 1", c2.slot.refs)
|
||||
}
|
||||
if state := c2.ConvState(1, mlx.DTypeFloat32); state == nil || !state.Valid() {
|
||||
t.Fatal("clone conv state invalid after freeing sibling clone")
|
||||
}
|
||||
if state := c2.DeltaState(1, mlx.DTypeFloat32); state == nil || !state.Valid() {
|
||||
t.Fatal("clone delta state invalid after freeing sibling clone")
|
||||
}
|
||||
}
|
||||
@@ -1,339 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
cachepkg "github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
type fakeCache struct {
|
||||
canTrim bool
|
||||
trims []int
|
||||
freeCall int
|
||||
offset int
|
||||
}
|
||||
|
||||
func (f *fakeCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { return keys, values }
|
||||
func (f *fakeCache) State() (*mlx.Array, *mlx.Array) { return nil, nil }
|
||||
func (f *fakeCache) Materialize() []*mlx.Array { return nil }
|
||||
func (f *fakeCache) CanTrim() bool { return f.canTrim }
|
||||
func (f *fakeCache) Trim(n int) int {
|
||||
f.trims = append(f.trims, n)
|
||||
f.offset -= n
|
||||
return n
|
||||
}
|
||||
func (f *fakeCache) Clone() cachepkg.Cache { return &fakeCache{canTrim: f.canTrim, offset: f.offset} }
|
||||
func (f *fakeCache) Free() { f.freeCall++ }
|
||||
func (f *fakeCache) Offset() int { return f.offset }
|
||||
func (f *fakeCache) Len() int { return f.offset }
|
||||
|
||||
type fakeCheckpointCache struct {
|
||||
fakeCache
|
||||
bestPos int
|
||||
hasCheckpoint bool
|
||||
restoreCalls []int
|
||||
restoreSuccess bool
|
||||
}
|
||||
|
||||
func (f *fakeCheckpointCache) BestCheckpoint(target int) (int, bool) {
|
||||
if !f.hasCheckpoint || f.bestPos > target {
|
||||
return 0, false
|
||||
}
|
||||
return f.bestPos, true
|
||||
}
|
||||
|
||||
func (f *fakeCheckpointCache) RestoreCheckpoint(pos int) bool {
|
||||
f.restoreCalls = append(f.restoreCalls, pos)
|
||||
if !f.restoreSuccess || pos != f.bestPos {
|
||||
return false
|
||||
}
|
||||
f.offset = pos
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *fakeCheckpointCache) Clone() cachepkg.Cache {
|
||||
clone := *f
|
||||
clone.trims = nil
|
||||
clone.restoreCalls = nil
|
||||
return &clone
|
||||
}
|
||||
|
||||
func TestFindNearestCacheReusesAppendOnlyNonTrimmableCache(t *testing.T) {
|
||||
fc := &fakeCache{canTrim: false, offset: 2}
|
||||
r := &Runner{
|
||||
cache: &CacheEntry{
|
||||
Tokens: []int32{1, 2},
|
||||
Caches: []cachepkg.Cache{fc},
|
||||
},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4})
|
||||
|
||||
if len(gotCaches) != 1 || gotCaches[0] != fc {
|
||||
t.Fatalf("returned caches = %#v, want original cache", gotCaches)
|
||||
}
|
||||
if want := []int32{3, 4}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if fc.freeCall != 0 {
|
||||
t.Fatalf("free calls = %d, want 0", fc.freeCall)
|
||||
}
|
||||
if len(fc.trims) != 0 {
|
||||
t.Fatalf("trim calls = %v, want none", fc.trims)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheDropsNonTrimmableCacheOnDivergence(t *testing.T) {
|
||||
fc := &fakeCache{canTrim: false, offset: 4}
|
||||
r := &Runner{
|
||||
cache: &CacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4},
|
||||
Caches: []cachepkg.Cache{fc},
|
||||
},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
|
||||
|
||||
if gotCaches != nil {
|
||||
t.Fatalf("returned caches = %#v, want nil", gotCaches)
|
||||
}
|
||||
if want := []int32{1, 2, 9}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if fc.freeCall != 1 {
|
||||
t.Fatalf("free calls = %d, want 1", fc.freeCall)
|
||||
}
|
||||
if len(fc.trims) != 0 {
|
||||
t.Fatalf("trim calls = %v, want none", fc.trims)
|
||||
}
|
||||
if r.cache != nil {
|
||||
t.Fatal("runner cache should be cleared on non-trimmable divergence")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheTrimsTrimmableCacheOnDivergence(t *testing.T) {
|
||||
fc := &fakeCache{canTrim: true, offset: 4}
|
||||
r := &Runner{
|
||||
cache: &CacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4},
|
||||
Caches: []cachepkg.Cache{fc},
|
||||
},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
|
||||
|
||||
if len(gotCaches) != 1 || gotCaches[0] != fc {
|
||||
t.Fatalf("returned caches = %#v, want original cache", gotCaches)
|
||||
}
|
||||
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if fc.freeCall != 0 {
|
||||
t.Fatalf("free calls = %d, want 0", fc.freeCall)
|
||||
}
|
||||
if want := []int{2}; !reflect.DeepEqual(fc.trims, want) {
|
||||
t.Fatalf("trim calls = %v, want %v", fc.trims, want)
|
||||
}
|
||||
if want := []int32{1, 2}; !reflect.DeepEqual(r.cache.Tokens, want) {
|
||||
t.Fatalf("cached tokens = %v, want %v", r.cache.Tokens, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheRestoresCheckpointForNonTrimmableCaches(t *testing.T) {
|
||||
kv := &fakeCache{canTrim: true, offset: 7}
|
||||
rc1 := &fakeCheckpointCache{
|
||||
fakeCache: fakeCache{canTrim: false, offset: 7},
|
||||
bestPos: 4,
|
||||
hasCheckpoint: true,
|
||||
restoreSuccess: true,
|
||||
}
|
||||
rc2 := &fakeCheckpointCache{
|
||||
fakeCache: fakeCache{canTrim: false, offset: 7},
|
||||
bestPos: 4,
|
||||
hasCheckpoint: true,
|
||||
restoreSuccess: true,
|
||||
}
|
||||
|
||||
r := &Runner{
|
||||
cache: &CacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4, 5, 6, 7},
|
||||
Caches: []cachepkg.Cache{kv, rc1, rc2},
|
||||
},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 8})
|
||||
|
||||
if len(gotCaches) != 3 {
|
||||
t.Fatalf("returned caches len = %d, want 3", len(gotCaches))
|
||||
}
|
||||
if want := []int32{8}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if want := []int{3}; !reflect.DeepEqual(kv.trims, want) {
|
||||
t.Fatalf("kv trim calls = %v, want %v", kv.trims, want)
|
||||
}
|
||||
if want := []int{4}; !reflect.DeepEqual(rc1.restoreCalls, want) {
|
||||
t.Fatalf("rc1 restore calls = %v, want %v", rc1.restoreCalls, want)
|
||||
}
|
||||
if want := []int{4}; !reflect.DeepEqual(rc2.restoreCalls, want) {
|
||||
t.Fatalf("rc2 restore calls = %v, want %v", rc2.restoreCalls, want)
|
||||
}
|
||||
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(r.cache.Tokens, want) {
|
||||
t.Fatalf("cached tokens = %v, want %v", r.cache.Tokens, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheDropsOnMismatchedCheckpointRestorePoints(t *testing.T) {
|
||||
rc1 := &fakeCheckpointCache{
|
||||
fakeCache: fakeCache{canTrim: false, offset: 7},
|
||||
bestPos: 4,
|
||||
hasCheckpoint: true,
|
||||
restoreSuccess: true,
|
||||
}
|
||||
rc2 := &fakeCheckpointCache{
|
||||
fakeCache: fakeCache{canTrim: false, offset: 7},
|
||||
bestPos: 3,
|
||||
hasCheckpoint: true,
|
||||
restoreSuccess: true,
|
||||
}
|
||||
|
||||
r := &Runner{
|
||||
cache: &CacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4, 5, 6, 7},
|
||||
Caches: []cachepkg.Cache{rc1, rc2},
|
||||
},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 8})
|
||||
|
||||
if gotCaches != nil {
|
||||
t.Fatalf("returned caches = %#v, want nil", gotCaches)
|
||||
}
|
||||
if want := []int32{1, 2, 3, 4, 8}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if rc1.freeCall != 1 || rc2.freeCall != 1 {
|
||||
t.Fatalf("free calls = (%d,%d), want (1,1)", rc1.freeCall, rc2.freeCall)
|
||||
}
|
||||
if len(rc1.restoreCalls) != 0 || len(rc2.restoreCalls) != 0 {
|
||||
t.Fatalf("restore calls = (%v,%v), want none", rc1.restoreCalls, rc2.restoreCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheSelectsBestPrefixAcrossBranches(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "4")
|
||||
|
||||
short := &fakeCache{canTrim: true, offset: 2}
|
||||
long := &fakeCache{canTrim: true, offset: 4}
|
||||
shortEntry := &HybridCacheEntry{
|
||||
Tokens: []int32{1, 2},
|
||||
Caches: []cachepkg.Cache{short},
|
||||
}
|
||||
longEntry := &HybridCacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4},
|
||||
Caches: []cachepkg.Cache{long},
|
||||
}
|
||||
|
||||
r := &Runner{
|
||||
cache: shortEntry,
|
||||
caches: []*HybridCacheEntry{shortEntry, longEntry},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 3, 4, 9})
|
||||
|
||||
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if len(gotCaches) != 1 {
|
||||
t.Fatalf("returned caches len = %d, want 1", len(gotCaches))
|
||||
}
|
||||
if gotCaches[0] == long {
|
||||
t.Fatal("expected cloned cache in multi-branch mode, got original branch cache")
|
||||
}
|
||||
if r.cache != longEntry || r.caches[0] != longEntry {
|
||||
t.Fatal("best branch was not promoted to front of cache store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNearestCacheForksBranchWithCloneWhenMultiBranchEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "2")
|
||||
|
||||
base := &fakeCache{canTrim: true, offset: 4}
|
||||
baseEntry := &HybridCacheEntry{
|
||||
Tokens: []int32{1, 2, 3, 4},
|
||||
Caches: []cachepkg.Cache{base},
|
||||
}
|
||||
r := &Runner{
|
||||
cache: baseEntry,
|
||||
caches: []*HybridCacheEntry{baseEntry},
|
||||
}
|
||||
|
||||
gotCaches, gotTokens := r.FindNearestCache([]int32{1, 2, 9})
|
||||
|
||||
if want := []int32{9}; !reflect.DeepEqual(gotTokens, want) {
|
||||
t.Fatalf("tokens left = %v, want %v", gotTokens, want)
|
||||
}
|
||||
if len(gotCaches) != 1 {
|
||||
t.Fatalf("returned caches len = %d, want 1", len(gotCaches))
|
||||
}
|
||||
clone, ok := gotCaches[0].(*fakeCache)
|
||||
if !ok {
|
||||
t.Fatalf("returned cache type = %T, want *fakeCache", gotCaches[0])
|
||||
}
|
||||
if clone == base {
|
||||
t.Fatal("expected branch fork to return a cloned cache")
|
||||
}
|
||||
if len(base.trims) != 0 {
|
||||
t.Fatalf("base branch trim calls = %v, want none", base.trims)
|
||||
}
|
||||
if want := []int{2}; !reflect.DeepEqual(clone.trims, want) {
|
||||
t.Fatalf("forked branch trim calls = %v, want %v", clone.trims, want)
|
||||
}
|
||||
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(baseEntry.Tokens, want) {
|
||||
t.Fatalf("base entry tokens = %v, want %v", baseEntry.Tokens, want)
|
||||
}
|
||||
|
||||
r.InsertCache([]int32{1, 2, 9}, gotCaches)
|
||||
if len(r.caches) != 2 {
|
||||
t.Fatalf("cache store len = %d, want 2", len(r.caches))
|
||||
}
|
||||
if want := []int32{1, 2, 9}; !reflect.DeepEqual(r.caches[0].Tokens, want) {
|
||||
t.Fatalf("new branch tokens = %v, want %v", r.caches[0].Tokens, want)
|
||||
}
|
||||
if want := []int32{1, 2, 3, 4}; !reflect.DeepEqual(r.caches[1].Tokens, want) {
|
||||
t.Fatalf("preserved branch tokens = %v, want %v", r.caches[1].Tokens, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertCacheEvictsOldestBranchWhenStoreFull(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PROMPT_CACHE_BRANCHES", "2")
|
||||
|
||||
f1 := &fakeCache{canTrim: true, offset: 1}
|
||||
f2 := &fakeCache{canTrim: true, offset: 2}
|
||||
f3 := &fakeCache{canTrim: true, offset: 3}
|
||||
r := &Runner{}
|
||||
|
||||
r.InsertCache([]int32{1}, []cachepkg.Cache{f1})
|
||||
r.InsertCache([]int32{1, 2}, []cachepkg.Cache{f2})
|
||||
r.InsertCache([]int32{1, 2, 3}, []cachepkg.Cache{f3})
|
||||
|
||||
if len(r.caches) != 2 {
|
||||
t.Fatalf("cache store len = %d, want 2", len(r.caches))
|
||||
}
|
||||
if f1.freeCall != 1 {
|
||||
t.Fatalf("oldest branch free calls = %d, want 1", f1.freeCall)
|
||||
}
|
||||
if f2.freeCall != 0 || f3.freeCall != 0 {
|
||||
t.Fatalf("unexpected frees for retained branches: f2=%d f3=%d", f2.freeCall, f3.freeCall)
|
||||
}
|
||||
if want := []int32{1, 2, 3}; !reflect.DeepEqual(r.caches[0].Tokens, want) {
|
||||
t.Fatalf("MRU tokens = %v, want %v", r.caches[0].Tokens, want)
|
||||
}
|
||||
if want := []int32{1, 2}; !reflect.DeepEqual(r.caches[1].Tokens, want) {
|
||||
t.Fatalf("LRU tokens = %v, want %v", r.caches[1].Tokens, want)
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,4 @@ import (
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
|
||||
)
|
||||
|
||||
@@ -267,20 +267,3 @@ func LogArrays() {
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
|
||||
}
|
||||
|
||||
// Release forcibly frees arrays regardless of reference accounting.
|
||||
// Use only for arrays that are known to be unreachable by any live model state.
|
||||
func Release(s ...*Array) (n int) {
|
||||
seen := make(map[*Array]bool, len(s))
|
||||
for _, t := range s {
|
||||
if t == nil || !t.Valid() || seen[t] {
|
||||
continue
|
||||
}
|
||||
seen[t] = true
|
||||
n += t.NumBytes()
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.pinned = false
|
||||
t.ctx.ctx = nil
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
@@ -1,275 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
gatedDeltaMetalKernelOnce sync.Once
|
||||
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
||||
gatedDeltaMetalDisabled atomic.Bool
|
||||
)
|
||||
|
||||
const gatedDeltaMetalKernelSource = `
|
||||
auto n = thread_position_in_grid.z;
|
||||
auto b_idx = n / Hv;
|
||||
auto hv_idx = n % Hv;
|
||||
auto hk_idx = hv_idx / (Hv / Hk);
|
||||
constexpr int n_per_t = Dk / 32;
|
||||
|
||||
// q, k: [B, T, Hk, Dk]
|
||||
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
|
||||
// v, y: [B, T, Hv, Dv]
|
||||
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
y += b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
|
||||
auto dk_idx = thread_position_in_threadgroup.x;
|
||||
auto dv_idx = thread_position_in_grid.y;
|
||||
|
||||
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||
|
||||
float state[n_per_t];
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = static_cast<float>(i_state[s_idx]);
|
||||
}
|
||||
|
||||
// g: [B, T, Hv]
|
||||
auto g_ = g + b_idx * T * Hv;
|
||||
auto beta_ = beta + b_idx * T * Hv;
|
||||
|
||||
for (int t = 0; t < T; ++t) {
|
||||
float kv_mem = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] * g_[hv_idx];
|
||||
kv_mem += state[i] * k_[s_idx];
|
||||
}
|
||||
kv_mem = simd_sum(kv_mem);
|
||||
|
||||
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
|
||||
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] + k_[s_idx] * delta;
|
||||
out += state[i] * q_[s_idx];
|
||||
}
|
||||
out = simd_sum(out);
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
y[dv_idx] = static_cast<InT>(out);
|
||||
}
|
||||
|
||||
q_ += Hk * Dk;
|
||||
k_ += Hk * Dk;
|
||||
v_ += Hv * Dv;
|
||||
y += Hv * Dv;
|
||||
g_ += Hv;
|
||||
beta_ += Hv;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
o_state[s_idx] = static_cast<InT>(state[i]);
|
||||
}
|
||||
`
|
||||
|
||||
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
||||
vec := C.mlx_vector_string_new()
|
||||
ok := true
|
||||
for _, s := range values {
|
||||
cs := C.CString(s)
|
||||
if C.mlx_vector_string_append_value(vec, cs) != 0 {
|
||||
ok = false
|
||||
}
|
||||
C.free(unsafe.Pointer(cs))
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
cleanup := func() {
|
||||
C.mlx_vector_string_free(vec)
|
||||
}
|
||||
return vec, cleanup, ok
|
||||
}
|
||||
|
||||
func initGatedDeltaMetalKernel() {
|
||||
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
freeInputs()
|
||||
return
|
||||
}
|
||||
defer freeInputs()
|
||||
|
||||
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
freeOutputs()
|
||||
return
|
||||
}
|
||||
defer freeOutputs()
|
||||
|
||||
cName := C.CString("gated_delta_step")
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
cSource := C.CString(gatedDeltaMetalKernelSource)
|
||||
defer C.free(unsafe.Pointer(cSource))
|
||||
cHeader := C.CString("")
|
||||
defer C.free(unsafe.Pointer(cHeader))
|
||||
|
||||
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
|
||||
cName,
|
||||
inputs,
|
||||
outputs,
|
||||
cSource,
|
||||
cHeader,
|
||||
C.bool(true),
|
||||
C.bool(false),
|
||||
)
|
||||
}
|
||||
|
||||
// GatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
|
||||
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
|
||||
func GatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||
if gatedDeltaMetalDisabled.Load() {
|
||||
return nil, nil, false
|
||||
}
|
||||
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
if !q.Valid() || !k.Valid() || !v.Valid() || !g.Valid() || !beta.Valid() || !state.Valid() {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qd := q.Dims()
|
||||
kd := k.Dims()
|
||||
vd := v.Dims()
|
||||
gd := g.Dims()
|
||||
bd := beta.Dims()
|
||||
sd := state.Dims()
|
||||
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
Hv, Dv := vd[2], vd[3]
|
||||
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
dtype := q.DType()
|
||||
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
|
||||
if gatedDeltaMetalDisabled.Load() {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
cfg := C.mlx_fast_metal_kernel_config_new()
|
||||
defer C.mlx_fast_metal_kernel_config_free(cfg)
|
||||
|
||||
cInT := C.CString("InT")
|
||||
defer C.free(unsafe.Pointer(cInT))
|
||||
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
for _, tpl := range []struct {
|
||||
name string
|
||||
value int
|
||||
}{
|
||||
{name: "Dk", value: Dk},
|
||||
{name: "Dv", value: Dv},
|
||||
{name: "Hk", value: Hk},
|
||||
{name: "Hv", value: Hv},
|
||||
} {
|
||||
cn := C.CString(tpl.name)
|
||||
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||
C.free(unsafe.Pointer(cn))
|
||||
if rc != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
threadY := Dv
|
||||
if threadY > 4 {
|
||||
threadY = 4
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
tScalar := FromValue(T)
|
||||
inputs := []C.mlx_array{
|
||||
q.ctx,
|
||||
k.ctx,
|
||||
v.ctx,
|
||||
g.ctx,
|
||||
beta.ctx,
|
||||
state.ctx,
|
||||
tScalar.ctx,
|
||||
}
|
||||
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
y = New("GATED_DELTA_METAL_Y")
|
||||
nextState = New("GATED_DELTA_METAL_STATE")
|
||||
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||
return y, nextState, true
|
||||
}
|
||||
@@ -19,8 +19,7 @@ func doEval(outputs []*Array, async bool) {
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
for _, output := range outputs {
|
||||
// Callers may pass optional tensors (e.g. debug-only logprobs) as nil.
|
||||
if output != nil && output.Valid() {
|
||||
if output.Valid() {
|
||||
C.mlx_vector_array_append_value(vector, output.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,35 +113,6 @@ func Where(condition, a, b *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
|
||||
out := New("CONV1D")
|
||||
C.mlx_conv1d(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
weight.ctx,
|
||||
C.int(stride),
|
||||
C.int(padding),
|
||||
C.int(dilation),
|
||||
C.int(groups),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
if bias != nil && bias.Valid() {
|
||||
out = Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||
out := New("CONTIGUOUS")
|
||||
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||
groups := int32(x.Dim(x.NumDims() - 1))
|
||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||
}
|
||||
|
||||
// Convenience wrappers (function-style for the model code)
|
||||
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
@@ -300,24 +271,6 @@ func Sigmoid(a *Array) *Array {
|
||||
return a.Sigmoid()
|
||||
}
|
||||
|
||||
func Exp(a *Array) *Array {
|
||||
out := New("EXP")
|
||||
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Log(a *Array) *Array {
|
||||
out := New("LOG")
|
||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
||||
out := New("SOFTMAX_AXIS")
|
||||
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
|
||||
mask := New("")
|
||||
sinks := New("")
|
||||
@@ -335,11 +288,7 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
||||
|
||||
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
var w C.mlx_array
|
||||
if weight != nil {
|
||||
w = weight.ctx
|
||||
}
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -429,32 +378,6 @@ func Collect(v any) []*Array {
|
||||
return arrays
|
||||
}
|
||||
|
||||
// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
|
||||
func Snapshot(a *Array) *Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return a
|
||||
}
|
||||
out := New("SNAPSHOT")
|
||||
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// CollectReachable collects arrays from v and all transitive graph inputs.
|
||||
func CollectReachable(v any) []*Array {
|
||||
return Collect(v)
|
||||
}
|
||||
|
||||
// Detach returns a new Array handle that shares the same MLX value but does
|
||||
// not retain Go-side graph input references.
|
||||
func Detach(a *Array) *Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return a
|
||||
}
|
||||
out := New("DETACH")
|
||||
C.mlx_array_set(&out.ctx, a.ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
||||
if !v.IsValid() {
|
||||
return
|
||||
|
||||
@@ -6,10 +6,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
@@ -17,234 +14,6 @@ import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
const defaultRecurrentMaterializeInterval = 64
|
||||
const defaultPipelineTimingEvery = 64
|
||||
|
||||
func prefillChunkSize(lowMemoryDecode bool) int {
|
||||
if v := os.Getenv("OLLAMA_MLX_PREFILL_CHUNK"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
if lowMemoryDecode {
|
||||
// Recurrent/no-prompt-cache path favors lower peak memory over prefill throughput.
|
||||
// Keep this conservative to avoid transient prefill spikes and allocator thrash.
|
||||
return 32
|
||||
}
|
||||
return 2 << 10
|
||||
}
|
||||
|
||||
func hasRecurrentCaches(caches []cache.Cache) bool {
|
||||
for _, c := range caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := c.(*cache.RecurrentCache); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// recurrentMaterializeInterval controls periodic recurrent-cache materialization
|
||||
// during async decode. It exists to bound graph/handle growth when using fast
|
||||
// recurrent cache writes; it is primarily a memory/stability tuning knob, not a
|
||||
// throughput knob.
|
||||
func recurrentMaterializeInterval(lowMemoryDecode bool, hasRecurrent bool) int {
|
||||
if lowMemoryDecode || !hasRecurrent {
|
||||
return 0
|
||||
}
|
||||
if v := os.Getenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultRecurrentMaterializeInterval
|
||||
}
|
||||
|
||||
func mlxDebugMemoryEnabled() bool {
|
||||
return os.Getenv("OLLAMA_MLX_DEBUG_MEMORY") != ""
|
||||
}
|
||||
|
||||
// mlxPipelineTimingConfig controls runner-side decode pipeline timing logs. This
|
||||
// is diagnostic-only and intentionally separate from model-specific timing.
|
||||
func mlxPipelineTimingConfig() (enabled bool, every int) {
|
||||
if v, ok := os.LookupEnv("OLLAMA_MLX_PIPELINE_TIMING"); ok {
|
||||
if parsed, err := strconv.ParseBool(v); err == nil {
|
||||
enabled = parsed
|
||||
}
|
||||
}
|
||||
if !enabled {
|
||||
return false, 0
|
||||
}
|
||||
every = defaultPipelineTimingEvery
|
||||
if v := os.Getenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
every = n
|
||||
}
|
||||
}
|
||||
return true, every
|
||||
}
|
||||
|
||||
// mlxComputeLogprobsEnabled restores the old decode-step logprob normalization
|
||||
// path for profiling/experiments. It is off by default because the MLX runner
|
||||
// does not currently populate Response.Logprobs.
|
||||
func mlxComputeLogprobsEnabled() bool {
|
||||
if v, ok := os.LookupEnv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS"); ok {
|
||||
if enabled, err := strconv.ParseBool(v); err == nil {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
// The MLX runner currently does not populate Response.Logprobs, so skip the
|
||||
// full-vocab logprob normalization path unless explicitly requested for
|
||||
// debugging/experiments.
|
||||
return false
|
||||
}
|
||||
|
||||
type pipelineTiming struct {
|
||||
every int
|
||||
|
||||
stepCalls int
|
||||
stepAsync int
|
||||
stepSync int
|
||||
sampleInts int
|
||||
|
||||
stepTotalDur time.Duration
|
||||
forwardDur time.Duration
|
||||
unembedDur time.Duration
|
||||
sliceDur time.Duration
|
||||
logprobsDur time.Duration
|
||||
sampleDur time.Duration
|
||||
pinSweepDur time.Duration
|
||||
asyncEvalDur time.Duration
|
||||
sampleIntDur time.Duration
|
||||
lastEmitCount int
|
||||
}
|
||||
|
||||
func newPipelineTiming() *pipelineTiming {
|
||||
enabled, every := mlxPipelineTimingConfig()
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
pt := &pipelineTiming{every: every}
|
||||
fmt.Fprintf(os.Stderr, "mlx pipeline timing: enabled every=%d\n", every)
|
||||
return pt
|
||||
}
|
||||
|
||||
func (pt *pipelineTiming) recordStep(
|
||||
async bool,
|
||||
total, forward, unembed, slice, logprobs, sample, pinSweep, asyncEval time.Duration,
|
||||
) {
|
||||
if pt == nil {
|
||||
return
|
||||
}
|
||||
pt.stepCalls++
|
||||
if async {
|
||||
pt.stepAsync++
|
||||
} else {
|
||||
pt.stepSync++
|
||||
}
|
||||
pt.stepTotalDur += total
|
||||
pt.forwardDur += forward
|
||||
pt.unembedDur += unembed
|
||||
pt.sliceDur += slice
|
||||
pt.logprobsDur += logprobs
|
||||
pt.sampleDur += sample
|
||||
pt.pinSweepDur += pinSweep
|
||||
pt.asyncEvalDur += asyncEval
|
||||
}
|
||||
|
||||
func (pt *pipelineTiming) recordSampleInt(d time.Duration, decodeCount int) {
|
||||
if pt == nil {
|
||||
return
|
||||
}
|
||||
pt.sampleInts++
|
||||
pt.sampleIntDur += d
|
||||
pt.maybeEmit(false, decodeCount)
|
||||
}
|
||||
|
||||
func (pt *pipelineTiming) maybeEmit(force bool, decodeCount int) {
|
||||
if pt == nil {
|
||||
return
|
||||
}
|
||||
if !force {
|
||||
if pt.every <= 0 || decodeCount <= 0 || decodeCount%pt.every != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
if pt.lastEmitCount == decodeCount {
|
||||
return
|
||||
}
|
||||
pt.lastEmitCount = decodeCount
|
||||
|
||||
msAvg := func(d time.Duration, n int) float64 {
|
||||
if n <= 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(d) / float64(n) / float64(time.Millisecond)
|
||||
}
|
||||
stepResidual := pt.stepTotalDur - pt.forwardDur - pt.unembedDur - pt.sliceDur - pt.logprobsDur - pt.sampleDur - pt.pinSweepDur - pt.asyncEvalDur
|
||||
if stepResidual < 0 {
|
||||
stepResidual = 0
|
||||
}
|
||||
fmt.Fprintf(
|
||||
os.Stderr,
|
||||
"mlx pipeline timing: decode=%d step_calls=%d step_async=%d step_sync=%d avg_step_ms=%.2f fwd_ms=%.2f unembed_ms=%.2f slice_ms=%.2f logprobs_ms=%.2f sample_ms=%.2f pin_sweep_ms=%.2f async_eval_ms=%.2f step_residual_ms=%.2f sample_int_ms=%.2f\n",
|
||||
decodeCount,
|
||||
pt.stepCalls,
|
||||
pt.stepAsync,
|
||||
pt.stepSync,
|
||||
msAvg(pt.stepTotalDur, pt.stepCalls),
|
||||
msAvg(pt.forwardDur, pt.stepCalls),
|
||||
msAvg(pt.unembedDur, pt.stepCalls),
|
||||
msAvg(pt.sliceDur, pt.stepCalls),
|
||||
msAvg(pt.logprobsDur, pt.stepCalls),
|
||||
msAvg(pt.sampleDur, pt.stepCalls),
|
||||
msAvg(pt.pinSweepDur, pt.stepCalls),
|
||||
msAvg(pt.asyncEvalDur, pt.stepCalls),
|
||||
msAvg(stepResidual, pt.stepCalls),
|
||||
msAvg(pt.sampleIntDur, pt.sampleInts),
|
||||
)
|
||||
}
|
||||
|
||||
func finalizeRequestCaches(usePromptCache bool, insertCache func(), freeCaches func(), logMemory func(string, int)) {
|
||||
if usePromptCache {
|
||||
insertCache()
|
||||
logMemory("request_done_cached", -1)
|
||||
return
|
||||
}
|
||||
freeCaches()
|
||||
logMemory("request_done_freed", -1)
|
||||
}
|
||||
|
||||
func recordCacheCheckpoints(caches []cache.Cache, pos int) {
|
||||
if pos <= 0 {
|
||||
return
|
||||
}
|
||||
for _, c := range caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
if recorder, ok := c.(cache.CheckpointRecorder); ok {
|
||||
recorder.RecordCheckpoint(pos)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func freeOwnedCaches(caches []cache.Cache) {
|
||||
for i, c := range caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
c.Free()
|
||||
caches[i] = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
if r.Model == nil {
|
||||
return errors.New("model not loaded")
|
||||
@@ -262,24 +31,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
|
||||
usePromptCache := true
|
||||
if m, ok := r.Model.(interface{ DisablePromptCache() bool }); ok && m.DisablePromptCache() {
|
||||
usePromptCache = false
|
||||
}
|
||||
lowMemoryDecode := !usePromptCache
|
||||
if m, ok := r.Model.(interface{ LowMemoryDecode() bool }); ok {
|
||||
lowMemoryDecode = m.LowMemoryDecode()
|
||||
}
|
||||
prefillChunk := prefillChunkSize(lowMemoryDecode)
|
||||
|
||||
var caches []cache.Cache
|
||||
var tokens []int32
|
||||
if usePromptCache {
|
||||
caches, tokens = r.FindNearestCache(inputs)
|
||||
} else {
|
||||
tokens = inputs
|
||||
}
|
||||
|
||||
caches, tokens := r.FindNearestCache(inputs)
|
||||
if len(caches) == 0 {
|
||||
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
caches = cacheFactory.NewCaches()
|
||||
@@ -291,140 +43,40 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
materializeCaches := func() {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
state = append(state, c.Materialize()...)
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return
|
||||
}
|
||||
mlx.Eval(state...)
|
||||
}
|
||||
materializeRecurrentCaches := func() bool {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := c.(*cache.RecurrentCache); !ok {
|
||||
continue
|
||||
}
|
||||
state = append(state, c.Materialize()...)
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return false
|
||||
}
|
||||
mlx.Eval(state...)
|
||||
return true
|
||||
}
|
||||
freeCaches := func() {
|
||||
// Non-prompt-cache requests allocate fresh caches every generation.
|
||||
// Explicitly free cache-owned state (including recurrent checkpoints),
|
||||
// then sweep remaining intermediates.
|
||||
freeOwnedCaches(caches)
|
||||
mlx.Sweep()
|
||||
mlx.ClearCache()
|
||||
}
|
||||
debugMemory := mlxDebugMemoryEnabled()
|
||||
hasRecurrent := hasRecurrentCaches(caches)
|
||||
asyncRecurrentMaterializeEvery := recurrentMaterializeInterval(lowMemoryDecode, hasRecurrent)
|
||||
computeStepLogprobs := mlxComputeLogprobsEnabled()
|
||||
pipelineTiming := newPipelineTiming()
|
||||
logMemory := func(phase string, token int) {
|
||||
if !debugMemory {
|
||||
return
|
||||
}
|
||||
if token >= 0 {
|
||||
slog.Info("MLX memory", "phase", phase, "token", token, "memory", mlx.Memory{})
|
||||
return
|
||||
}
|
||||
slog.Info("MLX memory", "phase", phase, "memory", mlx.Memory{})
|
||||
}
|
||||
logMemory("prefill_start", -1)
|
||||
|
||||
total, processed := len(tokens), 0
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
for total-processed > 1 {
|
||||
n := min(prefillChunk, total-processed-1)
|
||||
n := min(2<<10, total-processed-1)
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
mlx.Sweep()
|
||||
materializeCaches()
|
||||
recordCacheCheckpoints(caches, processed+n)
|
||||
mlx.Eval(func() []*mlx.Array {
|
||||
s := make([]*mlx.Array, 2*len(caches))
|
||||
for i, c := range caches {
|
||||
s[2*i], s[2*i+1] = c.State()
|
||||
}
|
||||
return s
|
||||
}()...)
|
||||
processed += n
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
mlx.ClearCache()
|
||||
}
|
||||
logMemory("prefill_done", -1)
|
||||
|
||||
step := func(token *mlx.Array, async bool) (*mlx.Array, *mlx.Array) {
|
||||
var t0, t time.Time
|
||||
var forwardDur, unembedDur, sliceDur, logprobsDur, sampleDur, pinSweepDur, asyncEvalDur time.Duration
|
||||
if pipelineTiming != nil {
|
||||
t0 = time.Now()
|
||||
t = t0
|
||||
}
|
||||
|
||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
||||
if pipelineTiming != nil {
|
||||
forwardDur = time.Since(t)
|
||||
t = time.Now()
|
||||
}
|
||||
|
||||
logits := r.Model.Unembed(fwd)
|
||||
if pipelineTiming != nil {
|
||||
unembedDur = time.Since(t)
|
||||
t = time.Now()
|
||||
}
|
||||
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
if pipelineTiming != nil {
|
||||
sliceDur = time.Since(t)
|
||||
t = time.Now()
|
||||
}
|
||||
|
||||
var logprobs *mlx.Array
|
||||
sampleInput := logits
|
||||
if computeStepLogprobs {
|
||||
logprobs = logits.Subtract(logits.Logsumexp(true))
|
||||
sampleInput = logprobs
|
||||
}
|
||||
if pipelineTiming != nil {
|
||||
logprobsDur = time.Since(t)
|
||||
t = time.Now()
|
||||
}
|
||||
|
||||
sample := request.Sample(sampleInput)
|
||||
if pipelineTiming != nil {
|
||||
sampleDur = time.Since(t)
|
||||
t = time.Now()
|
||||
}
|
||||
logprobs := logits.Subtract(logits.Logsumexp(true))
|
||||
sample := request.Sample(logprobs)
|
||||
|
||||
mlx.Pin(sample, logprobs)
|
||||
mlx.Sweep()
|
||||
if pipelineTiming != nil {
|
||||
pinSweepDur = time.Since(t)
|
||||
}
|
||||
if async {
|
||||
mlx.AsyncEval(sample, logprobs)
|
||||
if pipelineTiming != nil {
|
||||
asyncEvalDur = time.Since(t)
|
||||
}
|
||||
}
|
||||
if pipelineTiming != nil {
|
||||
pipelineTiming.recordStep(async, time.Since(t0), forwardDur, unembedDur, sliceDur, logprobsDur, sampleDur, pinSweepDur, asyncEvalDur)
|
||||
}
|
||||
mlx.AsyncEval(sample, logprobs)
|
||||
|
||||
return sample, logprobs
|
||||
}
|
||||
|
||||
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed), !lowMemoryDecode)
|
||||
if lowMemoryDecode {
|
||||
// Materialize cache updates to prevent transform graph growth.
|
||||
materializeCaches()
|
||||
}
|
||||
recordCacheCheckpoints(caches, total)
|
||||
logMemory("decode_init", -1)
|
||||
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
@@ -432,34 +84,20 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
||||
outputs := make([]int32, 0, request.Options.MaxTokens)
|
||||
for i := range request.Options.MaxTokens {
|
||||
var nextSample, nextLogprobs *mlx.Array
|
||||
if !lowMemoryDecode {
|
||||
nextSample, nextLogprobs = step(sample, true)
|
||||
}
|
||||
nextSample, nextLogprobs := step(sample)
|
||||
|
||||
if i == 0 {
|
||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
||||
mlx.Eval(sample)
|
||||
logMemory("decode_first_eval", i)
|
||||
final.PromptTokensDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
var intWaitStart time.Time
|
||||
if pipelineTiming != nil {
|
||||
intWaitStart = time.Now()
|
||||
}
|
||||
output := int32(sample.Int())
|
||||
if pipelineTiming != nil {
|
||||
pipelineTiming.recordSampleInt(time.Since(intWaitStart), len(outputs)+1)
|
||||
}
|
||||
outputs = append(outputs, output)
|
||||
if !lowMemoryDecode {
|
||||
recordCacheCheckpoints(caches, total+len(outputs))
|
||||
}
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
mlx.Unpin(nextSample, nextLogprobs)
|
||||
mlx.Unpin(sample, logprobs)
|
||||
final.Token = int(output)
|
||||
final.DoneReason = 0
|
||||
final.CompletionTokens = i
|
||||
@@ -471,53 +109,18 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
Token: int(output),
|
||||
}
|
||||
|
||||
// For recurrent linear-attention models, avoid async prefetch to reduce
|
||||
// peak memory and clear allocator cache every token.
|
||||
if lowMemoryDecode {
|
||||
mlx.Unpin(sample, logprobs)
|
||||
mlx.Sweep()
|
||||
if i+1 >= request.Options.MaxTokens {
|
||||
break
|
||||
}
|
||||
mlx.ClearCache()
|
||||
sample, logprobs = step(mlx.FromValues([]int32{output}, 1), false)
|
||||
// Materialize cache updates to avoid unbounded transform chains.
|
||||
materializeCaches()
|
||||
recordCacheCheckpoints(caches, total+len(outputs))
|
||||
if i%32 == 0 {
|
||||
logMemory("decode_lowmem_step", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
mlx.Unpin(sample, logprobs)
|
||||
if asyncRecurrentMaterializeEvery > 0 && (i+1)%asyncRecurrentMaterializeEvery == 0 {
|
||||
if materializeRecurrentCaches() {
|
||||
mlx.Sweep()
|
||||
logMemory("decode_async_recurrent_materialize", i)
|
||||
}
|
||||
}
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
if i%64 == 0 {
|
||||
logMemory("decode_async_step", i)
|
||||
}
|
||||
|
||||
sample, logprobs = nextSample, nextLogprobs
|
||||
}
|
||||
|
||||
mlx.Unpin(sample, logprobs)
|
||||
if pipelineTiming != nil {
|
||||
pipelineTiming.maybeEmit(true, len(outputs))
|
||||
}
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
request.Responses <- final
|
||||
finalizeRequestCaches(usePromptCache,
|
||||
func() { r.InsertCache(append(inputs, outputs...), caches) },
|
||||
freeCaches,
|
||||
logMemory,
|
||||
)
|
||||
r.InsertCache(append(inputs, outputs...), caches)
|
||||
mlx.Sweep()
|
||||
|
||||
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
||||
@@ -526,6 +129,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
r.cache.LogCache()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,209 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
type stubCache struct {
|
||||
freeCalls int
|
||||
}
|
||||
|
||||
func (s *stubCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { return keys, values }
|
||||
func (s *stubCache) State() (*mlx.Array, *mlx.Array) { return nil, nil }
|
||||
func (s *stubCache) Materialize() []*mlx.Array { return nil }
|
||||
func (s *stubCache) CanTrim() bool { return true }
|
||||
func (s *stubCache) Trim(int) int { return 0 }
|
||||
func (s *stubCache) Clone() cache.Cache { return s }
|
||||
func (s *stubCache) Free() { s.freeCalls++ }
|
||||
func (s *stubCache) Offset() int { return 0 }
|
||||
func (s *stubCache) Len() int { return 0 }
|
||||
|
||||
func TestPrefillChunkSize(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "")
|
||||
if got := prefillChunkSize(false); got != 2<<10 {
|
||||
t.Fatalf("prefillChunkSize(false) = %d, want %d", got, 2<<10)
|
||||
}
|
||||
if got := prefillChunkSize(true); got != 32 {
|
||||
t.Fatalf("prefillChunkSize(true) = %d, want %d", got, 32)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefillChunkSizeEnvOverride(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "96")
|
||||
if got := prefillChunkSize(false); got != 96 {
|
||||
t.Fatalf("prefillChunkSize(false) with env = %d, want %d", got, 96)
|
||||
}
|
||||
if got := prefillChunkSize(true); got != 96 {
|
||||
t.Fatalf("prefillChunkSize(true) with env = %d, want %d", got, 96)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMLXDebugMemoryEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "")
|
||||
if mlxDebugMemoryEnabled() {
|
||||
t.Fatal("mlxDebugMemoryEnabled() = true, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "1")
|
||||
if !mlxDebugMemoryEnabled() {
|
||||
t.Fatal("mlxDebugMemoryEnabled() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasRecurrentCaches(t *testing.T) {
|
||||
if hasRecurrentCaches(nil) {
|
||||
t.Fatal("hasRecurrentCaches(nil) = true, want false")
|
||||
}
|
||||
|
||||
if hasRecurrentCaches([]cache.Cache{cache.NewKVCache()}) {
|
||||
t.Fatal("hasRecurrentCaches(kv-only) = true, want false")
|
||||
}
|
||||
|
||||
rc := cache.NewRecurrentCache(4, 8, 2, 16, 8)
|
||||
if !hasRecurrentCaches([]cache.Cache{cache.NewKVCache(), rc}) {
|
||||
t.Fatal("hasRecurrentCaches(mixed) = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecurrentMaterializeInterval(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "")
|
||||
|
||||
if got := recurrentMaterializeInterval(true, true); got != 0 {
|
||||
t.Fatalf("recurrentMaterializeInterval(lowmem=true, recurrent=true) = %d, want 0", got)
|
||||
}
|
||||
if got := recurrentMaterializeInterval(false, false); got != 0 {
|
||||
t.Fatalf("recurrentMaterializeInterval(lowmem=false, recurrent=false) = %d, want 0", got)
|
||||
}
|
||||
if got := recurrentMaterializeInterval(false, true); got != defaultRecurrentMaterializeInterval {
|
||||
t.Fatalf("recurrentMaterializeInterval(default) = %d, want %d", got, defaultRecurrentMaterializeInterval)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "16")
|
||||
if got := recurrentMaterializeInterval(false, true); got != 16 {
|
||||
t.Fatalf("recurrentMaterializeInterval(env=16) = %d, want 16", got)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "0")
|
||||
if got := recurrentMaterializeInterval(false, true); got != 0 {
|
||||
t.Fatalf("recurrentMaterializeInterval(env=0) = %d, want 0", got)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_RECURRENT_MATERIALIZE_INTERVAL", "-1")
|
||||
if got := recurrentMaterializeInterval(false, true); got != 0 {
|
||||
t.Fatalf("recurrentMaterializeInterval(env=-1) = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMLXPipelineTimingConfig(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "")
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "")
|
||||
if enabled, every := mlxPipelineTimingConfig(); enabled || every != 0 {
|
||||
t.Fatalf("mlxPipelineTimingConfig() = (%v, %d), want (false, 0)", enabled, every)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "1")
|
||||
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != defaultPipelineTimingEvery {
|
||||
t.Fatalf("mlxPipelineTimingConfig(enabled default) = (%v, %d), want (true, %d)", enabled, every, defaultPipelineTimingEvery)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "16")
|
||||
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != 16 {
|
||||
t.Fatalf("mlxPipelineTimingConfig(enabled env=16) = (%v, %d), want (true, 16)", enabled, every)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING_EVERY", "0")
|
||||
if enabled, every := mlxPipelineTimingConfig(); !enabled || every != defaultPipelineTimingEvery {
|
||||
t.Fatalf("mlxPipelineTimingConfig(enabled env=0) = (%v, %d), want (true, %d)", enabled, every, defaultPipelineTimingEvery)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_TIMING", "0")
|
||||
if enabled, every := mlxPipelineTimingConfig(); enabled || every != 0 {
|
||||
t.Fatalf("mlxPipelineTimingConfig(disabled) = (%v, %d), want (false, 0)", enabled, every)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMLXComputeLogprobsEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "")
|
||||
if mlxComputeLogprobsEnabled() {
|
||||
t.Fatal("mlxComputeLogprobsEnabled() = true, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "1")
|
||||
if !mlxComputeLogprobsEnabled() {
|
||||
t.Fatal("mlxComputeLogprobsEnabled() = false with env=1, want true")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_PIPELINE_COMPUTE_LOGPROBS", "0")
|
||||
if mlxComputeLogprobsEnabled() {
|
||||
t.Fatal("mlxComputeLogprobsEnabled() = true with env=0, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRequestCachesUsesPromptCachePath(t *testing.T) {
|
||||
insertCalls := 0
|
||||
freeCalls := 0
|
||||
logPhase := ""
|
||||
|
||||
finalizeRequestCaches(
|
||||
true,
|
||||
func() { insertCalls++ },
|
||||
func() { freeCalls++ },
|
||||
func(phase string, _ int) { logPhase = phase },
|
||||
)
|
||||
|
||||
if insertCalls != 1 {
|
||||
t.Fatalf("insert calls = %d, want 1", insertCalls)
|
||||
}
|
||||
if freeCalls != 0 {
|
||||
t.Fatalf("free calls = %d, want 0", freeCalls)
|
||||
}
|
||||
if logPhase != "request_done_cached" {
|
||||
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRequestCachesUsesFreePath(t *testing.T) {
|
||||
insertCalls := 0
|
||||
freeCalls := 0
|
||||
logPhase := ""
|
||||
|
||||
finalizeRequestCaches(
|
||||
false,
|
||||
func() { insertCalls++ },
|
||||
func() { freeCalls++ },
|
||||
func(phase string, _ int) { logPhase = phase },
|
||||
)
|
||||
|
||||
if insertCalls != 0 {
|
||||
t.Fatalf("insert calls = %d, want 0", insertCalls)
|
||||
}
|
||||
if freeCalls != 1 {
|
||||
t.Fatalf("free calls = %d, want 1", freeCalls)
|
||||
}
|
||||
if logPhase != "request_done_freed" {
|
||||
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_freed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFreeOwnedCaches(t *testing.T) {
|
||||
a := &stubCache{}
|
||||
b := &stubCache{}
|
||||
caches := []cache.Cache{a, nil, b}
|
||||
|
||||
freeOwnedCaches(caches)
|
||||
|
||||
if a.freeCalls != 1 {
|
||||
t.Fatalf("a free calls = %d, want 1", a.freeCalls)
|
||||
}
|
||||
if b.freeCalls != 1 {
|
||||
t.Fatalf("b free calls = %d, want 1", b.freeCalls)
|
||||
}
|
||||
if caches[0] != nil || caches[2] != nil {
|
||||
t.Fatalf("cache entries not nilled after free: %#v", caches)
|
||||
}
|
||||
}
|
||||
@@ -62,39 +62,6 @@ type Runner struct {
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
cache *CacheEntry
|
||||
caches []*HybridCacheEntry
|
||||
}
|
||||
|
||||
func releaseTensorMap(tensors map[string]*mlx.Array, keep map[*mlx.Array]struct{}) (count int, bytes int) {
|
||||
if len(tensors) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
seen := make(map[*mlx.Array]bool, len(tensors))
|
||||
toRelease := make([]*mlx.Array, 0, len(tensors))
|
||||
for name, arr := range tensors {
|
||||
if arr == nil || !arr.Valid() {
|
||||
delete(tensors, name)
|
||||
continue
|
||||
}
|
||||
if keep != nil {
|
||||
if _, ok := keep[arr]; ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
delete(tensors, name)
|
||||
if seen[arr] {
|
||||
continue
|
||||
}
|
||||
seen[arr] = true
|
||||
toRelease = append(toRelease, arr)
|
||||
}
|
||||
|
||||
if len(toRelease) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
return len(toRelease), mlx.Release(toRelease...)
|
||||
}
|
||||
|
||||
func (r *Runner) Load(modelName string) error {
|
||||
@@ -118,33 +85,9 @@ func (r *Runner) Load(modelName string) error {
|
||||
// Assign weights to model (model-specific logic)
|
||||
loadWeights := base.Weights(m)
|
||||
if err := loadWeights(tensors); err != nil {
|
||||
if count, bytes := releaseTensorMap(tensors, nil); count > 0 {
|
||||
slog.Info("Released tensors after load failure", "count", count, "bytes", mlx.PrettyBytes(bytes))
|
||||
}
|
||||
mlx.Sweep()
|
||||
mlx.ClearCache()
|
||||
return err
|
||||
}
|
||||
|
||||
// Materialize model-owned roots before releasing source tensor handles, then
|
||||
// pin only those roots. This avoids retaining large load-time intermediates
|
||||
// while still protecting shared model tensors from Sweep.
|
||||
roots := mlx.Collect(m)
|
||||
mlx.Eval(roots...)
|
||||
mlx.Pin(roots...)
|
||||
|
||||
keep := make(map[*mlx.Array]struct{})
|
||||
for _, arr := range roots {
|
||||
if arr != nil && arr.Valid() {
|
||||
keep[arr] = struct{}{}
|
||||
}
|
||||
}
|
||||
if count, bytes := releaseTensorMap(tensors, keep); count > 0 {
|
||||
slog.Info("Released unused model tensors", "count", count, "bytes", mlx.PrettyBytes(bytes))
|
||||
}
|
||||
mlx.Sweep()
|
||||
mlx.ClearCache()
|
||||
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
return nil
|
||||
|
||||
@@ -15,40 +15,6 @@ type LinearLayer interface {
|
||||
OutputDim() int32
|
||||
}
|
||||
|
||||
// Conv1d applies 1D convolution over NLC input.
|
||||
type Conv1d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Stride int32
|
||||
Padding int32
|
||||
Dilation int32
|
||||
Groups int32
|
||||
}
|
||||
|
||||
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
|
||||
if stride <= 0 {
|
||||
stride = 1
|
||||
}
|
||||
if dilation <= 0 {
|
||||
dilation = 1
|
||||
}
|
||||
if groups <= 0 {
|
||||
groups = 1
|
||||
}
|
||||
return &Conv1d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
Stride: stride,
|
||||
Padding: padding,
|
||||
Dilation: dilation,
|
||||
Groups: groups,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
|
||||
}
|
||||
|
||||
// Linear applies an affine transformation: y = x @ W.T + b
|
||||
type Linear struct {
|
||||
Weight *mlx.Array
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,206 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func TestParseConfigNestedDefaults(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model_type": "Qwen3_5MoeForConditionalGeneration",
|
||||
"text_config": {
|
||||
"hidden_size": 4096,
|
||||
"intermediate_size": 14336,
|
||||
"num_hidden_layers": 8,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"linear_num_value_heads": 64,
|
||||
"linear_num_key_heads": 16,
|
||||
"linear_key_head_dim": 128,
|
||||
"linear_value_head_dim": 128,
|
||||
"linear_conv_kernel_dim": 4,
|
||||
"num_experts": 16,
|
||||
"num_experts_per_tok": 4,
|
||||
"moe_intermediate_size": 2048,
|
||||
"shared_expert_intermediate_size": 4096,
|
||||
"rope_parameters": {
|
||||
"rope_theta": 500000,
|
||||
"partial_rotary_factor": 0.5
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.RopeTheta != 500000 {
|
||||
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
|
||||
}
|
||||
if cfg.RopeDim != 64 {
|
||||
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
|
||||
}
|
||||
if cfg.FullAttentionInterval != 4 {
|
||||
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
|
||||
}
|
||||
if !cfg.NormTopKProb {
|
||||
t.Fatalf("norm_topk_prob should default to true for MoE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerSelectionHelpers(t *testing.T) {
|
||||
cfg := &Config{
|
||||
NumHiddenLayers: 6,
|
||||
FullAttentionInterval: 3,
|
||||
NumExperts: 8,
|
||||
DecoderSparseStep: 2,
|
||||
MLPOnlyLayers: []int32{1},
|
||||
}
|
||||
|
||||
if !layerIsLinear(cfg, 0) {
|
||||
t.Fatalf("layer 0 should be linear")
|
||||
}
|
||||
if layerIsLinear(cfg, 2) {
|
||||
t.Fatalf("layer 2 should be full attention")
|
||||
}
|
||||
|
||||
if layerUsesMoE(cfg, 1) {
|
||||
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
|
||||
}
|
||||
if !layerUsesMoE(cfg, 3) {
|
||||
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTensorPathLayout(t *testing.T) {
|
||||
dummy := mlx.New("dummy")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
wantContainer string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "standard",
|
||||
key: "model.embed_tokens.weight",
|
||||
wantContainer: "",
|
||||
wantModel: "model.",
|
||||
},
|
||||
{
|
||||
name: "nested language model with inner model",
|
||||
key: "model.language_model.model.embed_tokens.weight",
|
||||
wantContainer: "model.language_model.",
|
||||
wantModel: "model.",
|
||||
},
|
||||
{
|
||||
name: "nested language model without inner model",
|
||||
key: "model.language_model.embed_tokens.weight",
|
||||
wantContainer: "model.language_model.",
|
||||
wantModel: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
layout := resolveTensorPathLayout(map[string]*mlx.Array{
|
||||
tt.key: dummy,
|
||||
})
|
||||
|
||||
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
|
||||
t.Fatalf(
|
||||
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
|
||||
layout.containerPrefix,
|
||||
layout.modelPrefix,
|
||||
tt.wantContainer,
|
||||
tt.wantModel,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRuntimeToggles(t *testing.T) {
|
||||
m := &Model{}
|
||||
if m.DisablePromptCache() {
|
||||
t.Fatal("DisablePromptCache() = true, want false")
|
||||
}
|
||||
if m.LowMemoryDecode() {
|
||||
t.Fatal("LowMemoryDecode() = true, want false")
|
||||
}
|
||||
if !m.EnableCompile() {
|
||||
t.Fatal("EnableCompile() = false, want true")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_LOW_MEMORY_DECODE", "0")
|
||||
if m.LowMemoryDecode() {
|
||||
t.Fatal("LowMemoryDecode() = true with env override 0, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_LOW_MEMORY_DECODE", "1")
|
||||
if !m.LowMemoryDecode() {
|
||||
t.Fatal("LowMemoryDecode() = false with env override 1, want true")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_ENABLE_COMPILE", "0")
|
||||
if m.EnableCompile() {
|
||||
t.Fatal("EnableCompile() = true with env override 0, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_ENABLE_COMPILE", "1")
|
||||
if !m.EnableCompile() {
|
||||
t.Fatal("EnableCompile() = false with env override, want true")
|
||||
}
|
||||
|
||||
if !qwen35FastRecurrentWrite() {
|
||||
t.Fatal("qwen35FastRecurrentWrite() = false, want true")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_FAST_RECURRENT_WRITE", "0")
|
||||
if qwen35FastRecurrentWrite() {
|
||||
t.Fatal("qwen35FastRecurrentWrite() = true with env override 0, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_QWEN3_5_FAST_RECURRENT_WRITE", "1")
|
||||
if !qwen35FastRecurrentWrite() {
|
||||
t.Fatal("qwen35FastRecurrentWrite() = false with env override 1, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesLayout(t *testing.T) {
|
||||
m := &Model{
|
||||
Config: &Config{
|
||||
LinearConvKernelDim: 4,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearKeyHeadDim: 8,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 16,
|
||||
},
|
||||
Layers: []*Layer{
|
||||
{IsLinear: true},
|
||||
{IsLinear: false},
|
||||
{IsLinear: true},
|
||||
},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if len(caches) != len(m.Layers) {
|
||||
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
|
||||
}
|
||||
|
||||
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
|
||||
}
|
||||
if _, ok := caches[1].(*cache.KVCache); !ok {
|
||||
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
|
||||
}
|
||||
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
|
||||
package qwen3_5_moe
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||
)
|
||||
|
||||
func init() {
|
||||
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
|
||||
}
|
||||
Reference in New Issue
Block a user