Compare commits

..

9 Commits

Author SHA1 Message Date
Patrick Devine
857cffd22a bugfix: fix crash bug in token cache logic
This change fixes a problem in the token cache logic to avoid panics caused by empty token arrays
by ensuring at least one token remains on full cache hits in the relevant function. This happens
if there is an exact match in the cache on subsequent generations.
2026-02-26 18:35:44 -08:00
Jeffrey Morgan
d98dda4676 model: fix qwen3 tool calling in thinking (#14477)
Align Qwen parser behavior with Transformers serve by allowing <tool_call> parsing while still in thinking collection.

Changes:

- qwen3vl: detect <tool_call> before </think> in thinking state and transition to tool parsing

- qwen3: same thinking-state tool detection and partial-tag overlap handling

- tests: update qwen3vl thinking/tool interleaving expectations

- tests: add qwen3 cases for tool call before </think> and split <tool_call> streaming
2026-02-26 16:13:18 -08:00
Eva H
d69ddc1edc fix: window app crash on startup when update is pending (#14451) 2026-02-26 16:47:12 -05:00
Eva H
9bf41969f0 app: fix first update check delayed by 1 hour (#14427) 2026-02-25 18:29:55 -05:00
Jesse Gross
0f23b7bff5 mlxrunner: Cancel in-flight requests when the client disconnects
Currently, a canceled request can result in computation continuing
in the background to completion. It can also trigger a deadlock
when there is nobody to read the output tokens and the pipeline
cannot continue to the next request.
2026-02-25 14:00:42 -08:00
Jesse Gross
4e57d2094e mlxrunner: Simplify pipeline memory and cache management
Particularly in error cases, it can be difficult to ensure that
all pinned memory is unpinned, MLX buffers are released and cache
state is consistent. This encapsulates those pieces and sets up
proper deferrals so that this happens automatically on exit.
2026-02-25 14:00:42 -08:00
Jeffrey Morgan
7f9efd53df model: add support for qwen3.5-27b model (#14415) 2026-02-25 01:09:58 -08:00
Jeffrey Morgan
da70c3222e model: support for qwen3.5 architecture (#14378) 2026-02-24 20:08:05 -08:00
Bruce MacDonald
9d902d63ce ggml: ensure tensor size is valid (#14406)
When quantizing tensors during model creation validate that the resulting sizes match what is expected based on the shape.
2026-02-24 21:52:44 -04:00
48 changed files with 2342 additions and 2006 deletions

View File

@@ -35,6 +35,7 @@ import (
var (
wv = &Webview{}
uiServerPort int
appStore *store.Store
)
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
@@ -208,6 +209,7 @@ func main() {
uiServerPort = port
st := &store.Store{}
appStore = st
// Enable CORS in development mode
if devMode {
@@ -294,8 +296,15 @@ func main() {
// Check for pending updates on startup (show tray notification if update is ready)
if updater.IsUpdatePending() {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
// On Windows, the tray is initialized in osRun(). Calling UpdateAvailable
// before that would dereference a nil tray callback.
// TODO: refactor so the update check runs after platform init on all platforms.
if runtime.GOOS == "windows" {
slog.Debug("update pending on startup, deferring tray notification until tray initialization")
} else {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
}
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
@@ -360,8 +369,7 @@ func startHiddenTasks() {
slog.Info("deferring pending update for fast startup")
} else {
// Check if auto-update is enabled before automatically upgrading
st := &store.Store{}
settings, err := st.Settings()
settings, err := appStore.Settings()
if err != nil {
slog.Warn("failed to load settings for upgrade check", "error", err)
} else if !settings.AutoUpdateEnabled {

View File

@@ -154,6 +154,10 @@ func handleURLSchemeRequest(urlScheme string) {
}
// UpdateAvailable forwards an available-update notification to the system
// tray. When the tray has not been initialized yet (app.t is nil), the
// notification is skipped and nil is returned, since there is no tray to
// display it on.
func UpdateAvailable(ver string) error {
	if t := app.t; t != nil {
		return t.UpdateAvailable(ver)
	}
	slog.Debug("tray not yet initialized, skipping update notification")
	return nil
}
@@ -165,6 +169,14 @@ func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
log.Fatalf("Failed to start: %s", err)
}
// Check for pending updates now that the tray is initialized.
// The platform-independent check in app.go fires before osRun,
// when app.t is still nil, so we must re-check here.
if updater.IsUpdatePending() {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

View File

@@ -1,202 +0,0 @@
import { describe, it, expect } from "vitest";
import { sanitize } from "hast-util-sanitize";
import { defaultSchema } from "rehype-sanitize";
// Mirror the sanitizeSchema from StreamingMarkdownContent.tsx.
// Extends GitHub's defaultSchema so that KaTeX math output and the custom
// <ol-citation> element survive sanitization, while <script>/<style> are
// removed together with their content (verified by the tests below).
const sanitizeSchema = {
  ...defaultSchema,
  // Allow the custom citation element in addition to the default tag set.
  tagNames: [...(defaultSchema.tagNames || []), "ol-citation"],
  attributes: {
    ...defaultSchema.attributes,
    // Permit className "math math-display" on <div> for block math.
    div: [
      ...(defaultSchema.attributes?.div || []),
      ["className", "math", "math-display"],
    ],
    // Permit className "math math-inline" on <span> for inline math.
    span: [
      ...(defaultSchema.attributes?.span || []),
      ["className", "math", "math-inline"],
    ],
    // Attributes carried by the custom citation element.
    "ol-citation": ["cursor", "start", "end"],
  },
  // Strip these tags entirely (tag and children) instead of unwrapping them.
  strip: ["script", "style"],
};
// Build a hast element node with the given tag name, properties map, and
// optional list of child nodes (defaults to no children).
function h(
  tag: string,
  props: Record<string, unknown>,
  kids: any[] = [],
): any {
  return { type: "element", tagName: tag, properties: props, children: kids };
}
// Build a hast text node carrying the given string value.
function text(value: string): any {
  return { type: "text", value: value };
}
// Wrap the given nodes in a hast root node.
function root(...nodes: any[]): any {
  return { type: "root", children: nodes };
}
// Test suite for the sanitization schema mirrored above: dangerous tags
// (<style>, <script>, <iframe>) must be removed along with their content,
// while math markup, the custom <ol-citation> element, code blocks, and
// standard markdown output must pass through unchanged.
describe("sanitizeSchema", () => {
  it("should strip <style> tags and their content", () => {
    const tree = root(
      h("style", {}, [
        text("body { background: red; } button { background: linear-gradient(blue, green); }"),
      ]),
      h("p", {}, [text("Hello world")]),
    );
    const result = sanitize(tree, sanitizeSchema);
    // <style> should be completely stripped (including content)
    const hasStyle = JSON.stringify(result).includes("background");
    expect(hasStyle).toBe(false);
    // <p> should survive
    expect(result.children).toHaveLength(1);
    expect(result.children[0].tagName).toBe("p");
  });
  it("should strip <script> tags and their content", () => {
    const tree = root(
      h("script", {}, [text("alert('xss')")]),
      h("p", {}, [text("Safe content")]),
    );
    const result = sanitize(tree, sanitizeSchema);
    const hasScript = JSON.stringify(result).includes("alert");
    expect(hasScript).toBe(false);
    expect(result.children).toHaveLength(1);
    expect(result.children[0].tagName).toBe("p");
  });
  it("should strip <iframe> tags", () => {
    const tree = root(
      h("iframe", { src: "https://evil.com" }, []),
      h("p", {}, [text("Safe content")]),
    );
    const result = sanitize(tree, sanitizeSchema);
    // <iframe> is not in the allowed tagNames, so it must not survive
    const hasIframe = result.children.some(
      (c: any) => c.tagName === "iframe",
    );
    expect(hasIframe).toBe(false);
  });
  it("should preserve math block elements (div.math.math-display)", () => {
    const tree = root(
      h("div", { className: ["math", "math-display"] }, [
        text("E = mc^2"),
      ]),
    );
    const result = sanitize(tree, sanitizeSchema);
    // Both the element and its math className values must survive
    expect(result.children).toHaveLength(1);
    expect(result.children[0].tagName).toBe("div");
    expect(result.children[0].properties.className).toEqual([
      "math",
      "math-display",
    ]);
  });
  it("should preserve math inline elements (span.math.math-inline)", () => {
    const tree = root(
      h("span", { className: ["math", "math-inline"] }, [text("x^2")]),
    );
    const result = sanitize(tree, sanitizeSchema);
    expect(result.children).toHaveLength(1);
    expect(result.children[0].tagName).toBe("span");
    expect(result.children[0].properties.className).toEqual([
      "math",
      "math-inline",
    ]);
  });
  it("should preserve ol-citation elements with attributes", () => {
    const tree = root(
      h("ol-citation", { cursor: "1", start: "25", end: "30" }, []),
    );
    const result = sanitize(tree, sanitizeSchema);
    // Custom element and all three whitelisted attributes must survive
    expect(result.children).toHaveLength(1);
    expect(result.children[0].tagName).toBe("ol-citation");
    expect(result.children[0].properties.cursor).toBe("1");
    expect(result.children[0].properties.start).toBe("25");
    expect(result.children[0].properties.end).toBe("30");
  });
  it("should preserve code elements with language classes", () => {
    const tree = root(
      h("pre", {}, [
        h("code", { className: ["language-python"] }, [
          text("print('hello')"),
        ]),
      ]),
    );
    const result = sanitize(tree, sanitizeSchema);
    expect(result.children).toHaveLength(1);
    // language-* classes are allowed on <code> by the default schema
    const code = result.children[0].children[0];
    expect(code.tagName).toBe("code");
    expect(code.properties.className).toEqual(["language-python"]);
  });
  it("should preserve standard markdown elements", () => {
    const tree = root(
      h("h1", {}, [text("Title")]),
      h("p", {}, [
        text("Some "),
        h("strong", {}, [text("bold")]),
        text(" and "),
        h("em", {}, [text("italic")]),
        text(" text."),
      ]),
      h("ul", {}, [
        h("li", {}, [text("Item 1")]),
        h("li", {}, [text("Item 2")]),
      ]),
      h("a", { href: "https://example.com" }, [text("A link")]),
    );
    const result = sanitize(tree, sanitizeSchema);
    const tagNames = result.children.map((c: any) => c.tagName);
    expect(tagNames).toEqual(["h1", "p", "ul", "a"]);
  });
  it("should strip model-generated HTML page that would corrupt the UI", () => {
    // Simulate a model generating a full HTML page
    const tree = root(
      h("style", {}, [
        text(`
* { margin: 0; padding: 0; }
button { background: linear-gradient(to right, #ff0000, #0000ff); }
.some-class { font-size: 72px; }
`),
      ]),
      h("div", {}, [
        h("h1", {}, [text("My Generated Page")]),
        h("p", {}, [text("This is model-generated content")]),
      ]),
    );
    const result = sanitize(tree, sanitizeSchema);
    // Style tag and its content should be gone
    const serialized = JSON.stringify(result);
    expect(serialized).not.toContain("linear-gradient");
    expect(serialized).not.toContain("margin: 0");
    // The safe content should remain
    expect(serialized).toContain("My Generated Page");
    expect(serialized).toContain("model-generated content");
  });
});

View File

@@ -1,36 +1,10 @@
import React from "react";
import { Streamdown, defaultRemarkPlugins, defaultRehypePlugins } from "streamdown";
import rehypeSanitize, { defaultSchema } from "rehype-sanitize";
import type { PluggableList } from "unified";
import { Streamdown, defaultRemarkPlugins } from "streamdown";
import remarkCitationParser from "@/utils/remarkCitationParser";
import CopyButton from "./CopyButton";
import type { BundledLanguage } from "shiki";
import { highlighter } from "@/lib/highlighter";
// Extend GitHub's default sanitization schema to support math rendering
// and custom citation elements while stripping dangerous tags like <style>
// and <script> that can leak from model-generated HTML content.
const sanitizeSchema = {
...defaultSchema,
tagNames: [
...(defaultSchema.tagNames || []),
"ol-citation",
],
attributes: {
...defaultSchema.attributes,
div: [
...(defaultSchema.attributes?.div || []),
["className", "math", "math-display"],
],
span: [
...(defaultSchema.attributes?.span || []),
["className", "math", "math-inline"],
],
"ol-citation": ["cursor", "start", "end"],
},
strip: ["script", "style"],
};
interface StreamingMarkdownContentProps {
content: string;
isStreaming?: boolean;
@@ -161,18 +135,6 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
];
}, []);
// Build rehype plugins: keep defaults (harden, raw, katex) but add
// sanitization after raw HTML parsing to prevent model-generated HTML
// (e.g. <style> tags) from affecting the UI
const rehypePlugins: PluggableList = React.useMemo(() => {
return [
defaultRehypePlugins.harden,
defaultRehypePlugins.raw,
[rehypeSanitize, sanitizeSchema],
defaultRehypePlugins.katex,
] as PluggableList;
}, []);
return (
<div
className={`
@@ -249,7 +211,6 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
parseIncompleteMarkdown={isStreaming}
isAnimating={isStreaming}
remarkPlugins={remarkPlugins}
rehypePlugins={rehypePlugins}
controls={false}
components={{
pre: CodeBlock,

View File

@@ -289,6 +289,7 @@ func (u *Updater) TriggerImmediateCheck() {
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
u.checkNow = make(chan struct{}, 1)
u.checkNow <- struct{}{} // Trigger first check after initial delay
go func() {
// Don't blast an update message immediately after startup
time.Sleep(UpdateCheckInitialDelay)
@@ -333,7 +334,7 @@ func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(str
continue
}
// Download successful - show tray notification (regardless of toggle state)
// Download successful - show tray notification
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn("failed to register update available with tray", "error", err)

View File

@@ -351,10 +351,13 @@ func TestTriggerImmediateCheck(t *testing.T) {
updater.StartBackgroundUpdaterChecker(ctx, cb)
// Wait for goroutine to start and pass initial delay
time.Sleep(10 * time.Millisecond)
// Wait for the initial check that fires after the initial delay
select {
case <-checkDone:
case <-time.After(2 * time.Second):
t.Fatal("initial check did not happen")
}
// With 1 hour interval, no check should have happened yet
initialCount := checkCount.Load()
// Trigger immediate check

View File

@@ -320,7 +320,7 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &lfm2Model{}
case "Lfm2VlForConditionalGeneration":
conv = &lfm2VLTextModel{}
case "Qwen3NextForCausalLM":
case "Qwen3NextForCausalLM", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration":
conv = &qwen3NextModel{}
case "NemotronHForCausalLM":
conv = &nemotronHModel{}

View File

@@ -1,6 +1,7 @@
package convert
import (
"encoding/json"
"fmt"
"io/fs"
"math"
@@ -13,8 +14,21 @@ import (
"github.com/ollama/ollama/fs/ggml"
)
type qwen3NextModel struct {
ModelParameters
// qwen3NextRopeScaling mirrors the legacy "rope_scaling" JSON object from
// older HF configs: a scaling type, a scaling factor, and optional mRoPE
// section sizes.
type qwen3NextRopeScaling struct {
	Type         string     `json:"type"`
	Factor       ropeFactor `json:"factor"`
	MropeSection []int32    `json:"mrope_section"`
}
// qwen3NextRopeParams mirrors the newer "rope_parameters" JSON object used by
// qwen3.5-family configs. parseMore falls back to these values when the
// corresponding legacy top-level RoPE fields are unset.
type qwen3NextRopeParams struct {
	MRopeInterleaved    bool    `json:"mrope_interleaved"`
	MropeSection        []int32 `json:"mrope_section"`
	RopeType            string  `json:"rope_type"`
	RopeTheta           float32 `json:"rope_theta"`
	PartialRotaryFactor float32 `json:"partial_rotary_factor"`
}
type qwen3NextTextConfig struct {
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
@@ -28,12 +42,13 @@ type qwen3NextModel struct {
// MoE config
NumExperts uint32 `json:"num_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NormTopkProb bool `json:"norm_topk_prob"`
NormTopkProb *bool `json:"norm_topk_prob"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
// Hybrid attention config
FullAttentionInterval uint32 `json:"full_attention_interval"`
FullAttentionInterval uint32 `json:"full_attention_interval"`
LayerTypes []string `json:"layer_types"`
// Linear attention (Gated Delta Net) config
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
@@ -43,16 +58,102 @@ type qwen3NextModel struct {
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
// RoPE config
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeScaling struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
} `json:"rope_scaling"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeScaling qwen3NextRopeScaling `json:"rope_scaling"`
RopeParameters qwen3NextRopeParams `json:"rope_parameters"`
}
// qwen3NextVisionConfig captures the "vision_config" section of the HF config
// for multimodal checkpoints. The preprocessing fields (PatchSize,
// TemporalPatchSize, SpatialMergeSize, Size, ImageMean, ImageStd) may instead
// be supplied by preprocessor_config.json; parseMore fills them in from that
// file when they are zero/empty here.
type qwen3NextVisionConfig struct {
	Depth                  uint32  `json:"depth"`
	HiddenSize             uint32  `json:"hidden_size"`
	NumHeads               uint32  `json:"num_heads"`
	InChannels             uint32  `json:"in_channels"`
	PatchSize              uint32  `json:"patch_size"`
	SpatialMergeSize       uint32  `json:"spatial_merge_size"`
	RMSNormEps             float32 `json:"layer_norm_epsilon"`
	RopeTheta              float32 `json:"rope_theta"`
	TemporalPatchSize      uint32  `json:"temporal_patch_size"`
	DeepstackVisualIndexes []int32 `json:"deepstack_visual_indexes"`
	// Image resize bounds in pixels.
	Size struct {
		ShortestEdge uint32 `json:"shortest_edge"`
		LongestEdge  uint32 `json:"longest_edge"`
	} `json:"size"`
	ImageMean []float32 `json:"image_mean"`
	ImageStd  []float32 `json:"image_std"`
}
// qwen3NextModel converts qwen3next-family checkpoints (Qwen3-Next and
// Qwen3.5 / Qwen3.5-MoE, including VL variants). The text config may appear
// either flattened at the top level (the embedded qwen3NextTextConfig) or
// nested under "text_config"; parseMore copies the nested form into the
// embedded one when present.
type qwen3NextModel struct {
	ModelParameters
	qwen3NextTextConfig
	TextConfig         *qwen3NextTextConfig  `json:"text_config"`
	VisionModel        qwen3NextVisionConfig `json:"vision_config"`
	ImageTokenID       uint32                `json:"image_token_id"`
	VisionStartTokenID uint32                `json:"vision_start_token_id"`
	VisionEndTokenID   uint32                `json:"vision_end_token_id"`
}
var _ ModelConverter = (*qwen3NextModel)(nil)
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
func (q *qwen3NextModel) parseMore(fsys fs.FS) error {
if q.TextConfig != nil {
q.qwen3NextTextConfig = *q.TextConfig
}
if q.RopeTheta == 0 {
q.RopeTheta = q.RopeParameters.RopeTheta
}
if q.PartialRotaryFactor == 0 {
q.PartialRotaryFactor = q.RopeParameters.PartialRotaryFactor
}
if q.RopeScaling.Type == "" && q.RopeParameters.RopeType != "" {
q.RopeScaling.Type = q.RopeParameters.RopeType
}
// Pull vision preprocessing fields when present.
if q.VisionModel.Depth > 0 {
if bts, err := fs.ReadFile(fsys, "preprocessor_config.json"); err == nil {
var pre struct {
Size struct {
ShortestEdge uint32 `json:"shortest_edge"`
LongestEdge uint32 `json:"longest_edge"`
} `json:"size"`
PatchSize uint32 `json:"patch_size"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
MergeSize uint32 `json:"merge_size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
}
if json.Unmarshal(bts, &pre) == nil {
if q.VisionModel.PatchSize == 0 {
q.VisionModel.PatchSize = pre.PatchSize
}
if q.VisionModel.TemporalPatchSize == 0 {
q.VisionModel.TemporalPatchSize = pre.TemporalPatchSize
}
if q.VisionModel.SpatialMergeSize == 0 {
q.VisionModel.SpatialMergeSize = pre.MergeSize
}
if q.VisionModel.Size.ShortestEdge == 0 {
q.VisionModel.Size.ShortestEdge = pre.Size.ShortestEdge
}
if q.VisionModel.Size.LongestEdge == 0 {
q.VisionModel.Size.LongestEdge = pre.Size.LongestEdge
}
if len(q.VisionModel.ImageMean) == 0 {
q.VisionModel.ImageMean = pre.ImageMean
}
if len(q.VisionModel.ImageStd) == 0 {
q.VisionModel.ImageStd = pre.ImageStd
}
}
}
}
if q.NumHiddenLayers == 0 {
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
}
@@ -74,36 +175,96 @@ func (q *qwen3NextModel) parseMore(_ fs.FS) error {
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
}
if q.FullAttentionInterval == 0 {
return fmt.Errorf("qwen3next: full_attention_interval must be set")
}
if q.FullAttentionInterval > q.NumHiddenLayers {
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
hasFull := false
for i := range q.NumHiddenLayers {
if (i+1)%q.FullAttentionInterval == 0 {
hasFull = true
break
}
}
if !hasFull {
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
if _, err := q.kvHeadCounts(); err != nil {
return err
}
return nil
}
// kvHeadCounts builds the per-layer KV head count array written as
// attention.head_count_kv: a non-zero entry (NumKeyValueHeads) marks a
// full-attention layer, a zero entry marks a recurrent (linear-attention)
// layer. It returns an error when the configuration cannot produce a valid
// hybrid layout.
func (q *qwen3NextModel) kvHeadCounts() ([]uint32, error) {
	// Preferred path: an explicit per-layer "layer_types" list in the config.
	if len(q.LayerTypes) > 0 {
		kv := make([]uint32, q.NumHiddenLayers)
		hasFull := false
		hasRecurrent := false
		for i := range q.NumHiddenLayers {
			layerType := ""
			if i < uint32(len(q.LayerTypes)) {
				layerType = q.LayerTypes[i]
			}
			// NOTE(review): layers beyond len(LayerTypes), or with an
			// unrecognized type string, are silently treated as recurrent —
			// confirm this matches the upstream config semantics.
			if layerType == "full_attention" {
				kv[i] = q.NumKeyValueHeads
				hasFull = true
			} else {
				hasRecurrent = true
			}
		}
		// A hybrid model must contain at least one layer of each kind.
		if !hasFull || !hasRecurrent {
			return nil, fmt.Errorf("qwen3next: layer_types must include both full_attention and linear_attention")
		}
		return kv, nil
	}
	// Fallback path: derive the layout from full_attention_interval, where
	// every interval-th layer (1-based) uses full attention.
	if q.FullAttentionInterval == 0 {
		return nil, fmt.Errorf("qwen3next: full_attention_interval must be set")
	}
	if q.FullAttentionInterval > q.NumHiddenLayers {
		return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
	}
	kv := make([]uint32, q.NumHiddenLayers)
	hasFull := false
	for i := range q.NumHiddenLayers {
		if (i+1)%q.FullAttentionInterval == 0 {
			kv[i] = q.NumKeyValueHeads
			hasFull = true
		}
	}
	// Reject configurations where no layer ends up with full attention.
	if !hasFull {
		return nil, fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
	}
	return kv, nil
}
// ropeSections returns the mRoPE section sizes, preferring the newer
// rope_parameters field and falling back to the legacy rope_scaling field
// when the former is absent.
func (q *qwen3NextModel) ropeSections() []int32 {
	sections := q.RopeParameters.MropeSection
	if len(sections) == 0 {
		sections = q.RopeScaling.MropeSection
	}
	return sections
}
// shouldReorderVHeads reports whether value heads must be reordered during
// conversion. Original qwen3next checkpoints — identified by "qwen3next" /
// "qwen3_next" appearing in model_type or any architecture name — keep their
// layout; every other import in this converter family defaults to the
// qwen3.5 layout and is reordered.
func (q *qwen3NextModel) shouldReorderVHeads() bool {
	isQwen3Next := func(s string) bool {
		s = strings.ToLower(s)
		return strings.Contains(s, "qwen3_next") || strings.Contains(s, "qwen3next")
	}
	if isQwen3Next(q.ModelType) {
		return false
	}
	for _, arch := range q.Architectures {
		if isQwen3Next(arch) {
			return false
		}
	}
	// Default to qwen3.5 layout for all other qwen3next-family imports.
	return true
}
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen3next"
kv["tokenizer.ggml.pre"] = "qwen2"
arch := "qwen35"
if q.NumExperts > 0 {
arch = "qwen35moe"
}
kv["general.architecture"] = arch
kv["tokenizer.ggml.pre"] = "qwen35"
kv["block_count"] = q.NumHiddenLayers
kv["context_length"] = q.MaxPositionEmbeddings
kv["embedding_length"] = q.HiddenSize
kv["feed_forward_length"] = q.IntermediateSize
kv["attention.head_count"] = q.NumAttentionHeads
headDim := q.HeadDim
if headDim == 0 && q.NumAttentionHeads > 0 {
headDim = q.HiddenSize / q.NumAttentionHeads
@@ -113,18 +274,31 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
kv["rope.freq_base"] = q.RopeTheta
// RoPE dimension count (partial rotary)
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
partialRotary := q.PartialRotaryFactor
if partialRotary > 0 && partialRotary <= 1 {
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
}
// MoE config
if sections := q.ropeSections(); len(sections) > 0 {
kv["mrope_sections"] = sections
kv["rope.mrope_section"] = sections
kv["rope.dimension_sections"] = sections
}
if q.RopeParameters.MRopeInterleaved {
kv["rope.mrope_interleaved"] = true
}
if q.RopeScaling.Type != "" && q.RopeScaling.Type != "default" {
kv["rope.scaling.type"] = q.RopeScaling.Type
kv["rope.scaling.factor"] = q.RopeScaling.Factor
}
if q.NumExperts > 0 {
kv["expert_count"] = q.NumExperts
kv["expert_used_count"] = q.NumExpertsPerToken
kv["norm_top_k_prob"] = q.NormTopkProb
if q.NormTopkProb != nil {
kv["norm_top_k_prob"] = *q.NormTopkProb
}
if q.MoEIntermediateSize > 0 {
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
}
@@ -133,33 +307,66 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
}
}
// SSM/Linear attention config
// d_inner = linear_value_head_dim * linear_num_value_heads
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
kv["ssm.inner_size"] = dInner
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
kv["ssm.state_size"] = q.LinearKeyHeadDim
kv["ssm.group_count"] = q.LinearNumKeyHeads
kv["ssm.time_step_rank"] = q.LinearNumValueHeads
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
interval := q.FullAttentionInterval
kv["full_attention_interval"] = interval
// Build per-layer KV head count array to identify layer types
// 0 = recurrent (linear attention), non-zero = full attention
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
for i := range q.NumHiddenLayers {
// Full attention every full_attention_interval layers (starting at interval-1)
if interval > 0 && (i+1)%interval == 0 {
kvHeadCounts[i] = q.NumKeyValueHeads
}
// else stays 0 (recurrent layer)
if q.shouldReorderVHeads() {
kv["ssm.v_head_reordered"] = true
}
if q.FullAttentionInterval > 0 {
kv["full_attention_interval"] = q.FullAttentionInterval
}
kv["attention.head_count_kv"] = kvHeadCounts
// RoPE scaling
if q.RopeScaling.Type != "" {
kv["rope.scaling.type"] = q.RopeScaling.Type
kv["rope.scaling.factor"] = q.RopeScaling.Factor
if headCounts, err := q.kvHeadCounts(); err == nil {
kv["attention.head_count_kv"] = headCounts
}
if q.VisionModel.Depth > 0 {
kv["vision.block_count"] = q.VisionModel.Depth
kv["vision.embedding_length"] = q.VisionModel.HiddenSize
kv["vision.attention.head_count"] = q.VisionModel.NumHeads
kv["vision.num_channels"] = q.VisionModel.InChannels
if q.VisionModel.PatchSize > 0 {
kv["vision.patch_size"] = q.VisionModel.PatchSize
}
if q.VisionModel.SpatialMergeSize > 0 {
kv["vision.spatial_merge_size"] = q.VisionModel.SpatialMergeSize
}
if q.VisionModel.RMSNormEps > 0 {
kv["vision.attention.layer_norm_epsilon"] = q.VisionModel.RMSNormEps
}
if q.VisionModel.RopeTheta > 0 {
kv["vision.rope.freq_base"] = q.VisionModel.RopeTheta
}
if q.VisionModel.TemporalPatchSize > 0 {
kv["vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize
}
kv["vision.deepstack_visual_indexes"] = q.VisionModel.DeepstackVisualIndexes
if q.VisionModel.Size.ShortestEdge > 0 {
kv["vision.shortest_edge"] = q.VisionModel.Size.ShortestEdge
}
if q.VisionModel.Size.LongestEdge > 0 {
kv["vision.longest_edge"] = q.VisionModel.Size.LongestEdge
}
if len(q.VisionModel.ImageMean) > 0 {
kv["vision.image_mean"] = q.VisionModel.ImageMean
}
if len(q.VisionModel.ImageStd) > 0 {
kv["vision.image_std"] = q.VisionModel.ImageStd
}
}
if q.ImageTokenID > 0 {
kv["image_token_id"] = q.ImageTokenID
}
if q.VisionStartTokenID > 0 {
kv["vision_start_token_id"] = q.VisionStartTokenID
}
if q.VisionEndTokenID > 0 {
kv["vision_end_token_id"] = q.VisionEndTokenID
}
return kv
@@ -168,7 +375,6 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// Create merges for expert tensors - stack individual experts into batched tensors
merges := make([]merge, q.NumHiddenLayers*3)
for i := range q.NumHiddenLayers {
merges[i*3+0] = merge{
@@ -185,16 +391,13 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
}
}
// Merge expert tensors
merged, remaining := mergeTensors(ts, merges...)
out = append(out, merged...)
// Process remaining tensors
for _, t := range remaining {
name := t.Name()
shape := t.Shape()
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
if strings.HasSuffix(name, ".ssm_in.weight") {
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
out = append(out, qkv, gate)
@@ -204,84 +407,299 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
}
switch {
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
// This matches the Python converter behavior for qwen3next
case strings.Contains(name, ".mlp.experts.gate_up_proj"):
out = append(out, slices.Collect(splitDim(t, 1,
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_gate_exps.weight")},
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_up_exps.weight")},
))...)
case strings.Contains(name, ".mlp.experts.down_proj"):
out = append(out, &ggml.Tensor{
Name: strings.NewReplacer(".mlp.experts.down_proj", ".ffn_down_exps.weight").Replace(name),
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
case strings.HasPrefix(name, "v.blk.") && strings.Contains(name, ".attn_qkv"):
out = append(out, slices.Collect(splitDim(t, 0,
split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
))...)
case strings.Contains(name, "patch_embed") && strings.HasSuffix(name, "weight"):
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: append([]uint64{shape[0] * shape[1]}, shape[2:]...),
WriterTo: t,
})
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
t.SetRepacker(q.addOne)
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
// Handle linear attention A_log -> ssm_a (negate and exp)
// Note: name has already been transformed by Replacements at this point
case strings.HasSuffix(name, ".ssm_a"):
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
// Compute -exp(A_log)
result := make([]float32, len(data))
for i, v := range data {
// -exp(v)
result[i] = -float32(math.Exp(float64(v)))
}
return result, nil
})
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
t.SetRepacker(q.repackSSMA())
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".attn_qkv.weight"):
if q.shouldReorderVHeads() {
t.SetRepacker(q.repackAttnQKV())
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".attn_gate.weight"):
if q.shouldReorderVHeads() {
// HF tensor layout is [out_features, in_features]; reorder rows.
t.SetRepacker(q.repackReorderDim(0, int(q.LinearValueHeadDim)))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".ssm_beta.weight"), strings.HasSuffix(name, ".ssm_alpha.weight"):
if q.shouldReorderVHeads() {
// HF tensor layout is [out_features, in_features]; reorder rows.
t.SetRepacker(q.repackReorderDim(0, 1))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".ssm_dt"):
if q.shouldReorderVHeads() {
t.SetRepacker(q.repackReorderDim(0, 1))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".ssm_out.weight"):
if q.shouldReorderVHeads() {
// HF out_proj layout is [out_features, in_features]; reorder columns.
t.SetRepacker(q.repackReorderDim(1, int(q.LinearValueHeadDim)))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
newShape := slices.Clone(shape)
if len(shape) == 3 {
if shape[0] == 1 {
// [1, D, K] -> [D, K]
newShape = []uint64{shape[1], shape[2]}
} else if shape[1] == 1 {
// [D, 1, K] -> [D, K]
newShape = []uint64{shape[0], shape[2]}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
newShape := slices.Clone(shape)
if len(shape) == 2 {
if shape[0] == 1 && shape[1] > 1 {
newShape = []uint64{shape[1]}
} else if shape[1] == 1 && shape[0] > 1 {
newShape = []uint64{shape[0]}
}
if q.shouldReorderVHeads() {
t.SetRepacker(q.repackConv1D())
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: newShape, WriterTo: t})
default:
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
}
}
return out
}
// repackReorderDim returns a Repacker that reorders value heads along the
// given tensor dimension, treating the data as headDim-sized units. The
// returned closure re-checks shouldReorderVHeads and passes data through
// untouched when no reordering is required, so it is safe to install
// unconditionally.
func (q *qwen3NextModel) repackReorderDim(dim, headDim int) Repacker {
	return func(_ string, data []float32, shape []uint64) ([]float32, error) {
		if q.shouldReorderVHeads() {
			keyHeads := int(q.LinearNumKeyHeads)
			valuePerKey := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
			return reorderHeadLayout(data, shape, dim, keyHeads, valuePerKey, headDim)
		}
		return data, nil
	}
}
// repackAttnQKV returns a Repacker for the fused linear-attention QKV
// projection weight. Q and K sections are copied through unchanged; only the
// V section is reordered from grouped to tiled head layout via
// reorderHeadLayout. Tensors whose geometry does not match the fused QKV size
// are passed through untouched.
func (q *qwen3NextModel) repackAttnQKV() Repacker {
	return func(_ string, data []float32, shape []uint64) ([]float32, error) {
		// Only 2-D weights of reorder-enabled models need repacking.
		if !q.shouldReorderVHeads() || len(shape) != 2 {
			return data, nil
		}
		rows := int(shape[0])
		cols := int(shape[1])
		numK := int(q.LinearNumKeyHeads)
		numV := int(q.LinearNumValueHeads)
		headK := int(q.LinearKeyHeadDim)
		headV := int(q.LinearValueHeadDim)
		// Per-section widths of the fused projection; Q and K share the
		// key-head geometry here.
		qDim := headK * numK
		kDim := headK * numK
		vDim := headV * numV
		qkvDim := qDim + kDim + vDim
		switch {
		case rows == qkvDim:
			// HF layout: [out_features, in_features]. Keep Q/K rows unchanged and
			// reorder only V rows from grouped -> tiled head layout.
			out := make([]float32, len(data))
			qkRows := qDim + kDim
			qkSize := qkRows * cols
			copy(out[:qkSize], data[:qkSize])
			vStart := qkSize
			vEnd := vStart + vDim*cols
			reorderedV, err := reorderHeadLayout(data[vStart:vEnd], []uint64{uint64(vDim), uint64(cols)}, 0, numK, numV/numK, headV)
			if err != nil {
				return nil, err
			}
			copy(out[vStart:vEnd], reorderedV)
			// Preserve any trailing data beyond the fused QKV rows.
			copy(out[vEnd:], data[vEnd:])
			return out, nil
		case cols == qkvDim:
			// Fallback for already-transposed [in_features, out_features] tensors.
			out := make([]float32, len(data))
			copy(out, data)
			for r := range rows {
				// Reorder the V slice of each row independently.
				base := r * cols
				vStart := base + qDim + kDim
				vEnd := vStart + vDim
				reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vDim)}, 0, numK, numV/numK, headV)
				if err != nil {
					return nil, err
				}
				copy(out[vStart:vEnd], reorderedV)
			}
			return out, nil
		default:
			// Unexpected geometry; leave the tensor untouched.
			return data, nil
		}
	}
}
// repackConv1D returns a Repacker for the linear-attention conv1d weight.
// The leading Q/K channels are copied through unchanged; only the V channel
// block is reordered from grouped to tiled head layout. 3-D inputs with a
// singleton dimension are first squeezed to [channels, kernel].
func (q *qwen3NextModel) repackConv1D() Repacker {
	return func(_ string, data []float32, shape []uint64) ([]float32, error) {
		if !q.shouldReorderVHeads() {
			return data, nil
		}
		// Squeeze [1, D, K] or [D, 1, K] down to [D, K].
		normShape := slices.Clone(shape)
		if len(shape) == 3 {
			if shape[0] == 1 {
				normShape = []uint64{shape[1], shape[2]}
			} else if shape[1] == 1 {
				normShape = []uint64{shape[0], shape[2]}
			}
		}
		if len(normShape) != 2 {
			return data, nil
		}
		rows := int(normShape[0])
		cols := int(normShape[1])
		numK := int(q.LinearNumKeyHeads)
		numV := int(q.LinearNumValueHeads)
		headK := int(q.LinearKeyHeadDim)
		headV := int(q.LinearValueHeadDim)
		// Q and K each contribute headK*numK channels ahead of the V block.
		qkChannels := 2 * headK * numK
		totalChannels := qkChannels + headV*numV
		if qkChannels <= 0 {
			return data, nil
		}
		switch {
		case rows == totalChannels:
			// HF layout after squeeze: [channels, kernel]
			out := make([]float32, len(data))
			prefix := qkChannels * cols
			copy(out[:prefix], data[:prefix])
			reorderedV, err := reorderHeadLayout(data[prefix:], []uint64{uint64(totalChannels - qkChannels), uint64(cols)}, 0, numK, numV/numK, headV)
			if err != nil {
				return nil, err
			}
			copy(out[prefix:], reorderedV)
			return out, nil
		case cols == totalChannels:
			// Fallback for transposed [kernel, channels]
			out := make([]float32, len(data))
			copy(out, data)
			vChannels := totalChannels - qkChannels
			for r := range rows {
				// Reorder the V channel span within each kernel row.
				base := r * cols
				vStart := base + qkChannels
				vEnd := vStart + vChannels
				reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vChannels)}, 0, numK, numV/numK, headV)
				if err != nil {
					return nil, err
				}
				copy(out[vStart:vEnd], reorderedV)
			}
			return out, nil
		default:
			// Unexpected channel geometry; pass through.
			return data, nil
		}
	}
}
// repackSSMA returns a Repacker that maps stored A_log values to
// A = -exp(A_log) and, for models using the tiled v-head layout, reorders the
// per-head entries to match.
func (q *qwen3NextModel) repackSSMA() Repacker {
	return func(_ string, data []float32, shape []uint64) ([]float32, error) {
		negExp := make([]float32, len(data))
		for i := range data {
			negExp[i] = -float32(math.Exp(float64(data[i])))
		}
		if q.shouldReorderVHeads() {
			kHeads := int(q.LinearNumKeyHeads)
			vPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
			return reorderHeadLayout(negExp, shape, 0, kHeads, vPerK, 1)
		}
		return negExp, nil
	}
}
// reorderHeadLayout swaps grouped and per-group head ordering along dimension
// dim: the dimension of size numKHeads*numVPerK*headDim is viewed as
// [numKHeads, numVPerK, headDim], the first two sub-axes are transposed to
// [numVPerK, numKHeads, headDim], and the result is flattened back to the
// original element count. Inputs that do not match the expected size along
// dim, or non-positive head parameters, are returned unchanged.
func reorderHeadLayout(data []float32, shape []uint64, dim int, numKHeads, numVPerK, headDim int) ([]float32, error) {
	if len(shape) == 0 || numKHeads <= 0 || numVPerK <= 0 || headDim <= 0 {
		return data, nil
	}
	dims := make([]int, len(shape))
	for i := range shape {
		dims[i] = int(shape[i])
	}
	// Support negative (from-the-end) dim indexing.
	if dim < 0 {
		dim += len(dims)
	}
	if dim < 0 || dim >= len(dims) {
		return data, nil
	}
	expected := numKHeads * numVPerK * headDim
	if dims[dim] != expected {
		return data, nil
	}
	// Expand dim into the three head sub-dimensions.
	newShape := make([]int, 0, len(dims)+2)
	newShape = append(newShape, dims[:dim]...)
	newShape = append(newShape, numKHeads, numVPerK, headDim)
	newShape = append(newShape, dims[dim+1:]...)
	var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
	if err := tt.Reshape(newShape...); err != nil {
		return nil, err
	}
	// Identity permutation with the two head axes swapped.
	perm := make([]int, len(newShape))
	for i := range perm {
		perm[i] = i
	}
	perm[dim], perm[dim+1] = perm[dim+1], perm[dim]
	tt, err := tensor.Transpose(tt, perm...)
	if err != nil {
		return nil, err
	}
	// Materialize the lazy transpose so the backing data is contiguous,
	// then flatten back to the original total length.
	tt = tensor.Materialize(tt)
	total := 1
	for _, d := range dims {
		total *= d
	}
	if err := tt.Reshape(total); err != nil {
		return nil, err
	}
	return native.VectorF32(tt.(*tensor.Dense))
}
type qkvzSplitSpec struct {
hidden int
headKDim int
@@ -369,7 +787,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
var err error
// Convert to [hidden, out_features] layout for slicing
tt, err = tensor.Transpose(tt, 1, 0)
if err != nil {
return nil, err
@@ -444,7 +861,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
}
}
// addOne adds 1.0 to all elements in the tensor (for norm weights)
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, int(shape[0]))
@@ -471,10 +887,21 @@ func (q *qwen3NextModel) Replacements() []string {
return []string{
// Embeddings and output
"lm_head", "output",
"model.language_model.embed_tokens", "token_embd",
"model.language_model.norm", "output_norm",
"model.language_model.layers", "blk",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
// Vision
"model.visual", "v",
"patch_embed.proj", "patch_embed",
"blocks", "blk",
"attn.qkv", "attn_qkv",
"attn.proj", "attn_out",
"deepstack_merger_list", "deepstack_merger",
// Layer norms
"input_layernorm", "attn_norm",
"post_attention_layernorm", "post_attention_norm",
@@ -487,9 +914,16 @@ func (q *qwen3NextModel) Replacements() []string {
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
// Linear attention (Gated Delta Net)
// Linear attention (legacy qwen3next)
"linear_attn.in_proj_qkvz", "ssm_in",
"linear_attn.in_proj_ba", "ssm_ba",
// Linear attention (qwen35)
"linear_attn.in_proj_qkv", "attn_qkv",
"linear_attn.in_proj_z", "attn_gate",
"linear_attn.in_proj_a", "ssm_alpha",
"linear_attn.in_proj_b", "ssm_beta",
"linear_attn.conv1d", "ssm_conv1d",
"linear_attn.dt_bias", "ssm_dt",
"linear_attn.dt_proj", "ssm_dt",
@@ -497,14 +931,14 @@ func (q *qwen3NextModel) Replacements() []string {
"linear_attn.norm", "ssm_norm",
"linear_attn.out_proj", "ssm_out",
// MoE (experts are stacked via mergeTensors, not replaced here)
// MoE
"mlp.gate.weight", "ffn_gate_inp.weight",
"mlp.shared_expert.down_proj", "ffn_down_shexp",
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
"mlp.shared_expert.up_proj", "ffn_up_shexp",
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
// Dense FFN (if any layers use it)
// Dense FFN
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",

View File

@@ -0,0 +1,563 @@
package convert
import (
"bytes"
"encoding/binary"
"os"
"slices"
"strings"
"testing"
"github.com/ollama/ollama/fs/ggml"
)
// boolPtr returns a pointer to a copy of v, for populating optional *bool
// config fields in test fixtures.
func boolPtr(v bool) *bool {
	p := new(bool)
	*p = v
	return p
}
// readTensorData serializes a converted tensor and decodes its payload back
// into float32 values, failing the test on any I/O or decode error.
func readTensorData(t *testing.T, tensor *ggml.Tensor) []float32 {
	t.Helper()

	var buf bytes.Buffer
	if _, err := tensor.WriteTo(&buf); err != nil {
		t.Fatal(err)
	}

	count := 1
	for _, dim := range tensor.Shape {
		count *= int(dim)
	}

	out := make([]float32, count)
	if err := binary.Read(&buf, binary.LittleEndian, &out); err != nil {
		t.Fatal(err)
	}
	return out
}
// A legacy "qwen3_next" model_type must keep the original v-head layout.
func TestQwen3NextLegacyModelTypeDisablesReorder(t *testing.T) {
	model := &qwen3NextModel{
		ModelParameters: ModelParameters{ModelType: "qwen3_next"},
	}
	if model.shouldReorderVHeads() {
		t.Fatalf("legacy qwen3_next model_type should not reorder v-head layout")
	}
}
// A legacy Qwen3Next architecture string must keep the original v-head layout.
func TestQwen3NextLegacyArchitectureDisablesReorder(t *testing.T) {
	model := &qwen3NextModel{
		ModelParameters: ModelParameters{Architectures: []string{"Qwen3NextForCausalLM"}},
	}
	if model.shouldReorderVHeads() {
		t.Fatalf("legacy Qwen3Next architecture should not reorder v-head layout")
	}
}
// TestQwen3NextKVLegacyConfig verifies KV metadata produced from a legacy
// flat qwen3_next config: the qwen35moe architecture, the qwen35 tokenizer
// pre, alternating per-layer KV head counts derived from
// FullAttentionInterval, no ssm.v_head_reordered flag, and a propagated
// norm_top_k_prob.
func TestQwen3NextKVLegacyConfig(t *testing.T) {
	m := &qwen3NextModel{
		ModelParameters: ModelParameters{
			ModelType: "qwen3_next",
		},
		qwen3NextTextConfig: qwen3NextTextConfig{
			MaxPositionEmbeddings:  8192,
			HiddenSize:             512,
			NumHiddenLayers:        4,
			IntermediateSize:       2048,
			NumAttentionHeads:      8,
			NumKeyValueHeads:       2,
			HeadDim:                64,
			RopeTheta:              1_000_000,
			RMSNormEPS:             1e-6,
			NumExperts:             8,
			NumExpertsPerToken:     2,
			NormTopkProb:           boolPtr(true),
			MoEIntermediateSize:    256,
			SharedExpertIntermSize: 512,
			FullAttentionInterval:  2,
			LinearConvKernelDim:    4,
			LinearKeyHeadDim:       64,
			LinearNumKeyHeads:      2,
			LinearNumValueHeads:    4,
			LinearValueHeadDim:     64,
			PartialRotaryFactor:    0.25,
		},
	}
	// Empty temp dir: parseMore should succeed without any sidecar files.
	if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
		t.Fatal(err)
	}
	kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
	if got, want := kv["general.architecture"], "qwen35moe"; got != want {
		t.Fatalf("unexpected architecture: got %v want %v", got, want)
	}
	if got, want := kv["tokenizer.ggml.pre"], "qwen35"; got != want {
		t.Fatalf("unexpected tokenizer pre: got %v want %v", got, want)
	}
	headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
	if !ok {
		t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
	}
	// 0 marks linear-attention layers; every second layer is full attention.
	if got, want := headCountKV, []uint32{0, 2, 0, 2}; !slices.Equal(got, want) {
		t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
	}
	if _, ok := kv["ssm.v_head_reordered"]; ok {
		t.Fatalf("legacy qwen3next should not enable ssm.v_head_reordered")
	}
	if got, want := kv["norm_top_k_prob"], true; got != want {
		t.Fatalf("unexpected norm_top_k_prob: got %v want %v", got, want)
	}
}
// TestQwen35MoeOmitsNormTopKProbWhenUnset checks that norm_top_k_prob is
// absent from the KV metadata when the config leaves NormTopkProb nil.
func TestQwen35MoeOmitsNormTopKProbWhenUnset(t *testing.T) {
	m := &qwen3NextModel{
		ModelParameters: ModelParameters{
			ModelType: "qwen3_5",
		},
		qwen3NextTextConfig: qwen3NextTextConfig{
			MaxPositionEmbeddings: 4096,
			HiddenSize:            512,
			NumHiddenLayers:       4,
			IntermediateSize:      2048,
			NumAttentionHeads:     8,
			NumKeyValueHeads:      2,
			HeadDim:               64,
			RopeTheta:             1_000_000,
			RMSNormEPS:            1e-6,
			NumExperts:            8,
			NumExpertsPerToken:    2,
			FullAttentionInterval: 2,
			LinearConvKernelDim:   4,
			LinearKeyHeadDim:      64,
			LinearNumKeyHeads:     2,
			LinearNumValueHeads:   4,
			LinearValueHeadDim:    64,
			PartialRotaryFactor:   0.25,
		},
	}
	if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
		t.Fatal(err)
	}
	kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
	if _, ok := kv["norm_top_k_prob"]; ok {
		t.Fatalf("expected norm_top_k_prob to be omitted when not set in config")
	}
}
// TestQwen35KVFromTextConfig exercises the nested TextConfig path (qwen3_5
// with a vision tower): it verifies the qwen35 architecture, per-layer KV
// head counts derived from LayerTypes, the ssm.v_head_reordered flag, mrope
// section/interleave propagation, and vision metadata.
func TestQwen35KVFromTextConfig(t *testing.T) {
	m := &qwen3NextModel{
		ModelParameters: ModelParameters{
			ModelType: "qwen3_5",
		},
		TextConfig: &qwen3NextTextConfig{
			MaxPositionEmbeddings: 16384,
			HiddenSize:            1024,
			NumHiddenLayers:       4,
			IntermediateSize:      4096,
			NumAttentionHeads:     8,
			NumKeyValueHeads:      4,
			HeadDim:               128,
			RMSNormEPS:            1e-6,
			// Explicit per-layer schedule instead of FullAttentionInterval.
			LayerTypes: []string{
				"linear_attention",
				"full_attention",
				"linear_attention",
				"full_attention",
			},
			LinearConvKernelDim: 4,
			LinearKeyHeadDim:    128,
			LinearNumKeyHeads:   2,
			LinearNumValueHeads: 4,
			LinearValueHeadDim:  128,
			RopeParameters: qwen3NextRopeParams{
				MRopeInterleaved:    true,
				MropeSection:        []int32{11, 11, 10},
				RopeType:            "default",
				RopeTheta:           10_000_000,
				PartialRotaryFactor: 0.25,
			},
		},
		VisionModel: qwen3NextVisionConfig{
			Depth:                  2,
			HiddenSize:             128,
			NumHeads:               4,
			InChannels:             3,
			PatchSize:              16,
			SpatialMergeSize:       2,
			RMSNormEps:             1e-6,
			RopeTheta:              10_000,
			TemporalPatchSize:      2,
			DeepstackVisualIndexes: []int32{1},
		},
		ImageTokenID:       1001,
		VisionStartTokenID: 1002,
		VisionEndTokenID:   1003,
	}
	m.VisionModel.Size.ShortestEdge = 224
	m.VisionModel.Size.LongestEdge = 4096
	m.VisionModel.ImageMean = []float32{0.5, 0.5, 0.5}
	m.VisionModel.ImageStd = []float32{0.2, 0.2, 0.2}
	if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
		t.Fatal(err)
	}
	kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
	if got, want := kv["general.architecture"], "qwen35"; got != want {
		t.Fatalf("unexpected architecture: got %v want %v", got, want)
	}
	headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
	if !ok {
		t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
	}
	// Zeros correspond to the linear_attention layers in LayerTypes.
	if got, want := headCountKV, []uint32{0, 4, 0, 4}; !slices.Equal(got, want) {
		t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
	}
	if got, ok := kv["ssm.v_head_reordered"].(bool); !ok || !got {
		t.Fatalf("expected ssm.v_head_reordered=true, got %v (%T)", kv["ssm.v_head_reordered"], kv["ssm.v_head_reordered"])
	}
	mrope, ok := kv["mrope_sections"].([]int32)
	if !ok {
		t.Fatalf("mrope_sections has unexpected type: %T", kv["mrope_sections"])
	}
	if got, want := mrope, []int32{11, 11, 10}; !slices.Equal(got, want) {
		t.Fatalf("unexpected mrope_sections: got %v want %v", got, want)
	}
	ropeSections, ok := kv["rope.dimension_sections"].([]int32)
	if !ok {
		t.Fatalf("rope.dimension_sections has unexpected type: %T", kv["rope.dimension_sections"])
	}
	if got, want := ropeSections, []int32{11, 11, 10}; !slices.Equal(got, want) {
		t.Fatalf("unexpected rope.dimension_sections: got %v want %v", got, want)
	}
	if got, ok := kv["rope.mrope_interleaved"].(bool); !ok || !got {
		t.Fatalf("expected rope.mrope_interleaved=true, got %v (%T)", kv["rope.mrope_interleaved"], kv["rope.mrope_interleaved"])
	}
	if got, want := kv["vision.block_count"], uint32(2); got != want {
		t.Fatalf("unexpected vision.block_count: got %v want %v", got, want)
	}
}
// TestQwen3NextReplacements spot-checks tensor-name rewriting for the
// language-model, vision, and legacy naming schemes.
func TestQwen3NextReplacements(t *testing.T) {
	replacer := strings.NewReplacer((&qwen3NextModel{}).Replacements()...)

	got, want := replacer.Replace("model.language_model.layers.1.linear_attn.in_proj_qkv.weight"), "blk.1.attn_qkv.weight"
	if got != want {
		t.Fatalf("unexpected language-model replacement: got %q want %q", got, want)
	}

	got, want = replacer.Replace("model.visual.blocks.0.attn.qkv.weight"), "v.blk.0.attn_qkv.weight"
	if got != want {
		t.Fatalf("unexpected vision replacement: got %q want %q", got, want)
	}

	got, want = replacer.Replace("model.layers.1.linear_attn.in_proj_qkvz.weight"), "blk.1.ssm_in.weight"
	if got != want {
		t.Fatalf("unexpected legacy replacement: got %q want %q", got, want)
	}
}
// TestQwen35ReordersVHeads checks that attn_gate rows (one per v-head) are
// reordered from grouped to tiled layout: with 2 k-heads and 2 v-per-k,
// rows [0 1 2 3] become [0 2 1 3].
func TestQwen35ReordersVHeads(t *testing.T) {
	m := &qwen3NextModel{
		ModelParameters: ModelParameters{
			ModelType: "qwen3_5",
		},
		qwen3NextTextConfig: qwen3NextTextConfig{
			LinearNumKeyHeads:   2,
			LinearNumValueHeads: 4,
			LinearValueHeadDim:  1,
		},
	}
	out := m.Tensors([]Tensor{
		&fakeTensor{
			name:  "blk.0.attn_gate.weight",
			shape: []uint64{4, 2},
			data:  []float32{0, 1, 2, 3, 4, 5, 6, 7},
		},
	})
	if len(out) != 1 {
		t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
	}
	if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
		t.Fatalf("unexpected data: got %v want %v", got, want)
	}
}
// TestQwen35ReordersAttnQKVOutputDim checks the fused QKV repack in HF
// [out_features, in_features] layout: Q and K rows pass through unchanged
// while V rows are reordered from grouped to tiled head layout.
func TestQwen35ReordersAttnQKVOutputDim(t *testing.T) {
	m := &qwen3NextModel{
		ModelParameters: ModelParameters{
			ModelType: "qwen3_5",
		},
		qwen3NextTextConfig: qwen3NextTextConfig{
			LinearNumKeyHeads:   2,
			LinearNumValueHeads: 4,
			LinearKeyHeadDim:    1,
			LinearValueHeadDim:  1,
		},
	}
	out := m.Tensors([]Tensor{
		&fakeTensor{
			name:  "blk.0.attn_qkv.weight",
			shape: []uint64{8, 2}, // [out_features, in_features] (HF layout)
			data: []float32{
				0, 1, // q0
				2, 3, // q1
				4, 5, // k0
				6, 7, // k1
				10, 11, // v(k0,v0)
				12, 13, // v(k0,v1)
				20, 21, // v(k1,v0)
				22, 23, // v(k1,v1)
			},
		},
	})
	if len(out) != 1 {
		t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
	}
	// Q/K rows unchanged; V rows tiled as (k0,v0),(k1,v0),(k0,v1),(k1,v1).
	if got, want := readTensorData(t, out[0]), []float32{
		0, 1, 2, 3, 4, 5, 6, 7,
		10, 11, 20, 21, 12, 13, 22, 23,
	}; !slices.Equal(got, want) {
		t.Fatalf("unexpected qkv data: got %v want %v", got, want)
	}
}
// TestQwen35ReordersSsmOutInputDim checks that ssm_out columns (the
// in_features dimension) are reordered from grouped to tiled v-head layout.
func TestQwen35ReordersSsmOutInputDim(t *testing.T) {
	m := &qwen3NextModel{
		ModelParameters: ModelParameters{
			ModelType: "qwen3_5",
		},
		qwen3NextTextConfig: qwen3NextTextConfig{
			LinearNumKeyHeads:   2,
			LinearNumValueHeads: 4,
			LinearValueHeadDim:  1,
		},
	}
	out := m.Tensors([]Tensor{
		&fakeTensor{
			name:  "blk.0.ssm_out.weight",
			shape: []uint64{2, 4},
			data:  []float32{0, 1, 2, 3, 4, 5, 6, 7},
		},
	})
	if len(out) != 1 {
		t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
	}
	if got, want := readTensorData(t, out[0]), []float32{0, 2, 1, 3, 4, 6, 5, 7}; !slices.Equal(got, want) {
		t.Fatalf("unexpected ssm_out data: got %v want %v", got, want)
	}
}
// TestQwen35ReordersSsmBetaRows checks that ssm_beta rows (one per v-head)
// are reordered from grouped to tiled layout.
func TestQwen35ReordersSsmBetaRows(t *testing.T) {
	m := &qwen3NextModel{
		ModelParameters: ModelParameters{
			ModelType: "qwen3_5",
		},
		qwen3NextTextConfig: qwen3NextTextConfig{
			LinearNumKeyHeads:   2,
			LinearNumValueHeads: 4,
		},
	}
	out := m.Tensors([]Tensor{
		&fakeTensor{
			name:  "blk.0.ssm_beta.weight",
			shape: []uint64{4, 2},
			data:  []float32{0, 1, 2, 3, 4, 5, 6, 7},
		},
	})
	if len(out) != 1 {
		t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
	}
	if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
		t.Fatalf("unexpected ssm_beta data: got %v want %v", got, want)
	}
}
// TestQwen35ReordersConv1DChannelDim checks the conv1d repack in squeezed
// [channels, kernel] layout: Q/K channels pass through while the trailing V
// channels are reordered from grouped to tiled head layout.
func TestQwen35ReordersConv1DChannelDim(t *testing.T) {
	m := &qwen3NextModel{
		ModelParameters: ModelParameters{
			ModelType: "qwen3_5",
		},
		qwen3NextTextConfig: qwen3NextTextConfig{
			LinearNumKeyHeads:   2,
			LinearNumValueHeads: 4,
			LinearKeyHeadDim:    1,
			LinearValueHeadDim:  1,
		},
	}
	out := m.Tensors([]Tensor{
		&fakeTensor{
			name:  "blk.0.ssm_conv1d.weight",
			shape: []uint64{8, 2}, // [channels, kernel] after squeeze
			data: []float32{
				0, 1, // q0
				2, 3, // q1
				4, 5, // k0
				6, 7, // k1
				10, 11, // v(k0,v0)
				12, 13, // v(k0,v1)
				20, 21, // v(k1,v0)
				22, 23, // v(k1,v1)
			},
		},
	})
	if len(out) != 1 {
		t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
	}
	// Q/K channels unchanged; V channels tiled across k-heads.
	if got, want := readTensorData(t, out[0]), []float32{
		0, 1, 2, 3, 4, 5, 6, 7,
		10, 11, 20, 21, 12, 13, 22, 23,
	}; !slices.Equal(got, want) {
		t.Fatalf("unexpected conv1d data: got %v want %v", got, want)
	}
}
// TestLegacyQwen3NextDoesNotReorderVHeads checks that a legacy qwen3_next
// model passes attn_gate data through without any v-head reordering.
func TestLegacyQwen3NextDoesNotReorderVHeads(t *testing.T) {
	m := &qwen3NextModel{
		ModelParameters: ModelParameters{
			ModelType: "qwen3_next",
		},
		qwen3NextTextConfig: qwen3NextTextConfig{
			LinearNumKeyHeads:   2,
			LinearNumValueHeads: 4,
			LinearValueHeadDim:  1,
		},
	}
	out := m.Tensors([]Tensor{
		&fakeTensor{
			name:  "blk.0.attn_gate.weight",
			shape: []uint64{4, 1},
			data:  []float32{0, 1, 2, 3},
		},
	})
	if len(out) != 1 {
		t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
	}
	if got, want := readTensorData(t, out[0]), []float32{0, 1, 2, 3}; !slices.Equal(got, want) {
		t.Fatalf("unexpected data for legacy qwen3next: got %v want %v", got, want)
	}
}
// TestQwen35MoePackedExperts checks that a fused per-expert gate_up_proj
// tensor of shape [experts, 2*ffn, hidden] is split into separate
// ffn_gate_exps / ffn_up_exps halves, and that down_proj is carried through
// with its shape intact.
func TestQwen35MoePackedExperts(t *testing.T) {
	m := &qwen3NextModel{
		qwen3NextTextConfig: qwen3NextTextConfig{
			NumHiddenLayers: 1,
		},
	}
	out := m.Tensors([]Tensor{
		&fakeTensor{
			name:  "blk.0.mlp.experts.gate_up_proj",
			shape: []uint64{2, 4, 3},
			data: []float32{
				0, 1, 2,
				3, 4, 5,
				6, 7, 8,
				9, 10, 11,
				12, 13, 14,
				15, 16, 17,
				18, 19, 20,
				21, 22, 23,
			},
		},
		&fakeTensor{
			name:  "blk.0.mlp.experts.down_proj",
			shape: []uint64{2, 5, 3},
			data:  make([]float32, 2*5*3),
		},
	})
	// get looks up an output tensor by name, returning nil when missing.
	get := func(name string) *ggml.Tensor {
		for _, tensor := range out {
			if tensor.Name == name {
				return tensor
			}
		}
		return nil
	}
	gate := get("blk.0.ffn_gate_exps.weight")
	if gate == nil {
		t.Fatalf("missing tensor %q", "blk.0.ffn_gate_exps.weight")
	}
	if got, want := gate.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
		t.Fatalf("unexpected gate shape: got %v want %v", got, want)
	}
	// Gate takes the first half of each expert's fused rows.
	if got, want := readTensorData(t, gate), []float32{
		0, 1, 2, 3, 4, 5,
		12, 13, 14, 15, 16, 17,
	}; !slices.Equal(got, want) {
		t.Fatalf("unexpected gate values: got %v want %v", got, want)
	}
	up := get("blk.0.ffn_up_exps.weight")
	if up == nil {
		t.Fatalf("missing tensor %q", "blk.0.ffn_up_exps.weight")
	}
	if got, want := up.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
		t.Fatalf("unexpected up shape: got %v want %v", got, want)
	}
	// Up takes the second half of each expert's fused rows.
	if got, want := readTensorData(t, up), []float32{
		6, 7, 8, 9, 10, 11,
		18, 19, 20, 21, 22, 23,
	}; !slices.Equal(got, want) {
		t.Fatalf("unexpected up values: got %v want %v", got, want)
	}
	down := get("blk.0.ffn_down_exps.weight")
	if down == nil {
		t.Fatalf("missing tensor %q", "blk.0.ffn_down_exps.weight")
	}
	if got, want := down.Shape, []uint64{2, 5, 3}; !slices.Equal(got, want) {
		t.Fatalf("unexpected down shape: got %v want %v", got, want)
	}
}
// TestQwen35SharedExpertGateKeepsMatrixShape checks that a [1, 4] shared
// expert gate is emitted with its 2-D shape preserved (not squeezed to 1-D)
// when no reordering applies.
func TestQwen35SharedExpertGateKeepsMatrixShape(t *testing.T) {
	m := &qwen3NextModel{}
	out := m.Tensors([]Tensor{
		&fakeTensor{
			name:  "blk.0.ffn_gate_inp_shexp.weight",
			shape: []uint64{1, 4},
			data:  []float32{0, 1, 2, 3},
		},
	})
	if len(out) != 1 {
		t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
	}
	if got, want := out[0].Shape, []uint64{1, 4}; !slices.Equal(got, want) {
		t.Fatalf("unexpected shared gate shape: got %v want %v", got, want)
	}
}

View File

@@ -101,6 +101,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
t.Pre = "deepseek-coder"
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
t.Pre = "qwen2"
case "00431aed57e696b747435f734d1e3b9b1bfd931a121fb5cac7129e97c181e9ba":
t.Pre = "qwen35"
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
// noop, empty pretokenizer
default:

View File

@@ -386,6 +386,28 @@ func TestParseTokenizer(t *testing.T) {
Pre: "default",
},
},
{
name: "qwen35 pretokenizer",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"pre_tokenizer": {
"type": "Sequence",
"pretokenizers": [
{
"type": "Split",
"pattern": {
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
}
}
]
}
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{Model: "gpt2"},
Pre: "qwen35",
},
},
}
for _, tt := range cases {

View File

@@ -290,6 +290,7 @@ func (kv KV) OllamaEngineRequired() bool {
"olmo3",
"qwen25vl",
"qwen3", "qwen3moe",
"qwen35", "qwen35moe",
"qwen3next",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
@@ -868,7 +869,12 @@ func (f GGML) SupportsFlashAttention() bool {
return false
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
arch := f.KV().Architecture()
if slices.Contains([]string{"qwen35", "qwen35moe", "qwen3next"}, arch) {
return true
}
if slices.Contains([]string{"gemma2"}, arch) {
return false
}
@@ -892,6 +898,7 @@ func (f GGML) FlashAttention() bool {
"nemotron_h", "nemotron_h_moe",
"olmo3",
"qwen3", "qwen3moe",
"qwen35", "qwen35moe",
"qwen3next",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))

View File

@@ -245,7 +245,22 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
padding := ggufPadding(offset, int64(alignment))
llm.tensorOffset = uint64(offset + padding)
// get file size to validate tensor bounds
fileSize, err := rs.Seek(0, io.SeekEnd)
if err != nil {
return fmt.Errorf("failed to determine file size: %w", err)
}
if _, err := rs.Seek(offset, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek back after size check: %w", err)
}
for _, tensor := range llm.tensors {
tensorEnd := llm.tensorOffset + tensor.Offset + tensor.Size()
if tensorEnd > uint64(fileSize) {
return fmt.Errorf("tensor %q offset+size (%d) exceeds file size (%d)", tensor.Name, tensorEnd, fileSize)
}
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
return fmt.Errorf("failed to get current offset: %w", err)

View File

@@ -11,21 +11,21 @@ import (
)
func TestWriteGGUF(t *testing.T) {
b := bytes.NewBuffer(make([]byte, 2*3))
tensorData := make([]byte, 2*3*4) // 6 F32 elements = 24 bytes
for range 8 {
t.Run("shuffle", func(t *testing.T) {
t.Parallel()
ts := []*Tensor{
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewReader(tensorData)},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewReader(tensorData)},
}
rand.Shuffle(len(ts), func(i, j int) {
@@ -98,4 +98,32 @@ func TestWriteGGUF(t *testing.T) {
}
})
}
t.Run("truncated_tensor_data", func(t *testing.T) {
t.Parallel()
ts := []*Tensor{
{Name: "blk.0.attn.weight", Kind: 0, Shape: []uint64{512, 2}, WriterTo: bytes.NewBuffer(make([]byte, 32))},
}
w, err := os.CreateTemp(t.TempDir(), "truncated_*.bin")
if err != nil {
t.Fatal(err)
}
defer w.Close()
if err := WriteGGUF(w, KV{"general.architecture": "test"}, ts); err != nil {
t.Fatal(err)
}
r, err := os.Open(w.Name())
if err != nil {
t.Fatal(err)
}
defer r.Close()
if _, err := Decode(r, -1); err == nil {
t.Error("Decode should reject GGUF files where tensor data extends beyond file size")
}
})
}

View File

@@ -11,9 +11,9 @@ import (
)
const (
DefaultCheckpointCount = 32
DefaultCheckpointCount = 24
DefaultCheckpointMinPos = int32(16)
DefaultCheckpointInterval = int32(1280)
DefaultCheckpointInterval = int32(1664)
)
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")

View File

@@ -195,6 +195,7 @@ type Tensor interface {
Concat(ctx Context, t2 Tensor, dim int) Tensor
Rows(ctx Context, t2 Tensor) Tensor
SetRows(ctx Context, src Tensor, idxs Tensor) Tensor
SetInplace(ctx Context, src Tensor, nb1, nb2, nb3, offset int) Tensor
Copy(ctx Context, t2 Tensor) Tensor
Duplicate(ctx Context) Tensor

View File

@@ -1345,6 +1345,21 @@ func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tenso
}
}
// SetInplace wraps ggml_set_inplace: it writes src into a view of t described
// by the byte strides nb1..nb3 and byte offset, returning the resulting
// tensor while sharing t's backend.
// NOTE(review): whether the returned tensor aliases t's buffer or is a new
// node is determined by ggml_set_inplace — confirm against ggml.h.
func (t *Tensor) SetInplace(ctx ml.Context, src ml.Tensor, nb1, nb2, nb3, offset int) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_set_inplace(
			ctx.(*Context).ctx,
			t.t,
			src.(*Tensor).t,
			C.size_t(nb1),
			C.size_t(nb2),
			C.size_t(nb3),
			C.size_t(offset),
		),
	}
}
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,

View File

@@ -2,595 +2,58 @@ package qwen3next
import (
"math"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var _ kvcache.Cache = (*HybridCache)(nil)
var (
_ kvcache.Cache = (*HybridCache)(nil)
_ kvcache.CheckpointCache = (*HybridCache)(nil)
)
// HybridCache stores:
// - a standard causal KV cache for full attention layers
// - per-sequence conv state for linear attention layers
// - per-sequence delta state for linear attention layers
//
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
// HybridCache adapts the shared recurrent cache base for Qwen3-Next naming.
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
// Conv state dimensions
convDim int // convKernelSize - 1
convChannels int // d_inner + 2 * num_k_heads * head_k_dim
// Delta state dimensions
deltaStateSize int // headVDim * headVDim * numVHeads
// slot mapping for recurrent state (copy-on-write)
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
// per-layer delta state buffers (allocated lazily)
deltaCtxs map[int]ml.Context
deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
// recurrent checkpoints (per slot)
checkpointCount int
checkpointMinPos int32
checkpointInterval int32
checkpointCtxSize int
checkpoints map[int]*slotCheckpointStore
pendingRestore map[int]checkpointRestore
curCheckpointPos []int32
curCheckpointSlots map[int]int
reserveCheckpoints bool
checkpointConvCtxs map[int]ml.Context
checkpointDeltaCtxs map[int]ml.Context
checkpointReserved map[int]struct{}
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
writableError error
*kvcache.Recurrent
}
func NewHybridCache(
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
convDim, convChannels, deltaStateSize int,
) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
convDim: convDim,
convChannels: convChannels,
deltaStateSize: deltaStateSize,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
deltaCtxs: make(map[int]ml.Context),
deltaStates: make(map[int]ml.Tensor),
checkpointCount: checkpointCountDefault,
checkpointMinPos: checkpointMinPosDefault,
checkpointInterval: checkpointIntervalDefault,
checkpoints: make(map[int]*slotCheckpointStore),
pendingRestore: make(map[int]checkpointRestore),
curCheckpointSlots: make(map[int]int),
checkpointConvCtxs: make(map[int]ml.Context),
checkpointDeltaCtxs: make(map[int]ml.Context),
checkpointReserved: make(map[int]struct{}),
}
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
Shift: shift,
ConvDim: convDim,
ConvChannels: convChannels,
RecurrentStateSize: deltaStateSize,
CheckpointLogPrefix: "qwen3next",
})
return &HybridCache{Recurrent: base}
}
// Init prepares the hybrid cache for use: it resets all checkpoint tracking
// state, sizes the checkpoint context pool, rebuilds the recurrent-state slot
// free list (one slot per sequence), and initializes the underlying causal KV
// cache.
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	c.backend = backend
	c.dtype = dtype
	c.maxSequences = maxSequences
	c.checkpoints = make(map[int]*slotCheckpointStore)
	c.pendingRestore = make(map[int]checkpointRestore)
	c.curCheckpointPos = c.curCheckpointPos[:0]
	c.curCheckpointSlots = make(map[int]int)
	c.checkpointReserved = make(map[int]struct{})
	// Context pool size scales with checkpoints per slot, floored at 8.
	c.checkpointCtxSize = c.checkpointCount * c.maxSequences
	if c.checkpointCtxSize < 8 {
		c.checkpointCtxSize = 8
	}
	// initialize slot allocator
	c.refCount = make([]int, maxSequences)
	c.freeSlots = c.freeSlots[:0]
	// Push slots in descending order so allocSlot (pop from tail) hands out
	// slot 0 first.
	for i := maxSequences - 1; i >= 0; i-- {
		c.freeSlots = append(c.freeSlots, i)
	}
	c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
// Close releases every lazily allocated conv/delta state context and
// checkpoint context, then closes the underlying causal KV cache.
func (c *HybridCache) Close() {
	for _, ctx := range c.convCtxs {
		ctx.Close()
	}
	for _, ctx := range c.deltaCtxs {
		ctx.Close()
	}
	for _, ctx := range c.checkpointConvCtxs {
		ctx.Close()
	}
	for _, ctx := range c.checkpointDeltaCtxs {
		ctx.Close()
	}
	c.kv.Close()
}
// SetConfig forwards the cache configuration to the causal KV cache.
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
	c.kv.SetConfig(config)
}

// SetLayer selects the active layer on the causal KV cache.
func (c *HybridCache) SetLayer(layer int) {
	c.kv.SetLayer(layer)
}

// Get forwards to the causal KV cache's Get for the active layer.
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	return c.kv.Get(ctx)
}

// Put stores key/value tensors into the causal KV cache.
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.kv.Put(ctx, key, value)
}
// StartForward begins a forward pass. It first delegates to the causal KV
// cache, then derives the recurrent-layer layout for the batch: recurrent
// state requires every sequence to contribute the same number of tokens, so
// batches with unequal per-sequence token counts are rejected with
// ErrNotSupported. In reserve mode (memory estimation) fake slot indices are
// used; otherwise state slots are allocated (and zeroed) for sequences new to
// the cache, and checkpoint planning is kicked off.
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
	if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
		return err
	}
	// Derive equal-length sequence layout for recurrent layers
	seqCounts := make(map[int]int)
	c.curSeqs = c.curSeqs[:0]
	for _, s := range batch.Sequences {
		if _, ok := seqCounts[s]; !ok {
			c.curSeqs = append(c.curSeqs, s)
		}
		seqCounts[s]++
	}
	if len(c.curSeqs) == 0 {
		return nil
	}
	nTokens := len(batch.Sequences)
	nSeqs := len(c.curSeqs)
	// Every sequence must supply exactly nTokens/nSeqs tokens.
	want := nTokens / nSeqs
	for _, s := range c.curSeqs {
		if seqCounts[s] != want {
			return kvcache.ErrNotSupported
		}
	}
	c.curSeqTokens = want
	// When reserving memory for estimation, use fake slot assignments
	if reserve {
		c.curSlots = c.curSlots[:0]
		slots := make([]int32, nSeqs)
		for i := range nSeqs {
			c.curSlots = append(c.curSlots, i)
			slots[i] = int32(i)
		}
		c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
		c.reserveCheckpoints = true
		c.planCheckpoints(batch)
		return nil
	}
	// Ensure slots exist for sequences in this batch
	c.curSlots = c.curSlots[:0]
	var newSlots []int
	for _, s := range c.curSeqs {
		slot, ok := c.slotForSeq[s]
		if !ok {
			var err error
			slot, err = c.allocSlot()
			if err != nil {
				return err
			}
			c.slotForSeq[s] = slot
			c.refCount[slot] = 1
			newSlots = append(newSlots, slot)
		}
		c.curSlots = append(c.curSlots, slot)
	}
	// Zero state for newly allocated slots
	if len(newSlots) > 0 {
		c.zeroSlots(ctx, newSlots)
	}
	// Create a tensor for the current slots
	slots := make([]int32, len(c.curSlots))
	for i, v := range c.curSlots {
		slots[i] = int32(v)
	}
	c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
	// Reset writable state for new forward pass
	c.writableEnsured = false
	c.writableError = nil
	c.reserveCheckpoints = false
	c.planCheckpoints(batch)
	return nil
}
// allocSlot pops a recurrent-state slot off the free list, returning
// kvcache.ErrKvCacheFull when none remain.
func (c *HybridCache) allocSlot() (int, error) {
	n := len(c.freeSlots)
	if n == 0 {
		return 0, kvcache.ErrKvCacheFull
	}
	slot := c.freeSlots[n-1]
	c.freeSlots = c.freeSlots[:n-1]
	return slot, nil
}
// freeSlot returns slot to the free list; out-of-range values are ignored.
func (c *HybridCache) freeSlot(slot int) {
	if slot < 0 || slot >= c.maxSequences {
		return
	}
	c.freeSlots = append(c.freeSlots, slot)
}
// zeroSlots zeros the recurrent state for the given slots across all layers.
// The zero tensors and slot indices live on the input context so they are
// uploaded once and shared by every per-layer SetRows.
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
	if len(slots) == 0 {
		return
	}
	inputCtx := ctx.Input()
	slotIndices := make([]int32, len(slots))
	for i, s := range slots {
		slotIndices[i] = int32(s)
	}
	slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
	// Zero conv states
	if len(c.convStates) > 0 {
		zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
		for _, buf := range c.convStates {
			ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
		}
	}
	// Zero delta states
	if len(c.deltaStates) > 0 {
		zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
		for _, buf := range c.deltaStates {
			ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
		}
	}
}
// EnsureWritable ensures sequences have private slots (copy-on-write).
// Any sequence in the current batch whose slot is shared (refCount > 1) is
// moved to a freshly allocated slot, with its recurrent state and checkpoint
// ring copied over, and the current-slots tensor is rebuilt to match.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
	for i, seq := range c.curSeqs {
		slot, ok := c.slotForSeq[seq]
		if !ok {
			continue
		}
		if slot < 0 || slot >= len(c.refCount) {
			continue
		}
		// Private already; nothing to detach.
		if c.refCount[slot] <= 1 {
			continue
		}
		newSlot, err := c.allocSlot()
		if err != nil {
			return err
		}
		c.refCount[slot]--
		c.refCount[newSlot] = 1
		c.slotForSeq[seq] = newSlot
		c.curSlots[i] = newSlot
		c.copyRecurrentState(ctx, slot, newSlot)
		c.copyCheckpoints(ctx, slot, newSlot)
	}
	// Rebuild current slots tensor
	slots := make([]int32, len(c.curSlots))
	for i, v := range c.curSlots {
		slots[i] = int32(v)
	}
	c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
	return nil
}
// copyRecurrentState copies the conv and delta recurrent state rows from
// srcSlot into dstSlot across every layer, forcing F32 on the way.
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
	srcIdx := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
	dstIdx := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
	copyRow := func(buf ml.Tensor) {
		row := buf.Rows(ctx, srcIdx).Cast(ctx, ml.DTypeF32)
		ctx.Forward(buf.SetRows(ctx, row, dstIdx))
	}
	for _, buf := range c.convStates {
		copyRow(buf)
	}
	for _, buf := range c.deltaStates {
		copyRow(buf)
	}
}
// CopyPrefix copies the first prefixLen positions of srcSeq into dstSeq.
// The KV cache does a real copy; the recurrent state is shared copy-on-write:
// dstSeq drops its old slot (freeing it when unreferenced) and takes a
// reference on srcSeq's slot.
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
	c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
	// Copy-on-write for recurrent state
	if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
		if c.validSlot(dstSlot) {
			c.refCount[dstSlot]--
			if c.refCount[dstSlot] <= 0 {
				c.refCount[dstSlot] = 0
				c.freeSlot(dstSlot)
			}
		}
		delete(c.slotForSeq, dstSeq)
	}
	srcSlot, ok := c.slotForSeq[srcSeq]
	if !ok {
		return
	}
	if c.validSlot(srcSlot) {
		c.slotForSeq[dstSeq] = srcSlot
		c.refCount[srcSlot]++
	}
}
// CanResume reports whether seq can resume generation at pos: the KV cache
// must be able to, and a non-zero pos additionally needs a recurrent-state
// checkpoint at or before it.
func (c *HybridCache) CanResume(seq int, pos int32) bool {
	switch {
	case !c.kv.CanResume(seq, pos):
		return false
	case pos == 0:
		return true
	default:
		return c.hasCheckpoint(seq, pos)
	}
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if beginIndex > 0 && endIndex != math.MaxInt32 {
return kvcache.ErrNotSupported
}
if beginIndex > 0 {
restore, ok := c.pendingRestore[seq]
if !ok || restore.pos+1 != beginIndex {
return kvcache.ErrNotSupported
}
if !c.restoreComplete(restore) {
return kvcache.ErrNotSupported
}
// If the recurrent slot is shared, detach it before applying a restore.
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
newSlot, err := c.allocSlot()
if err != nil {
return err
}
ctx := c.backend.NewContext()
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
ctx.Compute()
}
ctx.Close()
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
restore.slot = newSlot
c.pendingRestore[seq] = restore
}
}
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
if beginIndex > 0 {
restore := c.pendingRestore[seq]
delete(c.pendingRestore, seq)
return c.applyCheckpointRestore(restore)
}
// Removal invalidates recurrent state
slot, ok := c.slotForSeq[seq]
delete(c.pendingRestore, seq)
if !ok {
return nil
}
if !c.validSlot(slot) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.clearCheckpoints(slot)
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
// validSlot reports whether slot is an in-range recurrent-state slot index.
func (c *HybridCache) validSlot(slot int) bool {
	return slot >= 0 && slot < len(c.refCount)
}

// slotsTensor returns the int32 tensor of the current batch's slot indices,
// built by StartForward (and rebuilt by EnsureWritable).
func (c *HybridCache) slotsTensor() ml.Tensor {
	return c.curSlotsInput
}
// contiguousSlots returns the starting slot if current slots are contiguous
// and ordered (slot k == start+k), enabling a single-view fast path.
func (c *HybridCache) contiguousSlots() (int, bool) {
	if len(c.curSlots) == 0 {
		return 0, false
	}
	base := c.curSlots[0]
	for offset, slot := range c.curSlots {
		if slot != base+offset {
			return 0, false
		}
	}
	return base, true
}
// seqTokens returns the per-sequence token count of the current batch.
func (c *HybridCache) seqTokens() int {
	return c.curSeqTokens
}

// numSeqs returns the number of unique sequences in the current batch.
func (c *HybridCache) numSeqs() int {
	return len(c.curSeqs)
}
// convBuffer lazily creates and returns the per-layer conv-state buffer,
// shaped [convDim*convChannels, maxSequences] with one row per slot.
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
	if buf, ok := c.convStates[layer]; ok {
		return buf
	}
	if _, ok := c.convCtxs[layer]; !ok {
		c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}
	// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
	buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
	c.convStates[layer] = buf
	return buf
}
// deltaBuffer lazily creates and returns the per-layer delta-state buffer,
// shaped [deltaStateSize, maxSequences] with one row per slot.
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
	if buf, ok := c.deltaStates[layer]; ok {
		return buf
	}
	if _, ok := c.deltaCtxs[layer]; !ok {
		c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}
	// Recurrent delta state must stay in F32.
	buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
	c.deltaStates[layer] = buf
	return buf
}
// ensureWritableOnce runs copy-on-write detachment at most once per forward
// pass. It scans for any batch sequence whose slot is shared and, if one is
// found, invokes EnsureWritable, stashing any error in writableError.
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
	if c.writableEnsured {
		return
	}
	for _, seq := range c.curSeqs {
		slot, ok := c.slotForSeq[seq]
		if !ok || slot < 0 || slot >= len(c.refCount) {
			continue
		}
		if c.refCount[slot] > 1 {
			// At least one shared slot: detach all of them in one pass.
			if err := c.EnsureWritable(ctx); err != nil {
				c.writableError = err
			}
			break
		}
	}
	c.writableEnsured = true
}
// ConvState returns the conv state for current batch sequences as
// [convDim, convChannels, nSeqs], after ensuring each sequence owns a
// private (writable) slot.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
	c.ensureWritableOnce(ctx)
	if c.writableError != nil {
		return nil, c.writableError
	}
	buf := c.convBuffer(ctx, layer)
	cur := buf.Rows(ctx, c.slotsTensor())
	return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
}
// UpdateConvState writes a new conv state for current batch sequences and
// captures a checkpoint of it when one is planned for this forward pass.
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
	buf := c.convBuffer(ctx, layer)
	src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
	// Buffer is F32-only; cast before writing.
	srcF32 := src.Cast(ctx, ml.DTypeF32)
	if start, ok := c.contiguousSlots(); ok {
		// Fast path: contiguous slots allow a single view + copy
		offset := start * buf.Stride(1)
		view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
		ctx.Forward(srcF32.Copy(ctx, view))
	} else {
		ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
	}
	c.captureConvCheckpoint(ctx, layer, srcF32)
}
// DeltaState returns the delta state for current batch sequences as
// [headVDim, headVDim*numVHeads, nSeqs], after ensuring each sequence owns a
// private (writable) slot.
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
	c.ensureWritableOnce(ctx)
	if c.writableError != nil {
		return nil, c.writableError
	}
	buf := c.deltaBuffer(ctx, layer)
	cur := buf.Rows(ctx, c.slotsTensor())
	return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
}
// UpdateDeltaState writes a new delta state for current batch sequences.
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.deltaBuffer(ctx, layer)
src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
srcF32 := src.Cast(ctx, ml.DTypeF32)
if start, ok := c.contiguousSlots(); ok {
// Fast path: contiguous slots allow a single view + copy
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
ctx.Forward(srcF32.Copy(ctx, view))
} else {
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
c.UpdateRecurrentState(ctx, layer, newState)
}
func (c *HybridCache) seqTokens() int {
return c.SeqTokens()
}
func (c *HybridCache) numSeqs() int {
return c.NumSeqs()
}
// Keep qwen3next behavior for partial mid-sequence removals.
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if beginIndex > 0 && endIndex != math.MaxInt32 {
return kvcache.ErrNotSupported
}
c.captureDeltaCheckpoint(ctx, layer, srcF32)
}
// IsSupportedForBatch reports whether the current batch layout supports the
// recurrent layers (StartForward derived a non-empty equal-length layout).
func (c *HybridCache) IsSupportedForBatch() bool {
	return len(c.curSeqs) > 0 && c.curSeqTokens > 0
}
// Seqs returns a copy of the ordered unique sequences for the current
// forward pass.
func (c *HybridCache) Seqs() []int {
	return slices.Clone(c.curSeqs)
}

View File

@@ -1,498 +0,0 @@
package qwen3next
import (
"log/slog"
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
const (
	// Default ring size and spacing for recurrent-state checkpoints.
	checkpointCountDefault    = 32
	checkpointMinPosDefault   = int32(16)
	checkpointIntervalDefault = int32(1280)
)

// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
// memory usage while preserving prefix reuse for recurrent state.

// checkpointEntry is one snapshot of the recurrent state at position pos,
// holding per-layer conv and delta tensors. pos < 0 marks an empty entry.
type checkpointEntry struct {
	pos   int32
	conv  map[int]ml.Tensor
	delta map[int]ml.Tensor
}

// slotCheckpointStore is a fixed-size ring of checkpoint entries for one
// recurrent-state slot.
type slotCheckpointStore struct {
	entries []checkpointEntry
	size    int   // number of occupied entries
	next    int   // index the next record overwrites
	lastPos int32 // most recently recorded/pruned position, -1 if none
}

// checkpointRestore identifies a checkpoint (slot + ring index + position)
// pending application via Remove.
type checkpointRestore struct {
	slot int
	idx  int
	pos  int32
}
// newSlotCheckpointStore builds a ring of n empty checkpoint entries.
func newSlotCheckpointStore(n int) *slotCheckpointStore {
	s := &slotCheckpointStore{
		entries: make([]checkpointEntry, n),
		lastPos: -1,
	}
	for i := range s.entries {
		s.entries[i].pos = -1
	}
	return s
}
// reset marks every entry empty; entry tensor maps are kept for reuse.
func (s *slotCheckpointStore) reset() {
	for i := range s.entries {
		s.entries[i].pos = -1
	}
	s.size, s.next, s.lastPos = 0, 0, -1
}
// record claims the next ring entry for a checkpoint at pos and returns its
// index, or -1 when the ring has no capacity. Overwrites the oldest entry
// once full.
func (s *slotCheckpointStore) record(pos int32) int {
	n := len(s.entries)
	if n == 0 {
		return -1
	}
	idx := s.next
	s.next = (idx + 1) % n
	if s.size < n {
		s.size++
	}
	s.entries[idx].pos = pos
	s.lastPos = pos
	return idx
}
// bestIndex finds the checkpoint with the greatest position strictly below
// targetPos, returning its ring index and position, or ok=false if none.
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
	idx, best := -1, int32(-1)
	for i := range s.entries {
		p := s.entries[i].pos
		if p >= 0 && p < targetPos && p > best {
			idx, best = i, p
		}
	}
	if idx < 0 {
		return -1, -1, false
	}
	return idx, best, true
}
// pruneAfter invalidates every checkpoint recorded at a position greater
// than pos (they describe state that no longer exists after a truncation)
// and recomputes size, the next write index, and lastPos.
func (s *slotCheckpointStore) pruneAfter(pos int32) {
	if len(s.entries) == 0 {
		s.size = 0
		s.next = 0
		s.lastPos = -1
		return
	}
	size := 0
	next := -1
	minPos := int32(math.MaxInt32)
	minIdx := 0
	for i := range s.entries {
		if s.entries[i].pos > pos {
			// Invalidate; the entry's tensor maps are kept for reuse.
			s.entries[i].pos = -1
		}
		if s.entries[i].pos >= 0 {
			size++
			if s.entries[i].pos < minPos {
				minPos = s.entries[i].pos
				minIdx = i
			}
		} else if next == -1 {
			// First empty entry becomes the next write target.
			next = i
		}
	}
	s.size = size
	if size == 0 {
		s.next = 0
		s.lastPos = -1
		return
	}
	if next != -1 {
		s.next = next
	} else {
		// Full ring: overwrite the oldest checkpoint next.
		s.next = minIdx
	}
	// lastPos reflects the truncation point so interval spacing restarts there.
	s.lastPos = pos
}
// window summarizes the ring for diagnostics: occupied count, min/max stored
// positions (-1 when empty), and the last recorded position.
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
	minPos, maxPos = int32(math.MaxInt32), int32(-1)
	for i := range s.entries {
		p := s.entries[i].pos
		if p < 0 {
			continue
		}
		size++
		minPos = min(minPos, p)
		maxPos = max(maxPos, p)
	}
	if size == 0 {
		minPos, maxPos = -1, -1
	}
	return size, minPos, maxPos, s.lastPos
}
// planCheckpoints decides, per batch sequence, whether this forward pass
// should capture a checkpoint: the sequence's max position must be past
// checkpointMinPos and at least checkpointInterval past the slot's last
// checkpoint. Results go to curCheckpointPos (-1 = no checkpoint), indexed
// parallel to curSeqs.
func (c *HybridCache) planCheckpoints(batch input.Batch) {
	if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
		c.curCheckpointPos = c.curCheckpointPos[:0]
		for k := range c.curCheckpointSlots {
			delete(c.curCheckpointSlots, k)
		}
		return
	}
	if cap(c.curCheckpointPos) < len(c.curSeqs) {
		c.curCheckpointPos = make([]int32, len(c.curSeqs))
	} else {
		c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
	}
	for i := range c.curCheckpointPos {
		c.curCheckpointPos[i] = -1
	}
	// Clear per-pass slot->ring-index memo.
	for k := range c.curCheckpointSlots {
		delete(c.curCheckpointSlots, k)
	}
	// Highest position each sequence reaches in this batch.
	posMax := make(map[int]int32, len(c.curSeqs))
	for i, seq := range batch.Sequences {
		pos := batch.Positions[i]
		if cur, ok := posMax[seq]; !ok || pos > cur {
			posMax[seq] = pos
		}
	}
	for i, seq := range c.curSeqs {
		pos, ok := posMax[seq]
		if !ok {
			continue
		}
		if pos < c.checkpointMinPos {
			continue
		}
		slot := c.curSlots[i]
		store := c.checkpointStore(slot)
		lastPos := store.lastPos
		if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
			c.curCheckpointPos[i] = pos
		}
	}
}
// checkpointStore returns the checkpoint ring for slot, creating it lazily.
func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
	if store, ok := c.checkpoints[slot]; ok {
		return store
	}
	store := newSlotCheckpointStore(c.checkpointCount)
	c.checkpoints[slot] = store
	return store
}
// checkpointIndexForSlot returns the ring index to write this pass's
// checkpoint for slot into, recording pos on first use and memoizing the
// index so every layer targets the same entry. Returns -1 when
// checkpointing is disabled or the ring is empty.
func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
	if c.checkpointCount == 0 {
		return -1
	}
	if idx, ok := c.curCheckpointSlots[slot]; ok {
		return idx
	}
	idx := c.checkpointStore(slot).record(pos)
	if idx >= 0 {
		c.curCheckpointSlots[slot] = idx
	}
	return idx
}
// hasCheckpoint reports whether seq has a recurrent-state checkpoint
// strictly before pos.
func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
	if pos <= 0 {
		return false
	}
	slot, ok := c.slotForSeq[seq]
	if !ok {
		return false
	}
	store, ok := c.checkpoints[slot]
	if !ok {
		return false
	}
	_, _, found := store.bestIndex(pos)
	return found
}
// PrepareRestore locates the best checkpoint for seq strictly before
// targetPos and stages it in pendingRestore. It returns the position
// generation can resume from (checkpoint position + 1) and whether a
// checkpoint was found; the restore itself is applied later by Remove.
func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
	if targetPos <= 0 {
		return 0, false
	}
	slot, ok := c.slotForSeq[seq]
	if !ok {
		return 0, false
	}
	store, ok := c.checkpoints[slot]
	if !ok {
		slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
		return 0, false
	}
	idx, pos, ok := store.bestIndex(targetPos)
	if !ok {
		size, minPos, maxPos, lastPos := store.window()
		slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
			"min", minPos, "max", maxPos, "last", lastPos)
		return 0, false
	}
	c.pendingRestore[seq] = checkpointRestore{
		slot: slot,
		idx:  idx,
		pos:  pos,
	}
	return pos + 1, true
}
// applyCheckpointRestore copies a checkpoint's per-layer conv/delta tensors
// back into the slot's live state buffers, then prunes all checkpoints past
// the restored position. Returns kvcache.ErrNotSupported if the checkpoint
// is missing or incomplete.
func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
	entry, ok := c.restoreEntry(restore)
	if !ok {
		return kvcache.ErrNotSupported
	}
	ctx := c.backend.NewContext()
	defer ctx.Close()
	slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
	for layer, src := range entry.conv {
		buf := c.convBuffer(ctx, layer)
		ctx.Forward(buf.SetRows(ctx, src, slotIdx))
	}
	for layer, src := range entry.delta {
		buf := c.deltaBuffer(ctx, layer)
		ctx.Forward(buf.SetRows(ctx, src, slotIdx))
	}
	// Only compute when something was scheduled.
	if len(entry.conv) > 0 || len(entry.delta) > 0 {
		ctx.Compute()
	}
	store := c.checkpoints[restore.slot]
	store.pruneAfter(restore.pos)
	return nil
}
// restoreComplete reports whether the staged restore points at a valid,
// fully captured checkpoint entry.
func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
	_, ok := c.restoreEntry(restore)
	return ok
}
// restoreEntry resolves a staged restore to its checkpoint entry, returning
// ok=false when the slot has no ring, the index is out of range, the entry
// is empty, or the entry is missing tensors for some layer.
func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
	store, ok := c.checkpoints[restore.slot]
	if !ok {
		return nil, false
	}
	if restore.idx < 0 || restore.idx >= len(store.entries) {
		return nil, false
	}
	entry := &store.entries[restore.idx]
	if entry.pos < 0 || !c.entryComplete(entry) {
		return nil, false
	}
	return entry, true
}
// entryComplete reports whether entry holds a tensor for every layer that
// currently has live conv/delta state. Indexing a nil map safely yields the
// zero value, so no explicit nil-map checks are needed.
func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
	for layer := range c.convStates {
		if entry.conv[layer] == nil {
			return false
		}
	}
	for layer := range c.deltaStates {
		if entry.delta[layer] == nil {
			return false
		}
	}
	return true
}
// clearCheckpoints empties slot's checkpoint ring, if it has one.
func (c *HybridCache) clearCheckpoints(slot int) {
	store, ok := c.checkpoints[slot]
	if !ok {
		return
	}
	store.reset()
}
func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
if c.checkpointCount == 0 {
return
}
srcStore, ok := c.checkpoints[srcSlot]
if !ok || srcStore.size == 0 {
return
}
dstStore := c.checkpointStore(dstSlot)
dstStore.size = srcStore.size
dstStore.next = srcStore.next
dstStore.lastPos = srcStore.lastPos
for i := range srcStore.entries {
srcEntry := &srcStore.entries[i]
dstEntry := &dstStore.entries[i]
dstEntry.pos = srcEntry.pos
if srcEntry.conv != nil {
if dstEntry.conv == nil {
dstEntry.conv = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.conv {
dst := c.ensureCheckpointConv(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
if srcEntry.delta != nil {
if dstEntry.delta == nil {
dstEntry.delta = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.delta {
dst := c.ensureCheckpointDelta(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
}
}
// captureConvCheckpoint snapshots the just-written conv state (src, F32,
// one column per batch sequence) into the checkpoint ring for each sequence
// whose planned checkpoint position is set. In reserve mode it only
// pre-allocates the checkpoint tensors for memory estimation.
func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
	if c.checkpointCount == 0 {
		return
	}
	if c.reserveCheckpoints {
		c.reserveCheckpointConv(layer)
		return
	}
	if len(c.curCheckpointPos) == 0 {
		return
	}
	for i, pos := range c.curCheckpointPos {
		if pos < 0 {
			continue
		}
		slot := c.curSlots[i]
		idx := c.checkpointIndexForSlot(slot, pos)
		if idx < 0 {
			continue
		}
		entry := &c.checkpoints[slot].entries[idx]
		dst := c.ensureCheckpointConv(layer, entry)
		// Column i of src is this sequence's new state.
		seqSlice := src.Slice(ctx, 1, i, i+1, 1)
		ctx.Forward(seqSlice.Copy(ctx, dst))
	}
}
// captureDeltaCheckpoint snapshots the just-written delta state into the
// checkpoint ring for each sequence with a planned checkpoint position.
// Mirrors captureConvCheckpoint.
func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
	if c.checkpointCount == 0 {
		return
	}
	if c.reserveCheckpoints {
		c.reserveCheckpointDelta(layer)
		return
	}
	if len(c.curCheckpointPos) == 0 {
		return
	}
	for i, pos := range c.curCheckpointPos {
		if pos < 0 {
			continue
		}
		slot := c.curSlots[i]
		idx := c.checkpointIndexForSlot(slot, pos)
		if idx < 0 {
			continue
		}
		entry := &c.checkpoints[slot].entries[idx]
		dst := c.ensureCheckpointDelta(layer, entry)
		// Column i of src is this sequence's new state.
		seqSlice := src.Slice(ctx, 1, i, i+1, 1)
		ctx.Forward(seqSlice.Copy(ctx, dst))
	}
}
// ensureCheckpointConv returns entry's conv checkpoint tensor for layer,
// allocating it (and the per-layer checkpoint context) on first use.
// Tensors are reused across ring overwrites.
func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
	if entry.conv == nil {
		entry.conv = make(map[int]ml.Tensor)
	}
	if t, ok := entry.conv[layer]; ok {
		return t
	}
	ctx, ok := c.checkpointConvCtxs[layer]
	if !ok {
		ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
		c.checkpointConvCtxs[layer] = ctx
	}
	t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
	entry.conv[layer] = t
	return t
}
// ensureCheckpointDelta returns entry's delta checkpoint tensor for layer,
// allocating it (and the per-layer checkpoint context) on first use.
// Mirrors ensureCheckpointConv.
func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
	if entry.delta == nil {
		entry.delta = make(map[int]ml.Tensor)
	}
	if t, ok := entry.delta[layer]; ok {
		return t
	}
	ctx, ok := c.checkpointDeltaCtxs[layer]
	if !ok {
		ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
		c.checkpointDeltaCtxs[layer] = ctx
	}
	t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
	entry.delta[layer] = t
	return t
}
// reserveCheckpointConv pre-allocates the conv checkpoint tensor of every
// ring entry on every slot for layer, once, so memory estimation sees the
// full checkpoint footprint.
func (c *HybridCache) reserveCheckpointConv(layer int) {
	key := checkpointReserveKey(layer, 0)
	if _, done := c.checkpointReserved[key]; done {
		return
	}
	for slot := 0; slot < c.maxSequences; slot++ {
		store := c.checkpointStore(slot)
		for i := range store.entries {
			c.ensureCheckpointConv(layer, &store.entries[i])
		}
	}
	c.checkpointReserved[key] = struct{}{}
}
// reserveCheckpointDelta pre-allocates the delta checkpoint tensor of every
// ring entry on every slot for layer, once. Mirrors reserveCheckpointConv.
func (c *HybridCache) reserveCheckpointDelta(layer int) {
	key := checkpointReserveKey(layer, 1)
	if _, done := c.checkpointReserved[key]; done {
		return
	}
	for slot := 0; slot < c.maxSequences; slot++ {
		store := c.checkpointStore(slot)
		for i := range store.entries {
			c.ensureCheckpointDelta(layer, &store.entries[i])
		}
	}
	c.checkpointReserved[key] = struct{}{}
}
// checkpointReserveKey maps a (layer, kind) pair to a unique reservation
// key; kind is 0 for conv state and 1 for delta state.
func checkpointReserveKey(layer int, kind int) int {
	return 2*layer + kind
}

View File

@@ -1,300 +0,0 @@
package qwen3next
import (
"errors"
"math"
"os"
"testing"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
// newTestBackend builds a minimal ml.Backend from an empty temporary GGUF
// file; the backend is closed automatically when the test ends.
func newTestBackend(tb testing.TB) ml.Backend {
	tb.Helper()
	f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
	if err != nil {
		tb.Fatal(err)
	}
	if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
		// best-effort close; the write failure is what we report
		_ = f.Close()
		tb.Fatal(err)
	}
	if err := f.Close(); err != nil {
		tb.Fatal(err)
	}
	b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
	if err != nil {
		tb.Fatal(err)
	}
	tb.Cleanup(func() {
		b.Close()
	})
	return b
}
// TestSlotCheckpointStoreBestIndex verifies bestIndex picks the greatest
// position strictly below the target, including after ring overwrite.
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
	store := newSlotCheckpointStore(2)
	store.record(10)
	store.record(20)
	_, pos, ok := store.bestIndex(15)
	if !ok || pos != 10 {
		t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
	}
	store.record(30) // overwrite oldest (10)
	if _, _, ok := store.bestIndex(15); ok {
		t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
	}
	_, pos, ok = store.bestIndex(40)
	if !ok || pos != 30 {
		t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
	}
}
// TestHybridCachePrepareRestore verifies PrepareRestore stages the best
// checkpoint before the target and returns checkpoint position + 1.
func TestHybridCachePrepareRestore(t *testing.T) {
	cache := NewHybridCache(nil, 1, 1, 1)
	cache.checkpointCount = 3
	cache.checkpoints = make(map[int]*slotCheckpointStore)
	cache.pendingRestore = make(map[int]checkpointRestore)
	cache.slotForSeq[1] = 0
	store := cache.checkpointStore(0)
	store.record(5)
	store.record(9)
	store.record(15)
	restorePos, ok := cache.PrepareRestore(1, 12)
	if !ok {
		t.Fatalf("expected restore ok")
	}
	// Best checkpoint below 12 is 9, so resume position is 10.
	if restorePos != 10 {
		t.Fatalf("expected restorePos 10, got %d", restorePos)
	}
	rest, ok := cache.pendingRestore[1]
	if !ok {
		t.Fatalf("expected pending restore entry")
	}
	if rest.pos != 9 {
		t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
	}
}
// TestSlotCheckpointStorePruneAfter verifies pruning drops checkpoints past
// the cutoff while keeping earlier ones usable.
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
	store := newSlotCheckpointStore(3)
	store.record(10)
	store.record(20)
	store.record(30)
	store.pruneAfter(20)
	if store.lastPos != 20 {
		t.Fatalf("expected lastPos 20, got %d", store.lastPos)
	}
	_, pos, ok := store.bestIndex(25)
	if !ok || pos != 20 {
		t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
	}
	// The entry at 30 must be gone even for larger targets.
	_, pos, ok = store.bestIndex(35)
	if !ok || pos != 20 {
		t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
	}
}
// TestHybridCacheRestoreDetachesSharedSlot verifies that applying a restore
// to a sequence whose slot is shared first detaches it onto a private slot.
func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
	backend := newTestBackend(t)
	cache := NewHybridCache(nil, 1, 2, 2)
	cache.Init(backend, ml.DTypeF16, 2, 8, 2)
	// Seqs 1 and 2 share slot 0; slot 1 is free.
	cache.slotForSeq[1] = 0
	cache.slotForSeq[2] = 0
	cache.refCount[0] = 2
	cache.refCount[1] = 0
	cache.freeSlots = []int{1}
	store := cache.checkpointStore(0)
	idx := store.record(9)
	cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
	if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
		t.Fatalf("Remove failed: %v", err)
	}
	if cache.slotForSeq[1] == cache.slotForSeq[2] {
		t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
	}
	if cache.slotForSeq[1] != 1 {
		t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
	}
	if cache.slotForSeq[2] != 0 {
		t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
	}
	if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
		t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
	}
	if _, ok := cache.pendingRestore[1]; ok {
		t.Fatalf("expected pending restore to be cleared")
	}
}
// TestHybridCacheRestoreRejectsIncompleteCheckpoint verifies Remove fails
// with ErrNotSupported when the staged checkpoint lacks a layer's tensors.
func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
	cache := NewHybridCache(nil, 1, 2, 2)
	cache.checkpointCount = 3
	cache.checkpoints = make(map[int]*slotCheckpointStore)
	cache.pendingRestore = make(map[int]checkpointRestore)
	cache.slotForSeq[1] = 0
	cache.refCount = []int{1}
	cache.freeSlots = nil
	// Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
	cache.convStates[0] = nil  // placeholder to indicate layer 0 exists
	cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
	store := cache.checkpointStore(0)
	idx := store.record(9)
	entry := &store.entries[idx]
	// Only set conv checkpoint, not delta - making it incomplete
	entry.conv = map[int]ml.Tensor{0: nil}
	// entry.delta is not set, so checkpoint is incomplete
	cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
	err := cache.Remove(1, 10, math.MaxInt32)
	if !errors.Is(err, kvcache.ErrNotSupported) {
		t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
	}
}
// TestHybridCacheRestoreAcceptsCompleteCheckpoint verifies restoreComplete
// accepts a recorded checkpoint when no layers require tensors.
func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
	cache := NewHybridCache(nil, 1, 2, 2)
	cache.checkpointCount = 3
	cache.checkpoints = make(map[int]*slotCheckpointStore)
	cache.pendingRestore = make(map[int]checkpointRestore)
	cache.slotForSeq[1] = 0
	cache.refCount = []int{1}
	cache.freeSlots = nil
	// Don't set convStates/deltaStates - with no layers to check,
	// entryComplete will return true as long as entry.pos >= 0
	store := cache.checkpointStore(0)
	idx := store.record(9)
	cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
	// Test that restoreComplete returns true when no layers need checkpoints
	restore := cache.pendingRestore[1]
	if !cache.restoreComplete(restore) {
		t.Fatalf("expected restoreComplete to return true for complete checkpoint")
	}
}
// TestSlotCheckpointStoreRingBufferWrapAround verifies that overwriting a
// ring entry keeps its tensor maps (they are reused, not reallocated).
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
	// Test that ring buffer wrap-around reuses entries without clearing maps.
	store := newSlotCheckpointStore(3)
	// Fill the buffer
	store.record(10)
	store.record(20)
	store.record(30)
	// Create fake tensor data in the first entry's maps
	store.entries[0].conv = make(map[int]ml.Tensor)
	store.entries[0].conv[0] = nil // Simulated tensor reference
	store.entries[0].delta = make(map[int]ml.Tensor)
	store.entries[0].delta[0] = nil // Simulated tensor reference
	// Record another entry, which should wrap around and overwrite entry 0
	store.record(40)
	// Verify the maps are still present (we reuse tensors)
	if store.entries[0].conv == nil {
		t.Fatalf("expected conv map to be preserved on reuse")
	}
	if store.entries[0].delta == nil {
		t.Fatalf("expected delta map to be preserved on reuse")
	}
	// Verify the new position was recorded
	if store.entries[0].pos != 40 {
		t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
	}
}
// TestSlotCheckpointStoreFullCapacity verifies index assignment and lookup
// when the ring is filled exactly to capacity.
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
	// Test behavior when buffer is exactly at capacity
	store := newSlotCheckpointStore(2)
	idx1 := store.record(10)
	idx2 := store.record(20)
	if idx1 != 0 || idx2 != 1 {
		t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
	}
	if store.size != 2 {
		t.Fatalf("expected size 2, got %d", store.size)
	}
	// Verify both checkpoints are accessible
	_, pos1, ok1 := store.bestIndex(15)
	_, pos2, ok2 := store.bestIndex(25)
	if !ok1 || pos1 != 10 {
		t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
	}
	if !ok2 || pos2 != 20 {
		t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
	}
}
// TestSlotCheckpointStoreEmptyBuffer verifies a zero-capacity ring rejects
// records and reports no checkpoints.
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
	// Test behavior with zero-size buffer
	store := newSlotCheckpointStore(0)
	idx := store.record(10)
	if idx != -1 {
		t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
	}
	_, _, ok := store.bestIndex(15)
	if ok {
		t.Fatalf("expected no checkpoint for empty buffer")
	}
}
// TestSlotCheckpointStorePruneAfterAll verifies that pruning below every
// stored position empties the ring and resets lastPos.
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
	// Test pruning that removes all checkpoints
	store := newSlotCheckpointStore(3)
	store.record(10)
	store.record(20)
	store.record(30)
	// Prune everything by setting threshold below all positions
	store.pruneAfter(5)
	if store.size != 0 {
		t.Fatalf("expected size 0 after pruning all, got %d", store.size)
	}
	// When all checkpoints are pruned, lastPos is reset to -1
	if store.lastPos != -1 {
		t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
	}
	_, _, ok := store.bestIndex(100)
	if ok {
		t.Fatalf("expected no checkpoint after pruning all")
	}
}

View File

@@ -37,7 +37,9 @@ type GatedDeltaNet struct {
// Optimized path: pre-split QKV and gate
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha (legacy qwen3next)
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
@@ -96,7 +98,6 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
headVDim := opts.ssmDInner / numVHeads
convKernelSize := opts.convKernelSize
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
@@ -106,24 +107,40 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
baNewDim := 2 * numVHeads / numKHeads
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
var beta ml.Tensor
var alpha ml.Tensor
switch {
case gdn.SSMBetaAlpha != nil:
// Legacy qwen3next path: in_proj_ba packs beta/alpha grouped by K-head.
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
baNewDim := 2 * numVHeads / numKHeads
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
// Split beta and alpha
betaSize := numVHeads / numKHeads
alphaSize := numVHeads / numKHeads
betaSize := numVHeads / numKHeads
alphaSize := numVHeads / numKHeads
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
// Reshape to merge head dimensions
beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
// Keep beta layout consistent with qwen35.
// [1, numVHeads, nSeqTokens, nSeqs]
beta = b.Contiguous(ctx, 1, numVHeads, nSeqTokens, nSeqs)
alpha = a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
case gdn.SSMBeta != nil && gdn.SSMAlpha != nil:
// qwen35 path: beta/alpha are separate projections.
beta = gdn.SSMBeta.Forward(ctx, hiddenStates).Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
alpha = gdn.SSMAlpha.Forward(ctx, hiddenStates).Reshape(ctx, numVHeads, nSeqTokens, nSeqs)
default:
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
}
// Compute gate: softplus(alpha + dt_bias) * -A
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
alphaSoftplus := alphaBiased.Softplus(ctx)
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
gate = gate.Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
// Get conv state from cache
@@ -172,16 +189,20 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
// Repeat interleave Q and K if numKHeads != numVHeads
if numKHeads != numVHeads {
repeatFactor := numVHeads / numKHeads
if opts.vHeadReordered {
qConv = qConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
kConv = kConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
} else {
repeatFactor := numVHeads / numKHeads
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
}
}
// Choose computation mode based on sequence length
@@ -189,7 +210,9 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
if nSeqTokens == 1 {
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
} else {
// Use pre-computed masks from opts (created once in Model.Forward)
if opts.masks == nil {
opts.masks = createMasks(ctx)
}
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
}
@@ -310,9 +333,9 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs)
beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
// gate/beta: [1, numVHeads, nTokens, nSeqs] -> [1, nTokens, numVHeads, nSeqs]
gate = gate.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
beta = beta.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
// Compute padding
@@ -324,7 +347,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
q = q.Pad(ctx, 0, pad, 0, 0)
k = k.Pad(ctx, 0, pad, 0, 0)
v = v.Pad(ctx, 0, pad, 0, 0)
gate = gate.Pad(ctx, pad, 0, 0, 0)
gate = gate.Pad(ctx, 0, pad, 0, 0)
beta = beta.Pad(ctx, 0, pad, 0, 0)
}
@@ -344,10 +367,12 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
// Reshape gate and cumsum over chunk axis.
// [1, chunkSize, nChunks, H*nSeqs] -> transpose -> [chunkSize, 1, nChunks, H*nSeqs]
gate = gate.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
// g_cumsum = cumsum(gate)
gCumsum := gate.CumSum(ctx)
gCumsum := gate.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs).CumSum(ctx)
// Compute decay mask
gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
@@ -411,60 +436,64 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Process chunks and update state
var coreAttnOut ml.Tensor
newState := state
// Process chunks and update state.
// Keep a transposed view of v and recurrent state across chunks so the
// chunk loop does not need extra transpose+contiguous nodes.
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
for chunk := range nChunks {
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
// state^T - permute is needed but Contiguous creates a copy
stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
// v'_t = k_cumdecay @ state_t
vTPrime := kCumdecayChunk.Mulmat(ctx, stateT)
// v_prime = k_cumdecay @ state
vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
// v_new = v - v_prime
vNew := vChunk.Sub(ctx, vPrime)
vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// v_t_new = v_t - v'_t
vTNewChunk := vTChunk.Sub(ctx, vTPrime)
// attn_inter = (q * g_exp) @ state
qGExp := qChunk.Mul(ctx, gExpChunk)
attnInter := stateT.Mulmat(ctx, qGExp)
// core_attn_out = attn_inter + attn @ v_new
vAttn := vNewT.Mulmat(ctx, attnChunk)
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
if coreAttnOut == nil {
coreAttnOut = coreAttnOutChunk
} else {
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
}
v = v.SetInplace(
ctx,
coreAttnOutChunk,
v.Stride(1),
v.Stride(2),
v.Stride(3),
chunk*v.Stride(2),
)
// Update state for next chunk
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
// kgdmulvnew = key_gdiff_t @ v_new_t
kgdMulVNew := kGDiffChunkT.Mulmat(ctx, vTNewChunk)
// state = state * g_last + kgdmulvnew
gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
newState = newState.Mul(ctx, gExpLastReshaped)
newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
// stateT = stateT * g_last + kgdmulvnew
stateT = stateT.Mul(ctx, gExpLastChunk)
stateT = stateT.Add(ctx, kgdMulVNew)
}
// Final reshape
coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
// Slice to remove padding
if pad > 0 {
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
}
// Convert stateT back to cache layout [S_v, S_v, H_v, nSeqs]
newState := stateT.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, numVHeads, nSeqs)
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))

View File

@@ -1,9 +1,12 @@
package qwen3next
import (
"bytes"
"cmp"
"fmt"
"image"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
@@ -11,6 +14,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/model/models/qwen3vl"
"github.com/ollama/ollama/tokenizer"
)
@@ -41,10 +45,15 @@ type Options struct {
ssmNGroup int // num_k_heads
ssmDtRank int // num_v_heads
convKernelSize int // SSM conv kernel size
vHeadReordered bool
// Per-layer type from GGUF metadata
isRecurrent []bool
// RoPE mode config (used by qwen35/qwen35moe)
mropeSections []int
mropeInterleaved bool
// Pre-computed masks for chunked attention (created once per forward pass)
masks *Masks
}
@@ -54,7 +63,17 @@ func (o Options) headDim() int {
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
opts := []func(*rope.Options){rope.WithTypeNeoX()}
var opts []func(*rope.Options)
if len(o.mropeSections) > 0 {
if o.mropeInterleaved {
opts = append(opts, rope.WithInterleaveMRoPE(o.mropeSections))
} else {
opts = append(opts, rope.WithMRoPE(o.mropeSections))
}
} else {
opts = append(opts, rope.WithTypeNeoX())
}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
opts = append(opts,
@@ -214,20 +233,190 @@ type Model struct {
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Layers []Layer `gguf:"blk"`
Layers []Layer `gguf:"blk"`
Vision *qwen3vl.VisionModel `gguf:"v"`
ImageProcessor *qwen3vl.ImageProcessor
*Options
positionCache []int32
imageToken int32
visionStart int32
visionEnd int32
spatialMergeSize uint32
}
// mapPosition translates a token index into its MRoPE position. Indices that
// were recorded during PostTokenize are looked up in the position cache;
// indices beyond the cache continue linearly from the last cached position.
// With an empty cache the index is returned unchanged.
func (m *Model) mapPosition(id int32) int32 {
	n := int32(len(m.positionCache))
	switch {
	case id < n:
		return m.positionCache[id]
	case n > 0:
		// Extrapolate past the cache: one position per token after the
		// last cached entry.
		last := m.positionCache[n-1]
		return last + 1 + (id - n)
	default:
		return id
	}
}
// buildPositions constructs the position tensor for a batch. For text-only
// configurations (no MRoPE sections) it returns the raw batch positions. For
// MRoPE it emits four parallel channels per token — [time, height, width,
// extra] — and offsets the height/width channels of image tokens so each
// vision patch gets a 2D position within its grid.
func (m *Model) buildPositions(ctx ml.Context, batch input.Batch) ml.Tensor {
	if len(m.mropeSections) == 0 {
		return ctx.Input().FromInts(batch.Positions, len(batch.Positions))
	}
	// ggml MRoPE expects [time, height, width, extra] for each token.
	positionSlice := [][]int32{
		make([]int32, len(batch.Positions)),
		make([]int32, len(batch.Positions)),
		make([]int32, len(batch.Positions)),
		make([]int32, len(batch.Positions)),
	}
	for i, id := range batch.Positions {
		// Map through the cache built in PostTokenize so image spans use
		// their grid-based extent; the fourth ("extra") channel stays zero.
		p := m.mapPosition(id)
		positionSlice[0][i] = p
		positionSlice[1][i] = p
		positionSlice[2][i] = p
	}
	if m.Vision != nil {
		for _, mi := range batch.Multimodal {
			grid, ok := mi.Multimodal[0].Data.(*qwen3vl.Grid)
			if !ok {
				// Non-grid multimodal data carries no 2D layout; skip it.
				continue
			}
			// Grid width in merged-patch units; clamp to 1 to avoid a
			// divide-by-zero on degenerate grids.
			w := max(1, grid.Width/int(m.spatialMergeSize))
			for i := range mi.Multimodal[0].Tensor.Dim(1) {
				// Row/column offsets within the image grid.
				positionSlice[1][mi.Index+i] += int32(i / w)
				positionSlice[2][mi.Index+i] += int32(i % w)
			}
		}
	}
	// Concatenate the four channels into one flat tensor of length 4*nTokens.
	return ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
}
// EncodeMultimodal decodes an image from raw bytes and runs it through the
// vision tower. The first returned Multimodal entry carries the primary
// vision embeddings along with the patch grid; subsequent entries carry the
// deepstack embeddings injected into later text layers.
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
	// Text-only checkpoints have no vision tower; report that explicitly.
	if m.Vision == nil || m.ImageProcessor == nil || len(m.Vision.Layers) == 0 {
		return nil, model.ErrNoVisionModel
	}
	img, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}
	pixelValues, grid, err := m.ImageProcessor.ProcessImage(ctx, img)
	if err != nil {
		return nil, err
	}
	visionOutputs, deepstackVisualEmbeds := m.Vision.Forward(ctx, pixelValues, grid)
	// Primary embeddings first, tagged with the grid so PostTokenize /
	// buildPositions can recover the image layout.
	mm := []input.Multimodal{{Tensor: visionOutputs, Data: grid}}
	for i := range deepstackVisualEmbeds {
		mm = append(mm, input.Multimodal{Tensor: deepstackVisualEmbeds[i]})
	}
	return mm, nil
}
// PostTokenize expands each image input into the vision-start token, one
// image token per vision embedding, and the vision-end token, while
// recording a per-token MRoPE position in m.positionCache. All image tokens
// of one image share the same position; the position after the image jumps
// by the grid's larger merged dimension.
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
	// Reset (but keep capacity of) the cache from any previous prompt.
	m.positionCache = m.positionCache[:0]
	var result []*input.Input
	// appendInput records the token and its position in lockstep.
	appendInput := func(inp *input.Input, position int32) {
		result = append(result, inp)
		m.positionCache = append(m.positionCache, position)
	}
	var p int32
	for _, inp := range inputs {
		if inp.Multimodal == nil {
			// Plain text token: consumes one position.
			appendInput(inp, p)
			p++
			continue
		}
		grid := inp.Multimodal[0].Data.(*qwen3vl.Grid)
		tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
		// SameBatch spans the image tokens plus the end marker so the whole
		// image is processed in a single batch.
		appendInput(&input.Input{
			Token:     m.visionStart,
			SameBatch: tokensPerGrid + 1,
		}, p)
		p++
		// First image token carries the multimodal payload; the rest are
		// bare placeholders at the same position.
		appendInput(&input.Input{
			Token:          m.imageToken,
			Multimodal:     inp.Multimodal,
			MultimodalHash: inp.MultimodalHash,
		}, p)
		for range tokensPerGrid - 1 {
			appendInput(&input.Input{
				Token: m.imageToken,
			}, p)
		}
		// Advance past the image by its larger merged-grid dimension.
		gridSpan := max(grid.Width/int(m.spatialMergeSize), grid.Height/int(m.spatialMergeSize))
		p = p + int32(gridSpan)
		appendInput(&input.Input{
			Token: m.visionEnd,
		}, p)
		p++
	}
	return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
positions := m.buildPositions(ctx, batch)
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
if len(batch.Multimodal) > 0 {
hiddenStates = hiddenStates.Duplicate(ctx)
var deepstackVisualEmbeds []ml.Tensor
for _, mi := range batch.Multimodal {
visionOutputs := mi.Multimodal[0].Tensor
ctx.Forward(visionOutputs.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
if len(mi.Multimodal[1:]) > len(deepstackVisualEmbeds) {
deepstackVisualEmbeds = append(deepstackVisualEmbeds, make([]ml.Tensor, len(mi.Multimodal[1:])-len(deepstackVisualEmbeds))...)
}
for i, mm := range mi.Multimodal[1:] {
if deepstackVisualEmbeds[i] == nil {
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
}
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
}
}
cache := m.Cache.(*HybridCache)
m.Options.masks = nil
for i, layer := range m.Layers {
cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
var err error
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
if err != nil {
return nil, err
}
if i < len(deepstackVisualEmbeds) {
hiddenStates = hiddenStates.Add(ctx, deepstackVisualEmbeds[i])
}
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
cache := m.Cache.(*HybridCache)
// Create masks once per forward pass
m.Options.masks = createMasks(ctx)
// Masks are allocated lazily only for chunked recurrent prefill.
m.Options.masks = nil
for i, layer := range m.Layers {
cache.SetLayer(i)
@@ -249,10 +438,17 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
// Shift re-rotates cached keys after a context shift. The position cache is
// invalidated since cached positions no longer match the shifted layout. For
// MRoPE the scalar shift is replicated across the four position channels
// before the rotation is applied.
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	m.positionCache = nil
	if len(m.mropeSections) > 0 {
		// [n] -> [n, 4] -> flat [4n] to match buildPositions' channel layout.
		shift = shift.Repeat(ctx, 1, 4).Reshape(ctx, -1)
	}
	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
var _ model.Model = (*Model)(nil)
var (
_ model.Model = (*Model)(nil)
_ model.MultimodalProcessor = (*Model)(nil)
)
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
@@ -303,6 +499,22 @@ func New(c fs.Config) (model.Model, error) {
}
}
mropeSections := c.Ints("mrope_sections", nil)
if len(mropeSections) == 0 {
mropeSections = c.Ints("rope.mrope_section", nil)
}
if len(mropeSections) == 0 {
mropeSections = c.Ints("rope.dimension_sections", nil)
}
if len(mropeSections) > 4 {
mropeSections = mropeSections[:4]
}
ropeType := c.String("rope.scaling.type")
if ropeType == "" {
ropeType = c.String("rope.type")
}
opts := &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
@@ -318,7 +530,7 @@ func New(c fs.Config) (model.Model, error) {
valueLength: int(c.Uint("attention.value_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeType: c.String("rope.scaling.type"),
ropeType: ropeType,
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
@@ -331,7 +543,16 @@ func New(c fs.Config) (model.Model, error) {
ssmNGroup: int(c.Uint("ssm.group_count")),
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
convKernelSize: int(c.Uint("ssm.conv_kernel")),
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
isRecurrent: isRecurrent,
mropeSections: slices.Collect(func(yield func(int) bool) {
for _, section := range mropeSections {
if !yield(int(section)) {
return
}
}
}),
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
}
if opts.numKVHeads == 0 {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
@@ -353,6 +574,19 @@ func New(c fs.Config) (model.Model, error) {
return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
}
var vision *qwen3vl.VisionModel
var imageProcessor *qwen3vl.ImageProcessor
if c.Uint("vision.block_count", 0) > 0 {
vision = qwen3vl.NewVisionModel(c)
processor := qwen3vl.NewImageProcessor(c)
imageProcessor = &processor
}
spatialMergeSize := c.Uint("vision.spatial_merge_size", 2)
if spatialMergeSize == 0 {
spatialMergeSize = 2
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
@@ -371,8 +605,14 @@ func New(c fs.Config) (model.Model, error) {
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Layers: layers,
Options: opts,
Layers: layers,
Vision: vision,
ImageProcessor: imageProcessor,
Options: opts,
imageToken: int32(c.Uint("image_token_id", 151655)),
visionStart: int32(c.Uint("vision_start_token_id", 151652)),
visionEnd: int32(c.Uint("vision_end_token_id", 151653)),
spatialMergeSize: spatialMergeSize,
}
m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
@@ -380,5 +620,7 @@ func New(c fs.Config) (model.Model, error) {
}
func init() {
model.Register("qwen35", New)
model.Register("qwen35moe", New)
model.Register("qwen3next", New)
}

View File

@@ -0,0 +1,101 @@
package qwen3next
import (
"testing"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/model/models/qwen3vl"
)
// fakeTensor is a test stub that embeds the real tensor type but overrides
// Dim with fixed dimensions, so PostTokenize can be exercised without a
// backend allocation.
type fakeTensor struct {
	*ggml.Tensor
	dims []int
}

// Dim returns the stubbed size of dimension i.
func (t *fakeTensor) Dim(i int) int {
	return t.dims[i]
}
// makeImageInput builds a synthetic image input whose tensor reports the
// given number of vision tokens and whose grid has the given dimensions.
func makeImageInput(hash uint64, width, height, tokens int) *input.Input {
	tensor := &fakeTensor{dims: []int{1, tokens, 1, 1}}
	grid := &qwen3vl.Grid{Width: width, Height: height}
	mm := input.Multimodal{Tensor: tensor, Data: grid}
	return &input.Input{
		Multimodal:     []input.Multimodal{mm},
		MultimodalHash: hash,
	}
}
// TestPostTokenizeMultiImageSpans checks that two consecutive images are each
// expanded into start/image.../end token runs with the expected SameBatch
// spans, and that the recorded positions jump by each image's larger merged
// grid dimension (8/2=4 for both images here).
func TestPostTokenizeMultiImageSpans(t *testing.T) {
	m := &Model{
		imageToken:       10,
		visionStart:      11,
		visionEnd:        12,
		spatialMergeSize: 2,
	}
	inputs := []*input.Input{
		{Token: 100},
		makeImageInput(1, 8, 4, 4), // wide image: span = max(8,4)/2 = 4
		makeImageInput(2, 4, 8, 4), // tall image: span = max(4,8)/2 = 4
		{Token: 200},
	}
	got, err := m.PostTokenize(inputs)
	if err != nil {
		t.Fatalf("PostTokenize() error = %v", err)
	}
	// Expected token stream: text, then per image [start, 4 image tokens, end].
	want := []struct {
		token     int32
		hash      uint64
		sameBatch int
		hasMM     bool
	}{
		{token: 100},
		{token: 11, sameBatch: 5}, // 4 image tokens + end marker
		{token: 10, hash: 1, hasMM: true},
		{token: 10},
		{token: 10},
		{token: 10},
		{token: 12},
		{token: 11, sameBatch: 5},
		{token: 10, hash: 2, hasMM: true},
		{token: 10},
		{token: 10},
		{token: 10},
		{token: 12},
		{token: 200},
	}
	if len(got) != len(want) {
		t.Fatalf("len(got) = %d, want %d", len(got), len(want))
	}
	for i := range want {
		if got[i].Token != want[i].token {
			t.Fatalf("got[%d].Token = %d, want %d", i, got[i].Token, want[i].token)
		}
		if got[i].MultimodalHash != want[i].hash {
			t.Fatalf("got[%d].MultimodalHash = %d, want %d", i, got[i].MultimodalHash, want[i].hash)
		}
		if got[i].SameBatch != want[i].sameBatch {
			t.Fatalf("got[%d].SameBatch = %d, want %d", i, got[i].SameBatch, want[i].sameBatch)
		}
		hasMM := len(got[i].Multimodal) > 0
		if hasMM != want[i].hasMM {
			t.Fatalf("got[%d].hasMM = %v, want %v", i, hasMM, want[i].hasMM)
		}
	}
	// Image tokens share one position; positions resume 4 higher after each image.
	wantPositions := []int32{0, 1, 2, 2, 2, 2, 6, 7, 8, 8, 8, 8, 12, 13}
	if len(m.positionCache) != len(wantPositions) {
		t.Fatalf("len(positionCache) = %d, want %d", len(m.positionCache), len(wantPositions))
	}
	for i := range wantPositions {
		if m.positionCache[i] != wantPositions[i] {
			t.Fatalf("positionCache[%d] = %d, want %d", i, m.positionCache[i], wantPositions[i])
		}
	}
}

View File

@@ -24,8 +24,8 @@ type ImageProcessor struct {
imageStd []float32
}
// newImageProcessor creates a new image processor with default values
func newImageProcessor(c fs.Config) ImageProcessor {
// NewImageProcessor creates a new image processor with default values.
func NewImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 14))
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))

View File

@@ -56,60 +56,46 @@ var (
tokenVisionEnd int32 = 151653
)
type modelInput struct {
*input.Input
position int32
}
// PostTokenize arranges Qwen 3 VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
m.positionCache = m.positionCache[:0]
return slices.Collect(func(yield func(*input.Input) bool) {
for i := range inputs {
s := []modelInput{{Input: inputs[i]}}
if mm := inputs[i].Multimodal; mm != nil {
t := mm[0].Tensor
s = slices.Repeat([]modelInput{
{
position: int32(i + 1),
Input: &input.Input{Token: tokenVision},
},
}, t.Dim(1)+1+1)
var result []*input.Input
appendInput := func(inp *input.Input, position int32) {
result = append(result, inp)
m.positionCache = append(m.positionCache, position)
}
s[0] = modelInput{
Input: &input.Input{Token: tokenVisionStart},
position: int32(i),
}
s[len(s)-1] = modelInput{
Input: &input.Input{Token: tokenVisionEnd},
position: int32(i + mm[0].Data.(*Grid).Width/m.spatialMergeSize + 1),
}
s[1] = modelInput{
Input: &input.Input{
Token: tokenVision,
Multimodal: inputs[i].Multimodal,
MultimodalHash: inputs[i].MultimodalHash,
SameBatch: t.Dim(1),
},
position: int32(i + 1),
}
}
for _, e := range s {
position := e.position
if position == 0 && len(m.positionCache) > 0 {
position = m.positionCache[len(m.positionCache)-1] + 1
}
m.positionCache = append(m.positionCache, position)
if !yield(e.Input) {
return
}
}
var p int32
for _, inp := range inputs {
if inp.Multimodal == nil {
appendInput(inp, p)
p++
continue
}
}), nil
grid := inp.Multimodal[0].Data.(*Grid)
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
appendInput(&input.Input{Token: tokenVisionStart}, p)
p++
appendInput(&input.Input{
Token: tokenVision,
Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash,
SameBatch: tokensPerGrid,
}, p)
for range tokensPerGrid - 1 {
appendInput(&input.Input{Token: tokenVision}, p)
}
p = p + int32(grid.Width/m.spatialMergeSize)
appendInput(&input.Input{Token: tokenVisionEnd}, p)
p++
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
@@ -143,9 +129,13 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
}
deepstackVisualEmbeds = make([]ml.Tensor, len(mi.Multimodal[1:]))
if len(mi.Multimodal[1:]) > len(deepstackVisualEmbeds) {
deepstackVisualEmbeds = append(deepstackVisualEmbeds, make([]ml.Tensor, len(mi.Multimodal[1:])-len(deepstackVisualEmbeds))...)
}
for i, mm := range mi.Multimodal[1:] {
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
if deepstackVisualEmbeds[i] == nil {
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
}
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
}
}
@@ -189,8 +179,8 @@ func New(c fs.Config) (model.Model, error) {
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
VisionModel: NewVisionModel(c),
ImageProcessor: NewImageProcessor(c),
}
m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) {

View File

@@ -238,8 +238,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
return hiddenStates, deepstackStates
}
// newVisionModel creates a new instance of the Qwen vision model
func newVisionModel(c fs.Config) *VisionModel {
// NewVisionModel creates a new instance of the Qwen vision model.
func NewVisionModel(c fs.Config) *VisionModel {
deepstackVisualIndexes := c.Ints("vision.deepstack_visual_indexes")
model := &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),

View File

@@ -49,6 +49,8 @@ func ParserForName(name string) Parser {
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
case "qwen3-thinking":
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
case "qwen3.5":
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
case "qwen3-coder":
p = &Qwen3CoderParser{}
case "qwen3-vl-instruct":

View File

@@ -59,6 +59,7 @@ func TestBuiltInParsersStillWork(t *testing.T) {
{"qwen3-coder"},
{"lfm2"},
{"lfm2-thinking"},
{"qwen3.5"},
{"harmony"},
}

View File

@@ -204,6 +204,24 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
p.maybeThinkingOpenAtBOL = false
}
thinkingCloseIdx := strings.Index(acc, qwen3ThinkingCloseTag)
toolOpenIdx := strings.Index(acc, qwen3ToolOpenTag)
// If a tool call starts before </think>, treat that as the end of thinking
// for parsing purposes and continue in tool-call mode.
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
if len(before) > 0 {
events = append(events, qwen3EventThinkingContent{content: before})
}
if after == "" {
p.state = qwen3ParserStateToolStartedEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingToolContent
}
return events, true
}
if strings.Contains(acc, qwen3ThinkingCloseTag) {
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
if len(thinking) > 0 {
@@ -215,7 +233,7 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
p.state = qwen3ParserStateCollectingContent
}
return events, true
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
} else if overlapLen := max(overlap(acc, qwen3ThinkingCloseTag), overlap(acc, qwen3ToolOpenTag)); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen

View File

@@ -145,3 +145,88 @@ func TestQwen3ParserToolCall(t *testing.T) {
t.Fatalf("expected unit %q, got %v", "celsius", unit)
}
}
// TestQwen3ParserThinkingWithToolCallBeforeThinkingClose verifies that a
// <tool_call> appearing before </think> ends thinking collection: the text
// before the tag is emitted as thinking and the tool call is still parsed.
func TestQwen3ParserThinkingWithToolCallBeforeThinkingClose(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
	parser.Init(nil, nil, &api.ThinkValue{Value: true})
	// Note: no </think> anywhere in the input.
	input := "Let me think<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
	content, thinking, calls, err := parser.Add(input, true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	if thinking != "Let me think" {
		t.Fatalf("expected thinking %q, got %q", "Let me think", thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}
	if calls[0].Function.Name != "get_weather" {
		t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
	}
}
// TestQwen3ParserThinkingWithSplitToolOpenTag streams a <tool_call> tag split
// across two chunks while in thinking state. The partial tag must be held
// back (only the preceding thinking text is emitted) until the second chunk
// completes the tag and the tool call parses.
func TestQwen3ParserThinkingWithSplitToolOpenTag(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
	parser.Init(nil, nil, &api.ThinkValue{Value: true})
	// First chunk ends mid-tag: "<tool_ca" must be buffered, not emitted.
	content, thinking, calls, err := parser.Add("Let me think<tool_ca", false)
	if err != nil {
		t.Fatalf("parse failed on first chunk: %v", err)
	}
	if content != "" || thinking != "Let me think" || len(calls) != 0 {
		t.Fatalf(
			"expected content=%q thinking=%q calls=%d, got content=%q thinking=%q calls=%d",
			"",
			"Let me think",
			0,
			content,
			thinking,
			len(calls),
		)
	}
	// Second chunk completes the tag and the full tool call body.
	content, thinking, calls, err = parser.Add("ll>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}</tool_call>", true)
	if err != nil {
		t.Fatalf("parse failed on second chunk: %v", err)
	}
	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	if thinking != "" {
		t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}
	if calls[0].Function.Name != "get_weather" {
		t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
	}
}
// TestQwen35ParserRespectsNoThink verifies that the qwen3.5 parser, when
// initialized with thinking disabled, passes output through as plain content
// with no thinking text and no tool calls.
func TestQwen35ParserRespectsNoThink(t *testing.T) {
	const reply = "Hello! How can I help you today?"

	p := ParserForName("qwen3.5")
	if p == nil {
		t.Fatal("expected qwen3.5 parser")
	}
	p.Init(nil, nil, &api.ThinkValue{Value: false})

	content, thinking, calls, err := p.Add(reply, true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if thinking != "" {
		t.Fatalf("expected no thinking, got %q", thinking)
	}
	if content != reply {
		t.Fatalf("expected content %q, got %q", reply, content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

View File

@@ -180,7 +180,22 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
return events, false
}
case CollectingThinkingContent:
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
acc := p.buffer.String()
thinkingCloseIdx := strings.Index(acc, thinkingCloseTag)
toolOpenIdx := strings.Index(acc, toolOpenTag)
// If a tool call starts before </think>, treat that as the end of thinking
// for parsing purposes and continue in tool-call mode.
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
if len(before) > 0 {
events = append(events, qwenEventThinkingContent{content: before})
}
p.state = CollectingToolContent
return events, true
}
if strings.Contains(acc, thinkingCloseTag) {
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, qwenEventThinkingContent{content: thinking})
@@ -191,13 +206,13 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
p.state = CollectingContent
}
return events, true
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
} else if overlapLen := max(overlap(acc, thinkingCloseTag), overlap(acc, toolOpenTag)); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
@@ -205,11 +220,11 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {

View File

@@ -98,8 +98,12 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
desc: "nested thinking and tool call (outside thinking, inside tool call)",
steps: []step{
{
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
wantEvents: []qwenEvent{qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm nested tool call</tool_call>"}},
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
wantEvents: []qwenEvent{
qwenEventThinkingContent{content: "I'm thinking"},
qwenEventRawToolCall{raw: "I'm nested tool call"},
qwenEventContent{content: "</think>"},
},
},
},
},
@@ -109,8 +113,7 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
{
input: "<tool_call>I'm nested tool call<think>I'm thinking</think></tool_call>",
wantEvents: []qwenEvent{
qwenEventThinkingContent{content: "<tool_call>I'm nested tool call<think>I'm thinking"},
qwenEventContent{content: "</tool_call>"},
qwenEventRawToolCall{raw: "I'm nested tool call<think>I'm thinking</think>"},
},
},
},
@@ -121,8 +124,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
{
input: "I'm thinking<tool_call>I'm NOT a nested tool call</think></tool_call><tool_call>I'm nested tool call 2<think></tool_call></think>",
wantEvents: []qwenEvent{
qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm NOT a nested tool call"},
qwenEventContent{content: "</tool_call>"},
qwenEventThinkingContent{content: "I'm thinking"},
qwenEventRawToolCall{raw: "I'm NOT a nested tool call</think>"},
qwenEventRawToolCall{raw: "I'm nested tool call 2<think>"},
qwenEventContent{content: "</think>"},
},

View File

@@ -3,6 +3,7 @@ package renderers
import (
"bytes"
"encoding/json"
"fmt"
"sort"
"strings"
@@ -192,21 +193,25 @@ func lfm2RenderToolCalls(calls []api.ToolCall) string {
return sb.String()
}
func (r *LFM2Renderer) renderMessageContent(message api.Message) string {
func (r *LFM2Renderer) renderMessageContent(message api.Message, imageOffset int) string {
content := lfm2RenderContent(message.Content, r.useImgTags)
if len(message.Images) == 0 {
return content
}
// chatPrompt may already have inserted [img] / [img-n] placeholders.
if strings.Contains(content, "[img]") || strings.Contains(content, "[img-") || strings.Contains(content, "<image>") {
return content
}
var sb strings.Builder
placeholder := lfm2ImagePlaceholder(r.useImgTags)
for range message.Images {
sb.WriteString(placeholder)
if r.useImgTags {
for i := range message.Images {
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset+i))
}
} else {
placeholder := lfm2ImagePlaceholder(false)
if strings.Contains(content, placeholder) {
return content
}
for range message.Images {
sb.WriteString(placeholder)
}
}
sb.WriteString(content)
return sb.String()
@@ -262,6 +267,11 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
}
}
imageOffset := 0
for i := range startIdx {
imageOffset += len(messages[i].Images)
}
for i := startIdx; i < len(messages); i++ {
message := messages[i]
lastMessage := i == len(messages)-1
@@ -271,7 +281,8 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
sb.WriteString(message.Role)
sb.WriteString("\n")
content := r.renderMessageContent(message)
content := r.renderMessageContent(message, imageOffset)
imageOffset += len(message.Images)
if message.Role == "assistant" && !keepPastThinking && i != lastAssistantIndex {
if idx := strings.LastIndex(content, "</think>"); idx >= 0 {
content = strings.TrimSpace(content[idx+len("</think>"):])

View File

@@ -236,16 +236,6 @@ func TestLFM2Renderer_Images(t *testing.T) {
Content: "Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n[img]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "existing_indexed_img_placeholder_not_duplicated",
renderer: &LFM2Renderer{useImgTags: true},
message: api.Message{
Role: "user",
Content: "[img-0]Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
{

View File

@@ -1,6 +1,7 @@
package renderers
import (
"fmt"
"strings"
"github.com/ollama/ollama/api"
@@ -9,10 +10,11 @@ import (
type Qwen3VLRenderer struct {
isThinking bool
useImgTags bool
emitEmptyThinkOnNoThink bool
useImgTags bool
}
func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
func (r *Qwen3VLRenderer) renderContent(content api.Message, imageOffset int) (string, int) {
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
var subSb strings.Builder
for range content.Images {
@@ -20,7 +22,8 @@ func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
// model backends, and so we should eventually parameterize this or
// only output a placeholder such as [img]
if r.useImgTags {
subSb.WriteString("[img]")
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
imageOffset++
} else {
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
}
@@ -28,12 +31,17 @@ func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
// TODO: support videos
subSb.WriteString(content.Content)
return subSb.String()
return subSb.String(), imageOffset
}
func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
var sb strings.Builder
isThinking := r.isThinking
if think != nil {
isThinking = think.Bool()
}
if len(tools) > 0 {
sb.WriteString(imStartTag + "system\n")
if len(messages) > 0 && messages[0].Role == "system" {
@@ -57,7 +65,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
message := messages[i]
if multiStepTool && message.Role == "user" {
// Check if content starts with <tool_response> and ends with </tool_response>
content := r.renderContent(message)
content, _ := r.renderContent(message, 0)
if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
multiStepTool = false
lastQueryIndex = i
@@ -65,8 +73,10 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
}
}
imageOffset := 0
for i, message := range messages {
content := r.renderContent(message)
content, nextImageOffset := r.renderContent(message, imageOffset)
imageOffset = nextImageOffset
lastMessage := i == len(messages)-1
prefill := lastMessage && message.Role == "assistant"
@@ -76,13 +86,13 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
} else if message.Role == "assistant" {
contentReasoning := ""
if r.isThinking {
if isThinking {
if message.Thinking != "" {
contentReasoning = message.Thinking
}
}
if r.isThinking && i > lastQueryIndex {
if isThinking && i > lastQueryIndex {
if i == len(messages)-1 || contentReasoning != "" {
sb.WriteString("<|im_start|>" + message.Role + "\n<think>\n" + strings.Trim(contentReasoning, "\n")) // do we want to add a new line here?
if content != "" {
@@ -125,8 +135,10 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
// prefill at the end
if lastMessage && !prefill {
sb.WriteString("<|im_start|>assistant\n")
if r.isThinking {
if isThinking {
sb.WriteString("<think>\n")
} else if r.emitEmptyThinkOnNoThink {
sb.WriteString("<think>\n\n</think>\n\n")
}
}
}

View File

@@ -101,7 +101,7 @@ Let me analyze this image.`,
},
useImgTags: true,
expected: `<|im_start|>user
[img]Describe this image.<|im_end|>
[img-0]Describe this image.<|im_end|>
<|im_start|>assistant
Let me analyze this image.`,
},
@@ -123,7 +123,7 @@ Let me analyze this image.`,
},
useImgTags: true,
expected: `<|im_start|>user
[img][img]Describe these images.<|im_end|>
[img-0][img-1]Describe these images.<|im_end|>
<|im_start|>assistant
Let me analyze this image.`,
},

View File

@@ -1,6 +1,7 @@
package renderers
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -370,3 +371,74 @@ func TestFormatToolCallArgumentThinkingVL(t *testing.T) {
})
}
}
// TestQwen3VLRendererThinkOverride verifies that an explicit *api.ThinkValue
// passed to Render overrides the renderer's configured isThinking default:
// a nil think value keeps the configured behavior, Value:false suppresses the
// <think> prefill even on a thinking renderer, and Value:true forces the
// <think> prefill even on a non-thinking renderer.
func TestQwen3VLRendererThinkOverride(t *testing.T) {
msgs := []api.Message{
{Role: "user", Content: "Hello"},
}
// Case 1: nil override — the renderer's own isThinking=true default applies.
renderThinking, err := (&Qwen3VLRenderer{isThinking: true}).Render(msgs, nil, nil)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(renderThinking, "<|im_start|>assistant\n<think>\n") {
t.Fatalf("expected default thinking renderer to emit <think>, got:\n%s", renderThinking)
}
// Case 2: think=false override must suppress <think> despite isThinking=true.
renderNonThinking, err := (&Qwen3VLRenderer{isThinking: true}).Render(msgs, nil, &api.ThinkValue{Value: false})
if err != nil {
t.Fatal(err)
}
if strings.Contains(renderNonThinking, "<think>") {
t.Fatalf("expected think=false override to suppress <think>, got:\n%s", renderNonThinking)
}
// Case 3: think=true override must force <think> despite isThinking=false.
renderForcedThinking, err := (&Qwen3VLRenderer{isThinking: false}).Render(msgs, nil, &api.ThinkValue{Value: true})
if err != nil {
t.Fatal(err)
}
if !strings.Contains(renderForcedThinking, "<|im_start|>assistant\n<think>\n") {
t.Fatalf("expected think=true override to emit <think>, got:\n%s", renderForcedThinking)
}
}
// TestQwen3VLRendererThinkOverrideWithExplicitNoThinkPrefill verifies that
// when emitEmptyThinkOnNoThink is set and thinking is disabled via the
// think override, the renderer emits an explicit empty think block
// ("<think>\n\n</think>\n\n") in the assistant prefill instead of omitting
// the tags entirely.
func TestQwen3VLRendererThinkOverrideWithExplicitNoThinkPrefill(t *testing.T) {
msgs := []api.Message{
{Role: "user", Content: "Hello"},
}
renderNonThinking, err := (&Qwen3VLRenderer{
isThinking: true,
emitEmptyThinkOnNoThink: true,
}).Render(msgs, nil, &api.ThinkValue{Value: false})
if err != nil {
t.Fatal(err)
}
if !strings.Contains(renderNonThinking, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
t.Fatalf("expected explicit think=false prefill block, got:\n%s", renderNonThinking)
}
}
// TestQwenRendererNameNoThinkBehaviorSplit verifies the per-renderer-name
// split in no-think behavior: the "qwen3.5" renderer emits an explicit empty
// think block when thinking is disabled, while "qwen3-vl-thinking" keeps its
// legacy behavior of not emitting that empty block.
func TestQwenRendererNameNoThinkBehaviorSplit(t *testing.T) {
msgs := []api.Message{
{Role: "user", Content: "Hello"},
}
thinkFalse := &api.ThinkValue{Value: false}
// qwen3.5 is registered with emitEmptyThinkOnNoThink, so the empty
// <think>\n\n</think> prefill must appear.
qwen35Rendered, err := RenderWithRenderer("qwen3.5", msgs, nil, thinkFalse)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(qwen35Rendered, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
t.Fatalf("expected qwen3.5 renderer to emit explicit no-think prefill, got:\n%s", qwen35Rendered)
}
// qwen3-vl-thinking must NOT gain the new prefill behavior.
qwen3VLRendered, err := RenderWithRenderer("qwen3-vl-thinking", msgs, nil, thinkFalse)
if err != nil {
t.Fatal(err)
}
if strings.Contains(qwen3VLRendered, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
t.Fatalf("expected qwen3-vl-thinking renderer to keep legacy non-empty no-think behavior, got:\n%s", qwen3VLRendered)
}
}

View File

@@ -56,6 +56,9 @@ func rendererForName(name string) Renderer {
case "qwen3-vl-thinking":
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
return renderer
case "qwen3.5":
renderer := &Qwen3VLRenderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
return renderer
case "cogito":
renderer := &CogitoRenderer{isThinking: true}
return renderer

View File

@@ -29,17 +29,27 @@ func TestRegisterCustomRenderer(t *testing.T) {
}
func TestBuiltInRendererStillWorks(t *testing.T) {
// Test that qwen3-coder still works
tests := []struct {
name string
}{
{name: "qwen3-coder"},
{name: "qwen3.5"},
}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
result, err := RenderWithRenderer("qwen3-coder", messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result == "" {
t.Error("expected non-empty result from qwen3-coder renderer")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := RenderWithRenderer(tt.name, messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result == "" {
t.Fatalf("expected non-empty result from %s renderer", tt.name)
}
})
}
}

View File

@@ -86,6 +86,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
ID: len(images),
Data: i,
}
images = append(images, imgData)
if m.Config.Renderer != "" {
continue
}
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
if !strings.Contains(prompt, "[img]") {
@@ -93,8 +98,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
} else {
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
}
images = append(images, imgData)
}
msgs[currMsgIdx+cnt].Content = prefix + prompt
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
func TestChatPrompt(t *testing.T) {
@@ -330,3 +331,38 @@ func TestChatPromptTokenizeCalls(t *testing.T) {
})
}
}
// TestChatPromptRendererDoesNotRewriteMessageContent verifies that when a
// model has a Renderer configured, chatPrompt collects the image data but does
// not rewrite the message content with [img-n] placeholders (placeholder
// insertion is the renderer's job in that path).
func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
msgs := []api.Message{
{
Role: "user",
Content: "what do these photos have in common?",
Images: []api.ImageData{[]byte("img-1"), []byte("img-2"), []byte("img-3")},
},
}
originalContent := msgs[0].Content
m := Model{
Config: model.ConfigV2{Renderer: "qwen3-vl-instruct"},
ProjectorPaths: []string{"vision"},
}
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
think := false
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
if err != nil {
t.Fatal(err)
}
// The renderer path must leave the user-visible content untouched.
if msgs[0].Content != originalContent {
t.Fatalf("renderer path should not mutate message content: got %q, want %q", msgs[0].Content, originalContent)
}
// All three images must still be collected for the runner.
if got, want := len(images), 3; got != want {
t.Fatalf("len(images) = %d, want %d", got, want)
}
if prompt == "" {
t.Fatal("prompt is empty")
}
}

View File

@@ -6,6 +6,7 @@ import (
"log/slog"
"maps"
"os"
"slices"
"strings"
"unsafe"
@@ -33,6 +34,9 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
}
if uint64(len(data)) < q.from.Size() {
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
}
var f32s []float32
newType := fsggml.TensorType(q.to.Kind)
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
@@ -58,7 +62,7 @@ func useMoreBits(iLayer, nLayers int) bool {
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
}
func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
func qwen3LinearAttnQuantType(name string) (fsggml.TensorType, bool) {
switch {
// Full attention
case strings.HasSuffix(name, ".attn_q.weight"):
@@ -79,6 +83,10 @@ func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
// SSM
case strings.HasSuffix(name, ".ssm_ba.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".ssm_beta.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".ssm_alpha.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".ssm_out.weight"):
return fsggml.TensorTypeQ4_K, true
@@ -287,8 +295,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
newType := fsggml.TensorType(t.Kind)
if quantize {
if kv.Architecture() == "qwen3next" && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
if qt, ok := qwen3nextQuantType(name); ok {
if slices.Contains([]string{"qwen3next", "qwen35", "qwen35moe"}, kv.Architecture()) && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
if qt, ok := qwen3LinearAttnQuantType(name); ok {
return qt
}
}

View File

@@ -166,6 +166,60 @@ func TestGetTensorNewType(t *testing.T) {
}
}
// TestQwen3LinearAttentionQuantOverride verifies that newType applies the
// qwen3 linear-attention quantization overrides (Q4_K for SSM/attention
// tensors) for the qwen35 and qwen35moe architectures, and falls back to the
// generic quantization choice for other architectures.
func TestQwen3LinearAttentionQuantOverride(t *testing.T) {
cases := []struct {
name string
arch string
tensor string
fileType fsggml.FileType
expected fsggml.TensorType
}{
{
name: "qwen35_beta",
arch: "qwen35",
tensor: "blk.0.ssm_beta.weight",
fileType: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "qwen35_alpha",
arch: "qwen35",
tensor: "blk.0.ssm_alpha.weight",
fileType: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "qwen35moe_attn_qkv",
arch: "qwen35moe",
tensor: "blk.0.attn_qkv.weight",
fileType: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
// An unrecognized architecture must not pick up the override;
// Q5_K is the generic choice newType makes here.
name: "non_qwen35_falls_back",
arch: "foo",
tensor: "blk.0.attn_qkv.weight",
fileType: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ5_K,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
kv := fsggml.KV{"general.architecture": tt.arch}
got := newType(&fsggml.Tensor{
Name: tt.tensor,
Shape: []uint64{256, 256},
Kind: uint32(fsggml.TensorTypeF16),
}, kv, &quantizeState{}, tt.fileType)
if got != tt.expected {
t.Fatalf("unexpected tensor type for %s (%s): got %s want %s", tt.tensor, tt.arch, got, tt.expected)
}
})
}
}
func TestQuantizeModel(t *testing.T) {
cases := []struct {
name string
@@ -173,6 +227,7 @@ func TestQuantizeModel(t *testing.T) {
tensors []*fsggml.Tensor
newType string
expectedTensorTypes map[string]fsggml.TensorType
expectErr bool
}{
{
name: "f16_q4_k",
@@ -253,6 +308,36 @@ func TestQuantizeModel(t *testing.T) {
"output.weight": fsggml.TensorTypeQ8_0,
},
},
{
name: "f32_short_data",
kv: map[string]any{
"general.architecture": "foo",
},
tensors: []*fsggml.Tensor{
{
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF32),
Offset: uint64(0), Shape: []uint64{512, 2},
WriterTo: bytes.NewReader(make([]byte, 32)),
},
},
newType: "Q4_K",
expectErr: true,
},
{
name: "f16_short_data",
kv: map[string]any{
"general.architecture": "foo",
},
tensors: []*fsggml.Tensor{
{
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
Offset: uint64(0), Shape: []uint64{512, 2},
WriterTo: bytes.NewReader(make([]byte, 32)),
},
},
newType: "Q4_K",
expectErr: true,
},
}
for _, tt := range cases {
@@ -264,6 +349,9 @@ func TestQuantizeModel(t *testing.T) {
}
defer fp.Close()
meta, err := fsggml.Decode(fp, -1)
if tt.expectErr && err != nil {
return
}
if err != nil {
t.Fatal(err.Error())
}
@@ -283,6 +371,12 @@ func TestQuantizeModel(t *testing.T) {
}
err = quantize(fp, tmp, meta, ftype, progress)
if tt.expectErr {
if err == nil {
t.Fatal("expected quantize to return an error")
}
return
}
if err != nil {
t.Fatalf("error during quantize: %s", err)
}

View File

@@ -447,7 +447,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
// Some architectures are not safe with num_parallel > 1.
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}

View File

@@ -9,59 +9,104 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
// CacheEntry stores a single sequence
type CacheEntry struct {
Tokens []int32
Caches []cache.Cache
type kvCache struct {
// For now we only support a single entry, so this is just one sequence
tokens []int32
caches []cache.Cache
}
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
if r.cache == nil {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
type cacheSession struct {
cache *kvCache
inputs []int32
outputs []int32
caches []cache.Cache
remaining []int32
}
// begin prepares caches for a new request. It finds the nearest
// matching cache or creates new caches if none match. The returned
// cacheSession carries the shared caches plus the suffix of inputs
// that still needs to be run through the model.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
// Lazily create the per-layer caches on first use. Models may provide
// their own cache layout via an optional NewCaches factory; otherwise
// one KV cache per layer is allocated.
if len(c.caches) == 0 {
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
c.caches = cacheFactory.NewCaches()
} else {
c.caches = make([]cache.Cache, m.NumLayers())
for i := range c.caches {
c.caches[i] = cache.NewKVCache()
}
}
}
// Find longest common prefix
remaining := c.findRemaining(inputs)
return &cacheSession{
cache: c,
inputs: inputs,
caches: c.caches,
remaining: remaining,
}
}
// close saves the token state if the forward pass ran.
// Offset() > 0 indicates at least one token was processed; only then is the
// cached token list updated to cover exactly the tokens the caches hold.
// Assumes s.caches is non-empty (begin always populates it).
func (s *cacheSession) close() {
if offset := s.caches[0].Offset(); offset > 0 {
// Ensure that if we have run the forward pass and set the metadata
// that we also actually have the data
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
for _, c := range s.caches {
k, v := c.State()
arrays = append(arrays, k, v)
}
mlx.AsyncEval(arrays...)
// NOTE(review): append(s.inputs, s.outputs...) can write into the
// backing array of s.inputs when it has spare capacity — confirm no
// caller retains and reuses the inputs slice after close runs.
s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
}
}
// findRemaining finds the longest common prefix between tokens and the cached
// sequence, trims stale cache entries, and returns the remaining tokens.
func (c *kvCache) findRemaining(tokens []int32) []int32 {
prefix := 0
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
prefix++
}
switch {
case prefix == 0:
for _, c := range r.cache.Caches {
c.Free()
if prefix == len(tokens) && prefix > 0 {
// Leave one token to run through the model so we can sample a response.
prefix--
}
if prefix < len(c.tokens) {
trim := len(c.tokens) - prefix
for _, kv := range c.caches {
kv.Trim(trim)
}
r.cache = nil
c.tokens = c.tokens[:prefix]
}
if prefix == 0 {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
case prefix < len(r.cache.Tokens):
trim := len(r.cache.Tokens) - prefix
for _, c := range r.cache.Caches {
c.Trim(trim)
}
r.cache.Tokens = r.cache.Tokens[:prefix]
} else {
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
}
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
return r.cache.Caches, tokens[prefix:]
return tokens[prefix:]
}
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
r.cache = &CacheEntry{
Tokens: tokens,
Caches: caches,
func (c *kvCache) log() {
if len(c.caches) == 0 {
return
}
}
func (c *CacheEntry) LogCache() {
var totalBytes int
for _, kv := range c.Caches {
for _, kv := range c.caches {
k, v := kv.State()
totalBytes += k.NumBytes() + v.NumBytes()
}
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
}

View File

@@ -10,7 +10,6 @@ import (
"time"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
@@ -19,6 +18,23 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return errors.New("model not loaded")
}
var (
sample, logprobs *mlx.Array
nextSample, nextLogprobs *mlx.Array
)
defer func() {
mlx.Unpin(sample, logprobs)
mlx.Unpin(nextSample, nextLogprobs)
mlx.Sweep()
mlx.ClearCache()
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
r.cache.log()
}
}()
enableCompile := true
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
enableCompile = modelCompile.EnableCompile()
@@ -30,22 +46,19 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
inputs := r.Tokenizer.Encode(request.Prompt, true)
session := r.cache.begin(r.Model, inputs)
defer session.close()
caches, tokens := r.FindNearestCache(inputs)
if len(caches) == 0 {
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
caches = cacheFactory.NewCaches()
} else {
caches = make([]cache.Cache, r.Model.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
}
}
}
caches := session.caches
tokens := session.remaining
total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
if err := request.Ctx.Err(); err != nil {
return err
}
n := min(2<<10, total-processed-1)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
mlx.Sweep()
@@ -76,15 +89,18 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return sample, logprobs
}
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
var b bytes.Buffer
now := time.Now()
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
nextSample, nextLogprobs := step(sample)
if err := request.Ctx.Err(); err != nil {
return err
}
nextSample, nextLogprobs = step(sample)
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
@@ -94,43 +110,40 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
output := int32(sample.Int())
outputs = append(outputs, output)
session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) {
mlx.Unpin(nextSample, nextLogprobs)
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
break
}
request.Responses <- Response{
select {
case <-request.Ctx.Done():
return request.Ctx.Err()
case request.Responses <- Response{
Text: r.Decode(output, &b),
Token: int(output),
}:
}
mlx.Unpin(sample, logprobs)
sample, logprobs = nextSample, nextLogprobs
nextSample, nextLogprobs = nil, nil
if i%256 == 0 {
mlx.ClearCache()
}
sample, logprobs = nextSample, nextLogprobs
}
mlx.Unpin(sample, logprobs)
final.CompletionTokensDuration = time.Since(now)
request.Responses <- final
r.InsertCache(append(inputs, outputs...), caches)
mlx.Sweep()
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
if r.cache != nil {
r.cache.LogCache()
}
select {
case <-request.Ctx.Done():
return request.Ctx.Err()
case request.Responses <- final:
return nil
}
return nil
}
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {

View File

@@ -12,7 +12,6 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -25,8 +24,9 @@ type Request struct {
Responses chan Response
Pipeline func(Request) error
Ctx context.Context
sample.Sampler
caches []cache.Cache
}
type TextCompletionsRequest struct {
@@ -61,7 +61,7 @@ type Runner struct {
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
cache *CacheEntry
cache kvCache
}
func (r *Runner) Load(modelName string) error {
@@ -157,7 +157,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
return nil
case request := <-r.Requests:
if err := request.Pipeline(request); err != nil {
break
slog.Info("Request terminated", "error", err)
}
close(request.Responses)

View File

@@ -5,6 +5,7 @@ package mlxrunner
import (
"bytes"
"cmp"
"context"
"encoding/json"
"flag"
"fmt"
@@ -98,19 +99,36 @@ func Execute(args []string) error {
request.Options.TopK,
)
runner.Requests <- request
var cancel context.CancelFunc
request.Ctx, cancel = context.WithCancel(r.Context())
defer cancel()
select {
case <-r.Context().Done():
return
case runner.Requests <- request:
}
w.Header().Set("Content-Type", "application/jsonl")
w.WriteHeader(http.StatusOK)
enc := json.NewEncoder(w)
for response := range request.Responses {
if err := enc.Encode(response); err != nil {
slog.Error("Failed to encode response", "error", err)
for {
select {
case <-r.Context().Done():
return
}
case response, ok := <-request.Responses:
if !ok {
return
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
if err := enc.Encode(response); err != nil {
slog.Error("Failed to encode response", "error", err)
return
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}
}
})