Compare commits

...

36 Commits

Author SHA1 Message Date
ParthSareen
52f757d8a2 cmd: fix gofmt formatting in pi integration
Fixes formatting issues caught by golangci-lint.
2026-02-04 19:27:34 -08:00
ParthSareen
86aa7cd0a6 cmd: add pi integration to ollama launch
Adds support for launching Pi (Pi Coding Agent) via ollama launch.
Pi is a CLI tool distributed via npm that can use Ollama models.

Features:
- Configures ~/.pi/agent/models.json with Ollama provider settings
- Sets defaultProvider and defaultModel in ~/.pi/agent/settings.json
- Uses _launch marker to identify ollama-managed models
- Preserves user-managed models and other config fields
- Backups both files before modification
- Hidden from interactive selector (use ollama launch pi)

Tested with 20 comprehensive tests including safety tests
for user config preservation.

Usage:
  ollama launch pi
  ollama launch pi --model glm-4.7-flash
2026-02-04 18:51:40 -08:00
Jeffrey Morgan
255579aaa7 qwen3next: fix issue in delta net (#14075)
gDiffExp was being broadcast across the wrong axis when multiplying with k. This fix reshapes gDiffExp to [1, chunkSize, nChunks, ...]
2026-02-04 13:40:38 -08:00
Jeffrey Morgan
f7102ba826 runner: discard compute results if sequence replaced mid-batch (#14072)
If a sequence is replaced in s.seqs while a batch is computing, the old logits can be decoded into the new sequence. This change rechecks the sequence pointer after compute and skips decoding for replaced entries, preventing stale results from being applied.
2026-02-04 13:19:48 -08:00
Jeffrey Morgan
cefabd79a8 Revert "cmd: claude launch improvements (#14064)" (#14071)
This reverts commit ee25219edd.
2026-02-04 09:10:37 -08:00
Jeffrey Morgan
df70249520 server: optimize chatPrompt to reduce tokenization calls (#14040)
Change the truncation algorithm to start with all messages and remove
from the front until it fits, rather than adding messages one at a time
from the back. This reduces tokenization calls from O(n) to O(1) in the
common case where all messages fit in context.
2026-02-04 01:21:31 -08:00
Jeffrey Morgan
77eb2ca619 model: add qwen3-next architecture (#14051) 2026-02-03 23:27:21 -08:00
Parth Sareen
ee25219edd cmd: claude launch improvements (#14064) 2026-02-03 19:33:58 -08:00
Jeffrey Morgan
b1fccabb34 Revert "Update vendored llama.cpp to b7847" (#14061) 2026-02-03 18:39:36 -08:00
Bruce MacDonald
a6355329bf cmd: open browser on ollama signin when available (#14055)
When a browser is available open it to the connect URL automatically when running the `ollama signin` command. Browser is not opened in any other unauthorized scenario.
2026-02-03 16:42:09 -08:00
Parth Sareen
0398b24b42 cmd: launch defaults (#14035) 2026-02-02 23:19:11 -08:00
Parth Sareen
75b1dddf91 cmd: launch extra params (#14039) 2026-02-03 02:03:33 -05:00
Parth Sareen
e1e80ffc3e cmd/config: move config location (#14034) 2026-02-02 22:48:51 -05:00
Aleksandr Vukmirovich
71896485fd anthropic: add InputTokens to streaming response (#13934)
---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-02-02 18:29:37 -08:00
Jeffrey Morgan
ef00199fb4 Update vendor ggml code to a5bb8ba4 (#13832)
Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
Co-authored-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Shalini Salomi Bodapati <Shalini.Salomi.Bodapati@ibm.com>
2026-02-02 17:31:59 -08:00
Jeffrey Morgan
8f4a008139 Add GLM-OCR vision model support (#14024) 2026-02-02 15:39:18 -08:00
Patrick Devine
d8cc798c2b glm 4.7 flash support on experimental engine (#13838) 2026-02-02 15:22:11 -08:00
Richard Lyons
6582f6da5c llm: Make "do load request" error message more informative 2026-02-02 11:13:21 -08:00
Jesse Gross
0334ffa625 server: use tiered VRAM-based default context length
Replace binary low VRAM mode with tiered VRAM thresholds that set
default context lengths for all models:

- < 24 GiB VRAM: 4,096 context
- 24-48 GiB VRAM: 32,768 context
- >= 48 GiB VRAM: 262,144 context
2026-02-02 10:47:09 -08:00
Jesse Gross
d11fbd2c60 server: fix ollama ps showing configured instead of actual context length
When context length is clamped to the model's trained context length,
ollama ps now shows the actual clamped value instead of the originally
configured value.
2026-02-02 10:47:09 -08:00
Jeffrey Morgan
6a7c3f188e openclaw: run onboarding for fresh installs (#14006)
When launching OpenClaw without prior onboarding, run the onboarding
wizard instead of going straight to gateway. This ensures proper
gateway configuration (mode, token, etc.) before first use.

- Add onboarded() to check for wizard.lastRunAt marker in config
- Run onboard with --auth-choice skip --gateway-token ollama for fresh installs
- Existing installs (onboarding completed) run gateway directly
2026-02-01 13:46:45 -08:00
Jeffrey Morgan
427e2c962a docs: add redirect from clawdbot to openclaw (#14004) 2026-01-31 20:50:42 -08:00
Thanh Nguyen
27db7f806f cmd/config: rename integration to openclaw (#13979)
---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-01-31 18:31:13 -05:00
Dhiraj Lochib
3590fbfa76 runner: fix typo 'baackend' -> 'backend' in error messages (#13645)
Fix typo in three error messages where 'baackend' was written instead
of 'backend' in the /health endpoint handler when initializing the
dummy model load.
2026-01-31 13:26:20 -08:00
noureldin-azzab
cd0094f772 added stakpak to web & desktop (#13961) 2026-01-31 13:04:34 -08:00
Louis Beaumont
06bc8e6712 docs: add Screenpipe to Community Integrations (#13906)
Screenpipe is a 24/7 screen & mic recording tool that uses Ollama
for local LLM-powered search and AI features. 16k+ GitHub stars.
2026-01-31 12:49:52 -08:00
frob
fc5f9bb448 docs: remove unsupported quantizations (#13982) 2026-01-31 12:46:20 -08:00
frob
a0740f7ef7 docs: add GB10 to supported devices (#13987) 2026-01-31 12:45:27 -08:00
Parth Sareen
a0923cbdd0 cmd: ollama launch add placeholder text for selector (#13966) 2026-01-29 09:48:49 -08:00
Seokrin Taron Sung
f92e362b2e cmd: capitalize Ollama in serve command help text (#13965) 2026-01-29 09:47:53 -08:00
Tincho
aa23d8ecd2 docs: update installation command for OpenCode CLI (#13971) 2026-01-29 09:47:02 -08:00
Gabe Goodhart
7b62c41060 cmd/config: use envconfig.Host() for base API in launch config packages (#13937) 2026-01-27 13:30:00 -08:00
Parth Sareen
26acab64b7 docs: add clawdbot (#13925) 2026-01-26 18:32:54 -08:00
Gyungrai Wang
e0f03790b1 parsers/ministral: fix nested tool call parsing by counting brace nesting (#13905)
* parsers/ministral: fix nested tool call parsing by counting brace nesting

* fix lint error

* parsers: refactor ministral parser

The old one was very tied to expecting to see only one token at a time,
which I don't like to assume (who knows what the future might hold wrt
speculative decoding, etc). This new one follows a similar structure to
qwen3-coder's parser, which incidentally makes it easier to test as well
(since we can test the individual events that come out when given
particular inputs).

---------

Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-26 15:03:43 -08:00
Parth Sareen
3ab842b0f5 cmd: clawdbot config fixes (#13922) 2026-01-26 14:34:29 -08:00
Parth Sareen
b8e8ef8929 cmd: ollama launch clawdbot (#13921) 2026-01-26 13:40:59 -08:00
114 changed files with 13029 additions and 2665 deletions

View File

@@ -358,6 +358,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
- [Screenpipe](https://github.com/mediar-ai/screenpipe) (24/7 screen & mic recording with AI-powered search, uses Ollama for local LLM features)
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
@@ -465,6 +466,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
- [Stakpak](https://github.com/stakpak/agent) (An open source, vendor neutral DevOps agent that works with any model, and any stack, for teams who just want to ship)
### Cloud

3
anthropic/anthropic.go Normal file → Executable file
View File

@@ -211,6 +211,7 @@ type MessageDelta struct {
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
@@ -721,6 +722,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
})
}
c.inputTokens = r.Metrics.PromptEvalCount
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
@@ -732,6 +734,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
StopReason: stopReason,
},
Usage: DeltaUsage{
InputTokens: c.inputTokens,
OutputTokens: c.outputTokens,
},
},

20
anthropic/anthropic_test.go Normal file → Executable file
View File

@@ -642,7 +642,7 @@ func TestStreamConverter_Basic(t *testing.T) {
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events2 := conv.Process(resp2)
@@ -650,6 +650,24 @@ func TestStreamConverter_Basic(t *testing.T) {
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_delta" {
if data, ok := e.Data.(MessageDeltaEvent); ok {
if data.Type != "message_delta" {
t.Errorf("unexpected data type: %+v", data)
}
if data.Delta.StopReason != "end_turn" {
t.Errorf("unexpected stop reason: %+v", data.Delta.StopReason)
}
if data.Usage.InputTokens != 10 || data.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", data.Usage)
}
} else {
t.Errorf("unexpected data: %+v", e.Data)
}
}
if e.Event == "message_stop" {
hasStop = true
}

View File

@@ -29,6 +29,7 @@ import (
"github.com/containerd/console"
"github.com/mattn/go-runewidth"
"github.com/olekukonko/tablewriter"
"github.com/pkg/browser"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
@@ -52,7 +53,7 @@ import (
"github.com/ollama/ollama/x/imagegen"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
@@ -663,6 +664,7 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
fmt.Println()
if aErr.SigninURL != "" {
_ = browser.OpenURL(aErr.SigninURL)
fmt.Printf(ConnectInstructions, aErr.SigninURL)
}
return nil
@@ -1888,7 +1890,7 @@ func NewCLI() *cobra.Command {
serveCmd := &cobra.Command{
Use: "serve",
Aliases: []string{"start"},
Short: "Start ollama",
Short: "Start Ollama",
Args: cobra.ExactArgs(0),
RunE: RunServer,
}

View File

@@ -1553,7 +1553,7 @@ func TestShowInfoImageGen(t *testing.T) {
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
QuantizationLevel: "Q8",
},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
@@ -1565,7 +1565,7 @@ func TestShowInfoImageGen(t *testing.T) {
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization FP8 \n" +
" quantization Q8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +

View File

@@ -6,6 +6,8 @@ import (
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner for Claude Code integration
@@ -13,11 +15,13 @@ type Claude struct{}
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string) []string {
func (c *Claude) args(model string, extra []string) []string {
var args []string
if model != "" {
return []string{"--model", model}
args = append(args, "--model", model)
}
return nil
args = append(args, extra...)
return args
}
func (c *Claude) findPath() (string, error) {
@@ -39,18 +43,18 @@ func (c *Claude) findPath() (string, error) {
return fallback, nil
}
func (c *Claude) Run(model string) error {
func (c *Claude) Run(model string, args []string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model)...)
cmd := exec.Command(claudePath, c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"ANTHROPIC_BASE_URL=http://localhost:11434",
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)

View File

@@ -84,17 +84,21 @@ func TestClaudeArgs(t *testing.T) {
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
{"empty model", "", nil},
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
{"empty model", "", nil, nil},
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
got := c.args(tt.model, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}

View File

@@ -14,20 +14,21 @@ type Codex struct{}
func (c *Codex) String() string { return "Codex" }
func (c *Codex) args(model string) []string {
func (c *Codex) args(model string, extra []string) []string {
args := []string{"--oss"}
if model != "" {
args = append(args, "-m", model)
}
args = append(args, extra...)
return args
}
func (c *Codex) Run(model string) error {
func (c *Codex) Run(model string, args []string) error {
if err := checkCodexVersion(); err != nil {
return err
}
cmd := exec.Command("codex", c.args(model)...)
cmd := exec.Command("codex", c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

View File

@@ -11,17 +11,20 @@ func TestCodexArgs(t *testing.T) {
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", []string{"--oss"}},
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", nil, []string{"--oss"}},
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
got := c.args(tt.model, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
@@ -20,6 +21,14 @@ type config struct {
}
func configPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "config.json"), nil
}
func legacyConfigPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
@@ -27,6 +36,46 @@ func configPath() (string, error) {
return filepath.Join(home, ".ollama", "config", "config.json"), nil
}
// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
func migrateConfig() (bool, error) {
oldPath, err := legacyConfigPath()
if err != nil {
return false, err
}
oldData, err := os.ReadFile(oldPath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
var js json.RawMessage
if err := json.Unmarshal(oldData, &js); err != nil {
slog.Warn("legacy config has invalid JSON, skipping migration", "path", oldPath, "error", err)
return false, nil
}
newPath, err := configPath()
if err != nil {
return false, err
}
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
return false, err
}
if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
return false, fmt.Errorf("write new config: %w", err)
}
_ = os.Remove(oldPath)
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
slog.Info("migrated config", "from", oldPath, "to", newPath)
return true, nil
}
func load() (*config, error) {
path, err := configPath()
if err != nil {
@@ -34,6 +83,11 @@ func load() (*config, error) {
}
data, err := os.ReadFile(path)
if err != nil && os.IsNotExist(err) {
if migrated, merr := migrateConfig(); merr == nil && migrated {
data, err = os.ReadFile(path)
}
}
if err != nil {
if os.IsNotExist(err) {
return &config{Integrations: make(map[string]*integration)}, nil

View File

@@ -200,12 +200,10 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Create corrupted config.json file
dir := filepath.Join(tmpDir, ".ollama", "config")
dir := filepath.Join(tmpDir, ".ollama")
os.MkdirAll(dir, 0o755)
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
// Corrupted file is treated as empty, so loadIntegration returns not found
_, err := loadIntegration("test")
if err == nil {
t.Error("expected error for nonexistent integration in corrupted file")
@@ -267,7 +265,7 @@ func TestConfigPath(t *testing.T) {
t.Fatal(err)
}
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
expected := filepath.Join(tmpDir, ".ollama", "config.json")
if path != expected {
t.Errorf("expected %s, got %s", expected, path)
}
@@ -322,6 +320,183 @@ func TestLoad(t *testing.T) {
})
}
func TestMigrateConfig(t *testing.T) {
t.Run("migrates legacy file to new location", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
newPath, _ := configPath()
got, err := os.ReadFile(newPath)
if err != nil {
t.Fatalf("new config not found: %v", err)
}
if string(got) != string(data) {
t.Errorf("content mismatch: got %s", got)
}
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
t.Error("legacy file should have been removed")
}
if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
t.Error("legacy directory should have been removed")
}
})
t.Run("no-op when no legacy file exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("expected no migration")
}
})
t.Run("skips corrupt legacy file", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("should not migrate corrupt file")
}
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
t.Error("corrupt legacy file should not have been deleted")
}
})
t.Run("new path takes precedence over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
newDir := filepath.Join(tmpDir, ".ollama")
os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if _, ok := cfg.Integrations["new"]; !ok {
t.Error("expected new-path integration to be loaded")
}
if _, ok := cfg.Integrations["old"]; ok {
t.Error("legacy integration should not have been loaded")
}
})
t.Run("idempotent when called twice", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
if _, err := migrateConfig(); err != nil {
t.Fatal(err)
}
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("second migration should be a no-op")
}
})
t.Run("legacy directory preserved if not empty", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
if _, err := migrateConfig(); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
t.Error("directory with other files should not have been removed")
}
if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
t.Error("other files in legacy directory should be untouched")
}
})
t.Run("save writes to new path after migration", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
// load triggers migration, then save should write to new path
if err := saveIntegration("codex", []string{"qwen2.5"}); err != nil {
t.Fatal(err)
}
newPath := filepath.Join(tmpDir, ".ollama", "config.json")
if _, err := os.Stat(newPath); os.IsNotExist(err) {
t.Error("save should write to new path")
}
// old path should not be recreated
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
t.Error("save should not recreate legacy path")
}
})
t.Run("load triggers migration transparently", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
t.Error("migration via load() did not preserve data")
}
})
}
func TestSave(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)

View File

@@ -7,6 +7,8 @@ import (
"os/exec"
"path/filepath"
"slices"
"github.com/ollama/ollama/envconfig"
)
// Droid implements Runner and Editor for Droid integration
@@ -37,7 +39,7 @@ type modelEntry struct {
func (d *Droid) String() string { return "Droid" }
func (d *Droid) Run(model string) error {
func (d *Droid) Run(model string, args []string) error {
if _, err := exec.LookPath("droid"); err != nil {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
}
@@ -51,7 +53,7 @@ func (d *Droid) Run(model string) error {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid")
cmd := exec.Command("droid", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -117,7 +119,7 @@ func (d *Droid) Edit(models []string) error {
newModels = append(newModels, modelEntry{
Model: model,
DisplayName: model,
BaseURL: "http://localhost:11434/v1",
BaseURL: envconfig.Host().String() + "/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: 64000,

View File

@@ -218,7 +218,7 @@ func TestDroidEdit(t *testing.T) {
}
}
if model["baseUrl"] != "http://localhost:11434/v1" {
if model["baseUrl"] != "http://127.0.0.1:11434/v1" {
t.Errorf("unexpected baseUrl: %s", model["baseUrl"])
}
if model["apiKey"] != "ollama" {
@@ -447,7 +447,7 @@ const testDroidSettingsFixture = `{
{
"model": "existing-ollama-model",
"displayName": "existing-ollama-model",
"baseUrl": "http://localhost:11434/v1",
"baseUrl": "http://127.0.0.1:11434/v1",
"apiKey": "ollama",
"provider": "generic-chat-completion-api",
"maxOutputTokens": 64000,

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
)
@@ -22,7 +23,7 @@ import (
// Runner can run an integration with a model.
type Runner interface {
Run(model string) error
Run(model string, args []string) error
// String returns the human-readable name of the integration
String() string
}
@@ -41,9 +42,29 @@ type Editor interface {
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
"clawdbot": &Openclaw{},
"codex": &Codex{},
"moltbot": &Openclaw{},
"droid": &Droid{},
"opencode": &OpenCode{},
"openclaw": &Openclaw{},
"pi": &Pi{},
}
// recommendedModels are shown when the user has no models or as suggestions.
// Order matters: local models first, then cloud models.
var recommendedModels = []selectItem{
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
{Name: "glm-4.7:cloud", Description: "Recommended"},
{Name: "kimi-k2.5:cloud", Description: "Recommended"},
}
// integrationAliases are hidden from the interactive selector but work as CLI arguments.
var integrationAliases = map[string]bool{
"clawdbot": true,
"moltbot": true,
"pi": true,
}
func selectIntegration() (string, error) {
@@ -54,6 +75,9 @@ func selectIntegration() (string, error) {
names := slices.Sorted(maps.Keys(integrations))
var items []selectItem
for _, name := range names {
if integrationAliases[name] {
continue
}
r := integrations[name]
description := r.String()
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
@@ -82,62 +106,25 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return nil, err
}
if len(models.Models) == 0 {
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
var items []selectItem
cloudModels := make(map[string]bool)
var existing []modelInfo
for _, m := range models.Models {
if m.RemoteModel != "" {
cloudModels[m.Name] = true
}
items = append(items, selectItem{Name: m.Name})
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
}
if len(items) == 0 {
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
}
// Get previously configured models (saved config takes precedence)
var preChecked []string
if saved, err := loadIntegration(name); err == nil {
preChecked = saved.Models
} else if editor, ok := r.(Editor); ok {
preChecked = editor.Models()
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
for _, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current)
// If current model is configured, move to front of preChecked
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
if len(items) == 0 {
return nil, fmt.Errorf("no models available")
}
// Sort: checked first, then alphabetical
slices.SortFunc(items, func(a, b selectItem) int {
ac, bc := checked[a.Name], checked[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
var selected []string
// only editors support multi-model selection
if _, ok := r.(Editor); ok {
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
if err != nil {
@@ -151,7 +138,27 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
selected = []string{model}
}
// if any model in selected is a cloud model, ensure signed in
var toPull []string
for _, m := range selected {
if !existingModels[m] {
toPull = append(toPull, m)
}
}
if len(toPull) > 0 {
msg := fmt.Sprintf("Download %s?", strings.Join(toPull, ", "))
if ok, err := confirmPrompt(msg); err != nil {
return nil, err
} else if !ok {
return nil, errCancelled
}
for _, m := range toPull {
fmt.Fprintf(os.Stderr, "\n")
if err := pullModel(ctx, client, m); err != nil {
return nil, fmt.Errorf("failed to pull %s: %w", m, err)
}
}
}
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
@@ -221,13 +228,13 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return selected, nil
}
func runIntegration(name, modelName string) error {
func runIntegration(name, modelName string, args []string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName)
return r.Run(modelName, args)
}
// LaunchCmd returns the cobra command for launching integrations.
@@ -236,7 +243,7 @@ func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) erro
var configFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION]",
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
@@ -245,19 +252,43 @@ Supported integrations:
codex Codex
droid Droid
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
Examples:
ollama launch
ollama launch claude
ollama launch claude --model <model>
ollama launch droid --config (does not auto-launch)`,
Args: cobra.MaximumNArgs(1),
ollama launch droid --config (does not auto-launch)
ollama launch codex -- -p myprofile (pass extra args to integration)
ollama launch codex -- --sandbox workspace-write`,
Args: cobra.ArbitraryArgs,
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
// Extract integration name and args to pass through using -- separator
var name string
if len(args) > 0 {
name = args[0]
var passArgs []string
dashIdx := cmd.ArgsLenAtDash()
if dashIdx == -1 {
// No "--" separator: only allow 0 or 1 args (integration name)
if len(args) > 1 {
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
}
if len(args) == 1 {
name = args[0]
}
} else {
// "--" was used: args before it = integration name, args after = passthrough
if dashIdx > 1 {
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
}
if dashIdx == 1 {
name = args[0]
}
passArgs = args[dashIdx:]
}
if name == "" {
var err error
name, err = selectIntegration()
if errors.Is(err, errCancelled) {
@@ -273,16 +304,14 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
// If launching without --model, use saved config if available
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0])
return runIntegration(name, config.Models[0], passArgs)
}
}
var models []string
if modelFlag != "" {
// When --model is specified, merge with existing models (new model becomes default)
models = []string{modelFlag}
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
for _, m := range existing.Models {
@@ -337,13 +366,13 @@ Examples:
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0])
return runIntegration(name, models[0], passArgs)
}
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
return nil
}
return runIntegration(name, models[0])
return runIntegration(name, models[0], passArgs)
},
}
@@ -351,3 +380,154 @@ Examples:
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
return cmd
}
type modelInfo struct {
Name string
Remote bool
}
// buildModelList merges existing models with recommendations, sorts them, and returns
// the ordered items along with maps of existing and cloud model names.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
existingModels = make(map[string]bool)
cloudModels = make(map[string]bool)
recommended := make(map[string]bool)
var hasLocalModel, hasCloudModel bool
for _, rec := range recommendedModels {
recommended[rec.Name] = true
}
for _, m := range existing {
existingModels[m.Name] = true
if m.Remote {
cloudModels[m.Name] = true
hasCloudModel = true
} else {
hasLocalModel = true
}
displayName := strings.TrimSuffix(m.Name, ":latest")
existingModels[displayName] = true
item := selectItem{Name: displayName}
if recommended[displayName] {
item.Description = "recommended"
}
items = append(items, item)
}
for _, rec := range recommendedModels {
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
continue
}
items = append(items, rec)
if isCloudModel(rec.Name) {
cloudModels[rec.Name] = true
}
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
for _, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
}
// Non-existing models get "install?" suffix and are pushed to the bottom.
// When user has no models, preserve recommended order.
notInstalled := make(map[string]bool)
for i := range items {
if !existingModels[items[i].Name] {
notInstalled[items[i].Name] = true
if items[i].Description != "" {
items[i].Description += ", install?"
} else {
items[i].Description = "install?"
}
}
}
if hasLocalModel || hasCloudModel {
slices.SortStableFunc(items, func(a, b selectItem) int {
ac, bc := checked[a.Name], checked[b.Name]
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
if !ac && !bc && aNew != bNew {
if aNew {
return 1
}
return -1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
}
return items, preChecked, existingModels, cloudModels
}
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud")
}
func pullModel(ctx context.Context, client *api.Client, model string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if resp.Completed == 0 {
return nil
}
if spinner != nil {
spinner.Stop()
}
bar, ok := bars[resp.Digest]
if !ok {
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
}
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
request := api.PullRequest{Name: model}
return client.Pull(ctx, &request, fn)
}

View File

@@ -1,10 +1,12 @@
package config
import (
"fmt"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/spf13/cobra"
)
@@ -90,8 +92,8 @@ func TestLaunchCmd(t *testing.T) {
cmd := LaunchCmd(mockCheck)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
@@ -121,7 +123,7 @@ func TestLaunchCmd(t *testing.T) {
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model")
err := runIntegration("unknown-integration", "model", nil)
if err == nil {
t.Error("expected error for unknown integration, got nil")
}
@@ -174,15 +176,336 @@ func TestLaunchCmd_NilHeartbeat(t *testing.T) {
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
for name, r := range integrations {
t.Run(name, func(t *testing.T) {
// Test String() doesn't panic and returns non-empty
displayName := r.String()
if displayName == "" {
t.Error("String() should not return empty")
}
// Test Run() exists (we can't call it without actually running the command)
// Just verify the method is available
var _ func(string) error = r.Run
var _ func(string, []string) error = r.Run
})
}
}
func TestParseArgs(t *testing.T) {
// Tests reflect cobra's ArgsLenAtDash() semantics:
// - cobra strips "--" from args
// - ArgsLenAtDash() returns the index where "--" was, or -1
tests := []struct {
name string
args []string // args as cobra delivers them (no "--")
dashIdx int // what ArgsLenAtDash() returns
wantName string
wantArgs []string
wantErr bool
}{
{
name: "no extra args, no dash",
args: []string{"claude"},
dashIdx: -1,
wantName: "claude",
},
{
name: "with extra args after --",
args: []string{"codex", "-p", "myprofile"},
dashIdx: 1,
wantName: "codex",
wantArgs: []string{"-p", "myprofile"},
},
{
name: "extra args only after --",
args: []string{"codex", "--sandbox", "workspace-write"},
dashIdx: 1,
wantName: "codex",
wantArgs: []string{"--sandbox", "workspace-write"},
},
{
name: "-- at end with no args after",
args: []string{"claude"},
dashIdx: 1,
wantName: "claude",
},
{
name: "-- with no integration name",
args: []string{"--verbose"},
dashIdx: 0,
wantName: "",
wantArgs: []string{"--verbose"},
},
{
name: "multiple args before -- is error",
args: []string{"claude", "codex", "--verbose"},
dashIdx: 2,
wantErr: true,
},
{
name: "multiple args without -- is error",
args: []string{"claude", "codex"},
dashIdx: -1,
wantErr: true,
},
{
name: "no args, no dash",
args: []string{},
dashIdx: -1,
wantName: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the parsing logic from LaunchCmd using dashIdx
var name string
var parsedArgs []string
var err error
dashIdx := tt.dashIdx
args := tt.args
if dashIdx == -1 {
if len(args) > 1 {
err = fmt.Errorf("unexpected arguments: %v", args[1:])
} else if len(args) == 1 {
name = args[0]
}
} else {
if dashIdx > 1 {
err = fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
} else {
if dashIdx == 1 {
name = args[0]
}
parsedArgs = args[dashIdx:]
}
}
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if name != tt.wantName {
t.Errorf("name = %q, want %q", name, tt.wantName)
}
if !slices.Equal(parsedArgs, tt.wantArgs) {
t.Errorf("args = %v, want %v", parsedArgs, tt.wantArgs)
}
})
}
}
func TestIsCloudModel(t *testing.T) {
tests := []struct {
name string
want bool
}{
{"glm-4.7:cloud", true},
{"kimi-k2.5:cloud", true},
{"glm-4.7-flash", false},
{"glm-4.7-flash:latest", false},
{"cloud-model", false},
{"model:cloudish", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isCloudModel(tt.name); got != tt.want {
t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
}
})
}
}
func names(items []selectItem) []string {
var out []string
for _, item := range items {
out = append(out, item.Name)
}
return out
}
func TestBuildModelList_NoExistingModels(t *testing.T) {
items, _, _, _ := buildModelList(nil, nil, "")
want := []string{"glm-4.7-flash", "qwen3:8b", "glm-4.7:cloud", "kimi-k2.5:cloud"}
if diff := cmp.Diff(want, names(items)); diff != "" {
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
}
for _, item := range items {
if !strings.HasSuffix(item.Description, "install?") {
t.Errorf("item %q should have description ending with 'install?', got %q", item.Name, item.Description)
}
}
}
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "qwen2.5:latest", Remote: false},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
want := []string{"llama3.2", "qwen2.5", "glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("cloud recs should be at bottom (-want +got):\n%s", diff)
}
}
func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
want := []string{"glm-4.7:cloud", "llama3.2", "glm-4.7-flash", "kimi-k2.5:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("mixed models should be alphabetical (-want +got):\n%s", diff)
}
}
func TestBuildModelList_PreCheckedFirst(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
got := names(items)
if got[0] != "llama3.2" {
t.Errorf("pre-checked model should be first, got %v", got)
}
}
func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
existing := []modelInfo{
{Name: "glm-4.7-flash", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
for _, item := range items {
switch item.Name {
case "glm-4.7-flash", "glm-4.7:cloud":
if strings.HasSuffix(item.Description, "install?") {
t.Errorf("installed recommended %q should not have 'install?' suffix, got %q", item.Name, item.Description)
}
case "kimi-k2.5:cloud", "qwen3:8b":
if !strings.HasSuffix(item.Description, "install?") {
t.Errorf("non-installed recommended %q should have 'install?' suffix, got %q", item.Name, item.Description)
}
}
}
}
func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
existing := []modelInfo{
{Name: "glm-4.7-flash", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
// glm-4.7-flash and glm-4.7:cloud are installed so they sort normally;
// kimi-k2.5:cloud and qwen3:8b are not installed so they go to the bottom
want := []string{"glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("existing cloud models should sort normally (-want +got):\n%s", diff)
}
}
func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "kimi-k2.5:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
// kimi-k2.5:cloud is installed so it sorts normally;
// the rest of the recommendations are not installed so they go to the bottom
want := []string{"kimi-k2.5:cloud", "llama3.2", "glm-4.7-flash", "glm-4.7:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("only non-installed models should be at bottom (-want +got):\n%s", diff)
}
for _, item := range items {
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
if !strings.HasSuffix(item.Description, "install?") {
t.Errorf("non-installed %q should have 'install?' suffix, got %q", item.Name, item.Description)
}
}
}
}
func TestBuildModelList_LatestTagStripped(t *testing.T) {
existing := []modelInfo{
{Name: "glm-4.7-flash:latest", Remote: false},
{Name: "llama3.2:latest", Remote: false},
}
items, _, existingModels, _ := buildModelList(existing, nil, "")
got := names(items)
// :latest should be stripped from display names
for _, name := range got {
if strings.HasSuffix(name, ":latest") {
t.Errorf("name %q should not have :latest suffix", name)
}
}
// glm-4.7-flash should not be duplicated (existing :latest matches the recommendation)
count := 0
for _, name := range got {
if name == "glm-4.7-flash" {
count++
}
}
if count != 1 {
t.Errorf("glm-4.7-flash should appear exactly once, got %d in %v", count, got)
}
// Stripped name should be in existingModels so it won't be pulled
if !existingModels["glm-4.7-flash"] {
t.Error("glm-4.7-flash should be in existingModels")
}
}
func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
_, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if !existingModels["llama3.2"] {
t.Error("llama3.2 should be in existingModels")
}
if !existingModels["glm-4.7:cloud"] {
t.Error("glm-4.7:cloud should be in existingModels")
}
if existingModels["glm-4.7-flash"] {
t.Error("glm-4.7-flash should not be in existingModels (it's a recommendation)")
}
if !cloudModels["glm-4.7:cloud"] {
t.Error("glm-4.7:cloud should be in cloudModels")
}
if !cloudModels["kimi-k2.5:cloud"] {
t.Error("kimi-k2.5:cloud should be in cloudModels (recommended cloud)")
}
if cloudModels["llama3.2"] {
t.Error("llama3.2 should not be in cloudModels")
}
}

254
cmd/config/openclaw.go Normal file
View File

@@ -0,0 +1,254 @@
package config
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/ollama/ollama/envconfig"
)
type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
const ansiGreen = "\033[32m"
func (c *Openclaw) Run(model string, args []string) error {
bin := "openclaw"
if _, err := exec.LookPath(bin); err != nil {
bin = "clawdbot"
if _, err := exec.LookPath(bin); err != nil {
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
}
}
models := []string{model}
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
models = config.Models
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
if !c.onboarded() {
// Onboarding not completed: run it (model already set via Edit)
// Use "ollama" as gateway token for simple local access
cmd := exec.Command(bin, "onboard",
"--auth-choice", "skip",
"--gateway-token", "ollama",
)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
// Onboarding completed: run gateway
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
cmd.Stdin = os.Stdin
// Capture output to detect "already running" message
var outputBuf bytes.Buffer
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
err := cmd.Run()
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
return nil
}
return err
}
// onboarded checks if OpenClaw onboarding wizard was completed
// by looking for the wizard.lastRunAt marker in the config
func (c *Openclaw) onboarded() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
} else {
return false
}
// Check for wizard.lastRunAt marker (set when onboarding completes)
wizard, _ := config["wizard"].(map[string]any)
if wizard == nil {
return false
}
lastRunAt, _ := wizard["lastRunAt"].(string)
return lastRunAt != ""
}
func (c *Openclaw) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".openclaw", "openclaw.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
if _, err := os.Stat(legacy); err == nil {
return []string{legacy}
}
return nil
}
func (c *Openclaw) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
// Read into map[string]any to preserve unknown fields
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
}
// Navigate/create: models.providers.ollama (preserving other providers)
modelsSection, _ := config["models"].(map[string]any)
if modelsSection == nil {
modelsSection = make(map[string]any)
}
providers, _ := modelsSection["providers"].(map[string]any)
if providers == nil {
providers = make(map[string]any)
}
ollama, _ := providers["ollama"].(map[string]any)
if ollama == nil {
ollama = make(map[string]any)
}
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
// needed to register provider
ollama["apiKey"] = "ollama-local"
// TODO(parthsareen): potentially move to responses
ollama["api"] = "openai-completions"
// Build map of existing models to preserve user customizations
existingModels, _ := ollama["models"].([]any)
existingByID := make(map[string]map[string]any)
for _, m := range existingModels {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
existingByID[id] = entry
}
}
}
var newModels []any
for _, model := range models {
entry := map[string]any{
"id": model,
"name": model,
"reasoning": false,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
// TODO(parthsareen): get these values from API
"contextWindow": 131072,
"maxTokens": 16384,
}
// Merge existing fields (user customizations)
if existing, ok := existingByID[model]; ok {
for k, v := range existing {
if _, isNew := entry[k]; !isNew {
entry[k] = v
}
}
}
newModels = append(newModels, entry)
}
ollama["models"] = newModels
providers["ollama"] = ollama
modelsSection["providers"] = providers
config["models"] = modelsSection
// Update agents.defaults.model.primary (preserving other agent settings)
agents, _ := config["agents"].(map[string]any)
if agents == nil {
agents = make(map[string]any)
}
defaults, _ := agents["defaults"].(map[string]any)
if defaults == nil {
defaults = make(map[string]any)
}
modelConfig, _ := defaults["model"].(map[string]any)
if modelConfig == nil {
modelConfig = make(map[string]any)
}
modelConfig["primary"] = "ollama/" + models[0]
defaults["model"] = modelConfig
agents["defaults"] = defaults
config["agents"] = agents
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
return writeWithBackup(configPath, data)
}
func (c *Openclaw) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
if err != nil {
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
if err != nil {
return nil
}
}
modelsSection, _ := config["models"].(map[string]any)
providers, _ := modelsSection["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
var result []string
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
result = append(result, id)
}
}
}
return result
}

878
cmd/config/openclaw_test.go Normal file
View File

@@ -0,0 +1,878 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestOpenclawIntegration(t *testing.T) {
c := &Openclaw{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "OpenClaw" {
t.Errorf("String() = %q, want %q", got, "OpenClaw")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = c
})
}
func TestOpenclawEdit(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("fresh install", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("multiple models - first is primary", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawModelExists(t, configPath, "mistral")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("preserve other providers", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider was removed")
}
})
t.Run("preserve top-level keys", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("theme was removed")
}
if cfg["mcp"] == nil {
t.Error("mcp was removed")
}
})
t.Run("preserve user customizations on models", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2"})
// User adds custom field
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
entry["customField"] = "user-value"
configData, _ := json.MarshalIndent(cfg, "", " ")
os.WriteFile(configPath, configData, 0o644)
// Re-run Edit
c.Edit([]string{"llama3.2"})
data, _ = os.ReadFile(configPath)
json.Unmarshal(data, &cfg)
models = cfg["models"].(map[string]any)
providers = models["providers"].(map[string]any)
ollama = providers["ollama"].(map[string]any)
modelList = ollama["models"].([]any)
entry = modelList[0].(map[string]any)
if entry["customField"] != "user-value" {
t.Error("custom field was lost")
}
})
t.Run("edit replaces models list", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2", "mistral"})
c.Edit([]string{"llama3.2"})
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawModelNotExists(t, configPath, "mistral")
})
t.Run("empty models is no-op", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
original := `{"existing":"data"}`
os.WriteFile(configPath, []byte(original), 0o644)
c.Edit([]string{})
data, _ := os.ReadFile(configPath)
if string(data) != original {
t.Error("empty models should not modify file")
}
})
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Error("result should be valid JSON")
}
})
t.Run("wrong type models section", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2")
})
}
func TestOpenclawModels(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("no config returns nil", func(t *testing.T) {
if models := c.Models(); len(models) > 0 {
t.Errorf("expected nil/empty, got %v", models)
}
})
t.Run("returns all ollama models", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[
{"id":"llama3.2"},
{"id":"mistral"}
]}}}
}`), 0o644)
models := c.Models()
if len(models) != 2 {
t.Errorf("expected 2 models, got %v", models)
}
})
}
// Helper functions
func assertOpenclawModelExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
return
}
}
}
t.Errorf("model %s not found", model)
}
func assertOpenclawModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models, _ := cfg["models"].(map[string]any)
providers, _ := models["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
t.Errorf("model %s should not exist", model)
}
}
}
}
func assertOpenclawPrimaryModel(t *testing.T, path, expected string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
model := defaults["model"].(map[string]any)
if model["primary"] != expected {
t.Errorf("primary model = %v, want %v", model["primary"], expected)
}
}
func TestOpenclawPaths(t *testing.T) {
c := &Openclaw{}
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns nil when config missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if paths := c.Paths(); paths != nil {
t.Errorf("expected nil, got %v", paths)
}
})
}
func TestOpenclawModelsEdgeCases(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("corrupted JSON returns nil", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at models level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at providers level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at ollama level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("model entry missing id", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for missing id")
}
})
t.Run("model id is not string", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for non-string id")
}
})
}
func TestOpenclawEditSchemaFields(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
// Verify required schema fields
if entry["reasoning"] != false {
t.Error("reasoning should be false")
}
if entry["input"] == nil {
t.Error("input should be set")
}
if entry["contextWindow"] == nil {
t.Error("contextWindow should be set")
}
if entry["maxTokens"] == nil {
t.Error("maxTokens should be set")
}
cost := entry["cost"].(map[string]any)
if cost["cacheRead"] == nil {
t.Error("cost.cacheRead should be set")
}
if cost["cacheWrite"] == nil {
t.Error("cost.cacheWrite should be set")
}
}
func TestOpenclawEditModelNames(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".openclaw")) }
t.Run("model with colon tag", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2:70b")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2:70b")
})
t.Run("model with slash", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"library/model:tag"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "library/model:tag")
assertOpenclawPrimaryModel(t, configPath, "ollama/library/model:tag")
})
t.Run("model with hyphen", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"test-model"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "test-model")
})
}
func TestOpenclawEditAgentsPreservation(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("preserve other agent defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature setting was lost")
}
})
t.Run("preserve other agents besides defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent was lost")
}
})
}
const testOpenclawFixture = `{
"theme": "dark",
"mcp": {"servers": {"custom": {"enabled": true}}},
"models": {
"providers": {
"anthropic": {"apiKey": "xxx"},
"ollama": {
"baseUrl": "http://127.0.0.1:11434/v1",
"models": [{"id": "old-model", "customField": "preserved"}]
}
}
},
"agents": {
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
"custom-agent": {"foo": "bar"}
}
}`
func TestOpenclawEdit_RoundTrip(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
// Verify top-level preserved
if cfg["theme"] != "dark" {
t.Error("theme not preserved")
}
mcp := cfg["mcp"].(map[string]any)
servers := mcp["servers"].(map[string]any)
if servers["custom"] == nil {
t.Error("mcp.servers.custom not preserved")
}
// Verify other providers preserved
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider not preserved")
}
// Verify agents preserved
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent not preserved")
}
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature not preserved")
}
}
func TestOpenclawEdit_Idempotent(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
c.Edit([]string{"llama3.2", "mistral"})
firstData, _ := os.ReadFile(configPath)
c.Edit([]string{"llama3.2", "mistral"})
secondData, _ := os.ReadFile(configPath)
if string(firstData) != string(secondData) {
t.Error("repeated edits with same models produced different results")
}
}
func TestOpenclawEdit_MultipleConsecutiveEdits(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
for i := range 10 {
models := []string{"model-a", "model-b"}
if i%2 == 0 {
models = []string{"model-x", "model-y", "model-z"}
}
if err := c.Edit(models); err != nil {
t.Fatalf("edit %d failed: %v", i, err)
}
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
}
if cfg["theme"] != "dark" {
t.Error("theme lost after multiple edits")
}
}
func TestOpenclawEdit_BackupCreated(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
os.MkdirAll(configDir, 0o755)
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
os.WriteFile(configPath, []byte(original), 0o644)
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
backups, _ := filepath.Glob(filepath.Join(backupDir, "openclaw.json.*"))
foundBackup := false
for _, backup := range backups {
data, _ := os.ReadFile(backup)
if string(data) == original {
foundBackup = true
break
}
}
if !foundBackup {
t.Error("backup with original content not found")
}
}
func TestOpenclawClawdbotAlias(t *testing.T) {
for _, alias := range []string{"clawdbot", "moltbot"} {
t.Run(alias+" alias resolves to Openclaw runner", func(t *testing.T) {
r, ok := integrations[alias]
if !ok {
t.Fatalf("%s not found in integrations", alias)
}
if _, ok := r.(*Openclaw); !ok {
t.Errorf("%s integration is %T, want *Openclaw", alias, r)
}
})
t.Run(alias+" is hidden from selector", func(t *testing.T) {
if !integrationAliases[alias] {
t.Errorf("%s should be in integrationAliases", alias)
}
})
}
}
func TestOpenclawLegacyPaths(t *testing.T) {
c := &Openclaw{}
t.Run("falls back to legacy clawdbot path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Fatalf("expected 1 path, got %d", len(paths))
}
if paths[0] != filepath.Join(legacyDir, "clawdbot.json") {
t.Errorf("expected legacy path, got %s", paths[0])
}
})
t.Run("prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Fatalf("expected 1 path, got %d", len(paths))
}
if paths[0] != filepath.Join(newDir, "openclaw.json") {
t.Errorf("expected new path, got %s", paths[0])
}
})
t.Run("Models reads from legacy path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"llama3.2"}]}}}
}`), 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "llama3.2" {
t.Errorf("expected [llama3.2], got %v", models)
}
})
t.Run("Models prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"new-model"}]}}}
}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"legacy-model"}]}}}
}`), 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "new-model" {
t.Errorf("expected [new-model], got %v", models)
}
})
t.Run("Edit reads new path over legacy when both exist", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{"theme":"new"}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"legacy"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(filepath.Join(newDir, "openclaw.json"))
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "new" {
t.Errorf("expected theme from new config, got %v", cfg["theme"])
}
})
t.Run("Edit migrates from legacy config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"dark"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Should write to new path
newPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
data, err := os.ReadFile(newPath)
if err != nil {
t.Fatal("expected new config file to be created")
}
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("legacy theme setting was not migrated")
}
})
}
func TestOpenclawEdit_CreatesDirectoryIfMissing(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
t.Fatal("directory should not exist before test")
}
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(configDir); os.IsNotExist(err) {
t.Fatal("directory was not created")
}
}
func TestOpenclawOnboarded(t *testing.T) {
c := &Openclaw{}
t.Run("returns false when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if c.onboarded() {
t.Error("expected false when no config exists")
}
})
t.Run("returns false when config exists but no wizard section", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
if c.onboarded() {
t.Error("expected false when no wizard section")
}
})
t.Run("returns false when wizard section exists but no lastRunAt", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{}}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard.lastRunAt is missing")
}
})
t.Run("returns false when wizard.lastRunAt is empty string", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":""}}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard.lastRunAt is empty")
}
})
t.Run("returns true when wizard.lastRunAt is set", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if !c.onboarded() {
t.Error("expected true when wizard.lastRunAt is set")
}
})
t.Run("checks legacy clawdbot path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if !c.onboarded() {
t.Error("expected true when legacy config has wizard.lastRunAt")
}
})
t.Run("prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
// New path has no wizard marker
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
// Legacy has wizard marker
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if c.onboarded() {
t.Error("expected false - should prefer new path which has no wizard marker")
}
})
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
if c.onboarded() {
t.Error("expected false for corrupted JSON")
}
})
t.Run("handles wrong type for wizard section", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":"not a map"}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard is wrong type")
}
})
}

View File

@@ -9,6 +9,8 @@ import (
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
@@ -16,7 +18,7 @@ type OpenCode struct{}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string) error {
func (o *OpenCode) Run(model string, args []string) error {
if _, err := exec.LookPath("opencode"); err != nil {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
}
@@ -30,7 +32,7 @@ func (o *OpenCode) Run(model string) error {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("opencode")
cmd := exec.Command("opencode", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -88,7 +90,7 @@ func (o *OpenCode) Edit(modelList []string) error {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)",
"options": map[string]any{
"baseURL": "http://localhost:11434/v1",
"baseURL": envconfig.Host().String() + "/v1",
},
}
}

196
cmd/config/pi.go Normal file
View File

@@ -0,0 +1,196 @@
package config
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
"github.com/ollama/ollama/envconfig"
)
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
type Pi struct{}
func (p *Pi) String() string { return "Pi" }
func (p *Pi) Run(model string, args []string) error {
if _, err := exec.LookPath("pi"); err != nil {
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := p.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("pi", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (p *Pi) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
var paths []string
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
if _, err := os.Stat(modelsPath); err == nil {
paths = append(paths, modelsPath)
}
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
if _, err := os.Stat(settingsPath); err == nil {
paths = append(paths, settingsPath)
}
return paths
}
func (p *Pi) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
}
providers, ok := config["providers"].(map[string]any)
if !ok {
providers = make(map[string]any)
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"baseUrl": envconfig.Host().String() + "/v1",
"api": "openai-completions",
"apiKey": "ollama",
}
}
existingModels, ok := ollama["models"].([]any)
if !ok {
existingModels = make([]any, 0)
}
// Build set of selected models to track which need to be added
selectedSet := make(map[string]bool, len(models))
for _, m := range models {
selectedSet[m] = true
}
// Build new models list:
// 1. Keep user-managed models (no _launch marker) - untouched
// 2. Keep ollama-managed models (_launch marker) that are still selected
// 3. Add new ollama-managed models
var newModels []any
for _, m := range existingModels {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
// User-managed model (no _launch marker) - always preserve
if !isPiOllamaModel(modelObj) {
newModels = append(newModels, m)
} else if selectedSet[id] {
// Ollama-managed and still selected - keep it
newModels = append(newModels, m)
selectedSet[id] = false
}
}
}
}
// Add newly selected models that weren't already in the list
for _, model := range models {
if selectedSet[model] {
newModels = append(newModels, map[string]any{
"id": model,
"_launch": true,
})
}
}
ollama["models"] = newModels
providers["ollama"] = ollama
config["providers"] = providers
configData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
return err
}
// Update settings.json with default provider and model
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
settings := make(map[string]any)
if data, err := os.ReadFile(settingsPath); err == nil {
_ = json.Unmarshal(data, &settings)
}
settings["defaultProvider"] = "ollama"
settings["defaultModel"] = models[0]
settingsData, err := json.MarshalIndent(settings, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, settingsData)
}
func (p *Pi) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
config, err := readJSONFile(configPath)
if err != nil {
return nil
}
providers, _ := config["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
models, _ := ollama["models"].([]any)
var result []string
for _, m := range models {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
result = append(result, id)
}
}
}
slices.Sort(result)
return result
}
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
func isPiOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
}
return false
}

609
cmd/config/pi_test.go Normal file
View File

@@ -0,0 +1,609 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestPiIntegration(t *testing.T) {
pi := &Pi{}
t.Run("String", func(t *testing.T) {
if got := pi.String(); got != "Pi" {
t.Errorf("String() = %q, want %q", got, "Pi")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = pi
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = pi
})
}
func TestPiPaths(t *testing.T) {
pi := &Pi{}
t.Run("returns empty when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
paths := pi.Paths()
if len(paths) != 0 {
t.Errorf("Paths() = %v, want empty", paths)
}
})
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
t.Fatal(err)
}
paths := pi.Paths()
if len(paths) != 1 || paths[0] != configPath {
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
}
})
}
func TestPiEdit(t *testing.T) {
pi := &Pi{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
configPath := filepath.Join(configDir, "models.json")
cleanup := func() {
os.RemoveAll(configDir)
}
readConfig := func() map[string]any {
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
return cfg
}
t.Run("returns nil for empty models", func(t *testing.T) {
if err := pi.Edit([]string{}); err != nil {
t.Errorf("Edit([]) error = %v, want nil", err)
}
})
t.Run("creates config with models", func(t *testing.T) {
cleanup()
models := []string{"llama3.2", "qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers, ok := cfg["providers"].(map[string]any)
if !ok {
t.Error("Config missing providers")
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
t.Error("Providers missing ollama")
}
modelsArray, ok := ollama["models"].([]any)
if !ok || len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %v", modelsArray)
}
if ollama["baseUrl"] == nil {
t.Error("Missing baseUrl")
}
if ollama["api"] != "openai-completions" {
t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
}
if ollama["apiKey"] != "ollama" {
t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
}
})
t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://custom:8080/v1",
"api": "custom-api",
"apiKey": "custom-key",
"models": [
{"id": "old-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"new-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
if ollama["baseUrl"] != "http://custom:8080/v1" {
t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
}
if ollama["api"] != "custom-api" {
t.Errorf("Custom api not preserved, got %v", ollama["api"])
}
if ollama["apiKey"] != "custom-key" {
t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
}
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
} else {
modelEntry := modelsArray[0].(map[string]any)
if modelEntry["id"] != "new-model" {
t.Errorf("Expected new-model, got %v", modelEntry["id"])
}
// Verify _launch marker is present
if modelEntry["_launch"] != true {
t.Errorf("Expected _launch marker to be true")
}
}
})
t.Run("replaces old models with new ones", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Old models must have _launch marker to be managed by us
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "old-model-1", "_launch": true},
{"id": "old-model-2", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"new-model-1", "new-model-2"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
t.Errorf("Expected new models, got %v", modelIDs)
}
if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
t.Errorf("Old models should have been removed, got %v", modelIDs)
}
})
t.Run("handles partial overlap in model list", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Models must have _launch marker to be managed
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "keep-model", "_launch": true},
{"id": "remove-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"keep-model", "add-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["keep-model"] || !modelIDs["add-model"] {
t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
}
if modelIDs["remove-model"] {
t.Errorf("remove-model should have been removed")
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
}
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("Failed to read config: %v", err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
}
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model, got %d", len(modelsArray))
}
})
// CRITICAL SAFETY TEST: verifies we don't stomp on user configs
t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// User has manually configured models in ollama provider (no _launch marker)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "user-model-1"},
{"id": "user-model-2", "customField": "preserved"},
{"id": "ollama-managed", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
// Add a new ollama-managed model
newModels := []string{"new-ollama-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
// Should have: new-ollama-model (managed) + 2 user models (preserved)
if len(modelsArray) != 3 {
t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
}
modelIDs := make(map[string]map[string]any)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = modelObj
}
// Verify new model has _launch marker
if m, ok := modelIDs["new-ollama-model"]; !ok {
t.Errorf("new-ollama-model should be present")
} else if m["_launch"] != true {
t.Errorf("new-ollama-model should have _launch marker")
}
// Verify user models are preserved
if _, ok := modelIDs["user-model-1"]; !ok {
t.Errorf("user-model-1 should be preserved")
}
if _, ok := modelIDs["user-model-2"]; !ok {
t.Errorf("user-model-2 should be preserved")
} else if modelIDs["user-model-2"]["customField"] != "preserved" {
t.Errorf("user-model-2 customField should be preserved")
}
// Verify old ollama-managed model is removed (not in new list)
if _, ok := modelIDs["ollama-managed"]; ok {
t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
}
})
t.Run("updates settings.json with default provider and model", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create existing settings with other fields
settingsPath := filepath.Join(configDir, "settings.json")
existingSettings := `{
"theme": "dark",
"customSetting": "value",
"defaultProvider": "anthropic",
"defaultModel": "claude-3"
}`
if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"llama3.2"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
// Verify defaultProvider is set to ollama
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
// Verify defaultModel is set to first model
if settings["defaultModel"] != "llama3.2" {
t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
}
// Verify other fields are preserved
if settings["theme"] != "dark" {
t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
}
if settings["customSetting"] != "value" {
t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
}
})
t.Run("creates settings.json if it does not exist", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
models := []string{"qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
settingsPath := filepath.Join(configDir, "settings.json")
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("settings.json should be created: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "qwen3:8b" {
t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
}
})
t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create corrupt settings
settingsPath := filepath.Join(configDir, "settings.json")
if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "test-model" {
t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
}
})
}
func TestPiModels(t *testing.T) {
pi := &Pi{}
t.Run("returns nil when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns models from config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "llama3.2"},
{"id": "qwen3:8b"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if len(models) != 2 {
t.Errorf("Models() returned %d models, want 2", len(models))
}
if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
}
})
t.Run("returns sorted models", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "z-model"},
{"id": "a-model"},
{"id": "m-model"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
}
})
t.Run("returns nil when models array is missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil when models array is missing", models)
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil for corrupt config", models)
}
})
}

View File

@@ -275,7 +275,11 @@ func parseInput(r io.Reader) (inputEvent, byte, error) {
func renderSelect(w io.Writer, prompt string, s *selectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
if s.filter == "" {
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
} else {
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
}
lineCount := 1
if len(filtered) == 0 {
@@ -314,7 +318,11 @@ func renderSelect(w io.Writer, prompt string, s *selectState) int {
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
if s.filter == "" {
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
} else {
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
}
lineCount := 1
if len(filtered) == 0 {
@@ -345,10 +353,15 @@ func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
suffix = " " + ansiGray + "(default)" + ansiReset
}
desc := ""
if item.Description != "" {
desc = " " + ansiGray + "- " + item.Description + ansiReset
}
if idx == s.highlighted && !s.focusOnButton {
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
fmt.Fprintf(w, " %s%s %s %s%s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, desc, suffix)
} else {
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
fmt.Fprintf(w, " %s %s %s%s%s\r\n", prefix, checkbox, item.Name, desc, suffix)
}
lineCount++
}

View File

@@ -313,8 +313,12 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
case "GlmOcrForConditionalGeneration":
conv = &glmOcrModel{}
case "Lfm2ForCausalLM":
conv = &lfm2Model{}
case "Qwen3NextForCausalLM":
conv = &qwen3NextModel{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

455
convert/convert_glmocr.go Normal file
View File

@@ -0,0 +1,455 @@
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"log/slog"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
// normalToNeoXRepacker creates a repacker that permutes Q/K weights from interleaved (LLaMA)
// to NeoX ordering for compatibility with GGML's M-RoPE kernel.
//
// For weights: reshape [out, in] -> [n_heads, head_dim, in], permute rotary dims, reshape back
// For biases: reshape [out] -> [n_heads, head_dim], permute rotary dims, reshape back
func normalToNeoXRepacker(nHeads, headDim int, partialRotaryFactor float32) func(string, []float32, []uint64) ([]float32, error) {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
rotaryDim := int(float32(headDim) * partialRotaryFactor)
if rotaryDim%2 != 0 {
rotaryDim = (rotaryDim / 2) * 2 // Round down to even
}
// Handle 1D (bias) or 2D (weight) tensors
is1D := len(shape) == 1
var inFeatures int
if is1D {
inFeatures = 1
} else {
inFeatures = int(shape[1])
}
outFeatures := int(shape[0])
nEffectiveHeads := outFeatures / headDim
if nEffectiveHeads != nHeads {
slog.Warn("normalToNeoX: unexpected head count", "effective", nEffectiveHeads, "expected", nHeads)
}
// Reshape to [n_heads, head_dim, in_features]
reshaped := make([]float32, len(data))
copy(reshaped, data)
// Permute the rotary dimensions: even indices first, then odd
// For each head, reorder [0,1,2,3,4,5...] to [0,2,4...,1,3,5...]
result := make([]float32, len(data))
halfRotary := rotaryDim / 2
for h := range nEffectiveHeads {
for f := range inFeatures {
for i := range halfRotary {
// Even dim (0, 2, 4, ...) -> position i
srcIdx := h*headDim*inFeatures + (2*i)*inFeatures + f
dstIdx := h*headDim*inFeatures + i*inFeatures + f
result[dstIdx] = reshaped[srcIdx]
// Odd dim (1, 3, 5, ...) -> position halfRotary + i
srcIdx = h*headDim*inFeatures + (2*i+1)*inFeatures + f
dstIdx = h*headDim*inFeatures + (halfRotary+i)*inFeatures + f
result[dstIdx] = reshaped[srcIdx]
}
// Non-rotary part: copy as-is
for i := rotaryDim; i < headDim; i++ {
srcIdx := h*headDim*inFeatures + i*inFeatures + f
result[srcIdx] = reshaped[srcIdx]
}
}
}
return result, nil
}
}
type glmOcrModel struct {
ModelParameters
TextConfig struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
MaxPositionEmbed uint32 `json:"max_position_embeddings"`
RMSNormEps float32 `json:"rms_norm_eps"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeParameters struct {
RopeType string `json:"rope_type"`
MRopeSection []int32 `json:"mrope_section"`
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionConfig struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
Depth uint32 `json:"depth"`
NumHeads uint32 `json:"num_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
OutHiddenSize uint32 `json:"out_hidden_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
ImageStartTokenID uint32 `json:"image_start_token_id"`
ImageEndTokenID uint32 `json:"image_end_token_id"`
VideoStartTokenID uint32 `json:"video_start_token_id"`
VideoEndTokenID uint32 `json:"video_end_token_id"`
ImageTokenID uint32 `json:"image_token_id"`
VideoTokenID uint32 `json:"video_token_id"`
// Preprocessor config (preprocessor_config.json)
Preprocessor struct {
Size struct {
ShortestEdge uint32 `json:"shortest_edge"`
LongestEdge uint32 `json:"longest_edge"`
} `json:"size"`
PatchSize uint32 `json:"patch_size"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
MergeSize uint32 `json:"merge_size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
} `json:"-"`
}
var _ ModelConverter = (*glmOcrModel)(nil)
func (m *glmOcrModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
if err != nil {
return err
}
return json.Unmarshal(bts, &m.Preprocessor)
}
func (m *glmOcrModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "glmocr"
// Text model parameters
kv["glmocr.block_count"] = cmp.Or(m.TextConfig.NumHiddenLayers, 16)
kv["glmocr.embedding_length"] = cmp.Or(m.TextConfig.HiddenSize, 1536)
kv["glmocr.attention.head_count"] = cmp.Or(m.TextConfig.NumAttentionHeads, 16)
kv["glmocr.attention.head_count_kv"] = cmp.Or(m.TextConfig.NumKeyValueHeads, 8)
headDim := cmp.Or(m.TextConfig.HeadDim, m.TextConfig.HiddenSize/m.TextConfig.NumAttentionHeads)
kv["glmocr.attention.key_length"] = headDim
kv["glmocr.attention.value_length"] = headDim
kv["glmocr.feed_forward_length"] = cmp.Or(m.TextConfig.IntermediateSize, 4608)
kv["glmocr.attention.layer_norm_rms_epsilon"] = cmp.Or(m.TextConfig.RMSNormEps, 1e-5)
kv["glmocr.context_length"] = cmp.Or(m.TextConfig.MaxPositionEmbed, 131072)
kv["glmocr.rope.freq_base"] = cmp.Or(m.TextConfig.RopeParameters.RopeTheta, float32(10000))
kv["glmocr.rope.partial_rotary_factor"] = cmp.Or(m.TextConfig.RopeParameters.PartialRotaryFactor, m.TextConfig.PartialRotaryFactor, float32(1.0))
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 {
kv["glmocr.rope.mrope_section"] = m.TextConfig.RopeParameters.MRopeSection
}
// Vision model parameters
kv["glmocr.vision.block_count"] = cmp.Or(m.VisionConfig.Depth, 24)
kv["glmocr.vision.embedding_length"] = cmp.Or(m.VisionConfig.HiddenSize, 1024)
kv["glmocr.vision.attention.head_count"] = cmp.Or(m.VisionConfig.NumHeads, 16)
kv["glmocr.vision.image_size"] = cmp.Or(m.VisionConfig.ImageSize, 336)
kv["glmocr.vision.patch_size"] = cmp.Or(m.VisionConfig.PatchSize, m.Preprocessor.PatchSize, 14)
kv["glmocr.vision.spatial_merge_size"] = cmp.Or(m.VisionConfig.SpatialMergeSize, m.Preprocessor.MergeSize, 2)
kv["glmocr.vision.temporal_patch_size"] = cmp.Or(m.VisionConfig.TemporalPatchSize, m.Preprocessor.TemporalPatchSize, 2)
kv["glmocr.vision.out_hidden_size"] = cmp.Or(m.VisionConfig.OutHiddenSize, 1536)
kv["glmocr.vision.intermediate_size"] = cmp.Or(m.VisionConfig.IntermediateSize, 4096)
kv["glmocr.vision.attention.layer_norm_rms_epsilon"] = cmp.Or(m.VisionConfig.RMSNormEps, 1e-5)
// Preprocessor-derived image settings (min/max pixels and normalization)
// Note: fs.Config.keyValue() auto-prepends architecture prefix, so use full key
if m.Preprocessor.Size.ShortestEdge > 0 {
kv["glmocr.vision.min_pixels"] = m.Preprocessor.Size.ShortestEdge
}
if m.Preprocessor.Size.LongestEdge > 0 {
kv["glmocr.vision.max_pixels"] = m.Preprocessor.Size.LongestEdge
}
if len(m.Preprocessor.ImageMean) == 3 {
kv["glmocr.vision.image_mean"] = m.Preprocessor.ImageMean
}
if len(m.Preprocessor.ImageStd) == 3 {
kv["glmocr.vision.image_std"] = m.Preprocessor.ImageStd
}
// Special tokens
kv["glmocr.image_token_id"] = m.ImageTokenID
kv["glmocr.image_start_token_id"] = m.ImageStartTokenID
kv["glmocr.image_end_token_id"] = m.ImageEndTokenID
kv["glmocr.video_token_id"] = m.VideoTokenID
kv["glmocr.video_start_token_id"] = m.VideoStartTokenID
kv["glmocr.video_end_token_id"] = m.VideoEndTokenID
return kv
}
func (m *glmOcrModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// Skip layers >= num_hidden_layers (Multi-Token Prediction layers not needed for basic inference)
numLayers := int(cmp.Or(m.TextConfig.NumHiddenLayers, 16))
skipLayer := func(name string) bool {
// Tensor names are already replaced to "blk.N.xxx" format
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(name)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return blkNum >= numLayers
}
for _, t := range ts {
name := t.Name()
// Skip next-n prediction layers (layers >= num_hidden_layers)
if skipLayer(name) {
continue
}
// Split ffn_gate_up into separate gate and up projections
if strings.Contains(name, "ffn_gate_up") {
for t := range splitDim(t, 0,
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_gate")},
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_up")},
) {
out = append(out, t)
}
continue
}
if strings.HasSuffix(name, "patch_embd.weight") {
shape := t.Shape()
if len(shape) == 5 && shape[2] == 2 {
newShape := []uint64{shape[0], shape[1], shape[3], shape[4]}
t0 := t.Clone()
t0.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tt.Slice(nil, nil, tensor.S(0, 1), nil, nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
if err := tt.Reshape(newDims...); err != nil {
return nil, err
}
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t0,
})
t1 := t.Clone()
t1.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tt.Slice(nil, nil, tensor.S(1, 2), nil, nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
if err := tt.Reshape(newDims...); err != nil {
return nil, err
}
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_1.weight", 1),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t1,
})
continue
}
if len(shape) == 4 {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
slog.Warn("glmocr: patch_embed weight has unexpected shape - not splitting", "shape", shape)
// Fall through to default handling
}
// Handle pre-split patch embedding weights
// Pattern 1: v.patch_embd.0.weight, v.patch_embd.1.weight -> patch_embd_0.weight, patch_embd_1.weight
// Pattern 2: v.patch_embd.weight.0, v.patch_embd.weight.1 -> patch_embd_0.weight, patch_embd_1.weight
if strings.Contains(name, "patch_embd.0.") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.0.", "patch_embd_0.", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
if strings.Contains(name, "patch_embd.1.") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.1.", "patch_embd_1.", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
// Handle .weight.0 and .weight.1 suffix patterns
if strings.HasSuffix(name, "patch_embd.weight.0") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight.0", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
if strings.HasSuffix(name, "patch_embd.weight.1") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight.1", "patch_embd_1.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
// Permute Q/K weights for M-RoPE compatibility (interleaved -> NeoX ordering)
// GGML's M-RoPE kernel uses NeoX-style rotation, but GLM-OCR uses interleaved (LLaMA-style)
// We permute at conversion time so the weights work correctly with GGML's kernel
// This aligns Q/K rotary dimensions with GGML's NeoX-style rotation
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 &&
strings.Contains(name, "blk.") && (strings.Contains(name, "attn_q.") || strings.Contains(name, "attn_k.")) {
// Get config values for permutation
nHeads := int(cmp.Or(m.TextConfig.NumAttentionHeads, 16))
nKVHeads := int(cmp.Or(m.TextConfig.NumKeyValueHeads, 8))
hiddenSize := int(cmp.Or(m.TextConfig.HiddenSize, 1536))
headDim := int(cmp.Or(m.TextConfig.HeadDim, uint32(hiddenSize/nHeads)))
partialRotaryFactor := cmp.Or(m.TextConfig.PartialRotaryFactor, m.TextConfig.RopeParameters.PartialRotaryFactor, float32(1.0))
// Use appropriate head count: nHeads for Q, nKVHeads for K
effectiveHeads := nHeads
if strings.Contains(name, "attn_k.") {
effectiveHeads = nKVHeads
}
permutedT := t.Clone()
permutedT.SetRepacker(normalToNeoXRepacker(effectiveHeads, headDim, partialRotaryFactor))
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: permutedT,
})
continue
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *glmOcrModel) Replacements() []string {
return []string{
// Vision encoder
"model.visual.patch_embed.proj_1", "v.patch_embd_1", // Second temporal split
"model.visual.patch_embed.proj", "v.patch_embd",
"model.visual.blocks", "v.blk",
"model.visual.post_layernorm", "v.post_ln",
"model.visual.downsample", "mm.patch_merger",
// Vision attention
"attn.qkv", "attn_qkv",
"attn.proj", "attn_out",
"attn.q_norm", "attn_q_norm",
"attn.k_norm", "attn_k_norm",
// Vision norms
"norm1", "ln1",
"norm2", "ln2",
// Vision MLP
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
// Merger (multimodal projector)
"model.visual.merger.proj", "mm.model.fc",
"model.visual.merger.post_projection_norm", "mm.post_norm",
"model.visual.merger.gate_proj", "mm.gate",
"model.visual.merger.up_proj", "mm.up",
"model.visual.merger.down_proj", "mm.down",
// Language model
"model.language_model.embed_tokens", "token_embd",
"model.language_model.layers", "blk",
"model.language_model.norm", "output_norm",
"lm_head", "output",
// Language model attention
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_out",
// Language model norms
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"post_self_attn_layernorm", "post_attn_norm",
"post_mlp_layernorm", "post_ffn_norm",
// Language model MLP (remove mlp. prefix so ffn_* names work)
"mlp.gate_up_proj", "ffn_gate_up",
"mlp.down_proj", "ffn_down",
}
}

View File

@@ -0,0 +1,512 @@
package convert
import (
"fmt"
"io/fs"
"math"
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type qwen3NextModel struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
// MoE config
NumExperts uint32 `json:"num_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NormTopkProb bool `json:"norm_topk_prob"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
// Hybrid attention config
FullAttentionInterval uint32 `json:"full_attention_interval"`
// Linear attention (Gated Delta Net) config
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
LinearKeyHeadDim uint32 `json:"linear_key_head_dim"`
LinearNumKeyHeads uint32 `json:"linear_num_key_heads"`
LinearNumValueHeads uint32 `json:"linear_num_value_heads"`
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
// RoPE config
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeScaling struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
} `json:"rope_scaling"`
}
var _ ModelConverter = (*qwen3NextModel)(nil)
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
if q.NumHiddenLayers == 0 {
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
}
if q.NumAttentionHeads == 0 {
return fmt.Errorf("qwen3next: num_attention_heads must be set")
}
if q.NumKeyValueHeads == 0 {
return fmt.Errorf("qwen3next: num_key_value_heads must be set")
}
if q.HeadDim == 0 {
return fmt.Errorf("qwen3next: head_dim must be set")
}
if q.RopeTheta == 0 {
return fmt.Errorf("qwen3next: rope_theta must be set")
}
if q.PartialRotaryFactor <= 0 || q.PartialRotaryFactor > 1 {
return fmt.Errorf("qwen3next: partial_rotary_factor must be in (0,1], got %v", q.PartialRotaryFactor)
}
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
}
if q.FullAttentionInterval == 0 {
return fmt.Errorf("qwen3next: full_attention_interval must be set")
}
if q.FullAttentionInterval > q.NumHiddenLayers {
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
hasFull := false
for i := range q.NumHiddenLayers {
if (i+1)%q.FullAttentionInterval == 0 {
hasFull = true
break
}
}
if !hasFull {
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
return nil
}
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen3next"
kv["tokenizer.ggml.pre"] = "qwen2"
kv["block_count"] = q.NumHiddenLayers
kv["context_length"] = q.MaxPositionEmbeddings
kv["embedding_length"] = q.HiddenSize
kv["feed_forward_length"] = q.IntermediateSize
kv["attention.head_count"] = q.NumAttentionHeads
headDim := q.HeadDim
if headDim == 0 && q.NumAttentionHeads > 0 {
headDim = q.HiddenSize / q.NumAttentionHeads
}
kv["attention.key_length"] = headDim
kv["attention.value_length"] = headDim
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
kv["rope.freq_base"] = q.RopeTheta
// RoPE dimension count (partial rotary)
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
partialRotary := q.PartialRotaryFactor
if partialRotary > 0 && partialRotary <= 1 {
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
}
// MoE config
if q.NumExperts > 0 {
kv["expert_count"] = q.NumExperts
kv["expert_used_count"] = q.NumExpertsPerToken
kv["norm_top_k_prob"] = q.NormTopkProb
if q.MoEIntermediateSize > 0 {
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
}
if q.SharedExpertIntermSize > 0 {
kv["expert_shared_feed_forward_length"] = q.SharedExpertIntermSize
}
}
// SSM/Linear attention config
// d_inner = linear_value_head_dim * linear_num_value_heads
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
kv["ssm.inner_size"] = dInner
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
interval := q.FullAttentionInterval
kv["full_attention_interval"] = interval
// Build per-layer KV head count array to identify layer types
// 0 = recurrent (linear attention), non-zero = full attention
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
for i := range q.NumHiddenLayers {
// Full attention every full_attention_interval layers (starting at interval-1)
if interval > 0 && (i+1)%interval == 0 {
kvHeadCounts[i] = q.NumKeyValueHeads
}
// else stays 0 (recurrent layer)
}
kv["attention.head_count_kv"] = kvHeadCounts
// RoPE scaling
if q.RopeScaling.Type != "" {
kv["rope.scaling.type"] = q.RopeScaling.Type
kv["rope.scaling.factor"] = q.RopeScaling.Factor
}
return kv
}
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// Create merges for expert tensors - stack individual experts into batched tensors
merges := make([]merge, q.NumHiddenLayers*3)
for i := range q.NumHiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
// Merge expert tensors
merged, remaining := mergeTensors(ts, merges...)
out = append(out, merged...)
// Process remaining tensors
for _, t := range remaining {
name := t.Name()
shape := t.Shape()
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
if strings.HasSuffix(name, ".ssm_in.weight") {
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
out = append(out, qkv, gate)
continue
}
panic(fmt.Sprintf("qwen3next: failed to split %s into attn_qkv/attn_gate (shape=%v)", name, shape))
}
switch {
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
// This matches the Python converter behavior for qwen3next
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
t.SetRepacker(q.addOne)
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
// Handle linear attention A_log -> ssm_a (negate and exp)
// Note: name has already been transformed by Replacements at this point
case strings.HasSuffix(name, ".ssm_a"):
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
// Compute -exp(A_log)
result := make([]float32, len(data))
for i, v := range data {
// -exp(v)
result[i] = -float32(math.Exp(float64(v)))
}
return result, nil
})
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
newShape := slices.Clone(shape)
if len(shape) == 3 {
if shape[0] == 1 {
// [1, D, K] -> [D, K]
newShape = []uint64{shape[1], shape[2]}
} else if shape[1] == 1 {
// [D, 1, K] -> [D, K]
newShape = []uint64{shape[0], shape[2]}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
newShape := slices.Clone(shape)
if len(shape) == 2 {
if shape[0] == 1 && shape[1] > 1 {
newShape = []uint64{shape[1]}
} else if shape[1] == 1 && shape[0] > 1 {
newShape = []uint64{shape[0]}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
default:
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
}
}
return out
}
type qkvzSplitSpec struct {
hidden int
headKDim int
headVDim int
numKHeads int
numVHeads int
qkvzDim int
qkvOut int
gateOut int
}
func (q *qwen3NextModel) qkvzSpec(shape []uint64) (qkvzSplitSpec, bool) {
if len(shape) != 2 {
return qkvzSplitSpec{}, false
}
numKHeads := int(q.LinearNumKeyHeads)
numVHeads := int(q.LinearNumValueHeads)
headKDim := int(q.LinearKeyHeadDim)
headVDim := int(q.LinearValueHeadDim)
if numKHeads == 0 || numVHeads == 0 || headKDim == 0 || headVDim == 0 {
return qkvzSplitSpec{}, false
}
if numVHeads%numKHeads != 0 {
return qkvzSplitSpec{}, false
}
hidden := int(shape[1])
vPerHead := headVDim * (numVHeads / numKHeads)
qkvzDim := 2*headKDim + 2*vPerHead
expectedOut := qkvzDim * numKHeads
if int(shape[0]) != expectedOut {
return qkvzSplitSpec{}, false
}
return qkvzSplitSpec{
hidden: hidden,
headKDim: headKDim,
headVDim: headVDim,
numKHeads: numKHeads,
numVHeads: numVHeads,
qkvzDim: qkvzDim,
qkvOut: 2*headKDim*numKHeads + headVDim*numVHeads,
gateOut: headVDim * numVHeads,
}, true
}
func (q *qwen3NextModel) splitQKVZTensor(t Tensor) (*ggml.Tensor, *ggml.Tensor, bool) {
spec, ok := q.qkvzSpec(t.Shape())
if !ok {
return nil, nil, false
}
qkvTensor := t.Clone()
qkvTensor.SetRepacker(q.repackQKVZ(spec, false))
gateTensor := t.Clone()
gateTensor.SetRepacker(q.repackQKVZ(spec, true))
qkvName := strings.Replace(t.Name(), "ssm_in", "attn_qkv", 1)
gateName := strings.Replace(t.Name(), "ssm_in", "attn_gate", 1)
return &ggml.Tensor{
Name: qkvName,
Kind: t.Kind(),
Shape: []uint64{uint64(spec.qkvOut), uint64(spec.hidden)},
WriterTo: qkvTensor,
}, &ggml.Tensor{
Name: gateName,
Kind: t.Kind(),
Shape: []uint64{uint64(spec.gateOut), uint64(spec.hidden)},
WriterTo: gateTensor,
}, true
}
func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repacker {
vPerHead := spec.headVDim * (spec.numVHeads / spec.numKHeads)
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
var err error
// Convert to [hidden, out_features] layout for slicing
tt, err = tensor.Transpose(tt, 1, 0)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
if err := tt.Reshape(spec.hidden, spec.numKHeads, spec.qkvzDim); err != nil {
return nil, err
}
offset := 0
qSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
if err != nil {
return nil, err
}
offset += spec.headKDim
kSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
if err != nil {
return nil, err
}
offset += spec.headKDim
vSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
if err != nil {
return nil, err
}
offset += vPerHead
zSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
if err != nil {
return nil, err
}
qMat := tensor.Materialize(qSlice).(*tensor.Dense)
kMat := tensor.Materialize(kSlice).(*tensor.Dense)
vMat := tensor.Materialize(vSlice).(*tensor.Dense)
zMat := tensor.Materialize(zSlice).(*tensor.Dense)
if err := qMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
return nil, err
}
if err := kMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
return nil, err
}
if err := vMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
return nil, err
}
if err := zMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
return nil, err
}
var out tensor.Tensor
if extractGate {
out = zMat
} else {
out, err = tensor.Concat(1, qMat, kMat, vMat)
if err != nil {
return nil, err
}
}
out = tensor.Materialize(out)
out, err = tensor.Transpose(out, 1, 0)
if err != nil {
return nil, err
}
out = tensor.Materialize(out)
if err := out.Reshape(out.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(out.(*tensor.Dense))
}
}
// addOne adds 1.0 to all elements in the tensor (for norm weights)
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, int(shape[0]))
n, err := n.Add(ones)
if err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 0)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
func (q *qwen3NextModel) Replacements() []string {
return []string{
// Embeddings and output
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
// Layer norms
"input_layernorm", "attn_norm",
"post_attention_layernorm", "post_attention_norm",
// Full attention (self_attn)
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
// Linear attention (Gated Delta Net)
"linear_attn.in_proj_qkvz", "ssm_in",
"linear_attn.in_proj_ba", "ssm_ba",
"linear_attn.conv1d", "ssm_conv1d",
"linear_attn.dt_bias", "ssm_dt",
"linear_attn.dt_proj", "ssm_dt",
"linear_attn.A_log", "ssm_a",
"linear_attn.norm", "ssm_norm",
"linear_attn.out_proj", "ssm_out",
// MoE (experts are stacked via mergeTensors, not replaced here)
"mlp.gate.weight", "ffn_gate_inp.weight",
"mlp.shared_expert.down_proj", "ffn_down_shexp",
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
"mlp.shared_expert.up_proj", "ffn_up_shexp",
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
// Dense FFN (if any layers use it)
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
}
}

View File

@@ -41,6 +41,7 @@ func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
strings.HasSuffix(t.name, ".bias") ||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||

View File

@@ -99,6 +99,8 @@ func (st safetensor) Kind() uint32 {
if st.dtype == "BF16" &&
!strings.HasPrefix(st.name, "v.") &&
!strings.HasPrefix(st.name, "s.") &&
!strings.HasPrefix(st.name, "mm.") &&
!strings.Contains(st.name, "ffn_gate_inp_shexp.weight") &&
kind != tensorKindFP32 {
kind = tensorKindBF16
}

View File

@@ -71,6 +71,10 @@
{
"source": "/api",
"destination": "/api/introduction"
},
{
"source": "/integrations/clawdbot",
"destination": "/integrations/openclaw"
}
],
"navigation": {
@@ -103,6 +107,7 @@
"pages": [
"/integrations/claude-code",
"/integrations/cline",
"/integrations/openclaw",
"/integrations/codex",
"/integrations/droid",
"/integrations/goose",

View File

@@ -10,6 +10,7 @@ Check your compute compatibility to see if your card is supported:
| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
| 12.1 | NVIDIA | `GB10 (DGX Spark)` |
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
| | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 9.0 | NVIDIA | `H200` `H100` |
@@ -163,4 +164,4 @@ To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`
by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -134,22 +134,12 @@ success
### Supported Quantizations
- `q4_0`
- `q4_1`
- `q5_0`
- `q5_1`
- `q8_0`
#### K-means Quantizations
- `q3_K_S`
- `q3_K_M`
- `q3_K_L`
- `q4_K_S`
- `q4_K_M`
- `q5_K_S`
- `q5_K_M`
- `q6_K`
## Sharing your model on ollama.com

View File

@@ -0,0 +1,50 @@
---
title: OpenClaw
---
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
## Install
Install [OpenClaw](https://openclaw.ai/)
```bash
npm install -g openclaw@latest
```
Then run the onboarding wizard:
```bash
openclaw onboard --install-daemon
```
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch openclaw
```
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
This configures OpenClaw to use Ollama and starts the gateway.
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
To configure without launching:
```shell
ollama launch openclaw --config
```
## Recommended Models
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

View File

@@ -9,7 +9,7 @@ OpenCode is an open-source AI coding assistant that runs in your terminal.
Install the [OpenCode CLI](https://opencode.ai):
```bash
curl -fsSL https://opencode.ai/install.sh | bash
curl -fsSL https://opencode.ai/install | bash
```
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>

View File

@@ -201,7 +201,7 @@ var (
// Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 0)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
// Enable Vulkan backend
@@ -290,7 +290,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},

View File

@@ -282,7 +282,7 @@ func TestVar(t *testing.T) {
func TestContextLength(t *testing.T) {
cases := map[string]uint{
"": 4096,
"": 0,
"2048": 2048,
}

View File

@@ -268,8 +268,10 @@ func (kv KV) OllamaEngineRequired() bool {
"olmo3",
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3next",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
"glmocr",
"lfm2",
}, kv.Architecture())
}
@@ -859,11 +861,13 @@ func (f GGML) FlashAttention() bool {
"bert",
"gemma3",
"glm4moelite",
"glmocr",
"gptoss", "gpt-oss",
"lfm2",
"mistral3",
"olmo3",
"qwen3", "qwen3moe",
"qwen3next",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))
}

1
go.mod
View File

@@ -27,6 +27,7 @@ require (
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0

3
go.sum
View File

@@ -174,6 +174,8 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -304,6 +306,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=

View File

@@ -75,3 +75,10 @@ type Cache interface {
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}
// CheckpointCache optionally supports restoring recurrent state to a prior
// position to avoid full prompt reprocessing when a prefix mismatch occurs.
// The returned position is the number of tokens that can be kept (prefix length).
type CheckpointCache interface {
PrepareRestore(seq int, targetPos int32) (int32, bool)
}

View File

@@ -0,0 +1,276 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan <jmorganca@gmail.com>
Date: Tue, 3 Feb 2026 12:00:00 -0800
Subject: [PATCH] ggml: metal solve_tri
---
ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++++++
ggml/src/ggml-metal/ggml-metal-device.h | 1 +
ggml/src/ggml-metal/ggml-metal-device.m | 11 ++++
ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++
ggml/src/ggml-metal/ggml-metal-ops.cpp | 63 +++++++++++++++++++++++
ggml/src/ggml-metal/ggml-metal-ops.h | 1 +
ggml/src/ggml-metal/ggml-metal.metal | 60 +++++++++++++++++++++
7 files changed, 177 insertions(+)
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 680904d13..83385c9ef 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
return res;
}
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
+ assert(op->op == GGML_OP_SOLVE_TRI);
+
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+
+ char base[256];
+ char name[256];
+
+ snprintf(base, 256, "kernel_solve_tri_f32");
+ snprintf(name, 256, "%s", base);
+
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+ if (!res.pipeline) {
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+ }
+
+ return res;
+}
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_GROUP_NORM);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 0a8b9211a..8a9d17460 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 7b5ee968c..4e5acfbe5 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+ case GGML_OP_SOLVE_TRI:
+ return ggml_is_contiguous(op->src[0]) &&
+ ggml_is_contiguous(op->src[1]) &&
+ op->src[0]->type == GGML_TYPE_F32 &&
+ op->src[1]->type == GGML_TYPE_F32 &&
+ op->type == GGML_TYPE_F32;
+ case GGML_OP_COUNT_EQUAL:
+ return has_simdgroup_reduction &&
+ op->src[0]->type == GGML_TYPE_I32 &&
+ op->src[1]->type == GGML_TYPE_I32 &&
+ op->type == GGML_TYPE_I64;
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 8944b07e9..cfdea9c07 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -500,6 +500,27 @@ typedef struct {
float eps;
} ggml_metal_kargs_l2_norm;
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ int32_t ne11;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_solve_tri;
+
typedef struct {
int64_t ne00;
int64_t ne01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 80864f303..4ac135603 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
} break;
+ case GGML_OP_SOLVE_TRI:
+ {
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
+ } break;
case GGML_OP_GROUP_NORM:
{
n_fuse = ggml_metal_op_group_norm(ctx, idx);
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
return 1;
}
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_solve_tri args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
+
+ const int64_t ncols = ne10;
+ const int64_t n_batches = (int64_t)ne02 * ne03;
+ const int64_t nr = n_batches * ncols;
+
+ int nth = 64;
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+ if (nth < 1) {
+ nth = 1;
+ }
+
+ const int64_t n_tg = (nr + nth - 1) / nth;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 902b54452..a475183d3 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index d33c16079..c37447a10 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
}
}
+kernel void kernel_solve_tri_f32(
+ constant ggml_metal_kargs_solve_tri & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ const uint64_t ncols = (uint64_t) args.ne10;
+ const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
+ const uint64_t nr = n_batches * ncols;
+
+ const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
+ if (gid >= nr) {
+ return;
+ }
+
+ const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
+ const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
+ const uint64_t i02 = rem / ncols;
+ const uint64_t i01 = rem - i02 * ncols;
+
+ const uint64_t sa0 = args.nb00 / sizeof(float);
+ const uint64_t sa1 = args.nb01 / sizeof(float);
+ const uint64_t sa2 = args.nb02 / sizeof(float);
+ const uint64_t sa3 = args.nb03 / sizeof(float);
+
+ const uint64_t sb0 = args.nb10 / sizeof(float);
+ const uint64_t sb1 = args.nb11 / sizeof(float);
+ const uint64_t sb2 = args.nb12 / sizeof(float);
+ const uint64_t sb3 = args.nb13 / sizeof(float);
+
+ const uint64_t sx0 = args.nb0 / sizeof(float);
+ const uint64_t sx1 = args.nb1 / sizeof(float);
+ const uint64_t sx2 = args.nb2 / sizeof(float);
+ const uint64_t sx3 = args.nb3 / sizeof(float);
+
+ device const float * A = (device const float *) src0;
+ device const float * B = (device const float *) src1;
+ device float * X = (device float *) dst;
+
+ const uint64_t A_base = i02 * sa2 + i03 * sa3;
+ const uint64_t B_base = i02 * sb2 + i03 * sb3;
+ const uint64_t X_base = i02 * sx2 + i03 * sx3;
+
+ const uint64_t n = (uint64_t) args.ne11;
+
+ for (uint64_t i00 = 0; i00 < n; ++i00) {
+ float sum = 0.0f;
+ for (uint64_t t = 0; t < i00; ++t) {
+ sum += A[A_base + i00 * sa1 + t * sa0] *
+ X[X_base + t * sx1 + i01 * sx0];
+ }
+
+ const float diag = A[A_base + i00 * sa1 + i00 * sa0];
+ X[X_base + i00 * sx1 + i01 * sx0] =
+ (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
+ }
+}
+
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,

View File

@@ -80,6 +80,7 @@ type LlamaServer interface {
GetPort() int
GetDeviceInfos(ctx context.Context) []ml.DeviceInfo
HasExited() bool
ContextLength() int
}
// llmServer is an instance of a runner hosting a single model
@@ -1200,7 +1201,8 @@ func (s *llmServer) initModel(ctx context.Context, req LoadRequest, operation Lo
resp, err := http.DefaultClient.Do(r)
if err != nil {
return nil, fmt.Errorf("do load request: %w", err)
slog.Error("do load request", "error", err)
return nil, errors.New("model failed to load, this may be due to resource limitations or an internal error, check ollama server logs for details")
}
defer resp.Body.Close()
@@ -1901,6 +1903,10 @@ func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
return 0
}
func (s *llmServer) ContextLength() int {
return s.options.NumCtx
}
func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
devices, err := ml.GetDevicesFromRunner(ctx, s)
if err != nil {

View File

@@ -170,6 +170,7 @@ type Tensor interface {
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
GELU_ERF(ctx Context) Tensor
QuickGELU(ctx Context, up ...Tensor) Tensor
SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor
@@ -206,6 +207,32 @@ type Tensor interface {
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Exp(ctx Context) Tensor
Neg(ctx Context) Tensor
// Clamp clamps values to [min, max] range
Clamp(ctx Context, min, max float32) Tensor
// Softplus computes ln(1 + exp(x))
Softplus(ctx Context) Tensor
// CumSum computes cumulative sum along dimension 0
CumSum(ctx Context) Tensor
// Diag creates a diagonal matrix from a 1D tensor
Diag(ctx Context) Tensor
// Tri converts a matrix to triangular form (0=upper+diag, 1=upper, 2=lower+diag, 3=lower)
Tri(ctx Context, triType int) Tensor
// Fill fills a tensor with a constant value (in-place)
Fill(ctx Context, value float32) Tensor
// Repeat4D repeats tensor to match target shape
Repeat4D(ctx Context, dim0, dim1, dim2, dim3 int) Tensor
// SolveTri solves a triangular system Ax = B
SolveTri(ctx Context, b Tensor, lower, left, unitDiag bool) Tensor
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}

View File

@@ -378,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
}
maxGraphNodes := max(1024, len(meta.Tensors().Items())*8)
maxGraphNodes := max(1024, len(meta.Tensors().Items())*32)
sched := C.ggml_backend_sched_new_ext(
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
@@ -1581,6 +1581,13 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) GELU_ERF(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_erf_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
var tt *C.struct_ggml_tensor
if len(t2) > 0 {
@@ -1772,6 +1779,76 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_exp(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
}
}
func (t *Tensor) Softplus(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_softplus(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) CumSum(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cumsum(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Diag(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_diag(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Tri(ctx ml.Context, triType int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tri(ctx.(*Context).ctx, t.t, C.enum_ggml_tri_type(triType)),
}
}
func (t *Tensor) Fill(ctx ml.Context, value float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_fill_inplace(ctx.(*Context).ctx, t.t, C.float(value)),
}
}
func (t *Tensor) Repeat4D(ctx ml.Context, dim0, dim1, dim2, dim3 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_repeat_4d(ctx.(*Context).ctx, t.t, C.int64_t(dim0), C.int64_t(dim1), C.int64_t(dim2), C.int64_t(dim3)),
}
}
func (t *Tensor) SolveTri(ctx ml.Context, b ml.Tensor, lower, left, unitDiag bool) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_solve_tri(ctx.(*Context).ctx, t.t, b.(*Tensor).t, C._Bool(lower), C._Bool(left), C._Bool(unitDiag)),
}
}
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
var mode C.uint32_t
switch samplingMode {

View File

@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_SOLVE_TRI);
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
char base[256];
char name[256];
snprintf(base, 256, "kernel_solve_tri_f32");
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_GROUP_NORM);

View File

@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);

View File

@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_SOLVE_TRI:
return ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
op->src[1]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_I64;
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:

View File

@@ -2385,6 +2385,27 @@ typedef struct {
float eps;
} ggml_metal_kargs_l2_norm;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_solve_tri;
typedef struct {
int64_t ne00;
int64_t ne01;
@@ -5813,6 +5834,66 @@ kernel void kernel_l2_norm_f32(
}
}
kernel void kernel_solve_tri_f32(
constant ggml_metal_kargs_solve_tri & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const uint64_t ncols = (uint64_t) args.ne10;
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
const uint64_t nr = n_batches * ncols;
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
if (gid >= nr) {
return;
}
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
const uint64_t i02 = rem / ncols;
const uint64_t i01 = rem - i02 * ncols;
const uint64_t sa0 = args.nb00 / sizeof(float);
const uint64_t sa1 = args.nb01 / sizeof(float);
const uint64_t sa2 = args.nb02 / sizeof(float);
const uint64_t sa3 = args.nb03 / sizeof(float);
const uint64_t sb0 = args.nb10 / sizeof(float);
const uint64_t sb1 = args.nb11 / sizeof(float);
const uint64_t sb2 = args.nb12 / sizeof(float);
const uint64_t sb3 = args.nb13 / sizeof(float);
const uint64_t sx0 = args.nb0 / sizeof(float);
const uint64_t sx1 = args.nb1 / sizeof(float);
const uint64_t sx2 = args.nb2 / sizeof(float);
const uint64_t sx3 = args.nb3 / sizeof(float);
device const float * A = (device const float *) src0;
device const float * B = (device const float *) src1;
device float * X = (device float *) dst;
const uint64_t A_base = i02 * sa2 + i03 * sa3;
const uint64_t B_base = i02 * sb2 + i03 * sb3;
const uint64_t X_base = i02 * sx2 + i03 * sx3;
const uint64_t n = (uint64_t) args.ne11;
for (uint64_t i00 = 0; i00 < n; ++i00) {
float sum = 0.0f;
for (uint64_t t = 0; t < i00; ++t) {
sum += A[A_base + i00 * sa1 + t * sa0] *
X[X_base + t * sx1 + i01 * sx0];
}
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
X[X_base + i00 * sx1 + i01 * sx0] =
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
}
}
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,

View File

@@ -500,6 +500,27 @@ typedef struct {
float eps;
} ggml_metal_kargs_l2_norm;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_solve_tri;
typedef struct {
int64_t ne00;
int64_t ne01;

View File

@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
} break;
case GGML_OP_SOLVE_TRI:
{
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
} break;
case GGML_OP_GROUP_NORM:
{
n_fuse = ggml_metal_op_group_norm(ctx, idx);
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_solve_tri args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
const int64_t ncols = ne10;
const int64_t n_batches = (int64_t)ne02 * ne03;
const int64_t nr = n_batches * ncols;
int nth = 64;
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
if (nth < 1) {
nth = 1;
}
const int64_t n_tg = (nr + nth - 1) / nth;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
return 1;
}
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

View File

@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);

View File

@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
}
}
kernel void kernel_solve_tri_f32(
constant ggml_metal_kargs_solve_tri & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const uint64_t ncols = (uint64_t) args.ne10;
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
const uint64_t nr = n_batches * ncols;
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
if (gid >= nr) {
return;
}
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
const uint64_t i02 = rem / ncols;
const uint64_t i01 = rem - i02 * ncols;
const uint64_t sa0 = args.nb00 / sizeof(float);
const uint64_t sa1 = args.nb01 / sizeof(float);
const uint64_t sa2 = args.nb02 / sizeof(float);
const uint64_t sa3 = args.nb03 / sizeof(float);
const uint64_t sb0 = args.nb10 / sizeof(float);
const uint64_t sb1 = args.nb11 / sizeof(float);
const uint64_t sb2 = args.nb12 / sizeof(float);
const uint64_t sb3 = args.nb13 / sizeof(float);
const uint64_t sx0 = args.nb0 / sizeof(float);
const uint64_t sx1 = args.nb1 / sizeof(float);
const uint64_t sx2 = args.nb2 / sizeof(float);
const uint64_t sx3 = args.nb3 / sizeof(float);
device const float * A = (device const float *) src0;
device const float * B = (device const float *) src1;
device float * X = (device float *) dst;
const uint64_t A_base = i02 * sa2 + i03 * sa3;
const uint64_t B_base = i02 * sb2 + i03 * sb3;
const uint64_t X_base = i02 * sx2 + i03 * sx3;
const uint64_t n = (uint64_t) args.ne11;
for (uint64_t i00 = 0; i00 < n; ++i00) {
float sum = 0.0f;
for (uint64_t t = 0; t < i00; ++t) {
sum += A[A_base + i00 * sa1 + t * sa0] *
X[X_base + t * sx1 + i01 * sx0];
}
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
X[X_base + i00 * sx1 + i01 * sx0] =
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
}
}
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,

View File

@@ -56,6 +56,18 @@ type fakeTensor struct {
Name string
}
// Stub methods to satisfy ml.Tensor interface
func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Clamp(ctx ml.Context, _, _ float32) ml.Tensor { return f }
func (f *fakeTensor) Softplus(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) CumSum(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Diag(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor { return f }
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
func (m *fakeBackend) Get(name string) ml.Tensor {
if slices.Contains(m.names, name) {
return &fakeTensor{Name: name}

View File

@@ -0,0 +1,174 @@
package glmocr
import (
"image"
"log/slog"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize int
patchSize int
temporalPatchSize int
spatialMergeSize int
minPixels int
maxPixels int
factor int
imageMean [3]float32
imageStd [3]float32
}
func newImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 14))
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
// Read normalization values from config if available, otherwise use CLIP defaults
imageMean := c.Floats("vision.image_mean", imageproc.ClipDefaultMean[:])
imageStd := c.Floats("vision.image_std", imageproc.ClipDefaultSTD[:])
// Default max_pixels: 2048 * patchSize^2 * mergeSize^2 * temporal = ~3.2M pixels
// This limits to ~16k patches (4k output tokens) to keep memory stable without flash attention
defaultMaxPixels := 2048 * patchSize * patchSize * spatialMergeSize * spatialMergeSize * temporalPatchSize
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 336)),
patchSize: patchSize,
temporalPatchSize: temporalPatchSize,
spatialMergeSize: spatialMergeSize,
minPixels: int(c.Uint("vision.min_pixels", uint32(8*patchSize*patchSize*spatialMergeSize*spatialMergeSize*temporalPatchSize))),
maxPixels: int(c.Uint("vision.max_pixels", uint32(defaultMaxPixels))),
factor: patchSize * spatialMergeSize,
imageMean: [3]float32{imageMean[0], imageMean[1], imageMean[2]},
imageStd: [3]float32{imageStd[0], imageStd[1], imageStd[2]},
}
}
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
factor := p.factor
temporalFactor := p.temporalPatchSize
numFrames := temporalFactor // single image
if height < factor || width < factor {
// Scale up small images
scale := float64(factor) / float64(min(height, width))
height = int(math.Ceil(float64(height) * scale))
width = int(math.Ceil(float64(width) * scale))
}
if temporalFactor <= 0 {
slog.Warn("temporal_patch_size must be > 0, defaulting to 1")
temporalFactor = 1
}
if numFrames < temporalFactor {
slog.Warn("num_frames must be >= temporal_patch_size, adjusting num_frames", "num_frames", numFrames, "temporal_patch_size", temporalFactor)
numFrames = temporalFactor
}
if aspectRatio := float64(max(height, width)) / float64(min(height, width)); aspectRatio > 200 {
slog.Warn("aspect ratio exceeds 200, image quality may be affected", "aspect_ratio", aspectRatio)
}
round := func(x float64) int { return int(math.RoundToEven(x)) }
hBar := round(float64(height)/float64(factor)) * factor
wBar := round(float64(width)/float64(factor)) * factor
tBar := round(float64(numFrames)/float64(temporalFactor)) * temporalFactor
if tBar*hBar*wBar > p.maxPixels {
beta := math.Sqrt(float64(numFrames*height*width) / float64(p.maxPixels))
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
} else if tBar*hBar*wBar < p.minPixels {
beta := math.Sqrt(float64(p.minPixels) / float64(numFrames*height*width))
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
img = imageproc.Composite(img)
origWidth := img.Bounds().Dx()
origHeight := img.Bounds().Dy()
// Calculate smart resize dimensions
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
// Resize image
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeCatmullrom)
// Normalize pixels - output format is [C, H, W] with rescale and channelFirst
// We keep [C, H, W] for patch extraction
normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
// Calculate grid dimensions (after Conv2D patching)
grid := &Grid{
Height: resizedHeight / p.patchSize,
Width: resizedWidth / p.patchSize,
Temporal: 1, // Single image
ImageHeight: resizedHeight,
ImageWidth: resizedWidth,
}
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil {
return nil, nil, err
}
return patches, grid, nil
}
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
channels := 3
patchSize := p.patchSize
mergeSize := p.spatialMergeSize
temporalPatchSize := p.temporalPatchSize
numPatches := grid.Temporal * grid.Height * grid.Width
patchDim := channels * temporalPatchSize * patchSize * patchSize
result := make([]float32, numPatches*patchDim)
patchIndex := 0
// Single temporal frame handling (copies to all frames)
for range grid.Temporal {
for h := 0; h < grid.Height; h += mergeSize {
for w := 0; w < grid.Width; w += mergeSize {
for mh := range mergeSize {
for mw := range mergeSize {
baseOffset := patchIndex * patchDim
for c := range channels {
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
for py := range patchSize {
for px := range patchSize {
y := (h+mh)*patchSize + py
x := (w+mw)*patchSize + px
srcIdx := c*height*width + y*width + x
dstIdx := channelOffset + (py * patchSize) + px
result[dstIdx] = pixels[srcIdx]
}
}
if temporalPatchSize > 1 {
frameSize := patchSize * patchSize
for tp := 1; tp < temporalPatchSize; tp++ {
currentFrameOffset := channelOffset + (tp * frameSize)
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
result[channelOffset:channelOffset+frameSize])
}
}
}
patchIndex++
}
}
}
}
}
return result, nil
}

View File

@@ -0,0 +1,235 @@
package glmocr
import (
"bytes"
"errors"
"image"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v"`
VisionDownsample *VisionDownsample `gguf:"mm.patch_merger"`
PatchMerger *PatchMerger `gguf:"mm"`
ImageProcessor
imageTokenID int32
imageStartTokenID int32
imageEndTokenID int32
}
var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
eosTokenID := int32(c.Uint("tokenizer.ggml.eos_token_id"))
eosTokenIDs := c.Ints("tokenizer.ggml.eos_token_ids")
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: allEOS,
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
imageTokenID: int32(c.Uint("image_token_id", 59280)),
imageStartTokenID: int32(c.Uint("image_start_token_id", 59256)),
imageEndTokenID: int32(c.Uint("image_end_token_id", 59257)),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
return m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Blocks) == 0 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, grid, err := m.ImageProcessor.ProcessImage(img)
if err != nil {
return nil, err
}
// Create pixel values tensor from flattened patches
// Shape: [patchDim, numPatches]
patchDim := m.VisionModel.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
// Forward through vision encoder
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
// Forward through downsample (patch merger)
if m.VisionDownsample == nil || m.VisionDownsample.Weight == nil {
return nil, errors.New("glmocr: missing vision downsample weights")
}
visionOutputs = m.VisionDownsample.Forward(ctx, visionOutputs, grid, m.VisionModel.VisionModelOptions)
// Forward through patch merger (FC + LayerNorm + GELU + SwiGLU FFN)
if m.PatchMerger == nil {
return nil, errors.New("glmocr: missing patch merger weights")
}
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.VisionModelOptions)
return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
// Reset position cache
m.TextModel.positionCache = m.TextModel.positionCache[:0]
m.TextModel.ropeDelta = 0
pos := int32(0)
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
continue
}
// Get grid info for position calculation
grid := inp.Multimodal[0].Data.(*Grid)
mergedH := grid.Height / m.VisionModel.spatialMergeSize
mergedW := grid.Width / m.VisionModel.spatialMergeSize
// Add image start token
result = append(result, &input.Input{Token: m.imageStartTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
// Add image tokens with multimodal data
// All image tokens share the same base position for temporal dimension
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
basePos := pos
sameBatch := tokensPerGrid - 1
if sameBatch < 0 {
sameBatch = 0
}
result = append(result, &input.Input{
Token: m.imageTokenID,
Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash,
SameBatch: sameBatch,
})
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
// Add placeholder tokens for remaining positions
// All image tokens use the same base position (temporal stays constant)
for range tokensPerGrid - 1 {
result = append(result, &input.Input{Token: m.imageTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
}
// Advance position by max(mergedH, mergedW) after image tokens
pos = basePos + int32(max(mergedH, mergedW))
// Add image end token
result = append(result, &input.Input{Token: m.imageEndTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
}
// Compute rope delta for continuation after the prefill segment:
// delta = (max_position_id + 1) - sequence_length
if len(m.TextModel.positionCache) > 0 {
last := m.TextModel.positionCache[len(m.TextModel.positionCache)-1]
m.TextModel.ropeDelta = last + 1 - int32(len(m.TextModel.positionCache))
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
// Initial token embedding
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
ctx.Forward(hiddenStates)
// Build position slices for M-RoPE
positionSlice := func() [][]int32 {
s := [][]int32{
make([]int32, len(batch.Positions)), // temporal
make([]int32, len(batch.Positions)), // height
make([]int32, len(batch.Positions)), // width
make([]int32, len(batch.Positions)), // unused (zeros)
}
for i, position := range batch.Positions {
// Translate through position cache or continue sequence
if position < int32(len(m.TextModel.positionCache)) {
position = m.TextModel.positionCache[position]
} else if len(m.TextModel.positionCache) > 0 {
// Continue sequence after cached positions using ropeDelta
position = position + m.TextModel.ropeDelta
}
s[0][i] = position
s[1][i] = position
s[2][i] = position
}
return s
}()
// Inject vision embeddings and adjust positions for image tokens
for _, mi := range batch.Multimodal {
img := mi.Multimodal[0].Tensor
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
w := grid.Width / m.VisionModel.spatialMergeSize
for i := range img.Dim(1) {
positionSlice[1][mi.Index+i] += int32(i / w)
positionSlice[2][mi.Index+i] += int32(i % w)
}
}
}
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
// Process through transformer layers
for i, layer := range m.TextModel.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.TextModel.Layers)-1 {
lastLayerOutputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextModel.TextModelOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("glmocr", New)
}

View File

@@ -0,0 +1,190 @@
package glmocr
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
)
type TextModelOptions struct {
hiddenSize int
numHeads int
numKVHeads int
headDim int
rotaryDim int
intermediateSize int
eps float32
ropeBase float32
mropeSections []int
}
func (o *TextModelOptions) applyMRoPE(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
// With 4 sections for [temporal, height, width, unused]
return nn.RoPE(ctx, states, positions, o.rotaryDim, o.ropeBase, 1.0, rope.WithMRoPE(o.mropeSections))
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
// Separate Q, K, V projections
q := sa.Query.Forward(ctx, hiddenStates)
k := sa.Key.Forward(ctx, hiddenStates)
v := sa.Value.Forward(ctx, hiddenStates)
// Reshape for GQA
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
k = k.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
v = v.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
// Apply M-RoPE (multi-resolution rotary position embeddings)
q = opts.applyMRoPE(ctx, q, positions)
k = opts.applyMRoPE(ctx, k, positions)
// Scaled dot-product attention with KV cache
scaleFactor := 1.0 / math.Sqrt(float64(opts.headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
// Reshape attention output: [headDim, numHeads, batchSize] -> [numHeads*headDim, batchSize]
// Note: numHeads * headDim = 16 * 128 = 2048, which is the attention hidden size
kqv = kqv.Reshape(ctx, opts.numHeads*opts.headDim, batchSize)
return sa.Output.Forward(ctx, kqv)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextModelOptions) ml.Tensor {
// SwiGLU: down(silu(gate(x)) * up(x))
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, gate)
}
type TextDecoderLayer struct {
// Input layernorm (before attention)
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *TextSelfAttention
// Post self-attention layernorm (after attention, before residual add)
PostAttnNorm *nn.RMSNorm `gguf:"post_attn_norm"`
// FFN input layernorm (after first residual, before MLP)
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
// Post MLP layernorm (after MLP, before residual add)
PostFFNNorm *nn.RMSNorm `gguf:"post_ffn_norm"`
}
func (l *TextDecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
// Attention block
residual := hiddenStates
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = l.SelfAttention.Forward(ctx, hiddenStates, positions, cache, opts)
hiddenStates = l.PostAttnNorm.Forward(ctx, hiddenStates, opts.eps)
// Prune to output positions in final layer
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
// MLP block
residual = hiddenStates
hiddenStates = l.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = l.PostFFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []TextDecoderLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextModelOptions
// positionCache stores the M-RoPE position for each token in the sequence.
// This is needed because image tokens share the same base position but have
// different height/width offsets, and the end token position depends on the
// image grid dimensions.
positionCache []int32
ropeDelta int32
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// Clear position cache when KV cache shifts
m.positionCache = nil
m.ropeDelta = 0
return m.applyMRoPE(ctx, key, shift), nil
}
func newTextModel(c fs.Config) *TextModel {
hiddenSize := int(c.Uint("embedding_length", 1536))
numHeads := int(c.Uint("attention.head_count", 16))
numKVHeads := int(c.Uint("attention.head_count_kv", 8))
intermediateSize := int(c.Uint("feed_forward_length", 4608))
eps := c.Float("attention.layer_norm_rms_epsilon", 1e-5)
ropeBase := c.Float("rope.freq_base", 10000)
headDim := int(c.Uint("attention.key_length", uint32(hiddenSize/numHeads)))
ropeDim := int(c.Uint("rope.dimension_count", uint32(headDim)))
if ropeDim <= 0 {
ropeDim = headDim
}
mropeSections := c.Ints("rope.mrope_section")
var sectionInts []int
if len(mropeSections) > 0 {
sectionInts = make([]int, len(mropeSections))
for i, section := range mropeSections {
sectionInts[i] = int(section)
}
} else {
// Default to GLM-OCR's HF ratio (2:3:3) scaled to rotaryDim/2.
// For rotaryDim=64 this yields [8, 12, 12].
total := ropeDim / 2
if total <= 0 {
total = 32
}
s0 := total * 2 / 8
s1 := total * 3 / 8
s2 := total - s0 - s1
sectionInts = []int{s0, s1, s2}
}
// GGML rope_multi: sector = (dim_pair) % sum(sections), mapping each pair to its position dim
rotaryDim := ropeDim
return &TextModel{
Layers: make([]TextDecoderLayer, c.Uint("block_count", 16)),
TextModelOptions: &TextModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
numKVHeads: numKVHeads,
headDim: headDim,
rotaryDim: rotaryDim,
intermediateSize: intermediateSize,
eps: eps,
ropeBase: ropeBase,
mropeSections: sectionInts,
},
}
}

View File

@@ -0,0 +1,355 @@
package glmocr
import (
"log/slog"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
)
type Grid struct {
Height int // Number of patches in height direction
Width int // Number of patches in width direction
Temporal int
ImageHeight int // Full image height in pixels
ImageWidth int // Full image width in pixels
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
numChannels int
patchSize int
temporalPatchSize int
imageSize int
spatialMergeSize int
outHiddenSize int
intermediateSize int
eps float32
}
type VisionPatchEmbed struct {
Proj *nn.Conv2D `gguf:"patch_embd_0"`
Proj1 *nn.Conv2D `gguf:"patch_embd_1"`
Bias ml.Tensor `gguf:"patch_embd.bias"`
}
func (pe *VisionPatchEmbed) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
_ = grid // patches are already in merge-block order
// pixelValues shape: [patchDim, numPatches]
numPatches := pixelValues.Shape()[1]
// Reshape to [patchSize*patchSize, temporalPatchSize, numChannels, numPatches]
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
// Permute to [temporalPatchSize, patchSize*patchSize, numChannels, numPatches]
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Slice temporal frames for Conv2D (simulate Conv3D)
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
s0, s1 := opts.patchSize, opts.patchSize
p0, p1 := 0, 0
d0, d1 := 1, 1
hiddenStates := pe.Proj.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
if pe.Proj1 != nil && opts.temporalPatchSize > 1 {
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
out1 := pe.Proj1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
hiddenStates = hiddenStates.Add(ctx, out1)
}
// Flatten to [hidden_size, num_patches]
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, numPatches)
// Add patch bias - reshape from [hidden_size] to [hidden_size, 1] for broadcasting
if pe.Bias != nil {
hiddenStates = hiddenStates.Add(ctx, pe.Bias.Reshape(ctx, opts.hiddenSize, 1))
}
return hiddenStates
}
type VisionSelfAttention struct {
QKV *nn.Linear `gguf:"attn_qkv"`
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
// Combined QKV projection: [3*hidden_size, batch_size]
qkv := sa.QKV.Forward(ctx, hiddenStates)
// Split using ChunkSections along dim 0 (handles byte offsets correctly)
// ChunkSections returns views - must make contiguous before further operations
chunks := qkv.ChunkSections(ctx, 0, opts.hiddenSize, opts.hiddenSize, opts.hiddenSize)
q := chunks[0].Contiguous(ctx)
k := chunks[1].Contiguous(ctx)
v := chunks[2].Contiguous(ctx)
// Reshape for multi-head attention: [hiddenSize, N] -> [headDim, numHeads, N]
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
k = k.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
v = v.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
// Apply Q-norm and K-norm after head reshape
// Weights are [headDim]=64, tensor is [headDim, numHeads, N]
q = sa.QNorm.Forward(ctx, q, opts.eps)
k = sa.KNorm.Forward(ctx, k, opts.eps)
// Apply rotary position embeddings with vision-style 2D positions.
// ggml's vision RoPE uses two position dimensions (H/W) with half-rotation pairs.
// We provide H/W sections and leave the remaining sections empty.
ropeFreqBase := float32(10000.0)
section := opts.headDim / 4
if section <= 0 {
section = 1
}
sections := []int{section, section, 0, 0}
q = nn.RoPE(ctx, q, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
k = nn.RoPE(ctx, k, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
// Try flash attention first (ScaledDotProductAttention), fall back to manual
if sdpa, ok := q.(ml.ScaledDotProductAttention); ok {
attention := sdpa.ScaledDotProductAttention(ctx, k, v, nil, nil, nil, scale, false)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
slog.Warn("glmocr: vision attention falling back to manual attention",
"batchSize", batchSize, "numHeads", opts.numHeads,
"hint", "set OLLAMA_FLASH_ATTENTION=1 to enable flash attention")
// Manual attention fallback
// q, k, v are [headDim, numHeads, batchSize] - GGML treats as 4D with implicit dim 3 = 1
q = q.Permute(ctx, 0, 2, 1, 3)
k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
// Attention scores
kq := k.MulmatFullPrec(ctx, q)
kq = kq.Scale(ctx, scale)
kq = kq.Softmax(ctx)
// Attention output: v @ kq (note: v first)
kqv := v.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
// SwiGLU: down(silu(gate(x)) * up(x))
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, gate)
}
type VisionBlock struct {
Norm1 *nn.RMSNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
Norm2 *nn.RMSNorm `gguf:"ln2"`
MLP *VisionMLP
}
func (b *VisionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Pre-norm architecture
residual := hiddenStates
hiddenStates = b.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = b.SelfAttention.Forward(ctx, hiddenStates, positions, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = b.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = b.MLP.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type VisionDownsample struct {
*nn.Conv2D
}
func (d *VisionDownsample) Forward(ctx ml.Context, hiddenStates ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
// Apply spatial downsampling via Conv2D
// Input: [hidden_size, num_patches] where patches are in merge-block order
if d.Conv2D == nil || d.Weight == nil {
slog.Error("VisionDownsample weights not loaded - model may be corrupted or incompatible")
return hiddenStates // Return input unchanged as fallback
}
merge := opts.spatialMergeSize
numOutputTokens := (grid.Height / merge) * (grid.Width / merge)
// Step 1: Reshape to [hidden_size, merge, merge, num_output_tokens]
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, merge, merge, numOutputTokens)
// Step 2: Permute to [merge, merge, hidden_size, num_output_tokens]
// ggml semantics: result.ne[perm[i]] = input.ne[i]
// So permute(2,0,1,3) on [1024,2,2,N] gives: ne[2]=1024, ne[0]=2, ne[1]=2, ne[3]=N -> [2,2,1024,N]
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
// Step 3: Apply Conv2D without bias (bias added after reshape)
// Note: ggml_conv_2d takes (kernel, input) - kernel must be receiver in ollama
s0, s1 := merge, merge
p0, p1 := 0, 0
d0, d1 := 1, 1
hiddenStates = d.Weight.Conv2D(ctx, hiddenStates, s0, s1, p0, p1, d0, d1)
// Step 4: Reshape to [out_hidden_size, num_output_tokens]
hiddenStates = hiddenStates.Reshape(ctx, opts.outHiddenSize, numOutputTokens)
// Step 5: Add bias after reshape
// Reshape bias from [out_hidden_size] to [out_hidden_size, 1] for proper broadcasting
if d.Bias != nil {
hiddenStates = hiddenStates.Add(ctx, d.Bias.Reshape(ctx, opts.outHiddenSize, 1))
}
return hiddenStates
}
type PatchMerger struct {
// GGUF tags align with mm.* keys used by the model
Proj *nn.Linear `gguf:"model.fc"` // mm.model.fc.weight
PostLN *nn.LayerNorm `gguf:"post_norm"` // mm.post_norm.weight/bias
GateProj *nn.Linear `gguf:"gate"` // mm.gate.weight
UpProj *nn.Linear `gguf:"up"` // mm.up.weight
DownProj *nn.Linear `gguf:"down"` // mm.down.weight
}
func (m *PatchMerger) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Linear projection
hiddenStates = m.Proj.Forward(ctx, hiddenStates)
// Post-projection layer norm + GELU ERF
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = hiddenStates.GELU_ERF(ctx)
// Force a copy to avoid in-place mutation issues with GELU_ERF
hiddenStates = hiddenStates.Contiguous(ctx)
// SwiGLU MLP: down(silu(gate(x)) * up(x))
gateOut := m.GateProj.Forward(ctx, hiddenStates)
upOut := m.UpProj.Forward(ctx, hiddenStates)
gate := gateOut.SILU(ctx, upOut)
return m.DownProj.Forward(ctx, gate)
}
type VisionModel struct {
PatchEmbed *VisionPatchEmbed
Blocks []VisionBlock `gguf:"blk"`
PostLN *nn.RMSNorm `gguf:"post_ln"`
// Note: Downsample is applied at the model level so mm.patch_merger stays separate
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
// Extract patch embeddings from flattened patches
hiddenStates := m.PatchEmbed.Forward(ctx, pixelValues, grid, m.VisionModelOptions)
// Create position IDs for RoPE (spatial grid)
// Patches are already in merge-block order from preprocessing
positions := m.createPositions(ctx, grid)
// Process through vision blocks
for _, block := range m.Blocks {
hiddenStates = block.Forward(ctx, hiddenStates, positions, m.VisionModelOptions)
}
// Post-layernorm
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, m.eps)
// Note: Downsample is now applied separately in Model.EncodeMultimodal
// so mm.patch_merger remains a distinct module
return hiddenStates
}
func (m *VisionModel) createPositions(ctx ml.Context, grid *Grid) ml.Tensor {
// Create spatial position IDs for vision RoPE
// Position layout: [height, width, height, width] - 4 sections for mrope
// Patches are in MERGE-BLOCK order after VisionPatchEmbed interleaving
// This follows the GLM-OCR rot_pos_emb layout
numPatches := grid.Height * grid.Width
mergeRatio := m.spatialMergeSize
// Build position arrays in merge-block order
// Each merge_ratio x merge_ratio block of patches is grouped together
hpos := make([]int32, numPatches)
wpos := make([]int32, numPatches)
ptr := 0
for y := 0; y < grid.Height; y += mergeRatio {
for x := 0; x < grid.Width; x += mergeRatio {
for dy := range mergeRatio {
for dx := range mergeRatio {
hpos[ptr] = int32(y + dy)
wpos[ptr] = int32(x + dx)
ptr++
}
}
}
}
// Build position arrays for 4 sections (mrope). ggml vision RoPE uses only H/W;
// keep remaining sections zeroed to match its conventions.
zeros := make([]int32, numPatches)
s := [][]int32{
hpos, // Section 0: height
wpos, // Section 1: width
zeros, // Section 2: unused
zeros, // Section 3: unused
}
return ctx.Input().FromInts(slices.Concat(s...), numPatches*4)
}
func newVisionModel(c fs.Config) *VisionModel {
hiddenSize := int(c.Uint("vision.embedding_length", 1024))
numHeads := int(c.Uint("vision.attention.head_count", 16))
numChannels := int(c.Uint("vision.num_channels", 3))
patchSize := int(c.Uint("vision.patch_size", 14))
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
imageSize := int(c.Uint("vision.image_size", 336))
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
outHiddenSize := int(c.Uint("vision.out_hidden_size", 1536))
intermediateSize := int(c.Uint("vision.intermediate_size", 4096))
eps := c.Float("vision.attention.layer_norm_rms_epsilon", 1e-5)
return &VisionModel{
Blocks: make([]VisionBlock, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: hiddenSize / numHeads,
numChannels: numChannels,
patchSize: patchSize,
temporalPatchSize: temporalPatchSize,
imageSize: imageSize,
spatialMergeSize: spatialMergeSize,
outHiddenSize: outHiddenSize,
intermediateSize: intermediateSize,
eps: eps,
},
}
}

View File

@@ -8,6 +8,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/glm4moelite"
_ "github.com/ollama/ollama/model/models/glmocr"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/lfm2"
_ "github.com/ollama/ollama/model/models/llama"
@@ -19,5 +20,6 @@ import (
_ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3"
_ "github.com/ollama/ollama/model/models/qwen3next"
_ "github.com/ollama/ollama/model/models/qwen3vl"
)

View File

@@ -0,0 +1,103 @@
package qwen3next
import (
"errors"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
// with the attention layer requirements.
var ErrUnsupportedBatchLayout = errors.New("qwen3next: unsupported batch layout")
// FullAttention implements gated attention with QK normalization and sigmoid-gated output.
// Key differences from standard attention:
// - Q projection outputs 2x size (Q + gate interleaved)
// - Both Q and K have RMSNorm
// - Output is gated: attn * sigmoid(gate)
type FullAttention struct {
Query *nn.Linear `gguf:"attn_q"` // outputs [n_embd_head * 2, n_head]
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
// Use Dim() instead of Shape() for consistent behavior during graph construction
hiddenDim := hiddenStates.Dim(0)
batchSize := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2) // 0 if 2D tensor
if cache != nil && cache.IsSupportedForBatch() {
seqTokens := cache.seqTokens()
seqs := cache.numSeqs()
if seqTokens > 0 && seqs > 0 {
if nSeqs > 0 {
// 3D tensor: [hiddenDim, seqTokens, nSeqs]
if batchSize != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
batchSize = seqTokens * seqs
} else if batchSize != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
}
}
}
headDim := opts.headDim()
numHeads := opts.numHeads
// Q projection outputs query + gate interleaved
qFull := sa.Query.Forward(ctx, hiddenStates)
// Reshape to [headDim * 2, numHeads, batchSize]
qFull = qFull.Reshape(ctx, headDim*2, numHeads, batchSize)
// Split Q and gate along dimension 0
// Q: first headDim elements, gate: second headDim elements
query := qFull.Slice(ctx, 0, 0, headDim, 1)
gate := qFull.Slice(ctx, 0, headDim, headDim*2, 1)
// Make query contiguous for further operations
query = query.Contiguous(ctx, headDim, numHeads, batchSize)
// K and V projections
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
// Derive numKVHeads from tensor dimensions (per-layer value)
numKVHeads := key.Dim(0) / headDim
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
// Apply QK normalization
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
// Apply RoPE
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
// Standard attention computation
scale := opts.attentionScale
if scale == 0 {
scale = 1.0 / math.Sqrt(float64(headDim))
}
attention := nn.Attention(ctx, query, key, value, scale, cache)
// Flatten heads
attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
// Apply sigmoid gate
// gate shape: [headDim, numHeads, batchSize] -> [headDim*numHeads, batchSize]
gate = gate.Contiguous(ctx, headDim*numHeads, batchSize)
gateSigmoid := gate.Sigmoid(ctx)
attention = attention.Mul(ctx, gateSigmoid)
return sa.Output.Forward(ctx, attention), nil
}

View File

@@ -0,0 +1,596 @@
package qwen3next
import (
"math"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var _ kvcache.Cache = (*HybridCache)(nil)
// HybridCache stores:
// - a standard causal KV cache for full attention layers
// - per-sequence conv state for linear attention layers
// - per-sequence delta state for linear attention layers
//
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
// Conv state dimensions
convDim int // convKernelSize - 1
convChannels int // d_inner + 2 * num_k_heads * head_k_dim
// Delta state dimensions
deltaStateSize int // headVDim * headVDim * numVHeads
// slot mapping for recurrent state (copy-on-write)
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
// per-layer delta state buffers (allocated lazily)
deltaCtxs map[int]ml.Context
deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
// recurrent checkpoints (per slot)
checkpointCount int
checkpointMinPos int32
checkpointInterval int32
checkpointCtxSize int
checkpoints map[int]*slotCheckpointStore
pendingRestore map[int]checkpointRestore
curCheckpointPos []int32
curCheckpointSlots map[int]int
reserveCheckpoints bool
checkpointConvCtxs map[int]ml.Context
checkpointDeltaCtxs map[int]ml.Context
checkpointReserved map[int]struct{}
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
writableError error
}
func NewHybridCache(
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
convDim, convChannels, deltaStateSize int,
) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
convDim: convDim,
convChannels: convChannels,
deltaStateSize: deltaStateSize,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
deltaCtxs: make(map[int]ml.Context),
deltaStates: make(map[int]ml.Tensor),
checkpointCount: checkpointCountDefault,
checkpointMinPos: checkpointMinPosDefault,
checkpointInterval: checkpointIntervalDefault,
checkpoints: make(map[int]*slotCheckpointStore),
pendingRestore: make(map[int]checkpointRestore),
curCheckpointSlots: make(map[int]int),
checkpointConvCtxs: make(map[int]ml.Context),
checkpointDeltaCtxs: make(map[int]ml.Context),
checkpointReserved: make(map[int]struct{}),
}
}
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
c.checkpoints = make(map[int]*slotCheckpointStore)
c.pendingRestore = make(map[int]checkpointRestore)
c.curCheckpointPos = c.curCheckpointPos[:0]
c.curCheckpointSlots = make(map[int]int)
c.checkpointReserved = make(map[int]struct{})
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
if c.checkpointCtxSize < 8 {
c.checkpointCtxSize = 8
}
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *HybridCache) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
for _, ctx := range c.deltaCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointConvCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointDeltaCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *HybridCache) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
// Derive equal-length sequence layout for recurrent layers
seqCounts := make(map[int]int)
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if _, ok := seqCounts[s]; !ok {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
if len(c.curSeqs) == 0 {
return nil
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
}
}
c.curSeqTokens = want
// When reserving memory for estimation, use fake slot assignments
if reserve {
c.curSlots = c.curSlots[:0]
slots := make([]int32, nSeqs)
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
slots[i] = int32(i)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
c.reserveCheckpoints = true
c.planCheckpoints(batch)
return nil
}
// Ensure slots exist for sequences in this batch
c.curSlots = c.curSlots[:0]
var newSlots []int
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
// Zero state for newly allocated slots
if len(newSlots) > 0 {
c.zeroSlots(ctx, newSlots)
}
// Create a tensor for the current slots
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
// Reset writable state for new forward pass
c.writableEnsured = false
c.writableError = nil
c.reserveCheckpoints = false
c.planCheckpoints(batch)
return nil
}
func (c *HybridCache) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, kvcache.ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *HybridCache) freeSlot(slot int) {
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroSlots zeros the recurrent state for the given slots across all layers.
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 {
return
}
inputCtx := ctx.Input()
slotIndices := make([]int32, len(slots))
for i, s := range slots {
slotIndices[i] = int32(s)
}
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
// Zero conv states
if len(c.convStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
// Zero delta states
if len(c.deltaStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
for _, buf := range c.deltaStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
}
// EnsureWritable ensures sequences have private slots (copy-on-write).
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
}
// Rebuild current slots tensor
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
for _, buf := range c.convStates {
rows := buf.Rows(ctx, src)
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
}
for _, buf := range c.deltaStates {
rows := buf.Rows(ctx, src)
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
}
}
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
// Copy-on-write for recurrent state
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
if c.validSlot(dstSlot) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
return
}
if c.validSlot(srcSlot) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *HybridCache) CanResume(seq int, pos int32) bool {
if !c.kv.CanResume(seq, pos) {
return false
}
if pos == 0 {
return true
}
return c.hasCheckpoint(seq, pos)
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if beginIndex > 0 && endIndex != math.MaxInt32 {
return kvcache.ErrNotSupported
}
if beginIndex > 0 {
restore, ok := c.pendingRestore[seq]
if !ok || restore.pos+1 != beginIndex {
return kvcache.ErrNotSupported
}
if !c.restoreComplete(restore) {
return kvcache.ErrNotSupported
}
// If the recurrent slot is shared, detach it before applying a restore.
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
newSlot, err := c.allocSlot()
if err != nil {
return err
}
ctx := c.backend.NewContext()
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
ctx.Compute()
}
ctx.Close()
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
restore.slot = newSlot
c.pendingRestore[seq] = restore
}
}
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
if beginIndex > 0 {
restore := c.pendingRestore[seq]
delete(c.pendingRestore, seq)
return c.applyCheckpointRestore(restore)
}
// Removal invalidates recurrent state
slot, ok := c.slotForSeq[seq]
delete(c.pendingRestore, seq)
if !ok {
return nil
}
if !c.validSlot(slot) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.clearCheckpoints(slot)
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
func (c *HybridCache) validSlot(slot int) bool {
return slot >= 0 && slot < len(c.refCount)
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.curSlotsInput
}
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
func (c *HybridCache) contiguousSlots() (int, bool) {
if len(c.curSlots) == 0 {
return 0, false
}
start := c.curSlots[0]
for i, s := range c.curSlots {
if s != start+i {
return 0, false
}
}
return start, true
}
func (c *HybridCache) seqTokens() int {
return c.curSeqTokens
}
func (c *HybridCache) numSeqs() int {
return len(c.curSeqs)
}
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
c.convStates[layer] = buf
return buf
}
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.deltaStates[layer]; ok {
return buf
}
if _, ok := c.deltaCtxs[layer]; !ok {
c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
// Recurrent delta state must stay in F32.
buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
c.deltaStates[layer] = buf
return buf
}
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
}
// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
}
// UpdateConvState writes a new conv state for current batch sequences.
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
srcF32 := src.Cast(ctx, ml.DTypeF32)
if start, ok := c.contiguousSlots(); ok {
// Fast path: contiguous slots allow a single view + copy
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
ctx.Forward(srcF32.Copy(ctx, view))
} else {
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
c.captureConvCheckpoint(ctx, layer, srcF32)
}
// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.deltaBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
}
// UpdateDeltaState writes a new delta state for current batch sequences.
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.deltaBuffer(ctx, layer)
src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
srcF32 := src.Cast(ctx, ml.DTypeF32)
if start, ok := c.contiguousSlots(); ok {
// Fast path: contiguous slots allow a single view + copy
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
ctx.Forward(srcF32.Copy(ctx, view))
} else {
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
c.captureDeltaCheckpoint(ctx, layer, srcF32)
}
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
return slices.Clone(c.curSeqs)
}

View File

@@ -0,0 +1,498 @@
package qwen3next
import (
"log/slog"
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
const (
checkpointCountDefault = 32
checkpointMinPosDefault = int32(16)
checkpointIntervalDefault = int32(1280)
)
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
// memory usage while preserving prefix reuse for recurrent state.
type checkpointEntry struct {
pos int32
conv map[int]ml.Tensor
delta map[int]ml.Tensor
}
type slotCheckpointStore struct {
entries []checkpointEntry
size int
next int
lastPos int32
}
type checkpointRestore struct {
slot int
idx int
pos int32
}
func newSlotCheckpointStore(n int) *slotCheckpointStore {
entries := make([]checkpointEntry, n)
for i := range entries {
entries[i].pos = -1
}
return &slotCheckpointStore{
entries: entries,
lastPos: -1,
}
}
func (s *slotCheckpointStore) reset() {
s.size = 0
s.next = 0
s.lastPos = -1
for i := range s.entries {
s.entries[i].pos = -1
}
}
func (s *slotCheckpointStore) record(pos int32) int {
if len(s.entries) == 0 {
return -1
}
idx := s.next
s.next = (s.next + 1) % len(s.entries)
if s.size < len(s.entries) {
s.size++
}
s.entries[idx].pos = pos
s.lastPos = pos
return idx
}
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
bestIdx := -1
bestPos := int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 || pos >= targetPos {
continue
}
if pos > bestPos {
bestPos = pos
bestIdx = i
}
}
if bestIdx < 0 {
return -1, -1, false
}
return bestIdx, bestPos, true
}
func (s *slotCheckpointStore) pruneAfter(pos int32) {
if len(s.entries) == 0 {
s.size = 0
s.next = 0
s.lastPos = -1
return
}
size := 0
next := -1
minPos := int32(math.MaxInt32)
minIdx := 0
for i := range s.entries {
if s.entries[i].pos > pos {
s.entries[i].pos = -1
}
if s.entries[i].pos >= 0 {
size++
if s.entries[i].pos < minPos {
minPos = s.entries[i].pos
minIdx = i
}
} else if next == -1 {
next = i
}
}
s.size = size
if size == 0 {
s.next = 0
s.lastPos = -1
return
}
if next != -1 {
s.next = next
} else {
// Full ring: overwrite the oldest checkpoint next.
s.next = minIdx
}
s.lastPos = pos
}
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
minPos = int32(math.MaxInt32)
maxPos = int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 {
continue
}
size++
if pos < minPos {
minPos = pos
}
if pos > maxPos {
maxPos = pos
}
}
if size == 0 {
minPos = -1
maxPos = -1
}
return size, minPos, maxPos, s.lastPos
}
func (c *HybridCache) planCheckpoints(batch input.Batch) {
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
c.curCheckpointPos = c.curCheckpointPos[:0]
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
return
}
if cap(c.curCheckpointPos) < len(c.curSeqs) {
c.curCheckpointPos = make([]int32, len(c.curSeqs))
} else {
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
}
for i := range c.curCheckpointPos {
c.curCheckpointPos[i] = -1
}
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
posMax := make(map[int]int32, len(c.curSeqs))
for i, seq := range batch.Sequences {
pos := batch.Positions[i]
if cur, ok := posMax[seq]; !ok || pos > cur {
posMax[seq] = pos
}
}
for i, seq := range c.curSeqs {
pos, ok := posMax[seq]
if !ok {
continue
}
if pos < c.checkpointMinPos {
continue
}
slot := c.curSlots[i]
store := c.checkpointStore(slot)
lastPos := store.lastPos
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
c.curCheckpointPos[i] = pos
}
}
}
func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
store, ok := c.checkpoints[slot]
if ok {
return store
}
store = newSlotCheckpointStore(c.checkpointCount)
c.checkpoints[slot] = store
return store
}
func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
if c.checkpointCount == 0 {
return -1
}
if idx, ok := c.curCheckpointSlots[slot]; ok {
return idx
}
store := c.checkpointStore(slot)
idx := store.record(pos)
if idx >= 0 {
c.curCheckpointSlots[slot] = idx
}
return idx
}
func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
if pos <= 0 {
return false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return false
}
store, ok := c.checkpoints[slot]
if !ok {
return false
}
_, _, ok = store.bestIndex(pos)
return ok
}
func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
if targetPos <= 0 {
return 0, false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return 0, false
}
store, ok := c.checkpoints[slot]
if !ok {
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
return 0, false
}
idx, pos, ok := store.bestIndex(targetPos)
if !ok {
size, minPos, maxPos, lastPos := store.window()
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
"min", minPos, "max", maxPos, "last", lastPos)
return 0, false
}
c.pendingRestore[seq] = checkpointRestore{
slot: slot,
idx: idx,
pos: pos,
}
return pos + 1, true
}
func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
entry, ok := c.restoreEntry(restore)
if !ok {
return kvcache.ErrNotSupported
}
ctx := c.backend.NewContext()
defer ctx.Close()
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
for layer, src := range entry.conv {
buf := c.convBuffer(ctx, layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
for layer, src := range entry.delta {
buf := c.deltaBuffer(ctx, layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
if len(entry.conv) > 0 || len(entry.delta) > 0 {
ctx.Compute()
}
store := c.checkpoints[restore.slot]
store.pruneAfter(restore.pos)
return nil
}
func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
_, ok := c.restoreEntry(restore)
return ok
}
func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
store, ok := c.checkpoints[restore.slot]
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
return nil, false
}
entry := &store.entries[restore.idx]
if entry.pos < 0 {
return nil, false
}
if !c.entryComplete(entry) {
return nil, false
}
return entry, true
}
func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
for layer := range c.convStates {
if entry.conv == nil || entry.conv[layer] == nil {
return false
}
}
for layer := range c.deltaStates {
if entry.delta == nil || entry.delta[layer] == nil {
return false
}
}
return true
}
func (c *HybridCache) clearCheckpoints(slot int) {
if store, ok := c.checkpoints[slot]; ok {
store.reset()
}
}
func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
if c.checkpointCount == 0 {
return
}
srcStore, ok := c.checkpoints[srcSlot]
if !ok || srcStore.size == 0 {
return
}
dstStore := c.checkpointStore(dstSlot)
dstStore.size = srcStore.size
dstStore.next = srcStore.next
dstStore.lastPos = srcStore.lastPos
for i := range srcStore.entries {
srcEntry := &srcStore.entries[i]
dstEntry := &dstStore.entries[i]
dstEntry.pos = srcEntry.pos
if srcEntry.conv != nil {
if dstEntry.conv == nil {
dstEntry.conv = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.conv {
dst := c.ensureCheckpointConv(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
if srcEntry.delta != nil {
if dstEntry.delta == nil {
dstEntry.delta = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.delta {
dst := c.ensureCheckpointDelta(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
}
}
func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointConv(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointConv(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointDelta(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointDelta(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
if entry.conv == nil {
entry.conv = make(map[int]ml.Tensor)
}
if t, ok := entry.conv[layer]; ok {
return t
}
ctx, ok := c.checkpointConvCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointConvCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
entry.conv[layer] = t
return t
}
func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
if entry.delta == nil {
entry.delta = make(map[int]ml.Tensor)
}
if t, ok := entry.delta[layer]; ok {
return t
}
ctx, ok := c.checkpointDeltaCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointDeltaCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
entry.delta[layer] = t
return t
}
func (c *HybridCache) reserveCheckpointConv(layer int) {
key := checkpointReserveKey(layer, 0)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointConv(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func (c *HybridCache) reserveCheckpointDelta(layer int) {
key := checkpointReserveKey(layer, 1)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointDelta(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func checkpointReserveKey(layer int, kind int) int {
return layer*2 + kind
}

View File

@@ -0,0 +1,300 @@
package qwen3next
import (
"errors"
"math"
"os"
"testing"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
func newTestBackend(tb testing.TB) ml.Backend {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
if err != nil {
tb.Fatal(err)
}
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
_ = f.Close()
tb.Fatal(err)
}
if err := f.Close(); err != nil {
tb.Fatal(err)
}
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
if err != nil {
tb.Fatal(err)
}
tb.Cleanup(func() {
b.Close()
})
return b
}
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
store := newSlotCheckpointStore(2)
store.record(10)
store.record(20)
_, pos, ok := store.bestIndex(15)
if !ok || pos != 10 {
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
}
store.record(30) // overwrite oldest (10)
if _, _, ok := store.bestIndex(15); ok {
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
}
_, pos, ok = store.bestIndex(40)
if !ok || pos != 30 {
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
}
}
func TestHybridCachePrepareRestore(t *testing.T) {
cache := NewHybridCache(nil, 1, 1, 1)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
store := cache.checkpointStore(0)
store.record(5)
store.record(9)
store.record(15)
restorePos, ok := cache.PrepareRestore(1, 12)
if !ok {
t.Fatalf("expected restore ok")
}
if restorePos != 10 {
t.Fatalf("expected restorePos 10, got %d", restorePos)
}
rest, ok := cache.pendingRestore[1]
if !ok {
t.Fatalf("expected pending restore entry")
}
if rest.pos != 9 {
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
}
}
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
store.pruneAfter(20)
if store.lastPos != 20 {
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
}
_, pos, ok := store.bestIndex(25)
if !ok || pos != 20 {
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
}
_, pos, ok = store.bestIndex(35)
if !ok || pos != 20 {
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
}
}
func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
backend := newTestBackend(t)
cache := NewHybridCache(nil, 1, 2, 2)
cache.Init(backend, ml.DTypeF16, 2, 8, 2)
cache.slotForSeq[1] = 0
cache.slotForSeq[2] = 0
cache.refCount[0] = 2
cache.refCount[1] = 0
cache.freeSlots = []int{1}
store := cache.checkpointStore(0)
idx := store.record(9)
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
t.Fatalf("Remove failed: %v", err)
}
if cache.slotForSeq[1] == cache.slotForSeq[2] {
t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
}
if cache.slotForSeq[1] != 1 {
t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
}
if cache.slotForSeq[2] != 0 {
t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
}
if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
}
if _, ok := cache.pendingRestore[1]; ok {
t.Fatalf("expected pending restore to be cleared")
}
}
func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
cache := NewHybridCache(nil, 1, 2, 2)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
// Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
cache.convStates[0] = nil // placeholder to indicate layer 0 exists
cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
store := cache.checkpointStore(0)
idx := store.record(9)
entry := &store.entries[idx]
// Only set conv checkpoint, not delta - making it incomplete
entry.conv = map[int]ml.Tensor{0: nil}
// entry.delta is not set, so checkpoint is incomplete
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
err := cache.Remove(1, 10, math.MaxInt32)
if !errors.Is(err, kvcache.ErrNotSupported) {
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
}
}
func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
cache := NewHybridCache(nil, 1, 2, 2)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
// Don't set convStates/deltaStates - with no layers to check,
// entryComplete will return true as long as entry.pos >= 0
store := cache.checkpointStore(0)
idx := store.record(9)
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
// Test that restoreComplete returns true when no layers need checkpoints
restore := cache.pendingRestore[1]
if !cache.restoreComplete(restore) {
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
}
}
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
// Test that ring buffer wrap-around reuses entries without clearing maps.
store := newSlotCheckpointStore(3)
// Fill the buffer
store.record(10)
store.record(20)
store.record(30)
// Create fake tensor data in the first entry's maps
store.entries[0].conv = make(map[int]ml.Tensor)
store.entries[0].conv[0] = nil // Simulated tensor reference
store.entries[0].delta = make(map[int]ml.Tensor)
store.entries[0].delta[0] = nil // Simulated tensor reference
// Record another entry, which should wrap around and overwrite entry 0
store.record(40)
// Verify the maps are still present (we reuse tensors)
if store.entries[0].conv == nil {
t.Fatalf("expected conv map to be preserved on reuse")
}
if store.entries[0].delta == nil {
t.Fatalf("expected delta map to be preserved on reuse")
}
// Verify the new position was recorded
if store.entries[0].pos != 40 {
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
}
}
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
// Test behavior when buffer is exactly at capacity
store := newSlotCheckpointStore(2)
idx1 := store.record(10)
idx2 := store.record(20)
if idx1 != 0 || idx2 != 1 {
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
}
if store.size != 2 {
t.Fatalf("expected size 2, got %d", store.size)
}
// Verify both checkpoints are accessible
_, pos1, ok1 := store.bestIndex(15)
_, pos2, ok2 := store.bestIndex(25)
if !ok1 || pos1 != 10 {
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
}
if !ok2 || pos2 != 20 {
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
}
}
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
// Test behavior with zero-size buffer
store := newSlotCheckpointStore(0)
idx := store.record(10)
if idx != -1 {
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
}
_, _, ok := store.bestIndex(15)
if ok {
t.Fatalf("expected no checkpoint for empty buffer")
}
}
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
// Test pruning that removes all checkpoints
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
// Prune everything by setting threshold below all positions
store.pruneAfter(5)
if store.size != 0 {
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
}
// When all checkpoints are pruned, lastPos is reset to -1
if store.lastPos != -1 {
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
}
_, _, ok := store.bestIndex(100)
if ok {
t.Fatalf("expected no checkpoint after pruning all")
}
}

View File

@@ -0,0 +1,472 @@
package qwen3next
import (
"errors"
"log/slog"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
const chunkSize = 64
// TriType constants for triangular matrix operations
const (
TriTypeUpperDiag = 0
TriTypeUpper = 1
TriTypeLowerDiag = 2
TriTypeLower = 3
)
// convKernel wraps the 1D convolution kernel tensor
type convKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// Masks holds pre-computed mask tensors for chunked attention
type Masks struct {
Causal ml.Tensor // Lower triangular [chunkSize, chunkSize]
Identity ml.Tensor // Diagonal [chunkSize, chunkSize]
Diag ml.Tensor // causal + identity
}
// GatedDeltaNet implements linear attention with SSM convolution and recurrent state.
// It implements the Operator interface directly.
type GatedDeltaNet struct {
// Optimized path: pre-split QKV and gate
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
SSMOut *nn.Linear `gguf:"ssm_out"`
// Layer index for cache access (set during model construction)
Layer int
}
// createMasks builds the constant mask tensors (called once, reused for all chunks)
func createMasks(ctx ml.Context) *Masks {
ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
ones = ones.Fill(ctx, 1.0)
causalMask := ones.Tri(ctx, TriTypeLower)
onesVec := ctx.Input().Zeros(ml.DTypeF32, chunkSize)
onesVec = onesVec.Fill(ctx, 1.0)
identity := onesVec.Diag(ctx)
diagMask := causalMask.Add(ctx, identity)
return &Masks{
Causal: causalMask,
Identity: identity,
Diag: diagMask,
}
}
func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
layer := gdn.Layer
nSeqTokens := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2)
if cache != nil && cache.IsSupportedForBatch() {
seqTokens := cache.seqTokens()
seqs := cache.numSeqs()
if seqTokens > 0 && seqs > 0 {
if nSeqs > 1 {
if nSeqTokens != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
}
} else {
if nSeqTokens != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
nSeqTokens = seqTokens
nSeqs = seqs
}
}
}
headKDim := opts.ssmDState
numKHeads := opts.ssmNGroup
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
convKernelSize := opts.convKernelSize
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Optimized path: pre-split QKV and gate
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
baNewDim := 2 * numVHeads / numKHeads
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
// Split beta and alpha
betaSize := numVHeads / numKHeads
alphaSize := numVHeads / numKHeads
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
// Reshape to merge head dimensions
beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
// Compute gate: softplus(alpha + dt_bias) * -A
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
alphaSoftplus := alphaBiased.Softplus(ctx)
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
// Get conv state from cache
convStates, err := cache.ConvState(ctx, layer)
if err != nil {
// Log this - if it happens, short-term context will be lost
slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
convStates = ctx.Input().Zeros(ml.DTypeF32, convKernelSize-1, qkvDim, nSeqs)
}
// Reshape conv states
convStates = convStates.Reshape(ctx, convKernelSize-1, qkvDim, nSeqs)
// Concatenate with input for convolution
convInput := convStates.Concat(ctx, qkvMixed, 0)
// Save new conv state (last convKernelSize-1 tokens)
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
cache.UpdateConvState(ctx, layer, lastConvStates)
// Apply SSM convolution (kernel must be F32 for Metal)
convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
convOutput = convOutput.SILU(ctx)
// Reshape for extraction
convQKVMix := convOutput.Contiguous(ctx, qkvDim, nSeqTokens*nSeqs)
// Extract convolved Q, K, V
qConv := convQKVMix.Slice(ctx, 0, 0, headKDim*numKHeads, 1)
kConv := convQKVMix.Slice(ctx, 0, headKDim*numKHeads, 2*headKDim*numKHeads, 1)
vConv := convQKVMix.Slice(ctx, 0, 2*headKDim*numKHeads, qkvDim, 1)
// Reshape to 4D
qConv = qConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
kConv = kConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)
// Get delta state from cache
state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
if err != nil {
// Log this - if it happens frequently, context will degrade
slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
state = ctx.Input().Zeros(ml.DTypeF32, headVDim, headVDim*numVHeads, nSeqs)
}
state = state.Reshape(ctx, headVDim, headVDim*numVHeads, 1, nSeqs)
// Repeat interleave Q and K if numKHeads != numVHeads
if numKHeads != numVHeads {
repeatFactor := numVHeads / numKHeads
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
}
// Choose computation mode based on sequence length
var attnOut ml.Tensor
if nSeqTokens == 1 {
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
} else {
// Use pre-computed masks from opts (created once in Model.Forward)
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
}
// Apply gated normalization
attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
// norm(attnOut, z) = RMSNorm(attnOut) * silu(z)
attnOutNorm := gdn.SSMNorm.Forward(ctx, attnOut2D, opts.eps)
zSilu := z2D.SILU(ctx)
attnOutGated := attnOutNorm.Mul(ctx, zSilu)
// Reshape for output projection
finalOutput := attnOutGated.Reshape(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)
out := gdn.SSMOut.Forward(ctx, finalOutput)
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
}
// deltaNetAutoregressive implements single-token state update.
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
func (gdn *GatedDeltaNet) deltaNetAutoregressive(
ctx ml.Context,
q, k, v, gate, beta, state ml.Tensor,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
nSeqs := q.Dim(3)
// L2 normalize Q and K
q = q.L2Norm(ctx, opts.eps)
k = k.L2Norm(ctx, opts.eps)
// Scale Q
scale := 1.0 / math.Sqrt(float64(headVDim))
q = q.Scale(ctx, scale)
// Sigmoid beta
beta = beta.Sigmoid(ctx)
// Reshape state: [headVDim, headVDim, numVHeads, nSeqs]
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
// Reshape gate and beta for broadcasting
gT := gate.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
betaT := beta.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
// Apply exponential to gate
gT = gT.Exp(ctx)
// state = state * g_t
state = state.Mul(ctx, gT)
// kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
kTUnsqueezed := k.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
kvMem := state.Mul(ctx, kTUnsqueezed)
// Sum over dim=-2 (second dimension after permute)
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
kvMem = kvMem.SumRows(ctx)
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3)
// v_t with singleton dimension
vT := v.Reshape(ctx, headVDim, 1, numVHeads, nSeqs)
// delta = (v_t - kv_mem) * beta_t
vDiff := vT.Sub(ctx, kvMem)
delta := vDiff.Mul(ctx, betaT)
// state = state + k_t.unsqueeze(-1) * delta
kTUnsqueezedBroad := kTUnsqueezed.Repeat4D(ctx, headVDim, headVDim, numVHeads, nSeqs)
kTDelta := kTUnsqueezedBroad.Mul(ctx, delta)
state = state.Add(ctx, kTDelta)
// core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2)
qTUnsqueezed := q.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
stateQ := state.Mul(ctx, qTUnsqueezed)
stateQ = stateQ.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
coreAttnOut := stateQ.SumRows(ctx)
coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
}
// deltaNetChunked implements chunked computation for prefill.
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
func (gdn *GatedDeltaNet) deltaNetChunked(
ctx ml.Context,
q, k, v, gate, beta, state ml.Tensor,
masks *Masks,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
headKDim := q.Dim(0)
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
nTokens := q.Dim(2)
nSeqs := q.Dim(3)
// L2 normalize Q and K
q = q.L2Norm(ctx, opts.eps)
k = k.L2Norm(ctx, opts.eps)
// Scale Q
scale := 1.0 / math.Sqrt(float64(headVDim))
q = q.Scale(ctx, scale)
// Sigmoid beta
beta = beta.Sigmoid(ctx)
// Permute tensors for chunked computation
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs)
beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
// Compute padding
pad := (chunkSize - nTokens%chunkSize) % chunkSize
nChunks := (nTokens + pad) / chunkSize
// Pad tensors
if pad > 0 {
q = q.Pad(ctx, 0, pad, 0, 0)
k = k.Pad(ctx, 0, pad, 0, 0)
v = v.Pad(ctx, 0, pad, 0, 0)
gate = gate.Pad(ctx, pad, 0, 0, 0)
beta = beta.Pad(ctx, 0, pad, 0, 0)
}
// Use pre-computed masks (passed in, not recreated)
causalMask := masks.Causal
identity := masks.Identity
diagMask := masks.Diag
identity4D := identity.Reshape(ctx, chunkSize, chunkSize, 1, 1)
// v_beta = v * beta, k_beta = k * beta
vBeta := v.Mul(ctx, beta)
kBeta := k.Mul(ctx, beta)
// Reshape for chunked computation
q = q.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
k = k.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
// g_cumsum = cumsum(gate)
gCumsum := gate.CumSum(ctx)
// Compute decay mask
gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
gcsJ := gCumsum.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
gcsBroadcast := gcsJ.Repeat4D(ctx, chunkSize, chunkSize, nChunks, numVHeads*nSeqs)
decayMask := gcsBroadcast.Sub(ctx, gcsI)
decayMask = decayMask.Mul(ctx, diagMask)
decayMask = decayMask.Exp(ctx)
decayMask = decayMask.Mul(ctx, diagMask)
// k @ k_beta^T
kMulKBeta := k.Mulmat(ctx, kBeta)
// k_decay = k @ k_beta^T * decay_mask
kDecay := kMulKBeta.Mul(ctx, decayMask)
// attn = -k_decay * causal_mask
attn := kDecay.Neg(ctx).Mul(ctx, causalMask)
// Triangular solve: (I - attn_lower)^-1 @ attn
attnLower := attn.Mul(ctx, causalMask)
lhs := attnLower.Neg(ctx).Add(ctx, identity4D)
linSolve := lhs.SolveTri(ctx, attn, true, true, false)
attn = linSolve.Mul(ctx, causalMask)
attn = attn.Add(ctx, identity4D)
// v = v_beta^T @ attn
vBetaT := vBeta.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
v = vBetaT.Mulmat(ctx, attn)
// Compute g_exp for state update
gCumsumT := gCumsum.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
gExp := gCumsumT.Exp(ctx)
// kbeta_gexp = k_beta * g_exp
kBetaGExp := kBeta.Mul(ctx, gExp)
// k_cumdecay = attn @ kbeta_gexp^T
kBetaGExpT := kBetaGExp.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
kCumdecay := attn.Mulmat(ctx, kBetaGExpT)
kCumdecay = kCumdecay.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Pre-compute attn_kq = (k @ q) * decay_mask * diag_mask
attnKQ := k.Mulmat(ctx, q)
attnKQ = attnKQ.Mul(ctx, decayMask)
attnKQ = attnKQ.Mul(ctx, diagMask)
// Pre-compute g_last and key_gdiff
// g_last = view of last element in g_cumsum along chunk_size dimension
// We need to get the last row of gCumsum: shape [chunkSize, 1, nChunks, H*n_seqs] -> [1, 1, nChunks, H*n_seqs]
gLast := gCumsum.Slice(ctx, 0, chunkSize-1, chunkSize, 1).Contiguous(ctx, 1, 1, nChunks, numVHeads*nSeqs)
gLastExp := gLast.Exp(ctx)
// g_diff = -(g_cumsum - g_last) = g_last - g_cumsum
gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
gDiffExp := gDiff.Exp(ctx)
// Reshapes g_diff_exp to [1, chunkSize, nChunks, ...]
gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Process chunks and update state
var coreAttnOut ml.Tensor
newState := state
for chunk := range nChunks {
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
// state^T - permute is needed but Contiguous creates a copy
stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
// v_prime = k_cumdecay @ state
vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
// v_new = v - v_prime
vNew := vChunk.Sub(ctx, vPrime)
vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// attn_inter = (q * g_exp) @ state
qGExp := qChunk.Mul(ctx, gExpChunk)
attnInter := stateT.Mulmat(ctx, qGExp)
// core_attn_out = attn_inter + attn @ v_new
vAttn := vNewT.Mulmat(ctx, attnChunk)
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
if coreAttnOut == nil {
coreAttnOut = coreAttnOutChunk
} else {
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
}
// Update state for next chunk
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
// state = state * g_last + kgdmulvnew
gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
newState = newState.Mul(ctx, gExpLastReshaped)
newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
}
// Final reshape
coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
// Slice to remove padding
if pad > 0 {
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
}
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
}

View File

@@ -0,0 +1,383 @@
package qwen3next
import (
"cmp"
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
// Options contains model configuration
type Options struct {
hiddenSize int
numHeads int
numKVHeads int
keyLength int
valueLength int
ropeDim int
eps float32
ropeBase float32
ropeScale float32
ropeType string
originalContextLength int
attentionScale float64
// MoE config
numExperts int
numExpertsUsed int
normTopKProb bool
// Linear attention (Gated Delta Net) config
ssmDInner int // d_inner = head_v_dim * num_v_heads
ssmDState int // head_k_dim
ssmNGroup int // num_k_heads
ssmDtRank int // num_v_heads
convKernelSize int // SSM conv kernel size
// Per-layer type from GGUF metadata
isRecurrent []bool
// Pre-computed masks for chunked attention (created once per forward pass)
masks *Masks
}
func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
opts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(attnFactor),
)
}
ropeDim := cmp.Or(o.ropeDim, o.headDim())
return nn.RoPE(ctx, states, positions, ropeDim, o.ropeBase, 1./o.ropeScale, opts...)
}
// Operator is the interface for attention-like operators
type Operator interface {
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
}
// MLP is the interface for feedforward networks
type MLP interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
}
// sparse implements MoE with shared experts
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
// Shared experts
SharedGateInp *nn.Linear `gguf:"ffn_gate_inp_shexp"`
SharedGate *nn.Linear `gguf:"ffn_gate_shexp"`
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
}
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
if batchSize == 0 {
batchSize = 1
}
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
// Router logits
routerLogits := mlp.Router.Forward(ctx, hiddenStates2D)
// Softmax routing weights
routingWeights := routerLogits.Softmax(ctx)
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
if opts.normTopKProb {
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
}
hiddenStates3D := hiddenStates2D.Reshape(ctx, hiddenStates2D.Dim(0), 1, hiddenStates2D.Dim(1))
// Expert computation with SILU activation
gateOut := mlp.Gate.Forward(ctx, hiddenStates3D, selectedExperts)
upOut := mlp.Up.Forward(ctx, hiddenStates3D, selectedExperts)
experts := gateOut.SILU(ctx, upOut)
experts = mlp.Down.Forward(ctx, experts, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
// Sum over experts
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
// Add shared experts if present
if mlp.SharedUp != nil {
sharedGate := mlp.SharedGate.Forward(ctx, hiddenStates2D)
sharedUp := mlp.SharedUp.Forward(ctx, hiddenStates2D)
sharedOut := sharedGate.SILU(ctx, sharedUp)
sharedOut = mlp.SharedDown.Forward(ctx, sharedOut)
// Apply shared expert gating
if mlp.SharedGateInp != nil {
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
sharedGateVal = sharedGateVal.Sigmoid(ctx)
// Broadcast gate to match dimensions
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
sharedOut = sharedOut.Mul(ctx, sharedGateVal)
}
moeOut = moeOut.Add(ctx, sharedOut)
}
return moeOut
}
// dense implements standard feedforward
type dense struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
// Layer represents a single transformer layer
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
AttentionPostNorm *nn.RMSNorm `gguf:"post_attention_norm"` // Post-attention norm before FFN
Operator Operator
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP MLP
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
residual := hiddenStates
// Pre-attention norm
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
// Attention (full or linear)
var err error
hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, positions, cache, opts)
if err != nil {
return nil, err
}
// Output projection for last layer
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
// First residual connection
hiddenStates = hiddenStates.Add(ctx, residual)
// Save for FFN residual
ffnResidual := hiddenStates
// Post-attention norm (before FFN)
hiddenStates = l.AttentionPostNorm.Forward(ctx, hiddenStates, opts.eps)
// FFN
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
// Second residual connection
return hiddenStates.Add(ctx, ffnResidual), nil
}
// Model is the main Qwen3-Next model
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Layers []Layer `gguf:"blk"`
*Options
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
cache := m.Cache.(*HybridCache)
// Create masks once per forward pass
m.Options.masks = createMasks(ctx)
for i, layer := range m.Layers {
cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
var err error
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
if err != nil {
return nil, err
}
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
var _ model.Model = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers)
// Get per-layer head counts (for detecting layer type)
type headCounts interface {
HeadCount() []uint64
HeadCountKV() []uint64
}
var isRecurrent []bool
var headCountKV []uint64
if hc, ok := c.(headCounts); ok {
headCountKV = hc.HeadCountKV()
}
isRecurrent = make([]bool, numLayers)
hasZero := false
hasFull := false
for i := range numLayers {
// If KV head count is 0, it's a recurrent layer
if i < len(headCountKV) && headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else if i < len(headCountKV) && headCountKV[i] > 0 {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
}
// Determine if MoE
isMoE := c.Uint("expert_count") > 0
for i := range layers {
if isRecurrent[i] {
layers[i].Operator = &GatedDeltaNet{Layer: i}
} else {
layers[i].Operator = &FullAttention{}
}
if isMoE {
layers[i].MLP = &sparse{}
} else {
layers[i].MLP = &dense{}
}
}
opts := &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: func() int {
for _, v := range headCountKV {
if v > 0 {
return int(v)
}
}
return 0
}(),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeType: c.String("rope.scaling.type"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
attentionScale: float64(c.Float("attention.scale")),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true),
ssmDInner: int(c.Uint("ssm.inner_size")),
ssmDState: int(c.Uint("ssm.state_size")),
ssmNGroup: int(c.Uint("ssm.group_count")),
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
convKernelSize: int(c.Uint("ssm.conv_kernel")),
isRecurrent: isRecurrent,
}
if opts.numKVHeads == 0 {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
}
// Calculate cache dimensions
convDim := max(0, opts.convKernelSize-1)
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
headVDim := 0
numVHeads := opts.ssmDtRank
if numVHeads > 0 {
headVDim = opts.ssmDInner / numVHeads
}
deltaStateSize := headVDim * headVDim * numVHeads
// Validate dimension assumption: headKDim == headVDim is required for state computations
headKDim := opts.ssmDState
if headKDim != headVDim && headKDim > 0 && headVDim > 0 {
return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
// Qwen3 tokenizers typically set add_bos_token=false and bos_token=null.
// Default to false when the GGUF key is missing to avoid injecting a spurious BOS.
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Layers: layers,
Options: opts,
}
m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
return &m, nil
}
func init() {
model.Register("qwen3next", New)
}

17
model/parsers/glmocr.go Normal file
View File

@@ -0,0 +1,17 @@
package parsers
import "github.com/ollama/ollama/api"
// GlmOcrParser is the GLM46 parser with thinking disabled.
type GlmOcrParser struct {
GLM46Parser
}
func (p *GlmOcrParser) HasThinkingSupport() bool {
return false
}
func (p *GlmOcrParser) Init(tools []api.Tool, _ *api.Message, _ *api.ThinkValue) []api.Tool {
p.tools = tools
return tools
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
@@ -17,12 +18,34 @@ const (
ministralCollectingToolArgs
)
// ministralEvent represents an event emitted during parsing
type ministralEvent interface {
isMinistralEvent()
}
type ministralEventContent struct {
content string
}
type ministralEventThinking struct {
thinking string
}
type ministralEventToolCall struct {
name string
args string // raw JSON string
}
func (ministralEventContent) isMinistralEvent() {}
func (ministralEventThinking) isMinistralEvent() {}
func (ministralEventToolCall) isMinistralEvent() {}
type MinistralParser struct {
state ministralParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
currentTool *api.Tool
pendingToolName string // stores tool name while collecting args
}
func (p *MinistralParser) HasToolSupport() bool {
@@ -63,74 +86,251 @@ func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
return nil, fmt.Errorf("tool '%s' not found", n)
}
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
const (
ministralToolCallsTag = "[TOOL_CALLS]"
ministralThinkTag = "[THINK]"
ministralThinkEndTag = "[/THINK]"
ministralArgsTag = "[ARGS]"
)
// eat consumes the parser's buffer, and returns a list of any unambiguous
// events from the current parser state. The second return value indicates
// whether to keep looping (true when state transitions, false when waiting
// for more data).
func (p *MinistralParser) eat() ([]ministralEvent, bool) {
var events []ministralEvent
switch p.state {
case ministralCollectingContent:
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
if before != "" {
return before, "", calls, nil
bufStr := p.buffer.String()
// Check for [TOOL_CALLS] tag
if strings.Contains(bufStr, ministralToolCallsTag) {
split := strings.SplitN(bufStr, ministralToolCallsTag, 2)
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
if len(before) > 0 {
events = append(events, ministralEventContent{content: before})
}
after := split[1]
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingToolName
} else if strings.Contains(p.buffer.String(), "[THINK]") {
return events, true
}
// Check for [THINK] tag
if strings.Contains(bufStr, ministralThinkTag) {
split := strings.SplitN(bufStr, ministralThinkTag, 2)
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
if len(before) > 0 {
events = append(events, ministralEventContent{content: before})
}
after := split[1]
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingThinkingContent
return "", "", calls, nil
} else {
p.buffer.Reset()
return s, "", calls, nil
return events, true
}
// Check for partial tag overlap with [TOOL_CALLS] or [THINK]
overlapToolCalls := overlap(bufStr, ministralToolCallsTag)
overlapThink := overlap(bufStr, ministralThinkTag)
maxOverlap := max(overlapToolCalls, overlapThink)
if maxOverlap > 0 {
// Withhold the potential partial tag
beforePartialTag := bufStr[:len(bufStr)-maxOverlap]
trailingWS := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWS
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventContent{content: unambiguous})
}
return events, false
}
// No tag found: emit content but withhold trailing whitespace
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventContent{content: unambiguous})
}
return events, false
case ministralCollectingThinkingContent:
if strings.Contains(p.buffer.String(), "[/THINK]") {
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
p.state = ministralCollectingContent
if after != "" {
p.buffer.Reset()
return after, thinkingContent, calls, nil
}
return "", thinkingContent, calls, nil
} else {
bufStr := p.buffer.String()
if strings.Contains(bufStr, ministralThinkEndTag) {
split := strings.SplitN(bufStr, ministralThinkEndTag, 2)
thinkingContent := split[0]
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.buffer.Reset()
return "", s, calls, nil
}
case ministralCollectingToolName:
if strings.Contains(p.buffer.String(), "[ARGS]") {
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
t, err := toolByName(p.tools, name)
if err != nil {
return "", "", calls, err
p.buffer.WriteString(after)
if len(thinkingContent) > 0 {
events = append(events, ministralEventThinking{thinking: thinkingContent})
}
p.currentTool = t
p.state = ministralCollectingToolArgs
return "", "", calls, nil
}
return "", "", calls, nil
case ministralCollectingToolArgs:
if strings.Contains(p.buffer.String(), "}") {
before, _ := splitAtTag(&p.buffer, "}", false)
before += "}"
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(before), &args); err != nil {
// todo - throw a better error
return "", "", calls, err
}
p.state = ministralCollectingContent
return events, true
}
call := api.ToolCall{
// Check for partial overlap with [/THINK]
if overlapLen := overlap(bufStr, ministralThinkEndTag); overlapLen > 0 {
unambiguous := bufStr[:len(bufStr)-overlapLen]
ambiguous := bufStr[len(bufStr)-overlapLen:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventThinking{thinking: unambiguous})
}
return events, false
}
// No tag found: emit all thinking content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, ministralEventThinking{thinking: bufStr})
}
return events, false
case ministralCollectingToolName:
bufStr := p.buffer.String()
if strings.Contains(bufStr, ministralArgsTag) {
split := strings.SplitN(bufStr, ministralArgsTag, 2)
toolName := split[0]
after := split[1]
p.pendingToolName = toolName
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingToolArgs
return events, true
}
// Wait for more data
return events, false
case ministralCollectingToolArgs:
bufStr := p.buffer.String()
jsonEnd := findJSONEnd(bufStr)
if jsonEnd != -1 {
jsonStr := bufStr[:jsonEnd+1]
remaining := bufStr[jsonEnd+1:]
events = append(events, ministralEventToolCall{
name: p.pendingToolName,
args: jsonStr,
})
p.pendingToolName = ""
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = ministralCollectingContent
return events, true
}
// Wait for more data
return events, false
default:
panic("unexpected ministral event")
}
}
// parseEvents loops calling eat() until it returns false
func (p *MinistralParser) parseEvents() []ministralEvent {
var all []ministralEvent
keepLooping := true
for keepLooping {
var events []ministralEvent
events, keepLooping = p.eat()
all = append(all, events...)
}
return all
}
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var contentBuilder, thinkingBuilder strings.Builder
var toolCalls []api.ToolCall
for _, event := range events {
switch e := event.(type) {
case ministralEventContent:
contentBuilder.WriteString(e.content)
case ministralEventThinking:
thinkingBuilder.WriteString(e.thinking)
case ministralEventToolCall:
// Validate tool exists
tool, toolErr := toolByName(p.tools, e.name)
if toolErr != nil {
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, toolErr
}
// Parse JSON arguments
var args api.ToolCallFunctionArguments
if jsonErr := json.Unmarshal([]byte(e.args), &args); jsonErr != nil {
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, jsonErr
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: p.currentTool.Function.Name,
Name: tool.Function.Name,
Arguments: args,
},
}
calls = append(calls, call)
return "", "", calls, nil
})
}
return "", "", calls, nil
}
return p.buffer.String(), thinking, calls, nil
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, nil
}
// findJSONEnd finds the index of the closing brace that completes a JSON object.
// It properly handles nested objects, arrays, and strings (including escaped characters).
// Returns -1 if the JSON is not yet complete.
func findJSONEnd(s string) int {
depth := 0
inString := false
escaped := false
for i, r := range s {
if inString {
switch {
case escaped:
// If the previous character was a backslash, skip this character
escaped = false
case r == '\\':
// Mark the next character as escaped
escaped = true
case r == '"':
// End of string literal
inString = false
}
continue
}
switch r {
case '"':
// Start of string literal
inString = true
case '{', '[':
// Increase nesting level for objects and arrays
depth++
case '}', ']':
// Decrease nesting level
depth--
if depth == 0 {
// Reached the end of the root JSON structure
return i
}
}
}
return -1
}

View File

@@ -0,0 +1,545 @@
package parsers
import (
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestMinistralParserStreaming(t *testing.T) {
type step struct {
input string
wantEvents []ministralEvent
}
cases := []struct {
desc string
tools []api.Tool
steps []step
think bool // whether to enable thinking support
}{
// Content streaming
{
desc: "simple content",
steps: []step{
{input: "Hello, how can I help you?", wantEvents: []ministralEvent{
ministralEventContent{content: "Hello, how can I help you?"},
}},
},
},
{
desc: "streaming content word by word",
steps: []step{
{input: "Hello,", wantEvents: []ministralEvent{ministralEventContent{content: "Hello,"}}},
{input: " how", wantEvents: []ministralEvent{ministralEventContent{content: " how"}}},
{input: " can I help?", wantEvents: []ministralEvent{ministralEventContent{content: " can I help?"}}},
},
},
// Simple tool calls
{
desc: "simple tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
steps: []step{
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "San Francisco"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "get_weather", args: `{"location": "San Francisco"}`},
}},
},
},
{
desc: "tool call with nested object",
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
steps: []step{
{input: `[TOOL_CALLS]create_entities[ARGS]{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
}},
},
},
{
desc: "tool call with deeply nested objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "update_config"}}},
steps: []step{
{input: `[TOOL_CALLS]update_config[ARGS]{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "update_config", args: `{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`},
}},
},
},
{
desc: "tool call with array of objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "process_items"}}},
steps: []step{
{input: `[TOOL_CALLS]process_items[ARGS]{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "process_items", args: `{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`},
}},
},
},
{
desc: "tool call with escaped quotes in string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "search"}}},
steps: []step{
{input: `[TOOL_CALLS]search[ARGS]{"query": "say \"hello\""}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "search", args: `{"query": "say \"hello\""}`},
}},
},
},
{
desc: "tool call with braces inside string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "format"}}},
steps: []step{
{input: `[TOOL_CALLS]format[ARGS]{"template": "Hello {name}!"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "format", args: `{"template": "Hello {name}!"}`},
}},
},
},
{
desc: "empty JSON object",
tools: []api.Tool{{Function: api.ToolFunction{Name: "no_args"}}},
steps: []step{
{input: `[TOOL_CALLS]no_args[ARGS]{}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "no_args", args: `{}`},
}},
},
},
{
desc: "JSON with newlines in string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "write"}}},
steps: []step{
{input: `[TOOL_CALLS]write[ARGS]{"content": "line1\nline2\nline3"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "write", args: `{"content": "line1\nline2\nline3"}`},
}},
},
},
{
desc: "backslash in string value",
tools: []api.Tool{{Function: api.ToolFunction{Name: "path"}}},
steps: []step{
{input: `[TOOL_CALLS]path[ARGS]{"dir": "C:\\Users\\test"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "path", args: `{"dir": "C:\\Users\\test"}`},
}},
},
},
// Content after tool call
{
desc: "content after tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
// NOTE: It's unclear if this is valid Ministral output, but the parser
// currently treats text after a tool call as regular content. This test
// documents that behavior so we notice if it changes.
{input: `[TOOL_CALLS]test[ARGS]{"a": 1}some content after`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{"a": 1}`},
ministralEventContent{content: "some content after"},
}},
},
},
// Multiple tool calls
{
desc: "multiple tool calls in sequence",
tools: []api.Tool{
{Function: api.ToolFunction{Name: "get_weather"}},
{Function: api.ToolFunction{Name: "get_time"}},
},
steps: []step{
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "NYC"}[TOOL_CALLS]get_time[ARGS]{"timezone": "EST"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "get_weather", args: `{"location": "NYC"}`},
ministralEventToolCall{name: "get_time", args: `{"timezone": "EST"}`},
}},
},
},
{
desc: "multiple tool calls streamed separately",
tools: []api.Tool{
{Function: api.ToolFunction{Name: "tool_a"}},
{Function: api.ToolFunction{Name: "tool_b"}},
},
steps: []step{
{input: `[TOOL_CALLS]tool_a[ARGS]{"x": 1}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "tool_a", args: `{"x": 1}`},
}},
{input: `[TOOL_CALLS]tool_b[ARGS]{"y": 2}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "tool_b", args: `{"y": 2}`},
}},
},
},
// Streaming tool calls
{
desc: "streaming tool call with nested objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
steps: []step{
{input: "[TOOL_CALLS]create_entities[ARGS]", wantEvents: []ministralEvent{}},
{input: `{"entities": [{"entityType": "Person",`, wantEvents: []ministralEvent{}},
{input: ` "name": "Jack",`, wantEvents: []ministralEvent{}},
{input: ` "observations": ["Works`, wantEvents: []ministralEvent{}},
{input: ` as a baker"]}`, wantEvents: []ministralEvent{}},
{input: `]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
}},
},
},
{
desc: "streaming with incomplete JSON waits for completion",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "[TOOL_CALLS]test[ARGS]{", wantEvents: []ministralEvent{}},
{input: `"a": {`, wantEvents: []ministralEvent{}},
{input: `"b": 1`, wantEvents: []ministralEvent{}},
{input: `}`, wantEvents: []ministralEvent{}},
{input: `}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{"a": {"b": 1}}`},
}},
},
},
// Partial tag handling
{
desc: "partial tool tag fakeout",
steps: []step{
{input: "abc[TOOL", wantEvents: []ministralEvent{ministralEventContent{content: "abc"}}},
{input: " not a tag", wantEvents: []ministralEvent{ministralEventContent{content: "[TOOL not a tag"}}},
},
},
{
desc: "tool call tag split across chunks",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "[TOOL_", wantEvents: []ministralEvent{}},
{input: "CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "content before tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
steps: []step{
{input: "hello [TOOL_CALLS]get_weather[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "hello"},
ministralEventToolCall{name: "get_weather", args: `{}`},
}},
},
},
{
desc: "whitespace between content and tool call is trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "content \n [TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "tabs and newlines before tool call are trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "content\t\n\t[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "non-breaking space before tool call is trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
// \u00a0 is non-breaking space, which unicode.IsSpace considers whitespace
{input: "content\u00a0[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "whitespace before THINK tag is trimmed",
steps: []step{
{input: "content \n [THINK]thinking[/THINK]after", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "after"},
}},
},
},
{
desc: "trailing whitespace withheld then emitted",
steps: []step{
{input: "Hello ", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: " world"}}},
},
},
{
desc: "trailing newline withheld then emitted",
steps: []step{
{input: "Hello\n", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: "\nworld"}}},
},
},
// Thinking support
{
desc: "thinking content",
think: true,
steps: []step{
{input: "thinking here[/THINK]", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "thinking here"},
}},
{input: "content after", wantEvents: []ministralEvent{
ministralEventContent{content: "content after"},
}},
},
},
{
desc: "thinking with whitespace after end tag",
think: true,
steps: []step{
{input: "my thoughts[/THINK] \n response", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "my thoughts"},
ministralEventContent{content: "response"},
}},
},
},
{
desc: "non-breaking space after think end tag is trimmed",
think: true,
steps: []step{
// \u00a0 is non-breaking space
{input: "thinking[/THINK]\u00a0response", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "response"},
}},
},
},
{
desc: "partial think end tag",
think: true,
steps: []step{
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
{input: "NK]after", wantEvents: []ministralEvent{ministralEventContent{content: "after"}}},
},
},
{
desc: "think tag fakeout",
think: true,
steps: []step{
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
{input: "not end tag", wantEvents: []ministralEvent{ministralEventThinking{thinking: "[/THInot end tag"}}},
},
},
{
desc: "thinking then tool call",
think: true,
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "let me think[/THINK][TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "let me think"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
// Content then THINK tag transition
{
desc: "content then think tag",
steps: []step{
{input: "content[THINK]thinking[/THINK]more", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "more"},
}},
},
},
// Unicode handling
{
desc: "unicode content",
steps: []step{
{input: "你好 🌍 مرحبا", wantEvents: []ministralEvent{
ministralEventContent{content: "你好 🌍 مرحبا"},
}},
},
},
{
desc: "unicode in tool args",
tools: []api.Tool{{Function: api.ToolFunction{Name: "greet"}}},
steps: []step{
{input: `[TOOL_CALLS]greet[ARGS]{"message": "你好 🌍"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "greet", args: `{"message": "你好 🌍"}`},
}},
},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
parser := MinistralParser{}
parser.hasThinkingSupport = tc.think
parser.Init(tc.tools, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
gotEvents := parser.parseEvents()
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
// avoid deep equal on empty vs. nil slices
continue
}
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
}
}
})
}
}
func TestMinistralParser_Errors(t *testing.T) {
t.Run("unknown tool returns error", func(t *testing.T) {
p := &MinistralParser{}
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "known_tool"}}}, nil, nil)
_, _, _, err := p.Add(`[TOOL_CALLS]unknown_tool[ARGS]{"a": 1}`, true)
if err == nil {
t.Fatal("expected error for unknown tool")
}
})
t.Run("invalid JSON returns error", func(t *testing.T) {
p := &MinistralParser{}
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "test"}}}, nil, nil)
_, _, _, err := p.Add(`[TOOL_CALLS]test[ARGS]{invalid json}`, true)
if err == nil {
t.Fatal("expected error for invalid JSON")
}
})
}
func TestFindJSONEnd(t *testing.T) {
tests := []struct {
name string
input string
expected int
}{
{
name: "simple object",
input: `{"a": 1}`,
expected: 7,
},
{
name: "nested object",
input: `{"a": {"b": 2}}`,
expected: 14,
},
{
name: "array inside object",
input: `{"items": [1, 2, 3]}`,
expected: 19,
},
{
name: "braces in string",
input: `{"template": "Hello {name}!"}`,
expected: 28,
},
{
name: "escaped quotes",
input: `{"msg": "say \"hi\""}`,
expected: 20,
},
{
name: "incomplete object",
input: `{"a": {"b": 1}`,
expected: -1,
},
{
name: "deeply nested",
input: `{"a": {"b": {"c": {"d": 1}}}}`,
expected: 28,
},
{
name: "object with trailing content",
input: `{"a": 1} extra`,
expected: 7,
},
{
name: "array",
input: `[{"a": 1}, {"b": 2}]`,
expected: 19,
},
{
name: "escaped backslash before quote",
input: `{"path": "C:\\"}`,
expected: 15,
},
{
name: "empty string",
input: "",
expected: -1,
},
{
name: "no opening brace",
input: "hello world",
expected: -1,
},
{
name: "only opening brace",
input: "{",
expected: -1,
},
{
name: "unclosed string",
input: `{"key": "unclosed`,
expected: -1,
},
{
name: "double escaped backslash then quote",
input: `{"path": "C:\\\\"}`,
expected: 17,
},
{
name: "unicode in key and value",
input: `{"키": "값"}`,
expected: 13,
},
{
name: "nested arrays",
input: `{"matrix": [[1, 2], [3, 4]]}`,
expected: 27,
},
{
name: "mixed nesting",
input: `{"a": [{"b": {"c": [1, 2, 3]}}]}`,
expected: 31,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := findJSONEnd(tt.input)
if result != tt.expected {
t.Errorf("findJSONEnd(%q) = %d, want %d", tt.input, result, tt.expected)
}
})
}
}
func TestMinistralParser_HasToolSupport(t *testing.T) {
p := &MinistralParser{}
if !p.HasToolSupport() {
t.Error("expected HasToolSupport to return true")
}
}
func TestMinistralParser_HasThinkingSupport(t *testing.T) {
p := &MinistralParser{hasThinkingSupport: false}
if p.HasThinkingSupport() {
t.Error("expected HasThinkingSupport to return false")
}
p = &MinistralParser{hasThinkingSupport: true}
if !p.HasThinkingSupport() {
t.Error("expected HasThinkingSupport to return true")
}
}

View File

@@ -3,6 +3,7 @@ package parsers
import (
"strings"
"unicode"
"unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
@@ -70,6 +71,8 @@ func ParserForName(name string) Parser {
return &FunctionGemmaParser{}
case "glm-4.7":
return &GLM47Parser{}
case "glm-ocr":
return &GlmOcrParser{}
case "lfm2":
return &LFM2Parser{hasThinkingSupport: false}
case "lfm2-thinking":
@@ -114,3 +117,33 @@ func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string
sb.WriteString(after)
return before, after // return events
}
// overlap returns the longest overlap between the suffix of s and the prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
// trailingWhitespaceLen returns the length in bytes of trailing whitespace in s
func trailingWhitespaceLen(s string) int {
remaining := s
total := 0
for len(remaining) > 0 {
r, size := utf8.DecodeLastRuneInString(remaining)
// if it's an invalid utf8 rune, assume it isn't whitespace
if r == utf8.RuneError && size == 1 {
break
}
if !unicode.IsSpace(r) {
break
}
total += size
remaining = remaining[:len(remaining)-size]
}
return total
}

View File

@@ -11,7 +11,6 @@ import (
"strconv"
"strings"
"unicode"
"unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
@@ -194,36 +193,6 @@ func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
}
}
// TODO(drifkin): move this to a shared location
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
func trailingWhitespaceLen(s string) int {
remaining := s
total := 0
for len(remaining) > 0 {
r, size := utf8.DecodeLastRuneInString(remaining)
// if it's an invalid utf8 rune, assume it isn't whitespace
if r == utf8.RuneError && size == 1 {
break
}
if !unicode.IsSpace(r) {
break
}
total += size
remaining = remaining[:len(remaining)-size]
}
return total
}
type XMLFunctionCall struct {
XMLName xml.Name `xml:"function"`
Name string `xml:"name,attr"`

109
model/renderers/glmocr.go Normal file
View File

@@ -0,0 +1,109 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
type GlmOcrRenderer struct{}
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatGLM47ToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
enableThinking := false
thinkingExplicitlySet := false
if thinkValue != nil {
enableThinking = thinkValue.Bool()
thinkingExplicitlySet = true
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>\n")
sb.WriteString(message.Content)
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
sb.WriteString("/nothink")
}
case "assistant":
sb.WriteString("<|assistant|>\n")
if message.Thinking != "" {
sb.WriteString("<think>" + strings.TrimSpace(message.Thinking) + "</think>")
} else {
sb.WriteString("<think></think>")
}
if message.Content != "" {
sb.WriteString("\n" + strings.TrimSpace(message.Content))
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderGlmOcrToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
sb.WriteString("\n")
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>\n")
case "system":
sb.WriteString("<|system|>\n")
sb.WriteString(message.Content)
sb.WriteString("\n")
}
}
sb.WriteString("<|assistant|>\n")
if thinkingExplicitlySet && !enableThinking {
sb.WriteString("<think></think>\n")
}
return sb.String(), nil
}
func renderGlmOcrToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}

View File

@@ -167,12 +167,12 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
// only start a new user block if this is the first tool response
if i == 0 || filteredMessages[i-1].Role != "tool" {
sb.WriteString(imStartTag + "user\n")
sb.WriteString(imStartTag + "user")
}
sb.WriteString("<tool_response>\n")
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>\n")
sb.WriteString("\n</tool_response>")
// close the user block only if this is the last tool response
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {

View File

@@ -1,6 +1,7 @@
package renderers
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -127,8 +128,7 @@ fahrenheit
<|im_start|>user
<tool_response>
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
</tool_response>
<|im_end|>
</tool_response><|im_end|>
<|im_start|>user
That sounds nice! What about New York?<|im_end|>
<|im_start|>assistant
@@ -233,8 +233,7 @@ I'll call double(1) and triple(2) for you.
</tool_response>
<tool_response>
{"number": 6}
</tool_response>
<|im_end|>
</tool_response><|im_end|>
<|im_start|>assistant
`,
},
@@ -280,8 +279,7 @@ call tool<|im_end|>
<|im_start|>user
<tool_response>
{"payload": {"foo": "bar"}}
</tool_response>
<|im_end|>
</tool_response><|im_end|>
<|im_start|>assistant
`,
},
@@ -337,6 +335,31 @@ func TestFormatToolCallArgument(t *testing.T) {
}
}
func TestQwen3CoderRendererToolResponseNoTrailingNewline(t *testing.T) {
msgs := []api.Message{
{Role: "user", Content: "call tool"},
{Role: "assistant", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{
Name: "echo",
Arguments: testArgs(map[string]any{"payload": "ok"}),
}},
}},
{Role: "tool", Content: "{\"payload\":\"ok\"}", ToolName: "echo"},
}
rendered, err := (&Qwen3CoderRenderer{}).Render(msgs, nil, nil)
if err != nil {
t.Fatal(err)
}
if strings.Contains(rendered, "</tool_response>\n<|im_end|>") {
t.Fatalf("expected no newline after </tool_response>, got:\n%s", rendered)
}
if !strings.Contains(rendered, "</tool_response><|im_end|>") {
t.Fatalf("expected </tool_response> to be immediately followed by <|im_end|>, got:\n%s", rendered)
}
}
func TestQwen3ToolDefinitionTypes(t *testing.T) {
tests := []struct {
name string

View File

@@ -82,6 +82,8 @@ func rendererForName(name string) Renderer {
return &FunctionGemmaRenderer{}
case "glm-4.7":
return &GLM47Renderer{}
case "glm-ocr":
return &GlmOcrRenderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
case "lfm2-thinking":

View File

@@ -124,8 +124,17 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*In
}
if c.cache != nil {
if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
numPast = 0
if numPast > 0 {
// Recurrent caches use checkpoints to pick a safe resume position.
if cc, ok := c.cache.(kvcache.CheckpointCache); ok {
if restored, ok := cc.PrepareRestore(slot.Id, numPast); ok {
numPast = restored
} else {
numPast = 0
}
} else if !c.cache.CanResume(slot.Id, numPast) {
numPast = 0
}
}
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)

View File

@@ -740,7 +740,11 @@ func (s *Server) computeBatch(activeBatch batchState) {
if seq == nil || nextBatchTokens[i] == nil {
continue
}
// If the sequence was replaced while this batch was computing, discard results.
if activeBatch.seqs[i] != seq {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
seq.lastUpdatedAt = t
if seq.numPredicted == 1 {
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
@@ -1358,7 +1362,7 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
// Dummy load to get the backend wired up
f, err := os.CreateTemp("", "*.bin")
if err != nil {
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
defer f.Close()
@@ -1368,13 +1372,13 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
"general.architecture": "llama",
"tokenizer.ggml.model": "gpt2",
}, nil); err != nil {
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
m, err = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
if err != nil {
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
slog.Debug("dummy model load took", "duration", time.Since(startLoad))

View File

@@ -3,7 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
"github.com/ollama/ollama/x/mlxrunner"
)
func Execute(args []string) error {
@@ -12,18 +12,18 @@ func Execute(args []string) error {
}
var newRunner bool
var imageRunner bool
var mlxRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--image-engine" {
if len(args) > 0 && args[0] == "--mlx-engine" {
args = args[1:]
imageRunner = true
mlxRunner = true
}
if imageRunner {
return imagerunner.Execute(args)
if mlxRunner {
return mlxrunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {

View File

@@ -27,14 +27,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// Clip images are represented as 768 tokens, each an embedding
imageNumTokens := 768
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n; i >= 0; i-- {
// always include the last message
if i == n {
continue
}
lastMsgIdx := len(msgs) - 1
currMsgIdx := 0
// Start with all messages and remove from the front until it fits in context
for i := 0; i <= lastMsgIdx; i++ {
// Collect system messages from the portion we're about to skip
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
@@ -54,20 +52,26 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
ctxLen := len(s)
if m.ProjectorPaths != nil {
for _, m := range msgs[i:] {
ctxLen += imageNumTokens * len(m.Images)
for _, msg := range msgs[i:] {
ctxLen += imageNumTokens * len(msg.Images)
}
}
if truncate && ctxLen > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
if !truncate || ctxLen <= opts.NumCtx {
currMsgIdx = i
break
}
// Must always include at least the last message
if i == lastMsgIdx {
currMsgIdx = lastMsgIdx
break
} else {
n = i
}
}
currMsgIdx := n
if currMsgIdx > 0 {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
}
for cnt, msg := range msgs[currMsgIdx:] {
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {

View File

@@ -2,6 +2,7 @@ package server
import (
"bytes"
"context"
"testing"
"github.com/google/go-cmp/cmp"
@@ -264,3 +265,68 @@ func TestChatPrompt(t *testing.T) {
})
}
}
func TestChatPromptTokenizeCalls(t *testing.T) {
tmpl, err := template.Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
if err != nil {
t.Fatal(err)
}
model := Model{Template: tmpl}
cases := []struct {
name string
limit int
msgs []api.Message
maxTokenizes int
}{
{
name: "all messages fit",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "message 1"},
{Role: "assistant", Content: "response 1"},
{Role: "user", Content: "message 2"},
{Role: "assistant", Content: "response 2"},
{Role: "user", Content: "message 3"},
},
maxTokenizes: 1,
},
{
name: "truncate to last message",
limit: 5,
msgs: []api.Message{
{Role: "user", Content: "message 1"},
{Role: "assistant", Content: "response 1"},
{Role: "user", Content: "message 2"},
{Role: "assistant", Content: "response 2"},
{Role: "user", Content: "message 3"},
},
maxTokenizes: 5,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tokenizeCount := 0
countingTokenize := func(ctx context.Context, s string) ([]int, error) {
tokenizeCount++
tokens, err := mockRunner{}.Tokenize(ctx, s)
return tokens, err
}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
think := false
_, _, err := chatPrompt(t.Context(), &model, countingTokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, true)
if err != nil {
t.Fatal(err)
}
if tokenizeCount > tt.maxTokenizes {
t.Errorf("tokenize called %d times, expected at most %d", tokenizeCount, tt.maxTokenizes)
}
})
}
}

View File

@@ -58,6 +58,48 @@ func useMoreBits(iLayer, nLayers int) bool {
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
}
func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
switch {
// Full attention
case strings.HasSuffix(name, ".attn_q.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".attn_k.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".attn_v.weight"):
return fsggml.TensorTypeQ6_K, true
case strings.HasSuffix(name, ".attn_output.weight"):
return fsggml.TensorTypeQ4_K, true
// Linear attention (Gated Delta Net) after split
case strings.HasSuffix(name, ".attn_qkv.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".attn_gate.weight"):
return fsggml.TensorTypeQ4_K, true
// SSM
case strings.HasSuffix(name, ".ssm_ba.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".ssm_out.weight"):
return fsggml.TensorTypeQ4_K, true
// MoE experts + shared experts
case strings.HasSuffix(name, ".ffn_down_exps.weight"):
return fsggml.TensorTypeQ6_K, true
case strings.HasSuffix(name, ".ffn_down_shexp.weight"):
return fsggml.TensorTypeQ6_K, true
case strings.HasSuffix(name, ".ffn_gate_exps.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".ffn_gate_shexp.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".ffn_up_exps.weight"):
return fsggml.TensorTypeQ4_K, true
case strings.HasSuffix(name, ".ffn_up_shexp.weight"):
return fsggml.TensorTypeQ4_K, true
}
return 0, false
}
func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType {
// Ported from llama_tensor_get_type, removed unsupported quantization types
nExperts := max(1, kv.Uint("expert_count", 0))
@@ -217,6 +259,7 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
// do not quantize expert gating tensors
quantize = quantize && !strings.Contains(name, "ffn_gate_inp.weight")
quantize = quantize && !strings.Contains(name, "ffn_gate_inp_shexp.weight")
// do not quantize positional embeddings and token types (BERT)
quantize = quantize && (name != "position_embd.weight")
@@ -244,6 +287,12 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
newType := fsggml.TensorType(t.Kind)
if quantize {
if kv.Architecture() == "qwen3next" && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
if qt, ok := qwen3nextQuantType(name); ok {
return qt
}
}
// get more optimal quantization type based on the tensor shape, layer, etc.
newType = getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype)
if newType != defaultType {

View File

@@ -75,16 +75,12 @@ func experimentEnabled(name string) bool {
var useClient2 = experimentEnabled("client2")
// Low VRAM mode is based on the sum of total VRAM (not free) and triggers
// reduced context length on some models
var lowVRAMThreshold uint64 = 20 * format.GibiByte
var mode string = gin.DebugMode
type Server struct {
addr net.Addr
sched *Scheduler
lowVRAM bool
addr net.Addr
sched *Scheduler
defaultNumCtx int
}
func init() {
@@ -107,8 +103,12 @@ var (
errBadTemplate = errors.New("template error")
)
func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
opts := api.DefaultOptions()
if opts.NumCtx == 0 {
opts.NumCtx = s.defaultNumCtx
}
if err := opts.FromMap(model.Options); err != nil {
return api.Options{}, err
}
@@ -140,20 +140,11 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
opts, err := modelOptions(model, requestOpts)
opts, err := s.modelOptions(model, requestOpts)
if err != nil {
return nil, nil, nil, err
}
// This model is much more capable with a larger context, so set that
// unless it would penalize performance too much
if !s.lowVRAM && slices.Contains([]string{
"gptoss", "gpt-oss",
"qwen3vl", "qwen3vlmoe",
}, model.Config.ModelFamily) {
opts.NumCtx = max(opts.NumCtx, 8192)
}
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
var runner *runnerRef
select {
@@ -1720,10 +1711,18 @@ func Serve(ln net.Listener) error {
for _, gpu := range gpus {
totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
}
if totalVRAM < lowVRAMThreshold {
s.lowVRAM = true
slog.Info("entering low vram mode", "total vram", format.HumanBytes2(totalVRAM), "threshold", format.HumanBytes2(lowVRAMThreshold))
// Set default context based on VRAM tier
// Use slightly lower thresholds (47/23 GiB vs. 48/24 GiB) to account for small differences in the exact value
switch {
case totalVRAM >= 47*format.GibiByte:
s.defaultNumCtx = 262144
case totalVRAM >= 23*format.GibiByte:
s.defaultNumCtx = 32768
default:
s.defaultNumCtx = 4096
}
slog.Info("vram-based default context", "total_vram", format.HumanBytes2(totalVRAM), "default_num_ctx", s.defaultNumCtx)
err = srvr.Serve(ln)
// If server is closed from the signal handler, wait for the ctx to be done
@@ -1897,8 +1896,8 @@ func (s *Server) PsHandler(c *gin.Context) {
Details: modelDetails,
ExpiresAt: v.expiresAt,
}
if v.Options != nil {
mr.ContextLength = v.Options.NumCtx
if v.llama != nil {
mr.ContextLength = v.llama.ContextLength()
}
// The scheduler waits to set expiresAt, so if a model is loading it's
// possible that it will be set to the unix epoch. For those cases, just

View File

@@ -15,6 +15,7 @@ import (
)
func TestGenerateDebugRenderOnly(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{
@@ -208,6 +209,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
}
func TestChatDebugRenderOnly(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{

View File

@@ -20,6 +20,7 @@ import (
// TestGenerateWithBuiltinRenderer tests that api/generate uses built-in renderers
// when in chat-like flow (messages present, no suffix, no template)
func TestGenerateWithBuiltinRenderer(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{
@@ -204,6 +205,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
// TestGenerateWithDebugRenderOnly tests that debug_render_only works with built-in renderers
func TestGenerateWithDebugRenderOnly(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{

View File

@@ -162,6 +162,7 @@ func TestGenerateChatRemote(t *testing.T) {
}
func TestGenerateChat(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{
@@ -878,6 +879,7 @@ func TestGenerateChat(t *testing.T) {
}
func TestGenerate(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{
@@ -2355,6 +2357,7 @@ func TestGenerateWithImages(t *testing.T) {
// TestImageGenerateStreamFalse tests that image generation respects stream=false
// and returns a single JSON response instead of streaming ndjson.
func TestImageGenerateStreamFalse(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
p := t.TempDir()

View File

@@ -0,0 +1,127 @@
package server
import (
"testing"
)
func TestModelOptionsNumCtxPriority(t *testing.T) {
tests := []struct {
name string
envContextLen string // empty means not set (uses 0 sentinel)
defaultNumCtx int // VRAM-based default
modelNumCtx int // 0 means not set in model
requestNumCtx int // 0 means not set in request
expectedNumCtx int
}{
{
name: "vram default when nothing else set",
envContextLen: "",
defaultNumCtx: 32768,
modelNumCtx: 0,
requestNumCtx: 0,
expectedNumCtx: 32768,
},
{
name: "env var overrides vram default",
envContextLen: "8192",
defaultNumCtx: 32768,
modelNumCtx: 0,
requestNumCtx: 0,
expectedNumCtx: 8192,
},
{
name: "model overrides vram default",
envContextLen: "",
defaultNumCtx: 32768,
modelNumCtx: 16384,
requestNumCtx: 0,
expectedNumCtx: 16384,
},
{
name: "model overrides env var",
envContextLen: "8192",
defaultNumCtx: 32768,
modelNumCtx: 16384,
requestNumCtx: 0,
expectedNumCtx: 16384,
},
{
name: "request overrides everything",
envContextLen: "8192",
defaultNumCtx: 32768,
modelNumCtx: 16384,
requestNumCtx: 4096,
expectedNumCtx: 4096,
},
{
name: "request overrides vram default",
envContextLen: "",
defaultNumCtx: 32768,
modelNumCtx: 0,
requestNumCtx: 4096,
expectedNumCtx: 4096,
},
{
name: "request overrides model",
envContextLen: "",
defaultNumCtx: 32768,
modelNumCtx: 16384,
requestNumCtx: 4096,
expectedNumCtx: 4096,
},
{
name: "low vram tier default",
envContextLen: "",
defaultNumCtx: 4096,
modelNumCtx: 0,
requestNumCtx: 0,
expectedNumCtx: 4096,
},
{
name: "high vram tier default",
envContextLen: "",
defaultNumCtx: 262144,
modelNumCtx: 0,
requestNumCtx: 0,
expectedNumCtx: 262144,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set or clear environment variable
if tt.envContextLen != "" {
t.Setenv("OLLAMA_CONTEXT_LENGTH", tt.envContextLen)
}
// Create server with VRAM-based default
s := &Server{
defaultNumCtx: tt.defaultNumCtx,
}
// Create model options (use float64 as FromMap expects JSON-style numbers)
var modelOpts map[string]any
if tt.modelNumCtx != 0 {
modelOpts = map[string]any{"num_ctx": float64(tt.modelNumCtx)}
}
model := &Model{
Options: modelOpts,
}
// Create request options (use float64 as FromMap expects JSON-style numbers)
var requestOpts map[string]any
if tt.requestNumCtx != 0 {
requestOpts = map[string]any{"num_ctx": float64(tt.requestNumCtx)}
}
opts, err := s.modelOptions(model, requestOpts)
if err != nil {
t.Fatalf("modelOptions failed: %v", err)
}
if opts.NumCtx != tt.expectedNumCtx {
t.Errorf("NumCtx = %d, want %d", opts.NumCtx, tt.expectedNumCtx)
}
})
}
}

View File

@@ -21,7 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
type LlmRequest struct {
@@ -195,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation model before attempting GGML load
// Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadImageGen(pending) {
if s.loadMLX(pending) {
break
}
continue
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
break
}
continue
}
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -552,11 +563,20 @@ iGPUScan:
return false
}
// loadImageGen loads an image generation model.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Use model name for imagegen (it resolves manifests by name, not file path)
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode mlxrunner.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {
mode = mlxrunner.ModeImageGen
} else {
mode = mlxrunner.ModeLLM
}
// Use model name for MLX (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName)
server, err := mlxrunner.NewServer(modelName, mode)
if err != nil {
req.errCh <- err
return true

View File

@@ -804,6 +804,7 @@ func (s *mockLlm) GetPort() int { return -
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
func (s *mockLlm) ContextLength() int { return 0 }
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted when idle.

View File

@@ -3,13 +3,13 @@ package model
type Capability string
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImage = Capability("image")
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImage = Capability("image")
)
func (c Capability) String() string {

View File

@@ -13,6 +13,7 @@ import (
"io"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/progress"
@@ -34,7 +35,7 @@ type ModelfileConfig struct {
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "fp8" for quantization
Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
}
@@ -53,10 +54,20 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
var parserName, rendererName string
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = []string{"completion"}
// Check if model supports thinking based on architecture
if supportsThinking(opts.ModelDir) {
capabilities = append(capabilities, "thinking")
}
// Set parser and renderer name based on architecture
parserName = getParserName(opts.ModelDir)
rendererName = getRendererName(opts.ModelDir)
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
@@ -81,14 +92,14 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
newManifestWriter(opts, capabilities, parserName, rendererName),
progressFn,
)
} else {
err = create.CreateImageGenModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
newManifestWriter(opts, capabilities, "", ""),
progressFn,
)
}
@@ -204,7 +215,7 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
}
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
@@ -229,6 +240,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
ModelFormat: "safetensors",
Capabilities: caps,
Requires: MinOllamaVersion,
Parser: parserName,
Renderer: rendererName,
}
configJSON, err := json.Marshal(configData)
if err != nil {
@@ -295,3 +308,146 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
return layers, nil
}
// supportsThinking checks if the model supports thinking mode based on its architecture.
// This reads the config.json from the model directory and checks the architectures field.
func supportsThinking(modelDir string) bool {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return false
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
// Check architectures that support thinking
thinkingArchitectures := []string{
"glm4moe", // GLM-4 MoE models
"deepseek", // DeepSeek models
"qwen3", // Qwen3 models
}
// Check the architecture list
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(archLower, thinkArch) {
return true
}
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(typeLower, thinkArch) {
return true
}
}
}
return false
}
// getParserName returns the parser name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate parser.
func getParserName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
}
// Check architectures for known parsers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3-coder"
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
}
}
return ""
}
// getRendererName returns the renderer name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate renderer.
func getRendererName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
}
// Check architectures for known renderers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3-coder"
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
}
}
return ""
}

View File

@@ -13,7 +13,11 @@ import (
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases.
// Supported quantization types: "fp8" (affine 8-bit)
// Supported quantization types:
// - "q4": affine 4-bit, group_size=32 (with qbiases)
// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
// - "q8": affine 8-bit, group_size=64 (with qbiases)
// - "mxfp8": Microsoft MX FP8, group_size=32 (no qbiases, E4M3 scales)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
@@ -54,12 +58,18 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
// Quantize based on quantization type
var qweight, scales, qbiases *mlx.Array
switch quantize {
case "fp4":
// affine mode: group_size=32, bits=4
case "q4":
// affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
case "fp8":
// affine mode: group_size=32, bits=8
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
case "nvfp4":
// NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
case "q8":
// affine mode: group_size=64, bits=8 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine")
case "mxfp8":
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbiases)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8")
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}

View File

@@ -228,7 +228,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
@@ -262,36 +262,134 @@ func ShouldQuantize(name, component string) bool {
return strings.HasSuffix(name, ".weight")
}
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
// This is a more detailed check that also considers tensor dimensions.
func ShouldQuantizeTensor(name string, shape []int32) bool {
// The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8").
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
return GetTensorQuantization(name, shape, quantize) != ""
}
// normalizeQuantType converts various quantization type aliases to canonical forms.
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> q4, q8/Q8/int8/INT8/fp8/FP8 -> q8, nvfp4/NVFP4, mxfp8/MXFP8
func normalizeQuantType(quantize string) string {
switch strings.ToUpper(quantize) {
case "Q4", "INT4", "FP4":
return "q4"
case "Q8", "INT8", "FP8":
return "q8"
case "NVFP4":
return "nvfp4"
case "MXFP8":
return "mxfp8"
default:
return quantize
}
}
// getQuantGroupSize returns the group size for a given quantization type.
// These must match the values used in quantize.go when creating quantized models.
func getQuantGroupSize(quantize string) int {
switch normalizeQuantType(quantize) {
case "nvfp4":
return 16
case "q4":
return 32
case "mxfp8":
return 32
case "q8":
return 64
default:
return 32
}
}
// GetTensorQuantization returns the appropriate quantization type for a tensor.
// Returns "" if the tensor should not be quantized.
// This implements mixed-precision quantization:
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
// - Output projection, gate/up weights: q4 (less sensitive)
// - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Norms, embeddings, biases, routing gates: no quantization
func GetTensorQuantization(name string, shape []int32, quantize string) string {
// Use basic name-based check first
if !ShouldQuantize(name, "") {
return false
return ""
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
return false
return ""
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
return false
return ""
}
// MLX quantization requires last dimension to be divisible by group size (32)
if shape[len(shape)-1]%32 != 0 {
return false
// Normalize quantization type to canonical form
quantNorm := normalizeQuantType(quantize)
// MLX quantization requires last dimension to be divisible by group size
// nvfp4: 16, q4/mxfp8: 32, q8: 64
groupSize := int32(32)
switch quantNorm {
case "nvfp4":
groupSize = 16
case "q8":
groupSize = 64
}
if shape[len(shape)-1]%groupSize != 0 {
return ""
}
return true
// Skip routing gate weights (should stay high precision)
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
return ""
}
// For NVFP4 or MXFP8, use the same quantization for all (no mixed precision)
if quantNorm == "nvfp4" || quantNorm == "mxfp8" {
return quantNorm
}
// Attention MLA weights - keep unquantized (bf16)
// These are highly sensitive: errors accumulate in the KV cache over time
// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
if strings.Contains(name, "q_a_proj") ||
strings.Contains(name, "q_b_proj") ||
strings.Contains(name, "kv_a_proj") ||
strings.Contains(name, "kv_b_proj") {
return "" // No quantization - keep bf16
}
// Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
if strings.Contains(name, "down_proj") {
return "q8"
}
// Output projection, gate/up weights - use requested quantization (Q4)
// o_proj, gate_proj, up_proj
if strings.Contains(name, "o_proj") ||
strings.Contains(name, "gate_proj") ||
strings.Contains(name, "up_proj") {
return quantNorm
}
// LM head - use requested quantization
if strings.Contains(name, "lm_head") {
return quantNorm
}
// Default to requested quantization for other weights
return quantNorm
}
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
@@ -330,9 +428,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
}
// Determine quantization type for this tensor (empty string if not quantizing)
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
quantizeType := ""
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
quantizeType = quantize
if quantize != "" {
quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize)
}
// Store as minimal safetensors format (88 bytes header overhead)
@@ -388,6 +487,23 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return fmt.Errorf("config.json not found in %s", modelDir)
}
// Create model_index.json with quantization info if quantizing
if quantize != "" {
modelIndex := map[string]any{
"quantization": strings.ToUpper(quantize),
"group_size": getQuantGroupSize(quantize),
}
indexData, err := json.MarshalIndent(modelIndex, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal model_index.json: %w", err)
}
indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json")
if err != nil {
return fmt.Errorf("failed to create model_index.json layer: %w", err)
}
layers = append(layers, indexLayer)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {

View File

@@ -536,41 +536,51 @@ func TestShouldQuantize(t *testing.T) {
func TestShouldQuantizeTensor(t *testing.T) {
tests := []struct {
name string
tensor string
shape []int32
want bool
name string
tensor string
shape []int32
quantize string
want bool
}{
// 2D tensors with sufficient size should be quantized
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
// Small tensors should not be quantized (< 1024 elements)
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
{"small 2D weight", "small.weight", []int32{31, 31}, false},
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
{"small 2D weight", "small.weight", []int32{31, 31}, "fp8", false},
// 1D tensors should not be quantized
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
{"1D tensor", "layer_norm.weight", []int32{4096}, "fp8", false},
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
// Norms should not be quantized regardless of shape
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, "fp8", false},
// Biases should not be quantized
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
// Group size divisibility tests
// FP8/FP4 require divisible by 32
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
// NVFP4 requires divisible by 16
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
got := ShouldQuantizeTensor(tt.tensor, tt.shape, tt.quantize)
if got != tt.want {
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
t.Errorf("ShouldQuantizeTensor(%q, %v, %q) = %v, want %v", tt.tensor, tt.shape, tt.quantize, got, tt.want)
}
})
}
@@ -741,7 +751,7 @@ func TestCreateImageGenModel_WithQuantize(t *testing.T) {
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
err := CreateImageGenModel("test-imagegen", dir, "q8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}

View File

@@ -15,15 +15,15 @@ import (
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
// Supported quantization types: fp8 (or empty for no quantization).
// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
case "", "fp4", "fp8":
case "", "q4", "q8", "nvfp4", "mxfp8":
// valid
default:
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize)
}
var layers []LayerInfo
@@ -89,7 +89,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
// Determine quantization type for this tensor (empty string if not quantizing)
quantizeType := ""
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape, quantize) {
quantizeType = quantize
}
@@ -213,10 +213,18 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
}
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
// MLX requires the last dimension to be divisible by the group size (32).
func canQuantizeShape(shape []int32) bool {
// MLX requires the last dimension to be divisible by the group size.
// nvfp4: 16, q4/mxfp8: 32, q8: 64
func canQuantizeShape(shape []int32, quantize string) bool {
if len(shape) < 2 {
return false
}
return shape[len(shape)-1]%32 == 0
groupSize := int32(32)
switch strings.ToUpper(quantize) {
case "NVFP4":
groupSize = 16
case "Q8":
groupSize = 64
}
return shape[len(shape)-1]%groupSize == 0
}

View File

@@ -9,6 +9,7 @@ type Cache interface {
Offset() int
Len() int
State() []*mlx.Array
Reset()
}
type KVCache struct {
@@ -63,6 +64,13 @@ func (c *KVCache) State() []*mlx.Array {
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// Reset clears the cache state for a new generation session
func (c *KVCache) Reset() {
c.keys = nil
c.values = nil
c.offset = 0
}
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
keys, values *mlx.Array
@@ -154,3 +162,11 @@ func (c *RotatingKVCache) State() []*mlx.Array {
func (c *RotatingKVCache) Offset() int { return c.offset }
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
// Reset clears the cache state for a new generation session
func (c *RotatingKVCache) Reset() {
c.keys = nil
c.values = nil
c.offset = 0
c.idx = 0
}

View File

@@ -102,14 +102,17 @@ func (m *ModelManifest) BlobPath(digest string) string {
return filepath.Join(m.BlobDir, blobName)
}
// GetTensorLayers returns all tensor layers for a given component.
// Component should be "text_encoder", "transformer", or "vae".
// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight").
// GetTensorLayers returns tensor layers, optionally filtered by component.
// If component is empty, returns all tensor layers (for LLM models).
// If component is specified (e.g., "text_encoder", "transformer", "vae"),
// returns only layers with that prefix.
func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
prefix := component + "/"
var layers []ManifestLayer
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) {
if layer.MediaType != "application/vnd.ollama.image.tensor" {
continue
}
if component == "" || strings.HasPrefix(layer.Name, component+"/") {
layers = append(layers, layer)
}
}
@@ -206,7 +209,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
info.Quantization = "Q8"
break
}
}

View File

@@ -991,6 +991,19 @@ func Concat(a, b *Array, axis int) *Array {
return Concatenate([]*Array{a, b}, axis)
}
// Stack stacks arrays along a new axis (axis 0 by default)
func Stack(arrays []*Array, axis int) *Array {
handles := make([]C.mlx_array, len(arrays))
for i, arr := range arrays {
handles[i] = arr.c
}
vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
res := C.mlx_array_new()
C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream())
C.mlx_vector_array_free(vec)
return newArray(res)
}
// Slice slices the array
func Slice(a *Array, start, stop []int32) *Array {
n := len(start)

View File

@@ -0,0 +1,840 @@
//go:build mlx
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
package glm4_moe_lite
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// RopeScaling holds RoPE scaling configuration
type RopeScaling struct {
Factor float32 `json:"factor"`
MscaleAllDim float32 `json:"mscale_all_dim"`
}
// Config holds GLM4-MoE-Lite model configuration
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
AttentionBias bool `json:"attention_bias"`
// MLA (Multi-head Latent Attention) parameters
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
// MoE parameters
NRoutedExperts int32 `json:"n_routed_experts"`
NSharedExperts int32 `json:"n_shared_experts"`
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
NormTopKProb bool `json:"norm_topk_prob"`
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
NGroup int32 `json:"n_group"`
TopKGroup int32 `json:"topk_group"`
// RoPE scaling
RopeScaling *RopeScaling `json:"rope_scaling"`
// Quantization parameters (set during load based on model quantization)
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
QuantBits int `json:"-"` // Bits per weight (4 or 8)
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
// Computed fields
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
}
// MLAAttention implements Multi-head Latent Attention with absorption.
// This uses absorbed MLA which operates in latent space for reduced KV cache.
type MLAAttention struct {
// Low-rank query projections
QAProj nn.LinearLayer `weight:"self_attn.q_a_proj"`
QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"`
QBProj nn.LinearLayer `weight:"self_attn.q_b_proj"`
// Low-rank KV projections (with shared rope component)
KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"`
KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"`
// Absorbed MLA projections (derived from kv_b_proj)
// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
EmbedQ *nn.MultiLinear `weight:"-"`
UnembedOut *nn.MultiLinear `weight:"-"`
// Output projection
OProj nn.LinearLayer `weight:"self_attn.o_proj"`
}
// Forward computes absorbed MLA attention output.
// This operates in latent space for reduced KV cache memory.
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Query path: q_a_proj -> layernorm -> q_b_proj
q := a.QAProj.Forward(x)
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
q = a.QBProj.Forward(q)
// Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
// Split Q into nope and rope parts
qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
// KV path: get compressed KV and k_pe
compressedKV := a.KVAProjWithMQA.Forward(x)
// Split into compressed_kv and k_pe (shared rope component)
kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
// k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
// Apply layernorm to get kv latent representation
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
// kvLatent: [B, L, kv_lora_rank] -> [B, 1, L, kv_lora_rank] for broadcasting
kvLatent = mlx.ExpandDims(kvLatent, 1)
// Apply RoPE to the rope parts
offset := 0
if c != nil {
offset = c.Offset()
}
qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
// ABSORBED MLA: project q_nope to latent space
// qNope: [B, num_heads, L, qk_nope_head_dim]
// EmbedQ: [num_heads, kv_lora_rank, qk_nope_head_dim]
// Result: [B, num_heads, L, kv_lora_rank]
qLatent := a.EmbedQ.Forward(qNope)
// Keys = concat(kvLatent, kPE)
// kvLatent: [B, 1, L, kv_lora_rank]
// kPE: [B, 1, L, qk_rope_head_dim]
// keys: [B, 1, L, kv_lora_rank + qk_rope_head_dim]
keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
// Cache the smaller latent representation
// We cache keys (latent + rope) and use empty values since values are derived from keys
cachedL := L
if c != nil {
// Create placeholder values with 0 dims for cache (we don't actually use cached values)
placeholderValues := mlx.Zeros([]int32{B, 1, L, 0}, mlx.DtypeFloat32)
keys, _ = c.Update(keys, placeholderValues, int(L))
cachedL = int32(keys.Shape()[2])
}
// Values are the first kv_lora_rank dims of keys (slice off rope part)
values := mlx.Slice(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})
// Queries = concat(qLatent, qPE)
// qLatent: [B, num_heads, L, kv_lora_rank]
// qPE: [B, num_heads, L, qk_rope_head_dim]
// queries: [B, num_heads, L, kv_lora_rank + qk_rope_head_dim]
queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
// Attention in latent space
// queries: [B, num_heads, L, kv_lora_rank + rope_dim]
// keys: [B, 1, cachedL, kv_lora_rank + rope_dim]
// values: [B, 1, cachedL, kv_lora_rank]
out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)
// ABSORBED MLA: unembed from latent space
// out: [B, num_heads, L, kv_lora_rank]
// UnembedOut: [num_heads, v_head_dim, kv_lora_rank]
// Result: [B, num_heads, L, v_head_dim]
out = a.UnembedOut.Forward(out)
// Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
return a.OProj.Forward(out)
}
// DenseMLP implements the standard SwiGLU MLP for dense layers
type DenseMLP struct {
GateProj nn.LinearLayer `weight:"mlp.gate_proj"`
UpProj nn.LinearLayer `weight:"mlp.up_proj"`
DownProj nn.LinearLayer `weight:"mlp.down_proj"`
}
// Forward applies the SwiGLU MLP
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(m.GateProj.Forward(x))
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
}
// MoEGate implements the expert gating mechanism
type MoEGate struct {
Gate nn.LinearLayer `weight:"mlp.gate"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
}
// Forward computes expert selection indices and scores
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
// Compute gate logits through linear layer (handles both quantized and non-quantized)
gates := g.Gate.Forward(x)
// Sigmoid scoring
scores := mlx.Sigmoid(gates)
origScores := scores
// Add correction bias if present
if g.EScoreCorrectionBias != nil {
scores = mlx.Add(scores, g.EScoreCorrectionBias)
}
// Group-wise expert selection (simplified for n_group=1)
// Select top-k experts
topK := cfg.NumExpertsPerTok
negScores := mlx.Neg(scores)
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
shape := inds.Shape()
inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})
// Get scores for selected experts
scores = mlx.TakeAlongAxis(origScores, inds, -1)
// Normalize if configured
if topK > 1 && cfg.NormTopKProb {
sumScores := mlx.Sum(scores, -1, true)
scores = mlx.Div(scores, sumScores)
}
// Apply routing scaling factor
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
return inds, scores
}
// SwitchMLP implements the MoE expert computation using stacked weights
// Note: No weight tags - these are populated manually by stacking expert weights
type SwitchMLP struct {
// Dequantized weights (used when GatherQMM not available)
GateWeight *mlx.Array
UpWeight *mlx.Array
DownWeight *mlx.Array
// Quantized weights (used with GatherQMM for 4/8-bit affine)
GateWeightQ, GateScales, GateBiases *mlx.Array
UpWeightQ, UpScales, UpBiases *mlx.Array
DownWeightQ, DownScales, DownBiases *mlx.Array
// Quantization bits per projection (supports mixed precision Q4/Q8)
GateBits int
UpBits int
DownBits int
// Quantization group size per projection (detected from tensor shapes)
GateGroupSize int
UpGroupSize int
DownGroupSize int
// If true, use GatherQMM with quantized weights
UseQuantized bool
}
// Forward applies the switched expert MLP
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
shape := x.Shape()
B, L := shape[0], shape[1]
topK := cfg.NumExpertsPerTok
// Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
// Flatten for gather_mm: [B*L, 1, 1, D]
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
// Flatten indices: [B, L, topK] -> [B*L, topK]
idxFlat := mlx.Reshape(indices, B*L, topK)
// Sort for efficient gather (when we have many tokens)
doSort := B*L >= 64
var invOrder *mlx.Array
n := B * L * topK
if doSort {
idxAll := mlx.Flatten(idxFlat)
order := mlx.Argsort(idxAll, 0)
invOrder = mlx.Argsort(order, 0)
// Reorder x based on sorted indices
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
}
var gate, up, hidden, down *mlx.Array
if s.UseQuantized {
// Use GatherQMM for quantized weights (faster, keeps weights quantized)
// Each projection may have different bits and group sizes (mixed precision: Q4 for gate/up, Q8 for down)
gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
} else {
// Use GatherMM for dequantized/non-quantized weights
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
}
// Unsort if we sorted
if doSort {
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
} else {
down = mlx.Squeeze(down, 2)
}
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}
// SharedExperts implements the shared expert MLP
type SharedExperts struct {
GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"`
UpProj nn.LinearLayer `weight:"mlp.shared_experts.up_proj"`
DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"`
}
// Forward applies the shared expert MLP
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(s.GateProj.Forward(x))
up := s.UpProj.Forward(x)
return s.DownProj.Forward(mlx.Mul(gate, up))
}
// MoE implements the full Mixture of Experts layer
type MoE struct {
Gate *MoEGate
SwitchMLP *SwitchMLP
SharedExperts *SharedExperts
}
// Forward applies the MoE layer
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
shape := x.Shape()
B, L := shape[0], shape[1]
// Get expert indices and scores
inds, scores := m.Gate.Forward(x, cfg)
// Apply routed experts
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
// Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
scoresExpanded := mlx.ExpandDims(scores, -1)
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
// Add shared experts if present
if m.SharedExperts != nil {
y = mlx.Add(y, m.SharedExperts.Forward(x))
}
return mlx.Reshape(y, B, L, cfg.HiddenSize)
}
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
type DenseBlock struct {
Attention *MLAAttention
MLP *DenseMLP
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the dense block
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Pre-norm attention with residual
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
h := mlx.Add(x, r)
// Pre-norm MLP with residual
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
return mlx.Add(h, r)
}
// MoEBlock represents a MoE transformer block
type MoEBlock struct {
Attention *MLAAttention
MoE *MoE
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the MoE block
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Pre-norm attention with residual
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
h := mlx.Add(x, r)
// Pre-norm MoE with residual
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
return mlx.Add(h, r)
}
// Block interface for both dense and MoE blocks
type Block interface {
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}
// Model represents the complete GLM4-MoE-Lite model
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []Block `weight:"-"` // Loaded manually due to different block types
Norm *nn.RMSNorm `weight:"model.norm"`
LMHead nn.LinearLayer `weight:"lm_head"`
tok *tokenizer.Tokenizer
*Config
}
// computeScale computes the attention scale.
// Uses the full key head dimension (qkNopeHeadDim + qkRopeHeadDim) to match the Ollama runner.
func computeScale(cfg *Config) float32 {
keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
scale := float32(1.0 / math.Sqrt(float64(keyLength)))
if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
scale *= s * s
}
return scale
}
// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
// Currently only 4-bit and 8-bit affine quantization are supported.
func supportsGatherQMM(mode string, bits int) bool {
return mode == "affine" && (bits == 4 || bits == 8)
}
// ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct {
Weight *mlx.Array // Quantized weight (if quantized) or dequantized weight
Scales *mlx.Array // Quantization scales (nil if not quantized)
Biases *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases)
Bits int // Quantization bits (4 or 8), 0 if not quantized
GroupSize int // Quantization group size, 0 if not quantized
}
// getQuantParams returns quantization parameters from model metadata.
// Returns groupSize, bits, and mode for the model's quantization type.
func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) {
groupSize, bits, mode = safetensors.QuantizationParams(weights.Quantization())
// Use metadata group_size if available (overrides default)
if gs := weights.GroupSize(); gs > 0 {
groupSize = gs
}
return groupSize, bits, mode
}
// loadExpertWeight loads an expert weight.
// If useQuantized is true and the weight is quantized with a supported mode, returns quantized components.
// Otherwise dequantizes and returns only the weight.
func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight {
w, _ := weights.GetTensor(path + ".weight")
if w == nil {
return nil
}
// Check if this is a quantized weight by looking for scales
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
scales, _ := weights.GetTensor(scalePath)
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
// Get quantization params from metadata
groupSize, bits, mode := getQuantParams(weights)
// Update config with group size (for GatherQMM calls)
if cfg.QuantGroupSize == 0 {
cfg.QuantGroupSize = groupSize
}
// If GatherQMM is supported and requested, return quantized components
if useQuantized && supportsGatherQMM(mode, bits) {
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
}
// Otherwise dequantize
return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
}
return &ExpertWeight{Weight: w}
}
// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
// Returns embed_q and unembed_out weights for per-head projections.
//
// kv_b_proj.weight shape: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
// Output:
// - embed_q: [num_heads, kv_lora_rank, qk_nope_head_dim] - projects q_nope to latent
// - unembed_out: [num_heads, v_head_dim, kv_lora_rank] - projects latent to output
func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
path := prefix + ".self_attn.kv_b_proj"
w, err := weights.GetTensor(path + ".weight")
if err != nil || w == nil {
return nil, nil
}
// Check if quantized and dequantize
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
scales, _ := weights.GetTensor(scalePath)
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
groupSize, bits, mode := getQuantParams(weights)
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
}
// w: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
// Reshape to [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)
// Split into wk and wv
// wk: [num_heads, qk_nope_head_dim, kv_lora_rank]
// wv: [num_heads, v_head_dim, kv_lora_rank]
wk := mlx.Slice(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
wv := mlx.Slice(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})
// Transform for absorbed MLA:
// embed_q: transpose(wk) -> [num_heads, kv_lora_rank, qk_nope_head_dim]
// This allows: q_nope @ embed_q.T = q_nope @ wk (absorbed key projection)
embedQ := mlx.Transpose(wk, 0, 2, 1)
// unembed_out: wv stays [num_heads, v_head_dim, kv_lora_rank]
// This allows: latent_out @ unembed_out.T = latent_out @ wv.T (absorbed value projection)
unembedOut := wv
return embedQ, unembedOut
}
// StackedExpertWeights holds stacked weights for all experts.
type StackedExpertWeights struct {
Weight *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed]
Scales *mlx.Array // Stacked scales (nil if not quantized)
Biases *mlx.Array // Stacked biases (nil if not quantized)
Bits int // Quantization bits (4 or 8), 0 if not quantized
GroupSize int // Quantization group size, 0 if not quantized
}
// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
func collectAndStackExpertWeights(
weights safetensors.WeightSource,
prefix string,
projName string,
numExperts int32,
useQuantized bool,
cfg *Config,
) *StackedExpertWeights {
var w, s, b []*mlx.Array
var bits, groupSize int
for e := int32(0); e < numExperts; e++ {
path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
ew := loadExpertWeight(weights, path, useQuantized, cfg)
if ew == nil {
continue
}
w = append(w, ew.Weight)
if ew.Scales != nil {
s = append(s, ew.Scales)
}
if ew.Biases != nil {
b = append(b, ew.Biases)
}
if e == 0 {
bits = ew.Bits
groupSize = ew.GroupSize
}
}
result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
if len(w) > 0 {
result.Weight = mlx.Stack(w, 0)
if len(s) > 0 {
result.Scales = mlx.Stack(s, 0)
}
if len(b) > 0 {
result.Biases = mlx.Stack(b, 0)
}
}
return result
}
// sanitizeExpertWeights stacks individual expert weights into tensors.
// If useQuantized is true and weights support GatherQMM, returns quantized components.
// Otherwise returns dequantized weights with nil scales/biases.
// Bits and GroupSize are detected per-weight to support mixed-precision (Q4 for gate/up, Q8 for down).
func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
gate = collectAndStackExpertWeights(weights, prefix, "gate_proj", numExperts, useQuantized, cfg)
up = collectAndStackExpertWeights(weights, prefix, "up_proj", numExperts, useQuantized, cfg)
down = collectAndStackExpertWeights(weights, prefix, "down_proj", numExperts, useQuantized, cfg)
return gate, up, down
}
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
// Read config from manifest
configData, err := manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Compute derived fields
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
cfg.Scale = computeScale(&cfg)
// Load weights from manifest blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
if err := weights.Load(0); err != nil {
return nil, fmt.Errorf("load weight data: %w", err)
}
// Set up quantization parameters (only if model is actually quantized)
// Note: QuantGroupSize will be detected dynamically from tensor shapes during weight loading
quantization := weights.Quantization()
useQuantized := false
if quantization != "" {
_, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization)
useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
}
// Load tokenizer from manifest with config files for EOS token detection
tokData, err := manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
// Build tokenizer config with companion files for EOS/BOS token loading
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData, // Already loaded above, contains eos_token_id
}
// Try to load generation_config.json if available (preferred source for EOS)
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
// Try to load tokenizer_config.json if available
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
// Load embedding, norm, and lm_head
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Load layers manually due to different block types
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
// Load attention (same for both block types)
attn := &MLAAttention{}
if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d attention: %w", i, err)
}
// Sanitize MLA weights for absorbed attention
embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg)
attn.EmbedQ = nn.NewMultiLinear(embedQ)
attn.UnembedOut = nn.NewMultiLinear(unembedOut)
if i < cfg.FirstKDenseReplace {
// Dense block
block := &DenseBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d dense: %w", i, err)
}
m.Layers[i] = block
} else {
// MoE block
block := &MoEBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
}
// Stack expert weights (pass cfg so group sizes can be detected)
gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg)
switchMLP := &SwitchMLP{UseQuantized: useQuantized}
if useQuantized {
switchMLP.GateWeightQ = gate.Weight
switchMLP.GateScales = gate.Scales
switchMLP.GateBiases = gate.Biases
switchMLP.GateBits = gate.Bits
switchMLP.GateGroupSize = gate.GroupSize
switchMLP.UpWeightQ = up.Weight
switchMLP.UpScales = up.Scales
switchMLP.UpBiases = up.Biases
switchMLP.UpBits = up.Bits
switchMLP.UpGroupSize = up.GroupSize
switchMLP.DownWeightQ = down.Weight
switchMLP.DownScales = down.Scales
switchMLP.DownBiases = down.Biases
switchMLP.DownBits = down.Bits
switchMLP.DownGroupSize = down.GroupSize
} else {
switchMLP.GateWeight = gate.Weight
switchMLP.UpWeight = up.Weight
switchMLP.DownWeight = down.Weight
}
block.MoE = &MoE{
Gate: &MoEGate{},
SwitchMLP: switchMLP,
}
// Load gate weights
if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d gate: %w", i, err)
}
// Load shared experts if present
if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{}
if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
}
}
m.Layers[i] = block
}
}
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
return m, nil
}
// Forward computes the forward pass of the model
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
}
h = m.Norm.Forward(h, m.RMSNormEps)
return m.LMHead.Forward(h)
}
// Interface methods
// NumLayers returns the number of transformer layers
func (m *Model) NumLayers() int { return len(m.Layers) }
// MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
// Tokenizer returns the model's tokenizer
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
// NewCache creates a new KV cache for the model
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
// This follows the GLM-4.7 format with <think> tag for reasoning mode.
func (m *Model) FormatPrompt(prompt string) string {
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}
// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
// When think is true, the prompt ends with <think> to enable reasoning mode.
// When think is false, the prompt ends with </think> to skip reasoning.
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
if think {
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
}
// NewRenderer returns a new Renderer for formatting multi-turn conversations.
func (m *Model) NewRenderer() *Renderer {
return &Renderer{}
}
// NewParser returns a new Parser for extracting thinking and tool calls from output.
func (m *Model) NewParser() *Parser {
return &Parser{}
}

View File

@@ -0,0 +1,479 @@
//go:build mlx
package glm4_moe_lite
import (
"context"
"encoding/json"
"encoding/xml"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type parserState int
const (
parserState_LookingForThinkingOpen parserState = iota
parserState_ThinkingStartedEatingWhitespace
parserState_CollectingThinking
parserState_ThinkingDoneEatingWhitespace
parserState_CollectingContent
parserState_ToolStartedEatingWhitespace
parserState_CollectingToolContent
)
const (
thinkingOpenTag = "<think>"
thinkingCloseTag = "</think>"
toolOpenTag = "<tool_call>"
toolCloseTag = "</tool_call>"
)
// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type Parser struct {
state parserState
buffer strings.Builder
tools []api.Tool
}
// HasToolSupport returns true as GLM4 supports tool calling.
func (p *Parser) HasToolSupport() bool {
return true
}
// HasThinkingSupport returns true as GLM4 supports thinking mode.
func (p *Parser) HasThinkingSupport() bool {
return true
}
// Init initializes the parser with tools and thinking configuration.
func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
// When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() {
p.state = parserState_CollectingThinking
}
return tools
}
type parserEvent interface {
isParserEvent()
}
type eventContent struct {
content string
}
func (eventContent) isParserEvent() {}
type eventRawToolCall struct {
raw string
}
func (eventRawToolCall) isParserEvent() {}
type eventThinkingContent struct {
content string
}
func (eventThinkingContent) isParserEvent() {}
// Add processes new output text and returns parsed content, thinking, and tool calls.
func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case eventRawToolCall:
toolCall, err := parseToolCall(event, p.tools)
if err != nil {
slog.Warn("glm-4 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case eventThinkingContent:
thinkingSb.WriteString(event.content)
case eventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *Parser) parseEvents() []parserEvent {
var all []parserEvent
keepLooping := true
for keepLooping {
var events []parserEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false // Still only whitespace, keep waiting for more input
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true // Successfully transitioned
}
// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
// the content after (optionally trimmed of leading whitespace), and updates the buffer
func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
p.buffer.Reset()
p.buffer.WriteString(after)
return before, after
}
func (p *Parser) eat() ([]parserEvent, bool) {
var events []parserEvent
switch p.state {
case parserState_LookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, thinkingOpenTag) {
// Found <think> opening tag
after := strings.TrimPrefix(trimmed, thinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = parserState_ThinkingStartedEatingWhitespace
} else {
p.state = parserState_CollectingThinking
}
return events, true
} else if strings.HasPrefix(thinkingOpenTag, trimmed) {
// Partial opening tag seen, keep accumulating
return events, false
} else if trimmed == "" {
// Only whitespace, keep accumulating
return events, false
} else {
// No thinking tag found, skip to content collection
p.state = parserState_CollectingContent
// Don't trim - we want to keep the original content
return events, true
}
case parserState_ThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)
case parserState_CollectingThinking:
acc := p.buffer.String()
if strings.Contains(acc, thinkingCloseTag) {
thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, eventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = parserState_ThinkingDoneEatingWhitespace
} else {
p.state = parserState_CollectingContent
}
return events, true
} else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
// Partial closing tag - withhold it along with any trailing whitespace before it
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventThinkingContent{content: unambiguous})
}
return events, false
} else {
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventThinkingContent{content: unambiguous})
}
return events, false
}
case parserState_ThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)
case parserState_CollectingContent:
if strings.Contains(p.buffer.String(), toolOpenTag) {
before, after := p.splitAtTag(toolOpenTag, true)
if len(before) > 0 {
events = append(events, eventContent{content: before})
}
if after == "" {
p.state = parserState_ToolStartedEatingWhitespace
} else {
p.state = parserState_CollectingToolContent
}
return events, true
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventContent{content: unambiguous})
}
return events, false
}
case parserState_ToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)
case parserState_CollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, toolCloseTag) {
toolContent, _ := p.splitAtTag(toolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("glm4 tool call closing tag found but no content before it")
}
events = append(events, eventRawToolCall{raw: toolContent})
p.state = parserState_CollectingContent
return events, true
} else {
// Keep accumulating - tool calls are not streamed
// We just wait for the closing tag
return events, false
}
default:
panic("unreachable")
}
}
// overlap returns the length of the overlap between the end of s and the start of tag.
func overlap(s, tag string) int {
for i := 1; i <= len(tag) && i <= len(s); i++ {
if strings.HasSuffix(s, tag[:i]) {
return i
}
}
return 0
}
// trailingWhitespaceLen returns the length of trailing whitespace in s.
func trailingWhitespaceLen(s string) int {
trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
return len(s) - len(trimmed)
}
// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
type ToolCallXML struct {
XMLName xml.Name `xml:"tool_call"`
Content string `xml:",chardata"` // Function name (text nodes between tags)
Keys []string `xml:"arg_key"` // All arg_key elements in document order
Values []string `xml:"arg_value"` // All arg_value elements in document order
}
// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeContent(s string) string {
var result strings.Builder
inTag := false
for i := range len(s) {
ch := s[i]
if ch == '<' {
// Check if this is a known tag
if strings.HasPrefix(s[i:], "<arg_key>") ||
strings.HasPrefix(s[i:], "</arg_key>") ||
strings.HasPrefix(s[i:], "<arg_value>") ||
strings.HasPrefix(s[i:], "</arg_value>") {
inTag = true
}
}
if inTag {
result.WriteByte(ch)
if ch == '>' {
inTag = false
}
} else {
// Escape special characters in text content
switch ch {
case '&':
result.WriteString("&amp;")
case '<':
result.WriteString("&lt;")
case '>':
result.WriteString("&gt;")
default:
result.WriteByte(ch)
}
}
}
return result.String()
}
func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
// Escape any unescaped entities in text content
escaped := escapeContent(raw.raw)
// Wrap the content in a root element to make it valid XML
xmlString := "<tool_call>" + escaped + "</tool_call>"
// Parse XML into struct
var parsed ToolCallXML
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
}
// Extract and trim function name
functionName := strings.TrimSpace(parsed.Content)
if functionName == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
// Verify keys and values are paired correctly
if len(parsed.Keys) != len(parsed.Values) {
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionName {
matchedTool = &tools[i]
break
}
}
// Build arguments map by pairing keys and values
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: functionName,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for i := range parsed.Keys {
key := strings.TrimSpace(parsed.Keys[i])
value := parsed.Values[i] // Don't trim here - parseValue handles it
// Look up parameter type
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
// Handle anyOf by collecting all types from the union
if len(prop.AnyOf) > 0 {
for _, anyOfProp := range prop.AnyOf {
paramType = append(paramType, anyOfProp.Type...)
}
} else {
paramType = prop.Type
}
}
}
// Parse value with type coercion
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
}
return toolCall, nil
}
// parseValue parses a string value and coerces it to the appropriate type based on paramType.
func parseValue(value string, paramType api.PropertyType) any {
value = strings.TrimSpace(value)
// If no type specified, return as string
if len(paramType) == 0 {
return value
}
// Try to parse based on specified types
for _, t := range paramType {
switch t {
case "boolean":
if value == "true" {
return true
}
if value == "false" {
return false
}
case "integer":
var i int64
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
return i
}
case "number":
var f float64
if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
return f
}
case "array", "object":
// Try to parse as JSON
var result any
if err := json.Unmarshal([]byte(value), &result); err == nil {
return result
}
}
}
// Default to string
return value
}

View File

@@ -0,0 +1,192 @@
//go:build mlx
package glm4_moe_lite
import (
"testing"
"github.com/ollama/ollama/api"
)
func TestParserThinking(t *testing.T) {
tests := []struct {
name string
input string
thinkEnabled bool
wantContent string
wantThinking string
wantToolCalls int
}{
{
name: "thinking enabled - simple thinking then content",
input: "Let me think about this...</think>Here is my answer.",
thinkEnabled: true,
wantThinking: "Let me think about this...",
wantContent: "Here is my answer.",
},
{
name: "thinking enabled - only thinking",
input: "I need to consider multiple factors...",
thinkEnabled: true,
wantThinking: "I need to consider multiple factors...",
wantContent: "",
},
{
name: "thinking disabled - direct content",
input: "Here is my direct answer.",
thinkEnabled: false,
wantThinking: "",
wantContent: "Here is my direct answer.",
},
{
name: "thinking with tool call",
input: "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
thinkEnabled: true,
wantThinking: "Let me search for that...",
wantContent: "I'll use a tool.",
wantToolCalls: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Parser{}
var thinkValue *api.ThinkValue
if tt.thinkEnabled {
thinkValue = &api.ThinkValue{Value: true}
} else {
thinkValue = &api.ThinkValue{Value: false}
}
// Define tools for tool call tests
props := api.NewToolPropertiesMap()
props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "search",
Parameters: api.ToolFunctionParameters{
Properties: props,
},
},
},
}
p.Init(tools, nil, thinkValue)
content, thinking, calls, err := p.Add(tt.input, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if thinking != tt.wantThinking {
t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
}
if content != tt.wantContent {
t.Errorf("content = %q, want %q", content, tt.wantContent)
}
if len(calls) != tt.wantToolCalls {
t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
}
})
}
}
func TestParserToolCall(t *testing.T) {
p := &Parser{}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: props,
},
},
},
}
// Initialize with thinking disabled
tv := &api.ThinkValue{Value: false}
p.Init(tools, nil, tv)
input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"
_, _, calls, err := p.Add(input, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
call := calls[0]
if call.Function.Name != "get_weather" {
t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
}
location, ok := call.Function.Arguments.Get("location")
if !ok || location != "San Francisco" {
t.Errorf("location = %v, want %q", location, "San Francisco")
}
unit, ok := call.Function.Arguments.Get("unit")
if !ok || unit != "celsius" {
t.Errorf("unit = %v, want %q", unit, "celsius")
}
}
func TestOverlap(t *testing.T) {
tests := []struct {
s string
tag string
want int
}{
{"hello<", "</think>", 1},
{"hello</", "</think>", 2},
{"hello</t", "</think>", 3},
{"hello</th", "</think>", 4},
{"hello</thi", "</think>", 5},
{"hello</thin", "</think>", 6},
{"hello</think", "</think>", 7},
{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
{"hello", "</think>", 0},
{"", "</think>", 0},
}
for _, tt := range tests {
t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
got := overlap(tt.s, tt.tag)
if got != tt.want {
t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
}
})
}
}
func TestTrailingWhitespaceLen(t *testing.T) {
tests := []struct {
s string
want int
}{
{"hello ", 3},
{"hello\n\t ", 3},
{"hello", 0},
{"", 0},
{" ", 3},
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
got := trailingWhitespaceLen(tt.s)
if got != tt.want {
t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,175 @@
//go:build mlx
package glm4_moe_lite
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
// Renderer renders messages for GLM4-MoE-Lite models.
//
// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses "enable_thinking" parameter:
// - enable_thinking=true: outputs <think> to start reasoning
// - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
// included in <think>...</think> blocks, equivalent to clear_thinking=false)
// - Users can disable thinking per-turn via thinkValue=false
type Renderer struct{}
// Render renders messages into the GLM4 chat format.
func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
think := true
if thinkValue != nil && !thinkValue.Bool() {
think = false
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>")
sb.WriteString(message.Content)
case "assistant":
sb.WriteString("<|assistant|>")
if message.Thinking != "" {
sb.WriteString("<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("</think>")
}
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("<tool_response>")
sb.WriteString(message.Content)
sb.WriteString("</tool_response>")
case "system":
sb.WriteString("<|system|>")
sb.WriteString(message.Content)
}
}
sb.WriteString("<|assistant|>")
if think {
sb.WriteString("<think>")
} else {
sb.WriteString("</think>")
}
return sb.String(), nil
}
// renderToolArguments converts tool call arguments to GLM4 XML format.
func renderToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}
// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and ,
func formatToolJSON(raw []byte) string {
var sb strings.Builder
sb.Grow(len(raw) + len(raw)/10)
inString := false
escaped := false
for i := range raw {
ch := raw[i]
sb.WriteByte(ch)
if inString {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
inString = false
}
continue
}
if ch == '"' {
inString = true
continue
}
if ch == ':' || ch == ',' {
sb.WriteByte(' ')
}
}
return sb.String()
}

View File

@@ -0,0 +1,205 @@
//go:build mlx
package glm4_moe_lite
import (
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestRendererSimple(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
// Thinking enabled (default)
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := "[gMASK]<sop><|user|>Hello<|assistant|><think>"
if result != expected {
t.Errorf("result = %q, want %q", result, expected)
}
}
func TestRendererThinkingDisabled(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
tv := &api.ThinkValue{Value: false}
result, err := r.Render(messages, nil, tv)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := "[gMASK]<sop><|user|>Hello<|assistant|></think>"
if result != expected {
t.Errorf("result = %q, want %q", result, expected)
}
}
func TestRendererMultiTurn(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
{Role: "user", Content: "And 3+3?"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check key parts
if !strings.Contains(result, "[gMASK]<sop>") {
t.Error("missing [gMASK]<sop> prefix")
}
if !strings.Contains(result, "<|user|>What is 2+2?") {
t.Error("missing first user message")
}
if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
t.Error("missing assistant message with thinking")
}
if !strings.Contains(result, "<|user|>And 3+3?") {
t.Error("missing second user message")
}
if !strings.HasSuffix(result, "<|assistant|><think>") {
t.Errorf("should end with <|assistant|><think>, got suffix: %q", result[len(result)-30:])
}
}
func TestRendererWithSystem(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
t.Error("missing system message")
}
}
func TestRendererWithTools(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "What's the weather?"},
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the weather for a location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: []string{"location"},
},
},
},
}
result, err := r.Render(messages, tools, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check for tool system prompt
if !strings.Contains(result, "<|system|>") {
t.Error("missing system tag for tools")
}
if !strings.Contains(result, "# Tools") {
t.Error("missing tools header")
}
if !strings.Contains(result, "<tools>") {
t.Error("missing tools tag")
}
if !strings.Contains(result, "get_weather") {
t.Error("missing tool name")
}
if !strings.Contains(result, "</tools>") {
t.Error("missing closing tools tag")
}
}
func TestRendererWithToolCalls(t *testing.T) {
r := &Renderer{}
args := api.NewToolCallFunctionArguments()
args.Set("location", "San Francisco")
messages := []api.Message{
{Role: "user", Content: "What's the weather in SF?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
},
},
},
{Role: "tool", Content: "Sunny, 72F"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "<tool_call>get_weather") {
t.Error("missing tool call")
}
if !strings.Contains(result, "<arg_key>location</arg_key>") {
t.Error("missing arg_key")
}
if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
t.Error("missing arg_value")
}
if !strings.Contains(result, "</tool_call>") {
t.Error("missing tool call closing tag")
}
if !strings.Contains(result, "<|observation|>") {
t.Error("missing observation tag")
}
if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
t.Error("missing tool response")
}
}
func TestFormatToolJSON(t *testing.T) {
input := []byte(`{"name":"test","value":123}`)
result := formatToolJSON(input)
// Should add spaces after : and ,
if !strings.Contains(result, ": ") {
t.Error("should add space after colon")
}
if !strings.Contains(result, ", ") {
t.Error("should add space after comma")
}
}

View File

@@ -32,10 +32,16 @@ func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
// Quantizes the weight immediately and evaluates to break lazy dependencies.
// Note: For modes like "nvfp4", qbiases will be nil.
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
// Eval immediately so bf16 weight can be freed
mlx.Eval(qw, scales, qbiases)
// Handle modes that don't return qbiases (e.g., nvfp4)
if qbiases != nil {
mlx.Eval(qw, scales, qbiases)
} else {
mlx.Eval(qw, scales)
}
return &QuantizedLinear{
Weight: qw,
Scales: scales,
@@ -77,10 +83,13 @@ func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear
// QuantizedLinear applies an affine transformation using quantized weights.
// Equivalent to mlx.nn.QuantizedLinear.
// Supports multiple quantization modes:
// - "affine": scale + zero-point bias (QBiases required)
// - "nvfp4": NVIDIA FP4 with E4M3 scales (QBiases nil)
type QuantizedLinear struct {
Weight *mlx.Array // Quantized weight data
Scales *mlx.Array // Scale factors for dequantization
QBiases *mlx.Array // Quantization biases (NOT layer bias)
QBiases *mlx.Array // Quantization biases (NOT layer bias), nil for nvfp4
Bias *mlx.Array // Layer bias [output_dims] or nil
GroupSize int
Bits int
@@ -220,3 +229,32 @@ func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
}
return out
}
// MultiLinearLayer is an interface for per-head linear layers.
// This allows swapping between MultiLinear (bf16) and pre-dequantized weights.
type MultiLinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
}
// MultiLinear performs per-head linear projections.
// Weight shape: [num_heads, output_dims, input_dims]
// Input shape: [B, num_heads, L, input_dims]
// Output shape: [B, num_heads, L, output_dims]
type MultiLinear struct {
Weight *mlx.Array `weight:"weight"`
}
// NewMultiLinear creates a MultiLinear layer with the given weight.
func NewMultiLinear(weight *mlx.Array) *MultiLinear {
return &MultiLinear{Weight: weight}
}
// Forward applies per-head linear transformation: x @ weight.T per head via broadcasting.
func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
// Weight: [num_heads, output_dims, input_dims]
// x: [B, num_heads, L, input_dims]
// wT: [num_heads, input_dims, output_dims]
// Result: [B, num_heads, L, output_dims]
wT := mlx.Transpose(ml.Weight, 0, 2, 1)
return mlx.Matmul(x, wT)
}

View File

@@ -1,284 +0,0 @@
//go:build mlx
// Package runner provides a subprocess server for image generation.
// It listens on a port and handles HTTP requests for image generation.
package runner
import (
"context"
"encoding/json"
"flag"
"fmt"
"image"
"log/slog"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
}
// Response is streamed back for each progress update
type Response struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"` // Base64-encoded PNG
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
}
// ImageModel is the interface for image generation models
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
}
// ImageEditModel extends ImageModel with image editing/conditioning capability.
// Models that support input images for editing should implement this interface.
type ImageEditModel interface {
ImageModel
GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error)
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model ImageModel
modelName string
}
// Execute is the entry point for the image runner subprocess
func Execute(args []string) error {
fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
modelName := fs.String("model", "", "path to image model")
port := fs.Int("port", 0, "port to listen on")
if err := fs.Parse(args); err != nil {
return err
}
if *modelName == "" {
return fmt.Errorf("--model is required")
}
if *port == 0 {
return fmt.Errorf("--port is required")
}
err := mlx.InitMLX()
if err != nil {
slog.Error("unable to initialize MLX", "error", err)
return err
}
slog.Info("MLX library initialized")
slog.Info("starting image runner", "model", *modelName, "port", *port)
// Detect model type and load appropriate model
modelType := imagegen.DetectModelType(*modelName)
slog.Info("detected model type", "type", modelType)
var model ImageModel
switch modelType {
case "Flux2KleinPipeline":
m := &flux2.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
model = m
default:
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
m := &zimage.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
model = m
}
server := &Server{
model: model,
modelName: *modelName,
}
// Set up HTTP handlers
mux := http.NewServeMux()
mux.HandleFunc("/health", server.healthHandler)
mux.HandleFunc("/completion", server.completionHandler)
httpServer := &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
Handler: mux,
}
// Handle shutdown
done := make(chan struct{})
go func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
slog.Info("shutting down image runner")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
close(done)
}()
slog.Info("image runner listening", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
return err
}
<-done
return nil
}
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}
func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Validate and decode input images
const maxInputImages = 2
if len(req.Images) > maxInputImages {
http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest)
return
}
var inputImages []image.Image
if len(req.Images) > 0 {
// TODO: add memory check for input images
inputImages = make([]image.Image, len(req.Images))
for i, imgBytes := range req.Images {
img, err := imagegen.DecodeImage(imgBytes)
if err != nil {
http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest)
return
}
inputImages[i] = img
}
slog.Info("decoded input images", "count", len(inputImages))
// Default width/height to first input image dimensions, scaled to max 1024
bounds := inputImages[0].Bounds()
w, h := bounds.Dx(), bounds.Dy()
if w > 1024 || h > 1024 {
if w > h {
h = h * 1024 / w
w = 1024
} else {
w = w * 1024 / h
h = 1024
}
}
req.Width = int32(w)
req.Height = int32(h)
}
// Serialize generation requests - MLX model may not handle concurrent generation
s.mu.Lock()
defer s.mu.Unlock()
// Model applies its own defaults for width/height/steps
// Only seed needs to be set here if not provided
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
}
// Set up streaming response
w.Header().Set("Content-Type", "application/x-ndjson")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
// Generate image using the common interface
ctx := r.Context()
enc := json.NewEncoder(w)
// Progress callback streams step updates
progress := func(step, total int) {
resp := Response{Step: step, Total: total}
enc.Encode(resp)
w.Write([]byte("\n"))
flusher.Flush()
}
// Use ImageEditModel if available and images provided, otherwise use basic ImageModel
var img *mlx.Array
var err error
if len(inputImages) > 0 {
editModel, ok := s.model.(ImageEditModel)
if !ok {
http.Error(w, "model does not support image editing", http.StatusBadRequest)
return
}
img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress)
} else {
img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
}
if err != nil {
// Don't send error for cancellation
if ctx.Err() != nil {
return
}
resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()
// Send final response with image data
resp := Response{
Image: imageData,
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
}

View File

@@ -17,17 +17,31 @@ type WeightSource interface {
GetTensor(name string) (*mlx.Array, error)
ListTensors() []string
HasTensor(name string) bool
Quantization() string // Returns "FP4", "FP8", or ""
Quantization() string // Returns "NVFP4", "Q4", "Q8", or ""
GroupSize() int // Returns quantization group size, or 0 if not specified
}
// quantizationParams returns groupSize, bits, mode for a quantization type.
// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
// QuantizationParams returns groupSize, bits, mode for a quantization type.
// MLX quantization modes:
// - "affine": scale + zero-point bias, group_size=32/64/128
// - "nvfp4": NVIDIA FP4 with E4M3 scales, group_size=16 (no bias)
// - "mxfp8": Microsoft MX FP8 with E4M3 scales, group_size=32 (no bias)
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "FP4":
case "NVFP4":
// NVIDIA FP4: group_size=16, bits=4, E4M3 scales (no qbias)
return 16, 4, "nvfp4"
case "FP4", "Q4", "INT4":
// 4-bit quantization with affine mode (scale + qbias)
return 32, 4, "affine"
case "MXFP8":
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbias)
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
// 8-bit quantization with affine mode (default for quantized models)
return 64, 8, "affine"
default:
return 32, 8, "affine" // FP8 or unknown
return 32, 8, "affine" // Default to affine
}
}
@@ -122,7 +136,8 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
// Handle nn.LinearLayer interface fields specially
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
linearLayerType := reflect.TypeOf((*nn.LinearLayer)(nil)).Elem()
if field.Type == linearLayerType {
if !hasTag {
continue // no tag = skip
}
@@ -137,6 +152,23 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
continue
}
// Handle nn.MultiLinearLayer interface fields specially
multiLinearLayerType := reflect.TypeOf((*nn.MultiLinearLayer)(nil)).Elem()
if field.Type == multiLinearLayerType {
if !hasTag {
continue // no tag = skip
}
layer, err := LoadMultiLinearLayer(weights, fullPath)
if err != nil {
if !optional {
*errs = append(*errs, fullPath+": "+err.Error())
}
continue
}
fieldVal.Set(reflect.ValueOf(layer))
continue
}
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
@@ -216,12 +248,86 @@ func joinPath(prefix, suffix string) string {
return prefix + "." + suffix
}
// LoadMultiLinearLayer loads a per-head linear layer from weights.
// Weight shape should be [num_heads, output_dims, input_dims].
// If quantized, always dequantizes since batched quantized matmul isn't supported.
func LoadMultiLinearLayer(weights WeightSource, path string) (nn.MultiLinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
hasScale := weights.HasTensor(scalePath)
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
}
if hasScale {
scales, err := weights.GetTensor(scalePath)
if err != nil {
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
}
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
// Always dequantize for MultiLinear - no batched quantized matmul support
// Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
weightShape := weight.Shape()
scalesShape := scales.Shape()
weightCols := int(weightShape[len(weightShape)-1])
scalesCols := int(scalesShape[len(scalesShape)-1])
// Detect quantization from tensor shapes
// groupSize = weightCols * packFactor / scalesCols
// Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
var bits, groupSize int
// Use metadata to help disambiguate when shapes are ambiguous
// (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
quantType := strings.ToUpper(weights.Quantization())
isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
if groupSize4 == 32 {
// Unambiguous: Q4 with group_size=32
bits = 4
groupSize = 32
} else if groupSize8 == 64 {
// Unambiguous: Q8 with group_size=64
bits = 8
groupSize = 64
} else if groupSize4 == 64 && groupSize8 == 32 {
// Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
if isQ8Type {
bits = 8
groupSize = 32
} else {
bits = 4
groupSize = 64
}
} else {
// Fallback: use global quantization params
_, bits, _ = QuantizationParams(weights.Quantization())
packFactor := 32 / bits
groupSize = weightCols * packFactor / scalesCols
}
weight = mlx.Dequantize(weight, scales, qbiases, groupSize, bits, "affine")
}
return nn.NewMultiLinear(weight), nil
}
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
// If {path}.weight_scale exists, creates a QuantizedLinear layer (or dequantizes if no kernel support).
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
hasScale := weights.HasTensor(scalePath)
if hasScale {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
@@ -245,9 +351,52 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
qbiases, _ = weights.GetTensor(qbiasPath)
}
groupSize, bits, mode := quantizationParams(weights.Quantization())
// Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
weightShape := weight.Shape()
scalesShape := scales.Shape()
weightCols := int(weightShape[len(weightShape)-1])
scalesCols := int(scalesShape[len(scalesShape)-1])
if mlx.MetalIsAvailable() {
// Detect quantization from tensor shapes
// groupSize = weightCols * packFactor / scalesCols
// Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
var bits, groupSize int
mode := "affine"
// Use metadata to help disambiguate when shapes are ambiguous
// (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
quantType := strings.ToUpper(weights.Quantization())
isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
if groupSize4 == 32 {
// Unambiguous: Q4 with group_size=32
bits = 4
groupSize = 32
} else if groupSize8 == 64 {
// Unambiguous: Q8 with group_size=64
bits = 8
groupSize = 64
} else if groupSize4 == 64 && groupSize8 == 32 {
// Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
if isQ8Type {
bits = 8
groupSize = 32
} else {
bits = 4
groupSize = 64
}
} else {
// Fallback: use global quantization params
_, bits, mode = QuantizationParams(weights.Quantization())
packFactor := 32 / bits
groupSize = weightCols * packFactor / scalesCols
}
// NVFP4 and MXFP8 don't have native quantized matmul kernels in MLX,
// so we always dequantize at load time. Affine modes (FP4, FP8) have kernel support.
if mlx.MetalIsAvailable() && mode != "nvfp4" && mode != "mxfp8" {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,

Some files were not shown because too many files have changed in this diff Show More