Mirror of https://github.com/ollama/ollama.git, synced 2026-02-23 18:46:44 -05:00

Compare commits: 3 commits, main...jmorganca/

Commits: e72e85da6a, 56f4135c3c, 80a1376f23

@@ -320,7 +320,7 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
        conv = &lfm2Model{}
    case "Lfm2VlForConditionalGeneration":
        conv = &lfm2VLTextModel{}
    case "Qwen3NextForCausalLM":
    case "Qwen3NextForCausalLM", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration":
        conv = &qwen3NextModel{}
    case "NemotronHForCausalLM":
        conv = &nemotronHModel{}

@@ -1,6 +1,7 @@
package convert

import (
    "encoding/json"
    "fmt"
    "io/fs"
    "math"
@@ -13,8 +14,21 @@ import (
    "github.com/ollama/ollama/fs/ggml"
)

type qwen3NextModel struct {
    ModelParameters
type qwen3NextRopeScaling struct {
    Type string `json:"type"`
    Factor ropeFactor `json:"factor"`
    MropeSection []int32 `json:"mrope_section"`
}

type qwen3NextRopeParams struct {
    MRopeInterleaved bool `json:"mrope_interleaved"`
    MropeSection []int32 `json:"mrope_section"`
    RopeType string `json:"rope_type"`
    RopeTheta float32 `json:"rope_theta"`
    PartialRotaryFactor float32 `json:"partial_rotary_factor"`
}

type qwen3NextTextConfig struct {
    MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
    HiddenSize uint32 `json:"hidden_size"`
    NumHiddenLayers uint32 `json:"num_hidden_layers"`
@@ -28,12 +42,13 @@ type qwen3NextModel struct {
    // MoE config
    NumExperts uint32 `json:"num_experts"`
    NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
    NormTopkProb bool `json:"norm_topk_prob"`
    NormTopkProb *bool `json:"norm_topk_prob"`
    MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
    SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`

    // Hybrid attention config
    FullAttentionInterval uint32 `json:"full_attention_interval"`
    FullAttentionInterval uint32 `json:"full_attention_interval"`
    LayerTypes []string `json:"layer_types"`

    // Linear attention (Gated Delta Net) config
    LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
@@ -43,67 +58,213 @@ type qwen3NextModel struct {
    LinearValueHeadDim uint32 `json:"linear_value_head_dim"`

    // RoPE config
    PartialRotaryFactor float32 `json:"partial_rotary_factor"`
    RopeScaling struct {
        Type string `json:"type"`
        Factor ropeFactor `json:"factor"`
    } `json:"rope_scaling"`
    PartialRotaryFactor float32 `json:"partial_rotary_factor"`
    RopeScaling qwen3NextRopeScaling `json:"rope_scaling"`
    RopeParameters qwen3NextRopeParams `json:"rope_parameters"`
}

type qwen3NextVisionConfig struct {
    Depth uint32 `json:"depth"`
    HiddenSize uint32 `json:"hidden_size"`
    NumHeads uint32 `json:"num_heads"`
    InChannels uint32 `json:"in_channels"`
    PatchSize uint32 `json:"patch_size"`
    SpatialMergeSize uint32 `json:"spatial_merge_size"`
    RMSNormEps float32 `json:"layer_norm_epsilon"`
    RopeTheta float32 `json:"rope_theta"`
    TemporalPatchSize uint32 `json:"temporal_patch_size"`
    DeepstackVisualIndexes []int32 `json:"deepstack_visual_indexes"`

    Size struct {
        ShortestEdge uint32 `json:"shortest_edge"`
        LongestEdge uint32 `json:"longest_edge"`
    } `json:"size"`

    ImageMean []float32 `json:"image_mean"`
    ImageStd []float32 `json:"image_std"`
}

type qwen3NextModel struct {
    ModelParameters
    qwen3NextTextConfig

    TextConfig *qwen3NextTextConfig `json:"text_config"`
    VisionModel qwen3NextVisionConfig `json:"vision_config"`

    ImageTokenID uint32 `json:"image_token_id"`
    VisionStartTokenID uint32 `json:"vision_start_token_id"`
    VisionEndTokenID uint32 `json:"vision_end_token_id"`
}

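The struct layout above accepts either the flat legacy qwen3next config or a nested `text_config` block; parseMore in the next hunk promotes the nested block onto the embedded qwen3NextTextConfig. A minimal standalone sketch of that promotion pattern, with hypothetical field names and assumed values (not part of the diff), is:

```go
package main

import (
    "encoding/json"
    "fmt"
)

// Hypothetical, trimmed-down mirror of the structs above; illustration only.
type textConfig struct {
    NumHiddenLayers       uint32 `json:"num_hidden_layers"`
    FullAttentionInterval uint32 `json:"full_attention_interval"`
}

type config struct {
    textConfig             // flat, legacy qwen3next layout
    TextConfig *textConfig `json:"text_config"` // nested qwen3.5-style layout
}

func main() {
    raw := []byte(`{"text_config": {"num_hidden_layers": 4, "full_attention_interval": 2}}`)

    var c config
    if err := json.Unmarshal(raw, &c); err != nil {
        panic(err)
    }
    // Same promotion parseMore performs: the nested config wins when present.
    if c.TextConfig != nil {
        c.textConfig = *c.TextConfig
    }
    fmt.Println(c.NumHiddenLayers, c.FullAttentionInterval) // 4 2
}
```
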
var _ ModelConverter = (*qwen3NextModel)(nil)

func (q *qwen3NextModel) parseMore(_ fs.FS) error {
    if q.NumHiddenLayers == 0 {
        return fmt.Errorf("qwen3next: num_hidden_layers must be set")
    }
    if q.NumAttentionHeads == 0 {
        return fmt.Errorf("qwen3next: num_attention_heads must be set")
    }
    if q.NumKeyValueHeads == 0 {
        return fmt.Errorf("qwen3next: num_key_value_heads must be set")
    }
    if q.HeadDim == 0 {
        return fmt.Errorf("qwen3next: head_dim must be set")
    }
    if q.RopeTheta == 0 {
        return fmt.Errorf("qwen3next: rope_theta must be set")
    }
    if q.PartialRotaryFactor <= 0 || q.PartialRotaryFactor > 1 {
        return fmt.Errorf("qwen3next: partial_rotary_factor must be in (0,1], got %v", q.PartialRotaryFactor)
    }
    if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
        return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
    }
    if q.FullAttentionInterval == 0 {
        return fmt.Errorf("qwen3next: full_attention_interval must be set")
    }
    if q.FullAttentionInterval > q.NumHiddenLayers {
        return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
func (q *qwen3NextModel) parseMore(fsys fs.FS) error {
    if q.TextConfig != nil {
        q.qwen3NextTextConfig = *q.TextConfig
    }

    hasFull := false
    for i := range q.NumHiddenLayers {
        if (i+1)%q.FullAttentionInterval == 0 {
            hasFull = true
            break
    if q.RopeTheta == 0 {
        q.RopeTheta = q.RopeParameters.RopeTheta
    }
    if q.PartialRotaryFactor == 0 {
        q.PartialRotaryFactor = q.RopeParameters.PartialRotaryFactor
    }

    if q.RopeScaling.Type == "" && q.RopeParameters.RopeType != "" {
        q.RopeScaling.Type = q.RopeParameters.RopeType
    }

    // Pull vision preprocessing fields when present.
    if q.VisionModel.Depth > 0 {
        if bts, err := fs.ReadFile(fsys, "preprocessor_config.json"); err == nil {
            var pre struct {
                Size struct {
                    ShortestEdge uint32 `json:"shortest_edge"`
                    LongestEdge uint32 `json:"longest_edge"`
                } `json:"size"`
                PatchSize uint32 `json:"patch_size"`
                TemporalPatchSize uint32 `json:"temporal_patch_size"`
                MergeSize uint32 `json:"merge_size"`
                ImageMean []float32 `json:"image_mean"`
                ImageStd []float32 `json:"image_std"`
            }
            if json.Unmarshal(bts, &pre) == nil {
                if q.VisionModel.PatchSize == 0 {
                    q.VisionModel.PatchSize = pre.PatchSize
                }
                if q.VisionModel.TemporalPatchSize == 0 {
                    q.VisionModel.TemporalPatchSize = pre.TemporalPatchSize
                }
                if q.VisionModel.SpatialMergeSize == 0 {
                    q.VisionModel.SpatialMergeSize = pre.MergeSize
                }
                if q.VisionModel.Size.ShortestEdge == 0 {
                    q.VisionModel.Size.ShortestEdge = pre.Size.ShortestEdge
                }
                if q.VisionModel.Size.LongestEdge == 0 {
                    q.VisionModel.Size.LongestEdge = pre.Size.LongestEdge
                }
                if len(q.VisionModel.ImageMean) == 0 {
                    q.VisionModel.ImageMean = pre.ImageMean
                }
                if len(q.VisionModel.ImageStd) == 0 {
                    q.VisionModel.ImageStd = pre.ImageStd
                }
            }
        }
    }
    if !hasFull {
        return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)

    if q.NumHiddenLayers == 0 {
        return fmt.Errorf("qwen35: num_hidden_layers must be set")
    }
    if q.NumAttentionHeads == 0 {
        return fmt.Errorf("qwen35: num_attention_heads must be set")
    }
    if q.NumKeyValueHeads == 0 {
        return fmt.Errorf("qwen35: num_key_value_heads must be set")
    }
    if q.HeadDim == 0 {
        return fmt.Errorf("qwen35: head_dim must be set")
    }
    if q.RopeTheta == 0 {
        return fmt.Errorf("qwen35: rope_theta must be set")
    }
    if q.PartialRotaryFactor <= 0 || q.PartialRotaryFactor > 1 {
        return fmt.Errorf("qwen35: partial_rotary_factor must be in (0,1], got %v", q.PartialRotaryFactor)
    }
    if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
        return fmt.Errorf("qwen35: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
    }
    if _, err := q.kvHeadCounts(); err != nil {
        return err
    }

    return nil
}

func (q *qwen3NextModel) kvHeadCounts() ([]uint32, error) {
    if len(q.LayerTypes) > 0 {
        kv := make([]uint32, q.NumHiddenLayers)
        hasFull := false
        hasRecurrent := false
        for i := range q.NumHiddenLayers {
            layerType := ""
            if i < uint32(len(q.LayerTypes)) {
                layerType = q.LayerTypes[i]
            }
            if layerType == "full_attention" {
                kv[i] = q.NumKeyValueHeads
                hasFull = true
            } else {
                hasRecurrent = true
            }
        }
        if !hasFull || !hasRecurrent {
            return nil, fmt.Errorf("qwen35: layer_types must include both full_attention and linear_attention")
        }
        return kv, nil
    }

    if q.FullAttentionInterval == 0 {
        return nil, fmt.Errorf("qwen35: full_attention_interval must be set")
    }
    if q.FullAttentionInterval > q.NumHiddenLayers {
        return nil, fmt.Errorf("qwen35: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
    }

    kv := make([]uint32, q.NumHiddenLayers)
    hasFull := false
    for i := range q.NumHiddenLayers {
        if (i+1)%q.FullAttentionInterval == 0 {
            kv[i] = q.NumKeyValueHeads
            hasFull = true
        }
    }
    if !hasFull {
        return nil, fmt.Errorf("qwen35: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
    }
    return kv, nil
}

func (q *qwen3NextModel) ropeSections() []int32 {
    if len(q.RopeParameters.MropeSection) > 0 {
        return q.RopeParameters.MropeSection
    }
    return q.RopeScaling.MropeSection
}

func (q *qwen3NextModel) shouldReorderVHeads() bool {
    modelType := strings.ToLower(q.ModelType)
    if strings.Contains(modelType, "qwen3_next") || strings.Contains(modelType, "qwen3next") {
        return false
    }

    for _, arch := range q.Architectures {
        arch = strings.ToLower(arch)
        if strings.Contains(arch, "qwen3next") || strings.Contains(arch, "qwen3_next") {
            return false
        }
    }

    // Default to qwen3.5 layout for all other qwen3next-family imports.
    return true
}

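For a concrete feel of the interval fallback in kvHeadCounts, assume 8 layers, full attention every 4th layer, and 2 KV heads (illustrative values only, not from the diff):

```go
package main

import "fmt"

func main() {
    // Mirrors the interval rule used when layer_types is absent:
    // layer i is full attention when (i+1) % interval == 0, else recurrent (0).
    layers, interval, kvHeads := uint32(8), uint32(4), uint32(2)
    counts := make([]uint32, layers)
    for i := range counts {
        if (uint32(i)+1)%interval == 0 {
            counts[i] = kvHeads
        }
    }
    fmt.Println(counts) // [0 0 0 2 0 0 0 2]
}
```
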
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
    kv := q.ModelParameters.KV(t)
    kv["general.architecture"] = "qwen3next"
    kv["tokenizer.ggml.pre"] = "qwen2"

    arch := "qwen35"
    if q.NumExperts > 0 {
        arch = "qwen35moe"
    }
    kv["general.architecture"] = arch
    kv["tokenizer.ggml.pre"] = "qwen35"
    kv["block_count"] = q.NumHiddenLayers
    kv["context_length"] = q.MaxPositionEmbeddings
    kv["embedding_length"] = q.HiddenSize
    kv["feed_forward_length"] = q.IntermediateSize
    kv["attention.head_count"] = q.NumAttentionHeads

    headDim := q.HeadDim
    if headDim == 0 && q.NumAttentionHeads > 0 {
        headDim = q.HiddenSize / q.NumAttentionHeads
@@ -113,18 +274,31 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
    kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
    kv["rope.freq_base"] = q.RopeTheta

    // RoPE dimension count (partial rotary)
    // partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
    partialRotary := q.PartialRotaryFactor
    if partialRotary > 0 && partialRotary <= 1 {
        kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
    }

    // MoE config
    if sections := q.ropeSections(); len(sections) > 0 {
        kv["mrope_sections"] = sections
        kv["rope.mrope_section"] = sections
        kv["rope.dimension_sections"] = sections
    }
    if q.RopeParameters.MRopeInterleaved {
        kv["rope.mrope_interleaved"] = true
    }

    if q.RopeScaling.Type != "" && q.RopeScaling.Type != "default" {
        kv["rope.scaling.type"] = q.RopeScaling.Type
        kv["rope.scaling.factor"] = q.RopeScaling.Factor
    }

    if q.NumExperts > 0 {
        kv["expert_count"] = q.NumExperts
        kv["expert_used_count"] = q.NumExpertsPerToken
        kv["norm_top_k_prob"] = q.NormTopkProb
        if q.NormTopkProb != nil {
            kv["norm_top_k_prob"] = *q.NormTopkProb
        }
        if q.MoEIntermediateSize > 0 {
            kv["expert_feed_forward_length"] = q.MoEIntermediateSize
        }
@@ -133,33 +307,66 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
        }
    }

    // SSM/Linear attention config
    // d_inner = linear_value_head_dim * linear_num_value_heads
    dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
    kv["ssm.inner_size"] = dInner
    kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
    kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
    kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
    kv["ssm.state_size"] = q.LinearKeyHeadDim
    kv["ssm.group_count"] = q.LinearNumKeyHeads
    kv["ssm.time_step_rank"] = q.LinearNumValueHeads
    kv["ssm.conv_kernel"] = q.LinearConvKernelDim
    interval := q.FullAttentionInterval
    kv["full_attention_interval"] = interval

    // Build per-layer KV head count array to identify layer types
    // 0 = recurrent (linear attention), non-zero = full attention
    kvHeadCounts := make([]uint32, q.NumHiddenLayers)
    for i := range q.NumHiddenLayers {
        // Full attention every full_attention_interval layers (starting at interval-1)
        if interval > 0 && (i+1)%interval == 0 {
            kvHeadCounts[i] = q.NumKeyValueHeads
        }
        // else stays 0 (recurrent layer)
    if q.shouldReorderVHeads() {
        kv["ssm.v_head_reordered"] = true
    }
    if q.FullAttentionInterval > 0 {
        kv["full_attention_interval"] = q.FullAttentionInterval
    }
    kv["attention.head_count_kv"] = kvHeadCounts

    // RoPE scaling
    if q.RopeScaling.Type != "" {
        kv["rope.scaling.type"] = q.RopeScaling.Type
        kv["rope.scaling.factor"] = q.RopeScaling.Factor
    if headCounts, err := q.kvHeadCounts(); err == nil {
        kv["attention.head_count_kv"] = headCounts
    }

    if q.VisionModel.Depth > 0 {
        kv["vision.block_count"] = q.VisionModel.Depth
        kv["vision.embedding_length"] = q.VisionModel.HiddenSize
        kv["vision.attention.head_count"] = q.VisionModel.NumHeads
        kv["vision.num_channels"] = q.VisionModel.InChannels
        if q.VisionModel.PatchSize > 0 {
            kv["vision.patch_size"] = q.VisionModel.PatchSize
        }
        if q.VisionModel.SpatialMergeSize > 0 {
            kv["vision.spatial_merge_size"] = q.VisionModel.SpatialMergeSize
        }
        if q.VisionModel.RMSNormEps > 0 {
            kv["vision.attention.layer_norm_epsilon"] = q.VisionModel.RMSNormEps
        }
        if q.VisionModel.RopeTheta > 0 {
            kv["vision.rope.freq_base"] = q.VisionModel.RopeTheta
        }
        if q.VisionModel.TemporalPatchSize > 0 {
            kv["vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize
        }
        kv["vision.deepstack_visual_indexes"] = q.VisionModel.DeepstackVisualIndexes
        if q.VisionModel.Size.ShortestEdge > 0 {
            kv["vision.shortest_edge"] = q.VisionModel.Size.ShortestEdge
        }
        if q.VisionModel.Size.LongestEdge > 0 {
            kv["vision.longest_edge"] = q.VisionModel.Size.LongestEdge
        }
        if len(q.VisionModel.ImageMean) > 0 {
            kv["vision.image_mean"] = q.VisionModel.ImageMean
        }
        if len(q.VisionModel.ImageStd) > 0 {
            kv["vision.image_std"] = q.VisionModel.ImageStd
        }
    }

    if q.ImageTokenID > 0 {
        kv["image_token_id"] = q.ImageTokenID
    }
    if q.VisionStartTokenID > 0 {
        kv["vision_start_token_id"] = q.VisionStartTokenID
    }
    if q.VisionEndTokenID > 0 {
        kv["vision_end_token_id"] = q.VisionEndTokenID
    }

    return kv
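A quick worked example of the partial-rotary computation in the hunk above, with assumed values (head_dim=128, partial_rotary_factor=0.25; not taken from any particular checkpoint):

```go
package main

import "fmt"

func main() {
    // Only the first 25% of each head's dimensions are rotated, so
    // rope.dimension_count comes out to 32 for a 128-wide head.
    headDim, partialRotary := uint32(128), float32(0.25)
    fmt.Println(uint32(float32(headDim) * partialRotary)) // 32
}
```
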
@@ -168,7 +375,6 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
||||
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
// Create merges for expert tensors - stack individual experts into batched tensors
|
||||
merges := make([]merge, q.NumHiddenLayers*3)
|
||||
for i := range q.NumHiddenLayers {
|
||||
merges[i*3+0] = merge{
|
||||
@@ -185,16 +391,13 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
// Merge expert tensors
|
||||
merged, remaining := mergeTensors(ts, merges...)
|
||||
out = append(out, merged...)
|
||||
|
||||
// Process remaining tensors
|
||||
for _, t := range remaining {
|
||||
name := t.Name()
|
||||
shape := t.Shape()
|
||||
|
||||
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
|
||||
if strings.HasSuffix(name, ".ssm_in.weight") {
|
||||
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
|
||||
out = append(out, qkv, gate)
|
||||
@@ -204,84 +407,299 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
}
|
||||
|
||||
switch {
|
||||
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
|
||||
// This matches the Python converter behavior for qwen3next
|
||||
case strings.Contains(name, ".mlp.experts.gate_up_proj"):
|
||||
out = append(out, slices.Collect(splitDim(t, 1,
|
||||
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_gate_exps.weight")},
|
||||
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_up_exps.weight")},
|
||||
))...)
|
||||
|
||||
case strings.Contains(name, ".mlp.experts.down_proj"):
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.NewReplacer(".mlp.experts.down_proj", ".ffn_down_exps.weight").Replace(name),
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
case strings.HasPrefix(name, "v.blk.") && strings.Contains(name, ".attn_qkv"):
|
||||
out = append(out, slices.Collect(splitDim(t, 0,
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
|
||||
))...)
|
||||
|
||||
case strings.Contains(name, "patch_embed") && strings.HasSuffix(name, "weight"):
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: append([]uint64{shape[0] * shape[1]}, shape[2:]...),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
|
||||
t.SetRepacker(q.addOne)
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
// Handle linear attention A_log -> ssm_a (negate and exp)
|
||||
// Note: name has already been transformed by Replacements at this point
|
||||
case strings.HasSuffix(name, ".ssm_a"):
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
// Compute -exp(A_log)
|
||||
result := make([]float32, len(data))
|
||||
for i, v := range data {
|
||||
// -exp(v)
|
||||
result[i] = -float32(math.Exp(float64(v)))
|
||||
}
|
||||
return result, nil
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
t.SetRepacker(q.repackSSMA())
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
case strings.HasSuffix(name, ".attn_qkv.weight"):
|
||||
if q.shouldReorderVHeads() {
|
||||
t.SetRepacker(q.repackAttnQKV())
|
||||
}
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
case strings.HasSuffix(name, ".attn_gate.weight"):
|
||||
if q.shouldReorderVHeads() {
|
||||
// HF tensor layout is [out_features, in_features]; reorder rows.
|
||||
t.SetRepacker(q.repackReorderDim(0, int(q.LinearValueHeadDim)))
|
||||
}
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
case strings.HasSuffix(name, ".ssm_beta.weight"), strings.HasSuffix(name, ".ssm_alpha.weight"):
|
||||
if q.shouldReorderVHeads() {
|
||||
// HF tensor layout is [out_features, in_features]; reorder rows.
|
||||
t.SetRepacker(q.repackReorderDim(0, 1))
|
||||
}
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
case strings.HasSuffix(name, ".ssm_dt"):
|
||||
if q.shouldReorderVHeads() {
|
||||
t.SetRepacker(q.repackReorderDim(0, 1))
|
||||
}
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
case strings.HasSuffix(name, ".ssm_out.weight"):
|
||||
if q.shouldReorderVHeads() {
|
||||
// HF out_proj layout is [out_features, in_features]; reorder columns.
|
||||
t.SetRepacker(q.repackReorderDim(1, int(q.LinearValueHeadDim)))
|
||||
}
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
|
||||
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
|
||||
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 3 {
|
||||
if shape[0] == 1 {
|
||||
// [1, D, K] -> [D, K]
|
||||
newShape = []uint64{shape[1], shape[2]}
|
||||
} else if shape[1] == 1 {
|
||||
// [D, 1, K] -> [D, K]
|
||||
newShape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
|
||||
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 2 {
|
||||
if shape[0] == 1 && shape[1] > 1 {
|
||||
newShape = []uint64{shape[1]}
|
||||
} else if shape[1] == 1 && shape[0] > 1 {
|
||||
newShape = []uint64{shape[0]}
|
||||
}
|
||||
if q.shouldReorderVHeads() {
|
||||
t.SetRepacker(q.repackConv1D())
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: newShape, WriterTo: t})
|
||||
|
||||
default:
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
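For reference, the `.ssm_a` case in the Tensors switch above rewrites A_log to A by storing -exp(A_log) element-wise. A tiny standalone illustration with assumed inputs (not part of the diff):

```go
package main

import (
    "fmt"
    "math"
)

func main() {
    // A_log = 0 becomes -1, A_log = ln(2) becomes -2, matching the
    // element-wise -exp transform applied by the ssm_a repacker.
    for _, aLog := range []float64{0, math.Ln2} {
        fmt.Println(-float32(math.Exp(aLog)))
    }
}
```
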
|
||||
|
||||
func (q *qwen3NextModel) repackReorderDim(dim, headDim int) Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
if !q.shouldReorderVHeads() {
|
||||
return data, nil
|
||||
}
|
||||
numK := int(q.LinearNumKeyHeads)
|
||||
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
|
||||
return reorderHeadLayout(data, shape, dim, numK, numVPerK, headDim)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) repackAttnQKV() Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
if !q.shouldReorderVHeads() || len(shape) != 2 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
rows := int(shape[0])
|
||||
cols := int(shape[1])
|
||||
numK := int(q.LinearNumKeyHeads)
|
||||
numV := int(q.LinearNumValueHeads)
|
||||
headK := int(q.LinearKeyHeadDim)
|
||||
headV := int(q.LinearValueHeadDim)
|
||||
qDim := headK * numK
|
||||
kDim := headK * numK
|
||||
vDim := headV * numV
|
||||
qkvDim := qDim + kDim + vDim
|
||||
|
||||
switch {
|
||||
case rows == qkvDim:
|
||||
// HF layout: [out_features, in_features]. Keep Q/K rows unchanged and
|
||||
// reorder only V rows from grouped -> tiled head layout.
|
||||
out := make([]float32, len(data))
|
||||
qkRows := qDim + kDim
|
||||
qkSize := qkRows * cols
|
||||
copy(out[:qkSize], data[:qkSize])
|
||||
|
||||
vStart := qkSize
|
||||
vEnd := vStart + vDim*cols
|
||||
reorderedV, err := reorderHeadLayout(data[vStart:vEnd], []uint64{uint64(vDim), uint64(cols)}, 0, numK, numV/numK, headV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(out[vStart:vEnd], reorderedV)
|
||||
copy(out[vEnd:], data[vEnd:])
|
||||
return out, nil
|
||||
|
||||
case cols == qkvDim:
|
||||
// Fallback for already-transposed [in_features, out_features] tensors.
|
||||
out := make([]float32, len(data))
|
||||
copy(out, data)
|
||||
for r := range rows {
|
||||
base := r * cols
|
||||
vStart := base + qDim + kDim
|
||||
vEnd := vStart + vDim
|
||||
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vDim)}, 0, numK, numV/numK, headV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(out[vStart:vEnd], reorderedV)
|
||||
}
|
||||
return out, nil
|
||||
|
||||
default:
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) repackConv1D() Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
if !q.shouldReorderVHeads() {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
normShape := slices.Clone(shape)
|
||||
if len(shape) == 3 {
|
||||
if shape[0] == 1 {
|
||||
normShape = []uint64{shape[1], shape[2]}
|
||||
} else if shape[1] == 1 {
|
||||
normShape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
}
|
||||
if len(normShape) != 2 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
rows := int(normShape[0])
|
||||
cols := int(normShape[1])
|
||||
numK := int(q.LinearNumKeyHeads)
|
||||
numV := int(q.LinearNumValueHeads)
|
||||
headK := int(q.LinearKeyHeadDim)
|
||||
headV := int(q.LinearValueHeadDim)
|
||||
qkChannels := 2 * headK * numK
|
||||
totalChannels := qkChannels + headV*numV
|
||||
if qkChannels <= 0 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case rows == totalChannels:
|
||||
// HF layout after squeeze: [channels, kernel]
|
||||
out := make([]float32, len(data))
|
||||
prefix := qkChannels * cols
|
||||
copy(out[:prefix], data[:prefix])
|
||||
reorderedV, err := reorderHeadLayout(data[prefix:], []uint64{uint64(totalChannels - qkChannels), uint64(cols)}, 0, numK, numV/numK, headV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(out[prefix:], reorderedV)
|
||||
return out, nil
|
||||
case cols == totalChannels:
|
||||
// Fallback for transposed [kernel, channels]
|
||||
out := make([]float32, len(data))
|
||||
copy(out, data)
|
||||
vChannels := totalChannels - qkChannels
|
||||
for r := range rows {
|
||||
base := r * cols
|
||||
vStart := base + qkChannels
|
||||
vEnd := vStart + vChannels
|
||||
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vChannels)}, 0, numK, numV/numK, headV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(out[vStart:vEnd], reorderedV)
|
||||
}
|
||||
return out, nil
|
||||
default:
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) repackSSMA() Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
result := make([]float32, len(data))
|
||||
for i, v := range data {
|
||||
result[i] = -float32(math.Exp(float64(v)))
|
||||
}
|
||||
if !q.shouldReorderVHeads() {
|
||||
return result, nil
|
||||
}
|
||||
numK := int(q.LinearNumKeyHeads)
|
||||
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
|
||||
return reorderHeadLayout(result, shape, 0, numK, numVPerK, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func reorderHeadLayout(data []float32, shape []uint64, dim int, numKHeads, numVPerK, headDim int) ([]float32, error) {
|
||||
if len(shape) == 0 || numKHeads <= 0 || numVPerK <= 0 || headDim <= 0 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
if dim < 0 {
|
||||
dim += len(dims)
|
||||
}
|
||||
if dim < 0 || dim >= len(dims) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
expected := numKHeads * numVPerK * headDim
|
||||
if dims[dim] != expected {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
newShape := make([]int, 0, len(dims)+2)
|
||||
newShape = append(newShape, dims[:dim]...)
|
||||
newShape = append(newShape, numKHeads, numVPerK, headDim)
|
||||
newShape = append(newShape, dims[dim+1:]...)
|
||||
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
if err := tt.Reshape(newShape...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
perm := make([]int, len(newShape))
|
||||
for i := range perm {
|
||||
perm[i] = i
|
||||
}
|
||||
perm[dim], perm[dim+1] = perm[dim+1], perm[dim]
|
||||
|
||||
tt, err := tensor.Transpose(tt, perm...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
|
||||
total := 1
|
||||
for _, d := range dims {
|
||||
total *= d
|
||||
}
|
||||
if err := tt.Reshape(total); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
}
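To make the permutation performed by reorderHeadLayout concrete, here is a plain-loop sketch with toy sizes (2 k-heads, 2 v-heads per k-head, headDim=1, 2 columns; assumed values only). The function above does the same reshape-and-transpose with its tensor library; the resulting ordering matches the {0, 1, 4, 5, 2, 3, 6, 7} expectation asserted in the tests further down.

```go
package main

import "fmt"

func main() {
    const numK, numVPerK, headDim, cols = 2, 2, 1, 2
    // Rows grouped per k-head: (k0,v0) (k0,v1) (k1,v0) (k1,v1).
    in := []float32{0, 1, 2, 3, 4, 5, 6, 7}
    out := make([]float32, len(in))
    for k := 0; k < numK; k++ {
        for v := 0; v < numVPerK; v++ {
            for d := 0; d < headDim; d++ {
                src := ((k*numVPerK+v)*headDim + d) * cols
                dst := ((v*numK+k)*headDim + d) * cols // v-index becomes the outer axis
                copy(out[dst:dst+cols], in[src:src+cols])
            }
        }
    }
    fmt.Println(out) // [0 1 4 5 2 3 6 7]
}
```
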
|
||||
|
||||
type qkvzSplitSpec struct {
|
||||
hidden int
|
||||
headKDim int
|
||||
@@ -369,7 +787,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
var err error
|
||||
|
||||
// Convert to [hidden, out_features] layout for slicing
|
||||
tt, err = tensor.Transpose(tt, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -444,7 +861,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
|
||||
}
|
||||
}
|
||||
|
||||
// addOne adds 1.0 to all elements in the tensor (for norm weights)
|
||||
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
|
||||
ones := tensor.Ones(tensor.Float32, int(shape[0]))
|
||||
@@ -471,10 +887,21 @@ func (q *qwen3NextModel) Replacements() []string {
|
||||
return []string{
|
||||
// Embeddings and output
|
||||
"lm_head", "output",
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"model.language_model.layers", "blk",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.norm", "output_norm",
|
||||
"model.layers", "blk",
|
||||
|
||||
// Vision
|
||||
"model.visual", "v",
|
||||
"patch_embed.proj", "patch_embed",
|
||||
"blocks", "blk",
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.proj", "attn_out",
|
||||
"deepstack_merger_list", "deepstack_merger",
|
||||
|
||||
// Layer norms
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
@@ -487,9 +914,16 @@ func (q *qwen3NextModel) Replacements() []string {
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
|
||||
// Linear attention (Gated Delta Net)
|
||||
// Linear attention (legacy qwen3next)
|
||||
"linear_attn.in_proj_qkvz", "ssm_in",
|
||||
"linear_attn.in_proj_ba", "ssm_ba",
|
||||
|
||||
// Linear attention (qwen35)
|
||||
"linear_attn.in_proj_qkv", "attn_qkv",
|
||||
"linear_attn.in_proj_z", "attn_gate",
|
||||
"linear_attn.in_proj_a", "ssm_alpha",
|
||||
"linear_attn.in_proj_b", "ssm_beta",
|
||||
|
||||
"linear_attn.conv1d", "ssm_conv1d",
|
||||
"linear_attn.dt_bias", "ssm_dt",
|
||||
"linear_attn.dt_proj", "ssm_dt",
|
||||
@@ -497,14 +931,14 @@ func (q *qwen3NextModel) Replacements() []string {
|
||||
"linear_attn.norm", "ssm_norm",
|
||||
"linear_attn.out_proj", "ssm_out",
|
||||
|
||||
// MoE (experts are stacked via mergeTensors, not replaced here)
|
||||
// MoE
|
||||
"mlp.gate.weight", "ffn_gate_inp.weight",
|
||||
"mlp.shared_expert.down_proj", "ffn_down_shexp",
|
||||
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
|
||||
"mlp.shared_expert.up_proj", "ffn_up_shexp",
|
||||
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
|
||||
|
||||
// Dense FFN (if any layers use it)
|
||||
// Dense FFN
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
|
||||
563
convert/convert_qwen3next_test.go
Normal file
563
convert/convert_qwen3next_test.go
Normal file
@@ -0,0 +1,563 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
func boolPtr(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
func readTensorData(t *testing.T, tensor *ggml.Tensor) []float32 {
|
||||
t.Helper()
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tensor.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
numel := 1
|
||||
for _, d := range tensor.Shape {
|
||||
numel *= int(d)
|
||||
}
|
||||
|
||||
values := make([]float32, numel)
|
||||
if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
func TestQwen3NextLegacyModelTypeDisablesReorder(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_next",
|
||||
},
|
||||
}
|
||||
|
||||
if m.shouldReorderVHeads() {
|
||||
t.Fatalf("legacy qwen3_next model_type should not reorder v-head layout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3NextLegacyArchitectureDisablesReorder(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
Architectures: []string{"Qwen3NextForCausalLM"},
|
||||
},
|
||||
}
|
||||
|
||||
if m.shouldReorderVHeads() {
|
||||
t.Fatalf("legacy Qwen3Next architecture should not reorder v-head layout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3NextKVLegacyConfig(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_next",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
MaxPositionEmbeddings: 8192,
|
||||
HiddenSize: 512,
|
||||
NumHiddenLayers: 4,
|
||||
IntermediateSize: 2048,
|
||||
NumAttentionHeads: 8,
|
||||
NumKeyValueHeads: 2,
|
||||
HeadDim: 64,
|
||||
RopeTheta: 1_000_000,
|
||||
RMSNormEPS: 1e-6,
|
||||
|
||||
NumExperts: 8,
|
||||
NumExpertsPerToken: 2,
|
||||
NormTopkProb: boolPtr(true),
|
||||
MoEIntermediateSize: 256,
|
||||
SharedExpertIntermSize: 512,
|
||||
|
||||
FullAttentionInterval: 2,
|
||||
|
||||
LinearConvKernelDim: 4,
|
||||
LinearKeyHeadDim: 64,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 64,
|
||||
|
||||
PartialRotaryFactor: 0.25,
|
||||
},
|
||||
}
|
||||
|
||||
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||
if got, want := kv["general.architecture"], "qwen35moe"; got != want {
|
||||
t.Fatalf("unexpected architecture: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["tokenizer.ggml.pre"], "qwen35"; got != want {
|
||||
t.Fatalf("unexpected tokenizer pre: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
|
||||
if !ok {
|
||||
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
|
||||
}
|
||||
if got, want := headCountKV, []uint32{0, 2, 0, 2}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
if _, ok := kv["ssm.v_head_reordered"]; ok {
|
||||
t.Fatalf("legacy qwen3next should not enable ssm.v_head_reordered")
|
||||
}
|
||||
if got, want := kv["norm_top_k_prob"], true; got != want {
|
||||
t.Fatalf("unexpected norm_top_k_prob: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35MoeOmitsNormTopKProbWhenUnset(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
MaxPositionEmbeddings: 4096,
|
||||
HiddenSize: 512,
|
||||
NumHiddenLayers: 4,
|
||||
IntermediateSize: 2048,
|
||||
NumAttentionHeads: 8,
|
||||
NumKeyValueHeads: 2,
|
||||
HeadDim: 64,
|
||||
RopeTheta: 1_000_000,
|
||||
RMSNormEPS: 1e-6,
|
||||
NumExperts: 8,
|
||||
NumExpertsPerToken: 2,
|
||||
FullAttentionInterval: 2,
|
||||
LinearConvKernelDim: 4,
|
||||
LinearKeyHeadDim: 64,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 64,
|
||||
PartialRotaryFactor: 0.25,
|
||||
},
|
||||
}
|
||||
|
||||
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||
if _, ok := kv["norm_top_k_prob"]; ok {
|
||||
t.Fatalf("expected norm_top_k_prob to be omitted when not set in config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35KVFromTextConfig(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
TextConfig: &qwen3NextTextConfig{
|
||||
MaxPositionEmbeddings: 16384,
|
||||
HiddenSize: 1024,
|
||||
NumHiddenLayers: 4,
|
||||
IntermediateSize: 4096,
|
||||
NumAttentionHeads: 8,
|
||||
NumKeyValueHeads: 4,
|
||||
HeadDim: 128,
|
||||
RMSNormEPS: 1e-6,
|
||||
|
||||
LayerTypes: []string{
|
||||
"linear_attention",
|
||||
"full_attention",
|
||||
"linear_attention",
|
||||
"full_attention",
|
||||
},
|
||||
|
||||
LinearConvKernelDim: 4,
|
||||
LinearKeyHeadDim: 128,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 128,
|
||||
|
||||
RopeParameters: qwen3NextRopeParams{
|
||||
MRopeInterleaved: true,
|
||||
MropeSection: []int32{11, 11, 10},
|
||||
RopeType: "default",
|
||||
RopeTheta: 10_000_000,
|
||||
PartialRotaryFactor: 0.25,
|
||||
},
|
||||
},
|
||||
VisionModel: qwen3NextVisionConfig{
|
||||
Depth: 2,
|
||||
HiddenSize: 128,
|
||||
NumHeads: 4,
|
||||
InChannels: 3,
|
||||
PatchSize: 16,
|
||||
SpatialMergeSize: 2,
|
||||
RMSNormEps: 1e-6,
|
||||
RopeTheta: 10_000,
|
||||
TemporalPatchSize: 2,
|
||||
DeepstackVisualIndexes: []int32{1},
|
||||
},
|
||||
ImageTokenID: 1001,
|
||||
VisionStartTokenID: 1002,
|
||||
VisionEndTokenID: 1003,
|
||||
}
|
||||
m.VisionModel.Size.ShortestEdge = 224
|
||||
m.VisionModel.Size.LongestEdge = 4096
|
||||
m.VisionModel.ImageMean = []float32{0.5, 0.5, 0.5}
|
||||
m.VisionModel.ImageStd = []float32{0.2, 0.2, 0.2}
|
||||
|
||||
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||
if got, want := kv["general.architecture"], "qwen35"; got != want {
|
||||
t.Fatalf("unexpected architecture: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
|
||||
if !ok {
|
||||
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
|
||||
}
|
||||
if got, want := headCountKV, []uint32{0, 4, 0, 4}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
if got, ok := kv["ssm.v_head_reordered"].(bool); !ok || !got {
|
||||
t.Fatalf("expected ssm.v_head_reordered=true, got %v (%T)", kv["ssm.v_head_reordered"], kv["ssm.v_head_reordered"])
|
||||
}
|
||||
|
||||
mrope, ok := kv["mrope_sections"].([]int32)
|
||||
if !ok {
|
||||
t.Fatalf("mrope_sections has unexpected type: %T", kv["mrope_sections"])
|
||||
}
|
||||
if got, want := mrope, []int32{11, 11, 10}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected mrope_sections: got %v want %v", got, want)
|
||||
}
|
||||
ropeSections, ok := kv["rope.dimension_sections"].([]int32)
|
||||
if !ok {
|
||||
t.Fatalf("rope.dimension_sections has unexpected type: %T", kv["rope.dimension_sections"])
|
||||
}
|
||||
if got, want := ropeSections, []int32{11, 11, 10}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected rope.dimension_sections: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
if got, ok := kv["rope.mrope_interleaved"].(bool); !ok || !got {
|
||||
t.Fatalf("expected rope.mrope_interleaved=true, got %v (%T)", kv["rope.mrope_interleaved"], kv["rope.mrope_interleaved"])
|
||||
}
|
||||
|
||||
if got, want := kv["vision.block_count"], uint32(2); got != want {
|
||||
t.Fatalf("unexpected vision.block_count: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3NextReplacements(t *testing.T) {
|
||||
r := strings.NewReplacer((&qwen3NextModel{}).Replacements()...)
|
||||
|
||||
if got, want := r.Replace("model.language_model.layers.1.linear_attn.in_proj_qkv.weight"), "blk.1.attn_qkv.weight"; got != want {
|
||||
t.Fatalf("unexpected language-model replacement: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := r.Replace("model.visual.blocks.0.attn.qkv.weight"), "v.blk.0.attn_qkv.weight"; got != want {
|
||||
t.Fatalf("unexpected vision replacement: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := r.Replace("model.layers.1.linear_attn.in_proj_qkvz.weight"), "blk.1.ssm_in.weight"; got != want {
|
||||
t.Fatalf("unexpected legacy replacement: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ReordersVHeads(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.attn_gate.weight",
|
||||
shape: []uint64{4, 2},
|
||||
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected data: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ReordersAttnQKVOutputDim(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearKeyHeadDim: 1,
|
||||
LinearValueHeadDim: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.attn_qkv.weight",
|
||||
shape: []uint64{8, 2}, // [out_features, in_features] (HF layout)
|
||||
data: []float32{
|
||||
0, 1, // q0
|
||||
2, 3, // q1
|
||||
4, 5, // k0
|
||||
6, 7, // k1
|
||||
10, 11, // v(k0,v0)
|
||||
12, 13, // v(k0,v1)
|
||||
20, 21, // v(k1,v0)
|
||||
22, 23, // v(k1,v1)
|
||||
},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
10, 11, 20, 21, 12, 13, 22, 23,
|
||||
}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected qkv data: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ReordersSsmOutInputDim(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.ssm_out.weight",
|
||||
shape: []uint64{2, 4},
|
||||
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{0, 2, 1, 3, 4, 6, 5, 7}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected ssm_out data: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ReordersSsmBetaRows(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.ssm_beta.weight",
|
||||
shape: []uint64{4, 2},
|
||||
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected ssm_beta data: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ReordersConv1DChannelDim(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_5",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearKeyHeadDim: 1,
|
||||
LinearValueHeadDim: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.ssm_conv1d.weight",
|
||||
shape: []uint64{8, 2}, // [channels, kernel] after squeeze
|
||||
data: []float32{
|
||||
0, 1, // q0
|
||||
2, 3, // q1
|
||||
4, 5, // k0
|
||||
6, 7, // k1
|
||||
10, 11, // v(k0,v0)
|
||||
12, 13, // v(k0,v1)
|
||||
20, 21, // v(k1,v0)
|
||||
22, 23, // v(k1,v1)
|
||||
},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
10, 11, 20, 21, 12, 13, 22, 23,
|
||||
}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected conv1d data: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyQwen3NextDoesNotReorderVHeads(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
ModelParameters: ModelParameters{
|
||||
ModelType: "qwen3_next",
|
||||
},
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.attn_gate.weight",
|
||||
shape: []uint64{4, 1},
|
||||
data: []float32{0, 1, 2, 3},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := readTensorData(t, out[0]), []float32{0, 1, 2, 3}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected data for legacy qwen3next: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35MoePackedExperts(t *testing.T) {
|
||||
m := &qwen3NextModel{
|
||||
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||
NumHiddenLayers: 1,
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.mlp.experts.gate_up_proj",
|
||||
shape: []uint64{2, 4, 3},
|
||||
data: []float32{
|
||||
0, 1, 2,
|
||||
3, 4, 5,
|
||||
6, 7, 8,
|
||||
9, 10, 11,
|
||||
12, 13, 14,
|
||||
15, 16, 17,
|
||||
18, 19, 20,
|
||||
21, 22, 23,
|
||||
},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "blk.0.mlp.experts.down_proj",
|
||||
shape: []uint64{2, 5, 3},
|
||||
data: make([]float32, 2*5*3),
|
||||
},
|
||||
})
|
||||
|
||||
get := func(name string) *ggml.Tensor {
|
||||
for _, tensor := range out {
|
||||
if tensor.Name == name {
|
||||
return tensor
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gate := get("blk.0.ffn_gate_exps.weight")
|
||||
if gate == nil {
|
||||
t.Fatalf("missing tensor %q", "blk.0.ffn_gate_exps.weight")
|
||||
}
|
||||
if got, want := gate.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected gate shape: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := readTensorData(t, gate), []float32{
|
||||
0, 1, 2, 3, 4, 5,
|
||||
12, 13, 14, 15, 16, 17,
|
||||
}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected gate values: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
up := get("blk.0.ffn_up_exps.weight")
|
||||
if up == nil {
|
||||
t.Fatalf("missing tensor %q", "blk.0.ffn_up_exps.weight")
|
||||
}
|
||||
if got, want := up.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected up shape: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := readTensorData(t, up), []float32{
|
||||
6, 7, 8, 9, 10, 11,
|
||||
18, 19, 20, 21, 22, 23,
|
||||
}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected up values: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
down := get("blk.0.ffn_down_exps.weight")
|
||||
if down == nil {
|
||||
t.Fatalf("missing tensor %q", "blk.0.ffn_down_exps.weight")
|
||||
}
|
||||
if got, want := down.Shape, []uint64{2, 5, 3}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected down shape: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35SharedExpertGateKeepsMatrixShape(t *testing.T) {
|
||||
m := &qwen3NextModel{}
|
||||
|
||||
out := m.Tensors([]Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.ffn_gate_inp_shexp.weight",
|
||||
shape: []uint64{1, 4},
|
||||
data: []float32{0, 1, 2, 3},
|
||||
},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||
}
|
||||
|
||||
if got, want := out[0].Shape, []uint64{1, 4}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected shared gate shape: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -101,6 +101,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
        t.Pre = "deepseek-coder"
    case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
        t.Pre = "qwen2"
    case "00431aed57e696b747435f734d1e3b9b1bfd931a121fb5cac7129e97c181e9ba":
        t.Pre = "qwen35"
    case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
        // noop, empty pretokenizer
    default:

@@ -171,6 +173,13 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
        }
    }

    // Match upstream behavior: prefer chat_template.jinja when present.
    if bts, err := fs.ReadFile(fsys, "chat_template.jinja"); err == nil {
        t.Template = string(bts)
    } else if err != nil && !errors.Is(err, os.ErrNotExist) {
        return nil, err
    }

    if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) {
    } else if err != nil {
        return nil, err

@@ -79,6 +79,21 @@ func TestParseTokenizer(t *testing.T) {
|
||||
Template: "<default template>",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat template jinja overrides tokenizer config",
|
||||
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||
"tokenizer.json": strings.NewReader(`{}`),
|
||||
"tokenizer_config.json": strings.NewReader(`{
|
||||
"chat_template": "<template from tokenizer config>"
|
||||
}`),
|
||||
"chat_template.jinja": strings.NewReader("<template from jinja>"),
|
||||
}),
|
||||
want: &Tokenizer{
|
||||
Vocabulary: &Vocabulary{Model: "gpt2"},
|
||||
Pre: "default",
|
||||
Template: "<template from jinja>",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "added tokens",
|
||||
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||
@@ -386,6 +401,28 @@ func TestParseTokenizer(t *testing.T) {
|
||||
Pre: "default",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "qwen35 pretokenizer",
|
||||
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||
"tokenizer.json": strings.NewReader(`{
|
||||
"pre_tokenizer": {
|
||||
"type": "Sequence",
|
||||
"pretokenizers": [
|
||||
{
|
||||
"type": "Split",
|
||||
"pattern": {
|
||||
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}`),
|
||||
}),
|
||||
want: &Tokenizer{
|
||||
Vocabulary: &Vocabulary{Model: "gpt2"},
|
||||
Pre: "qwen35",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
|
||||
@@ -290,6 +290,7 @@ func (kv KV) OllamaEngineRequired() bool {
        "olmo3",
        "qwen25vl",
        "qwen3", "qwen3moe",
        "qwen35", "qwen35moe",
        "qwen3next",
        "qwen3vl", "qwen3vlmoe",
        "glm4moelite",

@@ -868,7 +869,12 @@ func (f GGML) SupportsFlashAttention() bool {
        return false
    }

    if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
    arch := f.KV().Architecture()
    if slices.Contains([]string{"qwen35", "qwen35moe", "qwen3next"}, arch) {
        return true
    }

    if slices.Contains([]string{"gemma2"}, arch) {
        return false
    }

@@ -892,6 +898,7 @@ func (f GGML) FlashAttention() bool {
        "nemotron_h", "nemotron_h_moe",
        "olmo3",
        "qwen3", "qwen3moe",
        "qwen35", "qwen35moe",
        "qwen3next",
        "qwen3vl", "qwen3vlmoe",
    }, f.KV().String("general.architecture"))

@@ -11,9 +11,9 @@ import (
)

const (
    DefaultCheckpointCount = 32
    DefaultCheckpointCount = 24
    DefaultCheckpointMinPos = int32(16)
    DefaultCheckpointInterval = int32(1280)
    DefaultCheckpointInterval = int32(1664)
)

var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")

@@ -211,6 +211,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
    }

    kvct := strings.ToLower(envconfig.KvCacheType())
    arch := f.KV().Architecture()

    if tok == nil {
        flashAttention := ml.FlashAttentionAuto
@@ -220,6 +221,9 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
        } else {
            flashAttention = ml.FlashAttentionDisabled
        }
    } else if fa && slices.Contains([]string{"qwen35", "qwen35moe", "qwen3next"}, arch) {
        // Keep FA explicit for qwen35-family models to avoid architecture fallback to Auto.
        flashAttention = ml.FlashAttentionEnabled
    }

    if kvct != "" {

@@ -2,595 +2,58 @@ package qwen3next
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var _ kvcache.Cache = (*HybridCache)(nil)
|
||||
var (
|
||||
_ kvcache.Cache = (*HybridCache)(nil)
|
||||
_ kvcache.CheckpointCache = (*HybridCache)(nil)
|
||||
)
|
||||
|
||||
// HybridCache stores:
|
||||
// - a standard causal KV cache for full attention layers
|
||||
// - per-sequence conv state for linear attention layers
|
||||
// - per-sequence delta state for linear attention layers
|
||||
//
|
||||
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
|
||||
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
|
||||
// HybridCache adapts the shared recurrent cache base for Qwen3-Next naming.
|
||||
type HybridCache struct {
|
||||
kv *kvcache.Causal
|
||||
|
||||
backend ml.Backend
|
||||
dtype ml.DType
|
||||
maxSequences int
|
||||
|
||||
// Conv state dimensions
|
||||
convDim int // convKernelSize - 1
|
||||
convChannels int // d_inner + 2 * num_k_heads * head_k_dim
|
||||
|
||||
// Delta state dimensions
|
||||
deltaStateSize int // headVDim * headVDim * numVHeads
|
||||
|
||||
// slot mapping for recurrent state (copy-on-write)
|
||||
slotForSeq map[int]int
|
||||
refCount []int
|
||||
freeSlots []int
|
||||
|
||||
// per-layer conv state buffers (allocated lazily)
|
||||
convCtxs map[int]ml.Context
|
||||
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
|
||||
|
||||
// per-layer delta state buffers (allocated lazily)
|
||||
deltaCtxs map[int]ml.Context
|
||||
deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
|
||||
|
||||
// recurrent checkpoints (per slot)
|
||||
checkpointCount int
|
||||
checkpointMinPos int32
|
||||
checkpointInterval int32
|
||||
checkpointCtxSize int
|
||||
checkpoints map[int]*slotCheckpointStore
|
||||
pendingRestore map[int]checkpointRestore
|
||||
curCheckpointPos []int32
|
||||
curCheckpointSlots map[int]int
|
||||
reserveCheckpoints bool
|
||||
checkpointConvCtxs map[int]ml.Context
|
||||
checkpointDeltaCtxs map[int]ml.Context
|
||||
checkpointReserved map[int]struct{}
|
||||
|
||||
// current forward batch (derived in StartForward)
|
||||
curSeqs []int
|
||||
curSlots []int
|
||||
curSlotsInput ml.Tensor
|
||||
curSeqTokens int
|
||||
|
||||
// track if EnsureWritable has been called for this forward pass
|
||||
writableEnsured bool
|
||||
writableError error
|
||||
*kvcache.Recurrent
|
||||
}
|
||||
|
||||
func NewHybridCache(
|
||||
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
|
||||
convDim, convChannels, deltaStateSize int,
|
||||
) *HybridCache {
|
||||
return &HybridCache{
|
||||
kv: kvcache.NewCausalCache(shift),
|
||||
convDim: convDim,
|
||||
convChannels: convChannels,
|
||||
deltaStateSize: deltaStateSize,
|
||||
slotForSeq: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
deltaCtxs: make(map[int]ml.Context),
|
||||
deltaStates: make(map[int]ml.Tensor),
|
||||
checkpointCount: checkpointCountDefault,
|
||||
checkpointMinPos: checkpointMinPosDefault,
|
||||
checkpointInterval: checkpointIntervalDefault,
|
||||
checkpoints: make(map[int]*slotCheckpointStore),
|
||||
pendingRestore: make(map[int]checkpointRestore),
|
||||
curCheckpointSlots: make(map[int]int),
|
||||
checkpointConvCtxs: make(map[int]ml.Context),
|
||||
checkpointDeltaCtxs: make(map[int]ml.Context),
|
||||
checkpointReserved: make(map[int]struct{}),
|
||||
}
|
||||
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
|
||||
Shift: shift,
|
||||
ConvDim: convDim,
|
||||
ConvChannels: convChannels,
|
||||
RecurrentStateSize: deltaStateSize,
|
||||
CheckpointLogPrefix: "qwen3next",
|
||||
})
|
||||
return &HybridCache{Recurrent: base}
|
||||
}
|
||||
|
||||
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.backend = backend
|
||||
c.dtype = dtype
|
||||
c.maxSequences = maxSequences
|
||||
c.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
c.pendingRestore = make(map[int]checkpointRestore)
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
c.curCheckpointSlots = make(map[int]int)
|
||||
c.checkpointReserved = make(map[int]struct{})
|
||||
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
|
||||
if c.checkpointCtxSize < 8 {
|
||||
c.checkpointCtxSize = 8
|
||||
}
|
||||
|
||||
// initialize slot allocator
|
||||
c.refCount = make([]int, maxSequences)
|
||||
c.freeSlots = c.freeSlots[:0]
|
||||
for i := maxSequences - 1; i >= 0; i-- {
|
||||
c.freeSlots = append(c.freeSlots, i)
|
||||
}
|
||||
|
||||
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Close() {
|
||||
for _, ctx := range c.convCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.deltaCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointConvCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointDeltaCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
c.kv.Close()
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
|
||||
c.kv.SetConfig(config)
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetLayer(layer int) {
|
||||
c.kv.SetLayer(layer)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.kv.Get(ctx)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.kv.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Derive equal-length sequence layout for recurrent layers
|
||||
seqCounts := make(map[int]int)
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
for _, s := range batch.Sequences {
|
||||
if _, ok := seqCounts[s]; !ok {
|
||||
c.curSeqs = append(c.curSeqs, s)
|
||||
}
|
||||
seqCounts[s]++
|
||||
}
|
||||
|
||||
if len(c.curSeqs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
nTokens := len(batch.Sequences)
|
||||
nSeqs := len(c.curSeqs)
|
||||
want := nTokens / nSeqs
|
||||
for _, s := range c.curSeqs {
|
||||
if seqCounts[s] != want {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
c.curSeqTokens = want
|
||||
|
||||
// When reserving memory for estimation, use fake slot assignments
|
||||
if reserve {
|
||||
c.curSlots = c.curSlots[:0]
|
||||
slots := make([]int32, nSeqs)
|
||||
for i := range nSeqs {
|
||||
c.curSlots = append(c.curSlots, i)
|
||||
slots[i] = int32(i)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
c.reserveCheckpoints = true
|
||||
c.planCheckpoints(batch)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure slots exist for sequences in this batch
|
||||
c.curSlots = c.curSlots[:0]
|
||||
var newSlots []int
|
||||
for _, s := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[s]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.slotForSeq[s] = slot
|
||||
c.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
c.curSlots = append(c.curSlots, slot)
|
||||
}
|
||||
|
||||
// Zero state for newly allocated slots
|
||||
if len(newSlots) > 0 {
|
||||
c.zeroSlots(ctx, newSlots)
|
||||
}
|
||||
|
||||
// Create a tensor for the current slots
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
// Reset writable state for new forward pass
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
c.reserveCheckpoints = false
|
||||
c.planCheckpoints(batch)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
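// allocSlot pops a free recurrent-state slot, returning ErrKvCacheFull when every slot is in use.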
func (c *HybridCache) allocSlot() (int, error) {
|
||||
if len(c.freeSlots) == 0 {
|
||||
return 0, kvcache.ErrKvCacheFull
|
||||
}
|
||||
slot := c.freeSlots[len(c.freeSlots)-1]
|
||||
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
||||
return slot, nil
|
||||
}
|
||||
|
||||
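// freeSlot returns a valid slot index to the free list so it can be reused by another sequence.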
func (c *HybridCache) freeSlot(slot int) {
|
||||
if slot >= 0 && slot < c.maxSequences {
|
||||
c.freeSlots = append(c.freeSlots, slot)
|
||||
}
|
||||
}
|
||||
|
||||
// zeroSlots zeros the recurrent state for the given slots across all layers.
|
||||
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
|
||||
if len(slots) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
inputCtx := ctx.Input()
|
||||
|
||||
slotIndices := make([]int32, len(slots))
|
||||
for i, s := range slots {
|
||||
slotIndices[i] = int32(s)
|
||||
}
|
||||
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
|
||||
|
||||
// Zero conv states
|
||||
if len(c.convStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
|
||||
for _, buf := range c.convStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
|
||||
// Zero delta states
|
||||
if len(c.deltaStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
|
||||
for _, buf := range c.deltaStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureWritable ensures sequences have private slots (copy-on-write).
|
||||
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
|
||||
for i, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.refCount[slot] <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
c.curSlots[i] = newSlot
|
||||
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
}
|
||||
|
||||
// Rebuild current slots tensor
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
|
||||
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
|
||||
|
||||
for _, buf := range c.convStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
||||
}
|
||||
|
||||
for _, buf := range c.deltaStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
|
||||
// Copy-on-write for recurrent state
|
||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
||||
if c.validSlot(dstSlot) {
|
||||
c.refCount[dstSlot]--
|
||||
if c.refCount[dstSlot] <= 0 {
|
||||
c.refCount[dstSlot] = 0
|
||||
c.freeSlot(dstSlot)
|
||||
}
|
||||
}
|
||||
delete(c.slotForSeq, dstSeq)
|
||||
}
|
||||
|
||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if c.validSlot(srcSlot) {
|
||||
c.slotForSeq[dstSeq] = srcSlot
|
||||
c.refCount[srcSlot]++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CanResume(seq int, pos int32) bool {
|
||||
if !c.kv.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
if pos == 0 {
|
||||
return true
|
||||
}
|
||||
return c.hasCheckpoint(seq, pos)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore, ok := c.pendingRestore[seq]
|
||||
if !ok || restore.pos+1 != beginIndex {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
if !c.restoreComplete(restore) {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
// If the recurrent slot is shared, detach it before applying a restore.
|
||||
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
|
||||
restore.slot = newSlot
|
||||
c.pendingRestore[seq] = restore
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore := c.pendingRestore[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
return c.applyCheckpointRestore(restore)
|
||||
}
|
||||
|
||||
// Removal invalidates recurrent state
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.validSlot(slot) {
|
||||
delete(c.slotForSeq, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.refCount[slot]--
|
||||
if c.refCount[slot] <= 0 {
|
||||
c.refCount[slot] = 0
|
||||
c.clearCheckpoints(slot)
|
||||
c.freeSlot(slot)
|
||||
}
|
||||
delete(c.slotForSeq, seq)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) validSlot(slot int) bool {
|
||||
return slot >= 0 && slot < len(c.refCount)
|
||||
}
|
||||
|
||||
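// slotsTensor returns the per-sequence slot indices of the current forward pass as an input tensor.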
func (c *HybridCache) slotsTensor() ml.Tensor {
|
||||
return c.curSlotsInput
|
||||
}
|
||||
|
||||
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
|
||||
func (c *HybridCache) contiguousSlots() (int, bool) {
|
||||
if len(c.curSlots) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
start := c.curSlots[0]
|
||||
for i, s := range c.curSlots {
|
||||
if s != start+i {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return start, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.curSeqTokens
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return len(c.curSeqs)
|
||||
}
|
||||
|
||||
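// convBuffer lazily allocates the per-layer conv state buffer, shaped [convDim*convChannels, maxSequences] in F32.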
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.convStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.convCtxs[layer]; !ok {
|
||||
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
|
||||
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
|
||||
c.convStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
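// deltaBuffer lazily allocates the per-layer delta state buffer, shaped [deltaStateSize, maxSequences] in F32.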
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.deltaStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.deltaCtxs[layer]; !ok {
|
||||
c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
// Recurrent delta state must stay in F32.
|
||||
buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
|
||||
c.deltaStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
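// ensureWritableOnce performs copy-on-write detection at most once per forward pass and records any error for the state accessors to surface.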
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
|
||||
if !c.writableEnsured {
|
||||
needsWritable := false
|
||||
for _, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
||||
needsWritable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsWritable {
|
||||
if err := c.EnsureWritable(ctx); err != nil {
|
||||
c.writableError = err
|
||||
}
|
||||
}
|
||||
c.writableEnsured = true
|
||||
}
|
||||
}
|
||||
|
||||
// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
|
||||
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
||||
c.ensureWritableOnce(ctx)
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateConvState writes a new conv state for current batch sequences.
|
||||
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
// Fast path: contiguous slots allow a single view + copy
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
|
||||
ctx.Forward(srcF32.Copy(ctx, view))
|
||||
} else {
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
}
|
||||
|
||||
c.captureConvCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
|
||||
// DeltaState returns the delta state for current batch sequences as
|
||||
// [headVDim, headVDim*numVHeads, nSeqs].
|
||||
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
|
||||
c.ensureWritableOnce(ctx)
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
|
||||
return c.RecurrentState(ctx, layer, headVDim, headVDim*numVHeads)
|
||||
}
|
||||
|
||||
// UpdateDeltaState writes a new delta state for current batch sequences.
|
||||
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
// Fast path: contiguous slots allow a single view + copy
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
|
||||
ctx.Forward(srcF32.Copy(ctx, view))
|
||||
} else {
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
c.UpdateRecurrentState(ctx, layer, newState)
|
||||
}
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.SeqTokens()
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return c.NumSeqs()
|
||||
}
|
||||
|
||||
// Keep qwen3next behavior for partial mid-sequence removals.
|
||||
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
c.captureDeltaCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
|
||||
func (c *HybridCache) IsSupportedForBatch() bool {
|
||||
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
||||
}
|
||||
|
||||
// Seqs returns the ordered unique sequences for the current forward pass.
|
||||
func (c *HybridCache) Seqs() []int {
|
||||
return slices.Clone(c.curSeqs)
|
||||
return c.Recurrent.Remove(seq, beginIndex, endIndex)
|
||||
}
|
||||
|
||||
@@ -1,498 +0,0 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
checkpointCountDefault = 32
|
||||
checkpointMinPosDefault = int32(16)
|
||||
checkpointIntervalDefault = int32(1280)
|
||||
)
|
||||
|
||||
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
|
||||
// memory usage while preserving prefix reuse for recurrent state.
|
||||
|
||||
type checkpointEntry struct {
|
||||
pos int32
|
||||
conv map[int]ml.Tensor
|
||||
delta map[int]ml.Tensor
|
||||
}
|
||||
|
||||
type slotCheckpointStore struct {
|
||||
entries []checkpointEntry
|
||||
size int
|
||||
next int
|
||||
lastPos int32
|
||||
}
|
||||
|
||||
type checkpointRestore struct {
|
||||
slot int
|
||||
idx int
|
||||
pos int32
|
||||
}
|
||||
|
||||
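// newSlotCheckpointStore builds a fixed-size ring of checkpoint entries with every position initialized to -1 (empty).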
func newSlotCheckpointStore(n int) *slotCheckpointStore {
|
||||
entries := make([]checkpointEntry, n)
|
||||
for i := range entries {
|
||||
entries[i].pos = -1
|
||||
}
|
||||
return &slotCheckpointStore{
|
||||
entries: entries,
|
||||
lastPos: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) reset() {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
for i := range s.entries {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
}
|
||||
|
||||
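// record claims the next ring entry for a checkpoint at pos, overwriting older entries once the ring is full.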
func (s *slotCheckpointStore) record(pos int32) int {
|
||||
if len(s.entries) == 0 {
|
||||
return -1
|
||||
}
|
||||
idx := s.next
|
||||
s.next = (s.next + 1) % len(s.entries)
|
||||
if s.size < len(s.entries) {
|
||||
s.size++
|
||||
}
|
||||
s.entries[idx].pos = pos
|
||||
s.lastPos = pos
|
||||
return idx
|
||||
}
|
||||
|
||||
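// bestIndex returns the newest stored checkpoint strictly before targetPos, if any.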
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
|
||||
bestIdx := -1
|
||||
bestPos := int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 || pos >= targetPos {
|
||||
continue
|
||||
}
|
||||
if pos > bestPos {
|
||||
bestPos = pos
|
||||
bestIdx = i
|
||||
}
|
||||
}
|
||||
if bestIdx < 0 {
|
||||
return -1, -1, false
|
||||
}
|
||||
return bestIdx, bestPos, true
|
||||
}
|
||||
|
||||
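// pruneAfter invalidates all checkpoints recorded after pos and repositions the ring's write cursor.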
func (s *slotCheckpointStore) pruneAfter(pos int32) {
|
||||
if len(s.entries) == 0 {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
|
||||
size := 0
|
||||
next := -1
|
||||
minPos := int32(math.MaxInt32)
|
||||
minIdx := 0
|
||||
for i := range s.entries {
|
||||
if s.entries[i].pos > pos {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
if s.entries[i].pos >= 0 {
|
||||
size++
|
||||
if s.entries[i].pos < minPos {
|
||||
minPos = s.entries[i].pos
|
||||
minIdx = i
|
||||
}
|
||||
} else if next == -1 {
|
||||
next = i
|
||||
}
|
||||
}
|
||||
|
||||
s.size = size
|
||||
if size == 0 {
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
if next != -1 {
|
||||
s.next = next
|
||||
} else {
|
||||
// Full ring: overwrite the oldest checkpoint next.
|
||||
s.next = minIdx
|
||||
}
|
||||
s.lastPos = pos
|
||||
}
|
||||
|
||||
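// window summarizes the stored checkpoints: count, min/max positions, and the most recently recorded position.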
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
|
||||
minPos = int32(math.MaxInt32)
|
||||
maxPos = int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
size++
|
||||
if pos < minPos {
|
||||
minPos = pos
|
||||
}
|
||||
if pos > maxPos {
|
||||
maxPos = pos
|
||||
}
|
||||
}
|
||||
if size == 0 {
|
||||
minPos = -1
|
||||
maxPos = -1
|
||||
}
|
||||
return size, minPos, maxPos, s.lastPos
|
||||
}
|
||||
|
||||
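// planCheckpoints decides, per sequence in the batch, whether this forward pass should record a new
// checkpoint, based on the minimum checkpoint position and the checkpoint interval.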
func (c *HybridCache) planCheckpoints(batch input.Batch) {
|
||||
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cap(c.curCheckpointPos) < len(c.curSeqs) {
|
||||
c.curCheckpointPos = make([]int32, len(c.curSeqs))
|
||||
} else {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
|
||||
}
|
||||
for i := range c.curCheckpointPos {
|
||||
c.curCheckpointPos[i] = -1
|
||||
}
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
|
||||
posMax := make(map[int]int32, len(c.curSeqs))
|
||||
for i, seq := range batch.Sequences {
|
||||
pos := batch.Positions[i]
|
||||
if cur, ok := posMax[seq]; !ok || pos > cur {
|
||||
posMax[seq] = pos
|
||||
}
|
||||
}
|
||||
|
||||
for i, seq := range c.curSeqs {
|
||||
pos, ok := posMax[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if pos < c.checkpointMinPos {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
store := c.checkpointStore(slot)
|
||||
lastPos := store.lastPos
|
||||
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
|
||||
c.curCheckpointPos[i] = pos
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
|
||||
store, ok := c.checkpoints[slot]
|
||||
if ok {
|
||||
return store
|
||||
}
|
||||
store = newSlotCheckpointStore(c.checkpointCount)
|
||||
c.checkpoints[slot] = store
|
||||
return store
|
||||
}
|
||||
|
||||
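// checkpointIndexForSlot records a checkpoint entry for the slot at most once per forward pass and returns its ring index.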
func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
|
||||
if c.checkpointCount == 0 {
|
||||
return -1
|
||||
}
|
||||
if idx, ok := c.curCheckpointSlots[slot]; ok {
|
||||
return idx
|
||||
}
|
||||
store := c.checkpointStore(slot)
|
||||
idx := store.record(pos)
|
||||
if idx >= 0 {
|
||||
c.curCheckpointSlots[slot] = idx
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
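// hasCheckpoint reports whether a checkpoint strictly before pos exists for the sequence's slot.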
func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
|
||||
if pos <= 0 {
|
||||
return false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, _, ok = store.bestIndex(pos)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
|
||||
if targetPos <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
|
||||
return 0, false
|
||||
}
|
||||
idx, pos, ok := store.bestIndex(targetPos)
|
||||
if !ok {
|
||||
size, minPos, maxPos, lastPos := store.window()
|
||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
|
||||
"min", minPos, "max", maxPos, "last", lastPos)
|
||||
return 0, false
|
||||
}
|
||||
c.pendingRestore[seq] = checkpointRestore{
|
||||
slot: slot,
|
||||
idx: idx,
|
||||
pos: pos,
|
||||
}
|
||||
return pos + 1, true
|
||||
}
|
||||
|
||||
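// applyCheckpointRestore copies the checkpointed conv and delta states back into the slot's live buffers and prunes newer checkpoints.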
func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
|
||||
entry, ok := c.restoreEntry(restore)
|
||||
if !ok {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
|
||||
for layer, src := range entry.conv {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
for layer, src := range entry.delta {
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
|
||||
if len(entry.conv) > 0 || len(entry.delta) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
store := c.checkpoints[restore.slot]
|
||||
store.pruneAfter(restore.pos)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
|
||||
_, ok := c.restoreEntry(restore)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
|
||||
store, ok := c.checkpoints[restore.slot]
|
||||
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
|
||||
return nil, false
|
||||
}
|
||||
entry := &store.entries[restore.idx]
|
||||
if entry.pos < 0 {
|
||||
return nil, false
|
||||
}
|
||||
if !c.entryComplete(entry) {
|
||||
return nil, false
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
|
||||
for layer := range c.convStates {
|
||||
if entry.conv == nil || entry.conv[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for layer := range c.deltaStates {
|
||||
if entry.delta == nil || entry.delta[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *HybridCache) clearCheckpoints(slot int) {
|
||||
if store, ok := c.checkpoints[slot]; ok {
|
||||
store.reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
srcStore, ok := c.checkpoints[srcSlot]
|
||||
if !ok || srcStore.size == 0 {
|
||||
return
|
||||
}
|
||||
dstStore := c.checkpointStore(dstSlot)
|
||||
dstStore.size = srcStore.size
|
||||
dstStore.next = srcStore.next
|
||||
dstStore.lastPos = srcStore.lastPos
|
||||
|
||||
for i := range srcStore.entries {
|
||||
srcEntry := &srcStore.entries[i]
|
||||
dstEntry := &dstStore.entries[i]
|
||||
dstEntry.pos = srcEntry.pos
|
||||
if srcEntry.conv != nil {
|
||||
if dstEntry.conv == nil {
|
||||
dstEntry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.conv {
|
||||
dst := c.ensureCheckpointConv(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
if srcEntry.delta != nil {
|
||||
if dstEntry.delta == nil {
|
||||
dstEntry.delta = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.delta {
|
||||
dst := c.ensureCheckpointDelta(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointConv(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointConv(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointDelta(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointDelta(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.conv == nil {
|
||||
entry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.conv[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointConvCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointConvCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
|
||||
entry.conv[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.delta == nil {
|
||||
entry.delta = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.delta[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointDeltaCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointDeltaCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
|
||||
entry.delta[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *HybridCache) reserveCheckpointConv(layer int) {
|
||||
key := checkpointReserveKey(layer, 0)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointConv(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func (c *HybridCache) reserveCheckpointDelta(layer int) {
|
||||
key := checkpointReserveKey(layer, 1)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointDelta(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func checkpointReserveKey(layer int, kind int) int {
|
||||
return layer*2 + kind
|
||||
}
|
||||
@@ -1,300 +0,0 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func newTestBackend(tb testing.TB) ml.Backend {
|
||||
tb.Helper()
|
||||
|
||||
f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
|
||||
_ = f.Close()
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
tb.Cleanup(func() {
|
||||
b.Close()
|
||||
})
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
|
||||
store := newSlotCheckpointStore(2)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
|
||||
_, pos, ok := store.bestIndex(15)
|
||||
if !ok || pos != 10 {
|
||||
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
store.record(30) // overwrite oldest (10)
|
||||
|
||||
if _, _, ok := store.bestIndex(15); ok {
|
||||
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(40)
|
||||
if !ok || pos != 30 {
|
||||
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCachePrepareRestore(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 1, 1)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
store := cache.checkpointStore(0)
|
||||
store.record(5)
|
||||
store.record(9)
|
||||
store.record(15)
|
||||
|
||||
restorePos, ok := cache.PrepareRestore(1, 12)
|
||||
if !ok {
|
||||
t.Fatalf("expected restore ok")
|
||||
}
|
||||
if restorePos != 10 {
|
||||
t.Fatalf("expected restorePos 10, got %d", restorePos)
|
||||
}
|
||||
rest, ok := cache.pendingRestore[1]
|
||||
if !ok {
|
||||
t.Fatalf("expected pending restore entry")
|
||||
}
|
||||
if rest.pos != 9 {
|
||||
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
store.pruneAfter(20)
|
||||
|
||||
if store.lastPos != 20 {
|
||||
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, pos, ok := store.bestIndex(25)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(35)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
|
||||
backend := newTestBackend(t)
|
||||
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.Init(backend, ml.DTypeF16, 2, 8, 2)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.slotForSeq[2] = 0
|
||||
cache.refCount[0] = 2
|
||||
cache.refCount[1] = 0
|
||||
cache.freeSlots = []int{1}
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
|
||||
t.Fatalf("Remove failed: %v", err)
|
||||
}
|
||||
|
||||
if cache.slotForSeq[1] == cache.slotForSeq[2] {
|
||||
t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
|
||||
}
|
||||
if cache.slotForSeq[1] != 1 {
|
||||
t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
|
||||
}
|
||||
if cache.slotForSeq[2] != 0 {
|
||||
t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
|
||||
}
|
||||
if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
|
||||
t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
|
||||
}
|
||||
if _, ok := cache.pendingRestore[1]; ok {
|
||||
t.Fatalf("expected pending restore to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
|
||||
cache.convStates[0] = nil // placeholder to indicate layer 0 exists
|
||||
cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
entry := &store.entries[idx]
|
||||
// Only set conv checkpoint, not delta - making it incomplete
|
||||
entry.conv = map[int]ml.Tensor{0: nil}
|
||||
// entry.delta is not set, so checkpoint is incomplete
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
err := cache.Remove(1, 10, math.MaxInt32)
|
||||
if !errors.Is(err, kvcache.ErrNotSupported) {
|
||||
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Don't set convStates/deltaStates - with no layers to check,
|
||||
// entryComplete will return true as long as entry.pos >= 0
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
// Test that restoreComplete returns true when no layers need checkpoints
|
||||
restore := cache.pendingRestore[1]
|
||||
if !cache.restoreComplete(restore) {
|
||||
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
|
||||
// Test that ring buffer wrap-around reuses entries without clearing maps.
|
||||
store := newSlotCheckpointStore(3)
|
||||
|
||||
// Fill the buffer
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
// Create fake tensor data in the first entry's maps
|
||||
store.entries[0].conv = make(map[int]ml.Tensor)
|
||||
store.entries[0].conv[0] = nil // Simulated tensor reference
|
||||
store.entries[0].delta = make(map[int]ml.Tensor)
|
||||
store.entries[0].delta[0] = nil // Simulated tensor reference
|
||||
|
||||
// Record another entry, which should wrap around and overwrite entry 0
|
||||
store.record(40)
|
||||
|
||||
// Verify the maps are still present (we reuse tensors)
|
||||
if store.entries[0].conv == nil {
|
||||
t.Fatalf("expected conv map to be preserved on reuse")
|
||||
}
|
||||
if store.entries[0].delta == nil {
|
||||
t.Fatalf("expected delta map to be preserved on reuse")
|
||||
}
|
||||
|
||||
// Verify the new position was recorded
|
||||
if store.entries[0].pos != 40 {
|
||||
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
|
||||
// Test behavior when buffer is exactly at capacity
|
||||
store := newSlotCheckpointStore(2)
|
||||
|
||||
idx1 := store.record(10)
|
||||
idx2 := store.record(20)
|
||||
|
||||
if idx1 != 0 || idx2 != 1 {
|
||||
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
|
||||
}
|
||||
|
||||
if store.size != 2 {
|
||||
t.Fatalf("expected size 2, got %d", store.size)
|
||||
}
|
||||
|
||||
// Verify both checkpoints are accessible
|
||||
_, pos1, ok1 := store.bestIndex(15)
|
||||
_, pos2, ok2 := store.bestIndex(25)
|
||||
|
||||
if !ok1 || pos1 != 10 {
|
||||
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
|
||||
}
|
||||
if !ok2 || pos2 != 20 {
|
||||
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
|
||||
// Test behavior with zero-size buffer
|
||||
store := newSlotCheckpointStore(0)
|
||||
|
||||
idx := store.record(10)
|
||||
if idx != -1 {
|
||||
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(15)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint for empty buffer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
|
||||
// Test pruning that removes all checkpoints
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
// Prune everything by setting threshold below all positions
|
||||
store.pruneAfter(5)
|
||||
|
||||
if store.size != 0 {
|
||||
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
|
||||
}
|
||||
// When all checkpoints are pruned, lastPos is reset to -1
|
||||
if store.lastPos != -1 {
|
||||
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(100)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint after pruning all")
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,9 @@ type GatedDeltaNet struct {
|
||||
// Optimized path: pre-split QKV and gate
|
||||
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
|
||||
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
|
||||
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
|
||||
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha (legacy qwen3next)
|
||||
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
||||
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
@@ -96,7 +98,6 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
headVDim := opts.ssmDInner / numVHeads
|
||||
convKernelSize := opts.convKernelSize
|
||||
|
||||
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
|
||||
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
|
||||
|
||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||
@@ -106,24 +107,42 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
|
||||
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
|
||||
|
||||
baNewDim := 2 * numVHeads / numKHeads
|
||||
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
|
||||
var beta ml.Tensor
|
||||
var alpha ml.Tensor
|
||||
qwen35BetaGateLayout := false
|
||||
switch {
|
||||
case gdn.SSMBetaAlpha != nil:
|
||||
// Legacy qwen3next path: in_proj_ba packs beta/alpha grouped by K-head.
|
||||
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
|
||||
baNewDim := 2 * numVHeads / numKHeads
|
||||
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
|
||||
|
||||
// Split beta and alpha
|
||||
betaSize := numVHeads / numKHeads
|
||||
alphaSize := numVHeads / numKHeads
|
||||
betaSize := numVHeads / numKHeads
|
||||
alphaSize := numVHeads / numKHeads
|
||||
|
||||
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
|
||||
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
|
||||
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
|
||||
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
|
||||
|
||||
// Reshape to merge head dimensions
|
||||
beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
|
||||
alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||
beta = b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
|
||||
alpha = a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
case gdn.SSMBeta != nil && gdn.SSMAlpha != nil:
|
||||
// qwen35 path: beta/alpha are separate projections.
|
||||
beta = gdn.SSMBeta.Forward(ctx, hiddenStates).Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
|
||||
alpha = gdn.SSMAlpha.Forward(ctx, hiddenStates).Reshape(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||
qwen35BetaGateLayout = true
|
||||
|
||||
default:
|
||||
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
||||
}
|
||||
|
||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||
alphaSoftplus := alphaBiased.Softplus(ctx)
|
||||
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
|
||||
if qwen35BetaGateLayout {
|
||||
gate = gate.Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
|
||||
}
|
||||
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
// Get conv state from cache
|
||||
@@ -172,16 +191,20 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
|
||||
// Repeat interleave Q and K if numKHeads != numVHeads
|
||||
if numKHeads != numVHeads {
|
||||
repeatFactor := numVHeads / numKHeads
|
||||
if opts.vHeadReordered {
|
||||
qConv = qConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
|
||||
kConv = kConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
|
||||
} else {
|
||||
repeatFactor := numVHeads / numKHeads
|
||||
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
|
||||
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
|
||||
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
|
||||
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
}
|
||||
}
|
||||
|
||||
// Choose computation mode based on sequence length
|
||||
@@ -189,7 +212,9 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
if nSeqTokens == 1 {
|
||||
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
|
||||
} else {
|
||||
// Use pre-computed masks from opts (created once in Model.Forward)
|
||||
if opts.masks == nil {
|
||||
opts.masks = createMasks(ctx)
|
||||
}
|
||||
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"fmt"
|
||||
"image"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
@@ -11,6 +14,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/model/models/qwen3vl"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
@@ -41,10 +45,15 @@ type Options struct {
|
||||
ssmNGroup int // num_k_heads
|
||||
ssmDtRank int // num_v_heads
|
||||
convKernelSize int // SSM conv kernel size
|
||||
vHeadReordered bool
|
||||
|
||||
// Per-layer type from GGUF metadata
|
||||
isRecurrent []bool
|
||||
|
||||
// RoPE mode config (used by qwen35/qwen35moe)
|
||||
mropeSections []int
|
||||
mropeInterleaved bool
|
||||
|
||||
// Pre-computed masks for chunked attention (created once per forward pass)
|
||||
masks *Masks
|
||||
}
|
||||
@@ -54,7 +63,17 @@ func (o Options) headDim() int {
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
opts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
var opts []func(*rope.Options)
|
||||
if len(o.mropeSections) > 0 {
|
||||
if o.mropeInterleaved {
|
||||
opts = append(opts, rope.WithInterleaveMRoPE(o.mropeSections))
|
||||
} else {
|
||||
opts = append(opts, rope.WithMRoPE(o.mropeSections))
|
||||
}
|
||||
} else {
|
||||
opts = append(opts, rope.WithTypeNeoX())
|
||||
}
|
||||
|
||||
if o.ropeType == "yarn" {
|
||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||
opts = append(opts,
|
||||
@@ -214,20 +233,193 @@ type Model struct {
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Layers []Layer `gguf:"blk"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
Vision *qwen3vl.VisionModel `gguf:"v"`
|
||||
|
||||
ImageProcessor *qwen3vl.ImageProcessor
|
||||
|
||||
*Options
|
||||
|
||||
positionCache []int32
|
||||
imageToken int32
|
||||
visionStart int32
|
||||
visionEnd int32
|
||||
spatialMergeSize uint32
|
||||
}
|
||||
|
||||
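// mapPosition translates a token index into its cached MRoPE position, extrapolating past the end of the cache.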
func (m *Model) mapPosition(id int32) int32 {
|
||||
if id < int32(len(m.positionCache)) {
|
||||
return m.positionCache[id]
|
||||
}
|
||||
if len(m.positionCache) > 0 {
|
||||
return id - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func (m *Model) buildPositions(ctx ml.Context, batch input.Batch) ml.Tensor {
|
||||
if len(m.mropeSections) == 0 {
|
||||
return ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
}
|
||||
|
||||
// ggml MRoPE expects [time, height, width, extra] for each token.
|
||||
positionSlice := [][]int32{
|
||||
make([]int32, len(batch.Positions)),
|
||||
make([]int32, len(batch.Positions)),
|
||||
make([]int32, len(batch.Positions)),
|
||||
make([]int32, len(batch.Positions)),
|
||||
}
|
||||
|
||||
for i, id := range batch.Positions {
|
||||
p := m.mapPosition(id)
|
||||
positionSlice[0][i] = p
|
||||
positionSlice[1][i] = p
|
||||
positionSlice[2][i] = p
|
||||
}
|
||||
|
||||
if m.Vision != nil {
|
||||
for _, mi := range batch.Multimodal {
|
||||
grid, ok := mi.Multimodal[0].Data.(*qwen3vl.Grid)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
w := max(1, grid.Width/int(m.spatialMergeSize))
|
||||
for i := range mi.Multimodal[0].Tensor.Dim(1) {
|
||||
positionSlice[1][mi.Index+i] += int32(i / w)
|
||||
positionSlice[2][mi.Index+i] += int32(i % w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
if m.Vision == nil || m.ImageProcessor == nil || len(m.Vision.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues, grid, err := m.ImageProcessor.ProcessImage(ctx, img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
visionOutputs, deepstackVisualEmbeds := m.Vision.Forward(ctx, pixelValues, grid)
|
||||
mm := []input.Multimodal{{Tensor: visionOutputs, Data: grid}}
|
||||
for i := range deepstackVisualEmbeds {
|
||||
mm = append(mm, input.Multimodal{Tensor: deepstackVisualEmbeds[i]})
|
||||
}
|
||||
|
||||
return mm, nil
|
||||
}
|
||||
|
||||
type modelInput struct {
|
||||
*input.Input
|
||||
position int32
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
m.positionCache = m.positionCache[:0]
|
||||
return slices.Collect(func(yield func(*input.Input) bool) {
|
||||
for i := range inputs {
|
||||
s := []modelInput{{Input: inputs[i]}}
|
||||
if mm := inputs[i].Multimodal; mm != nil {
|
||||
t := mm[0].Tensor
|
||||
s = slices.Repeat([]modelInput{
|
||||
{
|
||||
position: int32(i + 1),
|
||||
Input: &input.Input{Token: m.imageToken},
|
||||
},
|
||||
}, t.Dim(1)+2)
|
||||
|
||||
s[0] = modelInput{
|
||||
Input: &input.Input{Token: m.visionStart},
|
||||
position: int32(i),
|
||||
}
|
||||
|
||||
s[len(s)-1] = modelInput{
|
||||
Input: &input.Input{Token: m.visionEnd},
|
||||
position: int32(i + mm[0].Data.(*qwen3vl.Grid).Width/int(m.spatialMergeSize) + 1),
|
||||
}
|
||||
|
||||
s[1] = modelInput{
|
||||
Input: &input.Input{
|
||||
Token: m.imageToken,
|
||||
Multimodal: inputs[i].Multimodal,
|
||||
MultimodalHash: inputs[i].MultimodalHash,
|
||||
SameBatch: t.Dim(1),
|
||||
},
|
||||
position: int32(i + 1),
|
||||
}
|
||||
}
|
||||
|
||||
for _, e := range s {
|
||||
position := e.position
|
||||
if position == 0 && len(m.positionCache) > 0 {
|
||||
position = m.positionCache[len(m.positionCache)-1] + 1
|
||||
}
|
||||
|
||||
m.positionCache = append(m.positionCache, position)
|
||||
if !yield(e.Input) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
positions := m.buildPositions(ctx, batch)
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
if len(batch.Multimodal) > 0 {
|
||||
hiddenStates = hiddenStates.Duplicate(ctx)
|
||||
|
||||
var deepstackVisualEmbeds []ml.Tensor
|
||||
for _, mi := range batch.Multimodal {
|
||||
visionOutputs := mi.Multimodal[0].Tensor
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
|
||||
deepstackVisualEmbeds = make([]ml.Tensor, len(mi.Multimodal[1:]))
|
||||
for i, mm := range mi.Multimodal[1:] {
|
||||
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
|
||||
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
|
||||
}
|
||||
}
|
||||
|
||||
cache := m.Cache.(*HybridCache)
|
||||
m.Options.masks = nil
|
||||
for i, layer := range m.Layers {
|
||||
cache.SetLayer(i)
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
var err error
|
||||
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if i < len(deepstackVisualEmbeds) {
|
||||
hiddenStates = hiddenStates.Add(ctx, deepstackVisualEmbeds[i])
|
||||
}
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}

    cache := m.Cache.(*HybridCache)

    // Create masks once per forward pass
    m.Options.masks = createMasks(ctx)
    // Masks are allocated lazily only for chunked recurrent prefill.
    m.Options.masks = nil

    for i, layer := range m.Layers {
        cache.SetLayer(i)
@@ -249,10 +441,17 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
    }

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
    m.positionCache = nil
    if len(m.mropeSections) > 0 {
        shift = shift.Repeat(ctx, 1, 4).Reshape(ctx, -1)
    }
    return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}

var _ model.Model = (*Model)(nil)
var (
    _ model.Model               = (*Model)(nil)
    _ model.MultimodalProcessor = (*Model)(nil)
)

func New(c fs.Config) (model.Model, error) {
    numLayers := int(c.Uint("block_count"))
@@ -303,6 +502,22 @@ func New(c fs.Config) (model.Model, error) {
        }
    }

    mropeSections := c.Ints("mrope_sections", nil)
    if len(mropeSections) == 0 {
        mropeSections = c.Ints("rope.mrope_section", nil)
    }
    if len(mropeSections) == 0 {
        mropeSections = c.Ints("rope.dimension_sections", nil)
    }
    if len(mropeSections) > 4 {
        mropeSections = mropeSections[:4]
    }

    ropeType := c.String("rope.scaling.type")
    if ropeType == "" {
        ropeType = c.String("rope.type")
    }

    opts := &Options{
        hiddenSize: int(c.Uint("embedding_length")),
        numHeads:   int(c.Uint("attention.head_count")),
@@ -318,7 +533,7 @@ func New(c fs.Config) (model.Model, error) {
        valueLength:           int(c.Uint("attention.value_length")),
        ropeDim:               int(c.Uint("rope.dimension_count")),
        eps:                   c.Float("attention.layer_norm_rms_epsilon"),
        ropeType:              c.String("rope.scaling.type"),
        ropeType:              ropeType,
        ropeBase:              c.Float("rope.freq_base"),
        ropeScale:             c.Float("rope.scaling.factor", 1),
        originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
@@ -331,7 +546,16 @@ func New(c fs.Config) (model.Model, error) {
        ssmNGroup:      int(c.Uint("ssm.group_count")),
        ssmDtRank:      int(c.Uint("ssm.time_step_rank")),
        convKernelSize: int(c.Uint("ssm.conv_kernel")),
        vHeadReordered: c.Bool("ssm.v_head_reordered", false),
        isRecurrent:    isRecurrent,
        mropeSections: slices.Collect(func(yield func(int) bool) {
            for _, section := range mropeSections {
                if !yield(int(section)) {
                    return
                }
            }
        }),
        mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
    }
    if opts.numKVHeads == 0 {
        return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
@@ -353,6 +577,19 @@ func New(c fs.Config) (model.Model, error) {
        return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
    }

    var vision *qwen3vl.VisionModel
    var imageProcessor *qwen3vl.ImageProcessor
    if c.Uint("vision.block_count", 0) > 0 {
        vision = qwen3vl.NewVisionModel(c)
        processor := qwen3vl.NewImageProcessor(c)
        imageProcessor = &processor
    }

    spatialMergeSize := c.Uint("vision.spatial_merge_size", 2)
    if spatialMergeSize == 0 {
        spatialMergeSize = 2
    }

    m := Model{
        Tokenizer: tokenizer.NewBytePairEncoding(
            &tokenizer.Vocabulary{
@@ -371,8 +608,14 @@ func New(c fs.Config) (model.Model, error) {
            },
            `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
        ),
        Layers:  layers,
        Options: opts,
        Layers:           layers,
        Vision:           vision,
        ImageProcessor:   imageProcessor,
        Options:          opts,
        imageToken:       int32(c.Uint("image_token_id", 151655)),
        visionStart:      int32(c.Uint("vision_start_token_id", 151652)),
        visionEnd:        int32(c.Uint("vision_end_token_id", 151653)),
        spatialMergeSize: spatialMergeSize,
    }

    m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
@@ -380,5 +623,7 @@ func New(c fs.Config) (model.Model, error) {
}

func init() {
    model.Register("qwen35", New)
    model.Register("qwen35moe", New)
    model.Register("qwen3next", New)
}

@@ -24,8 +24,8 @@ type ImageProcessor struct {
    imageStd []float32
}

// newImageProcessor creates a new image processor with default values
func newImageProcessor(c fs.Config) ImageProcessor {
// NewImageProcessor creates a new image processor with default values.
func NewImageProcessor(c fs.Config) ImageProcessor {
    patchSize := int(c.Uint("vision.patch_size", 14))
    mergeSize := int(c.Uint("vision.spatial_merge_size", 2))

@@ -189,8 +189,8 @@ func New(c fs.Config) (model.Model, error) {
            `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
        ),
        TextModel:      newTextModel(c),
        VisionModel:    newVisionModel(c),
        ImageProcessor: newImageProcessor(c),
        VisionModel:    NewVisionModel(c),
        ImageProcessor: NewImageProcessor(c),
    }

    m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) {

@@ -238,8 +238,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
    return hiddenStates, deepstackStates
}

// newVisionModel creates a new instance of the Qwen vision model
func newVisionModel(c fs.Config) *VisionModel {
// NewVisionModel creates a new instance of the Qwen vision model.
func NewVisionModel(c fs.Config) *VisionModel {
    deepstackVisualIndexes := c.Ints("vision.deepstack_visual_indexes")
    model := &VisionModel{
        Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),

@@ -52,9 +52,13 @@ func ParserForName(name string) Parser {
    case "qwen3-coder":
        p = &Qwen3CoderParser{}
    case "qwen3-vl-instruct":
        p = &Qwen3VLParser{hasThinkingSupport: false}
        p = &Qwen3VLParser{hasThinkingSupport: false, defaultThinking: false}
    case "qwen3-vl-thinking":
        p = &Qwen3VLParser{hasThinkingSupport: true}
        p = &Qwen3VLParser{hasThinkingSupport: true, defaultThinking: true}
    case "qwen3.5", "qwen3-vl":
        // qwen3.5 and qwen3-vl share one parser with runtime thinking toggle.
        // Default is non-thinking unless think=true is requested.
        p = &Qwen3VLParser{hasThinkingSupport: true, defaultThinking: false}
    case "ministral":
        p = &MinistralParser{hasThinkingSupport: false}
    case "passthrough":

@@ -59,6 +59,7 @@ func TestBuiltInParsersStillWork(t *testing.T) {
        {"qwen3-coder"},
        {"lfm2"},
        {"lfm2-thinking"},
        {"qwen3.5"},
        {"harmony"},
    }

@@ -29,6 +29,7 @@ type Qwen3VLParser struct {
    buffer             strings.Builder
    tools              []api.Tool
    hasThinkingSupport bool
    defaultThinking    bool
}

func (p *Qwen3VLParser) HasToolSupport() bool {
@@ -39,9 +40,18 @@ func (p *Qwen3VLParser) HasThinkingSupport() bool {
    return p.hasThinkingSupport
}

func (p *Qwen3VLParser) setInitialState(lastMessage *api.Message) {
func (p *Qwen3VLParser) setInitialState(lastMessage *api.Message, thinkValue *api.ThinkValue) {
    prefill := lastMessage != nil && lastMessage.Role == "assistant"
    if !p.HasThinkingSupport() {

    thinkingEnabled := false
    if p.HasThinkingSupport() {
        if thinkValue != nil {
            thinkingEnabled = thinkValue.Bool()
        } else {
            thinkingEnabled = p.defaultThinking
        }
    }
    if !thinkingEnabled {
        p.state = CollectingContent
        return
    }
@@ -56,7 +66,7 @@ func (p *Qwen3VLParser) setInitialState(lastMessage *api.Message) {

func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
    p.tools = tools
    p.setInitialState(lastMessage)
    p.setInitialState(lastMessage, thinkValue)
    return tools
}

@@ -204,7 +204,7 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
    }

    t.Run(tc.desc, func(t *testing.T) {
        parser := Qwen3VLParser{hasThinkingSupport: true}
        parser := Qwen3VLParser{hasThinkingSupport: true, defaultThinking: true}
        parser.Init([]api.Tool{}, nil, nil)
        // parser.state = CollectingThinkingContent

@@ -385,7 +385,7 @@ func TestQwen3VLParserState(t *testing.T) {
    }

    for _, tc := range cases {
        parser := Qwen3VLParser{hasThinkingSupport: tc.hasThinking}
        parser := Qwen3VLParser{hasThinkingSupport: tc.hasThinking, defaultThinking: tc.hasThinking}
        parser.Init(nil, tc.last, nil)
        if parser.state != tc.wantState {
            t.Errorf("%s: got state %v, want %v", tc.desc, parser.state, tc.wantState)
@@ -436,7 +436,7 @@ func TestQwen3VLThinkingParserWithThinkingPrefill(t *testing.T) {

    for _, tc := range cases {
        t.Run(tc.desc, func(t *testing.T) {
            parser := Qwen3VLParser{hasThinkingSupport: true}
            parser := Qwen3VLParser{hasThinkingSupport: true, defaultThinking: true}
            parser.Init([]api.Tool{}, last, nil)

            for i, step := range tc.steps {
@@ -499,7 +499,7 @@ func TestQwen3VLThinkingParserWithNonThinkingPrefill(t *testing.T) {

    for _, tc := range cases {
        t.Run(tc.desc, func(t *testing.T) {
            parser := Qwen3VLParser{hasThinkingSupport: true}
            parser := Qwen3VLParser{hasThinkingSupport: true, defaultThinking: true}
            parser.Init([]api.Tool{}, last, nil)

            for i, step := range tc.steps {
@@ -522,7 +522,7 @@ func TestQwen3VLThinkingParserWithNonThinkingPrefill(t *testing.T) {
func TestQwen3VLThinkingParserStreamingAssistantPrefillContent(t *testing.T) {
    // last message is assistant with content ⇒ start in CollectingContent
    last := &api.Message{Role: "assistant", Content: "has content"}
    parser := Qwen3VLParser{hasThinkingSupport: true}
    parser := Qwen3VLParser{hasThinkingSupport: true, defaultThinking: true}
    parser.Init([]api.Tool{}, last, nil)

    type step struct {
@@ -749,7 +749,7 @@ func TestQwen3VLThinkingWhitespaceHandling(t *testing.T) {
    }

    t.Run(tc.desc, func(t *testing.T) {
        parser := Qwen3VLParser{hasThinkingSupport: true}
        parser := Qwen3VLParser{hasThinkingSupport: true, defaultThinking: true}
        parser.Init([]api.Tool{}, nil, nil)

        for i, step := range tc.steps {
@@ -858,7 +858,7 @@ func TestQwen3VLToolCallWhitespaceHandling(t *testing.T) {
    }

    t.Run(tc.desc, func(t *testing.T) {
        parser := Qwen3VLParser{hasThinkingSupport: true}
        parser := Qwen3VLParser{hasThinkingSupport: true, defaultThinking: true}
        parser.Init([]api.Tool{}, tc.prefillMsg, nil)

        for i, step := range tc.steps {

model/parsers/qwen3vl_toggle_test.go (new file, 47 lines)
@@ -0,0 +1,47 @@
package parsers

import (
    "testing"

    "github.com/ollama/ollama/api"
)

func TestQwen3VLParserThinkToggle(t *testing.T) {
    t.Run("thinking enabled by default when configured", func(t *testing.T) {
        p := &Qwen3VLParser{hasThinkingSupport: true, defaultThinking: true}
        p.Init(nil, nil, nil)
        if p.state != CollectingThinkingContent {
            t.Fatalf("state = %v, want %v", p.state, CollectingThinkingContent)
        }
    })

    t.Run("thinking disabled by default", func(t *testing.T) {
        p := &Qwen3VLParser{hasThinkingSupport: true, defaultThinking: false}
        p.Init(nil, nil, nil)
        if p.state != CollectingContent {
            t.Fatalf("state = %v, want %v", p.state, CollectingContent)
        }
    })

    t.Run("thinking disabled at runtime", func(t *testing.T) {
        p := &Qwen3VLParser{hasThinkingSupport: true, defaultThinking: true}
        p.Init(nil, nil, &api.ThinkValue{Value: false})
        if p.state != CollectingContent {
            t.Fatalf("state = %v, want %v", p.state, CollectingContent)
        }

        content, thinking, calls, err := p.Add("plan</think>answer", true)
        if err != nil {
            t.Fatal(err)
        }
        if content != "plan</think>answer" {
            t.Fatalf("content = %q, want %q", content, "plan</think>answer")
        }
        if thinking != "" {
            t.Fatalf("thinking = %q, want empty", thinking)
        }
        if len(calls) != 0 {
            t.Fatalf("calls = %d, want 0", len(calls))
        }
    })
}

@@ -7,7 +7,8 @@ import (
)

type Qwen3VLRenderer struct {
    isThinking         bool
    hasThinkingSupport bool
    defaultThinking    bool

    useImgTags bool
}
@@ -31,8 +32,16 @@ func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
    return subSb.String()
}

func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
    var sb strings.Builder
    thinking := false
    if r.hasThinkingSupport {
        if thinkValue != nil {
            thinking = thinkValue.Bool()
        } else {
            thinking = r.defaultThinking
        }
    }

    if len(tools) > 0 {
        sb.WriteString(imStartTag + "system\n")
@@ -76,13 +85,13 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
    } else if message.Role == "assistant" {
        contentReasoning := ""

        if r.isThinking {
        if r.hasThinkingSupport {
            if message.Thinking != "" {
                contentReasoning = message.Thinking
            }
        }

        if r.isThinking && i > lastQueryIndex {
        if thinking && i > lastQueryIndex {
            if i == len(messages)-1 || contentReasoning != "" {
                sb.WriteString("<|im_start|>" + message.Role + "\n<think>\n" + strings.Trim(contentReasoning, "\n")) // do we want to add a new line here?
                if content != "" {
@@ -125,8 +134,12 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
    // prefill at the end
    if lastMessage && !prefill {
        sb.WriteString("<|im_start|>assistant\n")
        if r.isThinking {
        if thinking {
            sb.WriteString("<think>\n")
        } else if r.hasThinkingSupport {
            // In nothink mode, explicitly close any latent think block so
            // checkpoints that default to thinking start directly in content.
            sb.WriteString("</think>\n")
        }
    }
}
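
To make the prefill behaviour above concrete, a hedged sketch of the two prompt tails the renderer is expected to emit, mirroring the suffix assertions in the renderer tests further down:

    // thinking enabled: open a think block for the model to fill in.
    const thinkTail = "<|im_start|>assistant\n<think>\n"
    // thinking disabled on a thinking-capable checkpoint: close the latent
    // think block so generation starts directly in content.
    const nothinkTail = "<|im_start|>assistant\n</think>\n"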

@@ -509,7 +509,7 @@ I'll check.
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            rendered, err := (&Qwen3VLRenderer{isThinking: false, useImgTags: tt.useImgTags}).Render(tt.msgs, tt.tools, nil)
            rendered, err := (&Qwen3VLRenderer{hasThinkingSupport: false, useImgTags: tt.useImgTags}).Render(tt.msgs, tt.tools, nil)
            if err != nil {
                t.Fatal(err)
            }

@@ -1,6 +1,7 @@
package renderers

import (
    "strings"
    "testing"

    "github.com/google/go-cmp/cmp"
@@ -323,7 +324,7 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            rendered, err := (&Qwen3VLRenderer{isThinking: true}).Render(tt.msgs, tt.tools, nil)
            rendered, err := (&Qwen3VLRenderer{hasThinkingSupport: true, defaultThinking: true}).Render(tt.msgs, tt.tools, nil)
            if err != nil {
                t.Fatal(err)
            }
@@ -334,6 +335,51 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
        }
    }

func TestQwen3VLRendererThinkToggle(t *testing.T) {
    msgs := []api.Message{
        {Role: "user", Content: "hello"},
    }

    t.Run("thinking enabled by default when configured", func(t *testing.T) {
        r := &Qwen3VLRenderer{hasThinkingSupport: true, defaultThinking: true}
        out, err := r.Render(msgs, nil, nil)
        if err != nil {
            t.Fatal(err)
        }
        if !strings.Contains(out, "<think>\n") {
            t.Fatalf("expected thinking prefill in output, got: %q", out)
        }
    })

    t.Run("thinking disabled by default", func(t *testing.T) {
        r := &Qwen3VLRenderer{hasThinkingSupport: true, defaultThinking: false}
        out, err := r.Render(msgs, nil, nil)
        if err != nil {
            t.Fatal(err)
        }
        if strings.Contains(out, "<think>\n") {
            t.Fatalf("did not expect thinking prefill in output, got: %q", out)
        }
        if !strings.HasSuffix(out, "<|im_start|>assistant\n</think>\n") {
            t.Fatalf("unexpected assistant prefill: %q", out)
        }
    })

    t.Run("thinking disabled at runtime", func(t *testing.T) {
        r := &Qwen3VLRenderer{hasThinkingSupport: true, defaultThinking: true}
        out, err := r.Render(msgs, nil, &api.ThinkValue{Value: false})
        if err != nil {
            t.Fatal(err)
        }
        if strings.Contains(out, "<think>\n") {
            t.Fatalf("did not expect thinking prefill in output, got: %q", out)
        }
        if !strings.HasSuffix(out, "<|im_start|>assistant\n</think>\n") {
            t.Fatalf("unexpected assistant prefill: %q", out)
        }
    })
}

func TestFormatToolCallArgumentThinkingVL(t *testing.T) {
    tests := []struct {
        name string

@@ -51,10 +51,15 @@ func rendererForName(name string) Renderer {
        renderer := &Qwen3CoderRenderer{}
        return renderer
    case "qwen3-vl-instruct":
        renderer := &Qwen3VLRenderer{isThinking: false, useImgTags: RenderImgTags}
        renderer := &Qwen3VLRenderer{hasThinkingSupport: false, defaultThinking: false, useImgTags: RenderImgTags}
        return renderer
    case "qwen3-vl-thinking":
        renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
        renderer := &Qwen3VLRenderer{hasThinkingSupport: true, defaultThinking: true, useImgTags: RenderImgTags}
        return renderer
    case "qwen3.5", "qwen3-vl":
        // qwen3.5 and qwen3-vl share one renderer with runtime think toggle.
        // Default is non-thinking unless think=true is requested.
        renderer := &Qwen3VLRenderer{hasThinkingSupport: true, defaultThinking: false, useImgTags: RenderImgTags}
        return renderer
    case "cogito":
        renderer := &CogitoRenderer{isThinking: true}

@@ -29,17 +29,27 @@ func TestRegisterCustomRenderer(t *testing.T) {
    }

func TestBuiltInRendererStillWorks(t *testing.T) {
    // Test that qwen3-coder still works
    tests := []struct {
        name string
    }{
        {name: "qwen3-coder"},
        {name: "qwen3.5"},
    }

    messages := []api.Message{
        {Role: "user", Content: "Hello"},
    }

    result, err := RenderWithRenderer("qwen3-coder", messages, nil, nil)
    if err != nil {
        t.Fatalf("unexpected error: %v", err)
    }
    if result == "" {
        t.Error("expected non-empty result from qwen3-coder renderer")
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result, err := RenderWithRenderer(tt.name, messages, nil, nil)
            if err != nil {
                t.Fatalf("unexpected error: %v", err)
            }
            if result == "" {
                t.Fatalf("expected non-empty result from %s renderer", tt.name)
            }
        })
    }
}

@@ -6,6 +6,7 @@ import (
    "log/slog"
    "maps"
    "os"
    "slices"
    "strings"
    "unsafe"

@@ -58,7 +59,7 @@ func useMoreBits(iLayer, nLayers int) bool {
    return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
}

func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
func qwen3LinearAttnQuantType(name string) (fsggml.TensorType, bool) {
    switch {
    // Full attention
    case strings.HasSuffix(name, ".attn_q.weight"):
@@ -79,6 +80,10 @@ func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
    // SSM
    case strings.HasSuffix(name, ".ssm_ba.weight"):
        return fsggml.TensorTypeQ4_K, true
    case strings.HasSuffix(name, ".ssm_beta.weight"):
        return fsggml.TensorTypeQ4_K, true
    case strings.HasSuffix(name, ".ssm_alpha.weight"):
        return fsggml.TensorTypeQ4_K, true
    case strings.HasSuffix(name, ".ssm_out.weight"):
        return fsggml.TensorTypeQ4_K, true

@@ -287,8 +292,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil

    newType := fsggml.TensorType(t.Kind)
    if quantize {
        if kv.Architecture() == "qwen3next" && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
            if qt, ok := qwen3nextQuantType(name); ok {
        if slices.Contains([]string{"qwen3next", "qwen35", "qwen35moe"}, kv.Architecture()) && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
            if qt, ok := qwen3LinearAttnQuantType(name); ok {
                return qt
            }
        }

@@ -166,6 +166,60 @@ func TestGetTensorNewType(t *testing.T) {
    }
}

func TestQwen3LinearAttentionQuantOverride(t *testing.T) {
    cases := []struct {
        name     string
        arch     string
        tensor   string
        fileType fsggml.FileType
        expected fsggml.TensorType
    }{
        {
            name:     "qwen35_beta",
            arch:     "qwen35",
            tensor:   "blk.0.ssm_beta.weight",
            fileType: fsggml.FileTypeQ4_K_M,
            expected: fsggml.TensorTypeQ4_K,
        },
        {
            name:     "qwen35_alpha",
            arch:     "qwen35",
            tensor:   "blk.0.ssm_alpha.weight",
            fileType: fsggml.FileTypeQ4_K_M,
            expected: fsggml.TensorTypeQ4_K,
        },
        {
            name:     "qwen35moe_attn_qkv",
            arch:     "qwen35moe",
            tensor:   "blk.0.attn_qkv.weight",
            fileType: fsggml.FileTypeQ4_K_M,
            expected: fsggml.TensorTypeQ4_K,
        },
        {
            name:     "non_qwen35_falls_back",
            arch:     "foo",
            tensor:   "blk.0.attn_qkv.weight",
            fileType: fsggml.FileTypeQ4_K_M,
            expected: fsggml.TensorTypeQ5_K,
        },
    }

    for _, tt := range cases {
        t.Run(tt.name, func(t *testing.T) {
            kv := fsggml.KV{"general.architecture": tt.arch}
            got := newType(&fsggml.Tensor{
                Name:  tt.tensor,
                Shape: []uint64{256, 256},
                Kind:  uint32(fsggml.TensorTypeF16),
            }, kv, &quantizeState{}, tt.fileType)

            if got != tt.expected {
                t.Fatalf("unexpected tensor type for %s (%s): got %s want %s", tt.tensor, tt.arch, got, tt.expected)
            }
        })
    }
}

func TestQuantizeModel(t *testing.T) {
    cases := []struct {
        name string

@@ -76,6 +76,22 @@ func shouldUseHarmony(model *Model) bool {
    return false
}

func defaultThinkingForModel(model *Model) bool {
    // qwen3.5/qwen3-vl are runtime-toggle models that should default to
    // non-thinking unless explicitly requested.
    if slices.Contains([]string{"qwen3.5", "qwen3-vl"}, model.Config.Parser) ||
        slices.Contains([]string{"qwen3.5", "qwen3-vl"}, model.Config.Renderer) {
        return false
    }

    // Preserve legacy default-on behavior for explicit thinking variants.
    if model.Config.Parser == "qwen3-vl-thinking" || model.Config.Renderer == "qwen3-vl-thinking" {
        return true
    }

    return true
}
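
For clarity, a hedged summary of the defaults this helper yields for thinking-capable models; the mapping mirrors the branches above and is illustrative rather than exhaustive:

    // Illustrative mapping from parser/renderer name to the default think value.
    defaults := map[string]bool{
        "qwen3.5":           false, // runtime-toggle stack: off unless think=true is requested
        "qwen3-vl":          false,
        "qwen3-vl-thinking": true, // legacy explicit-thinking variant stays default-on
    }
    _ = defaults // every other thinking-capable model also defaults to true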

func experimentEnabled(name string) bool {
    return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
}
@@ -385,7 +401,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
    if slices.Contains(modelCaps, model.CapabilityThinking) {
        caps = append(caps, model.CapabilityThinking)
        if req.Think == nil {
            req.Think = &api.ThinkValue{Value: true}
            req.Think = &api.ThinkValue{Value: defaultThinkingForModel(m)}
        }
    } else {
        if req.Think != nil && req.Think.Bool() {
@@ -2149,7 +2165,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
    if slices.Contains(modelCaps, model.CapabilityThinking) {
        caps = append(caps, model.CapabilityThinking)
        if req.Think == nil {
            req.Think = &api.ThinkValue{Value: true}
            req.Think = &api.ThinkValue{Value: defaultThinkingForModel(m)}
        }
    } else {
        if req.Think != nil && req.Think.Bool() {

@@ -447,7 +447,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo

    // Some architectures are not safe with num_parallel > 1.
    // ref: https://github.com/ollama/ollama/issues/4165
    if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
    if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
        numParallel = 1
        slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
    }

@@ -264,13 +264,16 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
        }
    }

    resolvedParser := resolveParserName(opts.Modelfile, parserName)
    resolvedRenderer := resolveRendererName(opts.Modelfile, rendererName)

    // Create config blob with version requirement
    configData := model.ConfigV2{
        ModelFormat:  "safetensors",
        Capabilities: caps,
        Requires:     MinOllamaVersion,
        Parser:       resolveParserName(opts.Modelfile, parserName),
        Renderer:     resolveRendererName(opts.Modelfile, rendererName),
        Parser:       resolvedParser,
        Renderer:     resolvedRenderer,
    }
    configJSON, err := json.Marshal(configData)
    if err != nil {
@@ -427,6 +430,9 @@ func getParserName(modelDir string) string {
    if strings.Contains(archLower, "deepseek") {
        return "deepseek3"
    }
    if strings.Contains(archLower, "qwen3_5") {
        return "qwen3.5"
    }
    if strings.Contains(archLower, "qwen3") {
        return "qwen3"
    }
@@ -441,6 +447,9 @@ func getParserName(modelDir string) string {
    if strings.Contains(typeLower, "deepseek") {
        return "deepseek3"
    }
    if strings.Contains(typeLower, "qwen3_5") {
        return "qwen3.5"
    }
    if strings.Contains(typeLower, "qwen3") {
        return "qwen3"
    }
@@ -475,6 +484,9 @@ func getRendererName(modelDir string) string {
    if strings.Contains(archLower, "deepseek") {
        return "deepseek3"
    }
    if strings.Contains(archLower, "qwen3_5") {
        return "qwen3.5"
    }
    if strings.Contains(archLower, "qwen3") {
        return "qwen3-coder"
    }
@@ -489,6 +501,9 @@ func getRendererName(modelDir string) string {
    if strings.Contains(typeLower, "deepseek") {
        return "deepseek3"
    }
    if strings.Contains(typeLower, "qwen3_5") {
        return "qwen3.5"
    }
    if strings.Contains(typeLower, "qwen3") {
        return "qwen3-coder"
    }

x/create/client/create_qwen35_test.go (new file, 40 lines)
@@ -0,0 +1,40 @@
package client

import (
    "os"
    "path/filepath"
    "testing"
)

func writeConfig(t *testing.T, dir, cfg string) {
    t.Helper()
    if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(cfg), 0o644); err != nil {
        t.Fatal(err)
    }
}

func TestParserAndRendererInferenceQwen35(t *testing.T) {
    t.Run("qwen3.5 architecture uses qwen3.5 runtime-toggle stack", func(t *testing.T) {
        dir := t.TempDir()
        writeConfig(t, dir, `{"architectures":["Qwen3_5ForConditionalGeneration"],"model_type":"qwen3_5"}`)

        if got, want := getParserName(dir), "qwen3.5"; got != want {
            t.Fatalf("getParserName() = %q, want %q", got, want)
        }
        if got, want := getRendererName(dir), "qwen3.5"; got != want {
            t.Fatalf("getRendererName() = %q, want %q", got, want)
        }
    })

    t.Run("qwen3 legacy inference unchanged", func(t *testing.T) {
        dir := t.TempDir()
        writeConfig(t, dir, `{"architectures":["Qwen3ForCausalLM"],"model_type":"qwen3"}`)

        if got, want := getParserName(dir), "qwen3"; got != want {
            t.Fatalf("getParserName() = %q, want %q", got, want)
        }
        if got, want := getRendererName(dir), "qwen3-coder"; got != want {
            t.Fatalf("getRendererName() = %q, want %q", got, want)
        }
    })
}