Compare commits

..

1 Commits

Author SHA1 Message Date
Michael Yang
07d944fdfc move tokenizer to separate package 2026-01-28 14:25:31 -08:00
154 changed files with 3262 additions and 13181 deletions

View File

@@ -358,7 +358,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
- [Screenpipe](https://github.com/mediar-ai/screenpipe) (24/7 screen & mic recording with AI-powered search, uses Ollama for local LLM features)
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
@@ -466,7 +465,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
- [Stakpak](https://github.com/stakpak/agent) (An open source, vendor neutral DevOps agent that works with any model, and any stack, for teams who just want to ship)
### Cloud

3
anthropic/anthropic.go Executable file → Normal file
View File

@@ -211,7 +211,6 @@ type MessageDelta struct {
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
@@ -722,7 +721,6 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
})
}
c.inputTokens = r.Metrics.PromptEvalCount
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
@@ -734,7 +732,6 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
StopReason: stopReason,
},
Usage: DeltaUsage{
InputTokens: c.inputTokens,
OutputTokens: c.outputTokens,
},
},

20
anthropic/anthropic_test.go Executable file → Normal file
View File

@@ -642,7 +642,7 @@ func TestStreamConverter_Basic(t *testing.T) {
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
@@ -650,24 +650,6 @@ func TestStreamConverter_Basic(t *testing.T) {
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_delta" {
if data, ok := e.Data.(MessageDeltaEvent); ok {
if data.Type != "message_delta" {
t.Errorf("unexpected data type: %+v", data)
}
if data.Delta.StopReason != "end_turn" {
t.Errorf("unexpected stop reason: %+v", data.Delta.StopReason)
}
if data.Usage.InputTokens != 10 || data.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", data.Usage)
}
} else {
t.Errorf("unexpected data: %+v", e.Data)
}
}
if e.Event == "message_stop" {
hasStop = true
}

View File

@@ -29,7 +29,6 @@ import (
"github.com/containerd/console"
"github.com/mattn/go-runewidth"
"github.com/olekukonko/tablewriter"
"github.com/pkg/browser"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
@@ -53,7 +52,7 @@ import (
"github.com/ollama/ollama/x/imagegen"
)
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
@@ -664,7 +663,6 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
fmt.Println()
if aErr.SigninURL != "" {
_ = browser.OpenURL(aErr.SigninURL)
fmt.Printf(ConnectInstructions, aErr.SigninURL)
}
return nil
@@ -1890,7 +1888,7 @@ func NewCLI() *cobra.Command {
serveCmd := &cobra.Command{
Use: "serve",
Aliases: []string{"start"},
Short: "Start Ollama",
Short: "Start ollama",
Args: cobra.ExactArgs(0),
RunE: RunServer,
}

View File

@@ -1553,7 +1553,7 @@ func TestShowInfoImageGen(t *testing.T) {
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "Q8",
QuantizationLevel: "FP8",
},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
@@ -1565,7 +1565,7 @@ func TestShowInfoImageGen(t *testing.T) {
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization Q8 \n" +
" quantization FP8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +

View File

@@ -6,8 +6,6 @@ import (
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner for Claude Code integration
@@ -15,13 +13,11 @@ type Claude struct{}
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extra []string) []string {
var args []string
func (c *Claude) args(model string) []string {
if model != "" {
args = append(args, "--model", model)
return []string{"--model", model}
}
args = append(args, extra...)
return args
return nil
}
func (c *Claude) findPath() (string, error) {
@@ -43,18 +39,18 @@ func (c *Claude) findPath() (string, error) {
return fallback, nil
}
func (c *Claude) Run(model string, args []string) error {
func (c *Claude) Run(model string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model, args)...)
cmd := exec.Command(claudePath, c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_BASE_URL=http://localhost:11434",
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)

View File

@@ -84,21 +84,17 @@ func TestClaudeArgs(t *testing.T) {
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
{"empty model", "", nil, nil},
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
{"empty model", "", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, tt.args)
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}

View File

@@ -14,21 +14,20 @@ type Codex struct{}
func (c *Codex) String() string { return "Codex" }
func (c *Codex) args(model string, extra []string) []string {
func (c *Codex) args(model string) []string {
args := []string{"--oss"}
if model != "" {
args = append(args, "-m", model)
}
args = append(args, extra...)
return args
}
func (c *Codex) Run(model string, args []string) error {
func (c *Codex) Run(model string) error {
if err := checkCodexVersion(); err != nil {
return err
}
cmd := exec.Command("codex", c.args(model, args)...)
cmd := exec.Command("codex", c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

View File

@@ -11,20 +11,17 @@ func TestCodexArgs(t *testing.T) {
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", nil, []string{"--oss"}},
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", []string{"--oss"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, tt.args)
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}

View File

@@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
@@ -21,14 +20,6 @@ type config struct {
}
func configPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "config.json"), nil
}
func legacyConfigPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
@@ -36,46 +27,6 @@ func legacyConfigPath() (string, error) {
return filepath.Join(home, ".ollama", "config", "config.json"), nil
}
// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
func migrateConfig() (bool, error) {
oldPath, err := legacyConfigPath()
if err != nil {
return false, err
}
oldData, err := os.ReadFile(oldPath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
var js json.RawMessage
if err := json.Unmarshal(oldData, &js); err != nil {
slog.Warn("legacy config has invalid JSON, skipping migration", "path", oldPath, "error", err)
return false, nil
}
newPath, err := configPath()
if err != nil {
return false, err
}
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
return false, err
}
if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
return false, fmt.Errorf("write new config: %w", err)
}
_ = os.Remove(oldPath)
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
slog.Info("migrated config", "from", oldPath, "to", newPath)
return true, nil
}
func load() (*config, error) {
path, err := configPath()
if err != nil {
@@ -83,11 +34,6 @@ func load() (*config, error) {
}
data, err := os.ReadFile(path)
if err != nil && os.IsNotExist(err) {
if migrated, merr := migrateConfig(); merr == nil && migrated {
data, err = os.ReadFile(path)
}
}
if err != nil {
if os.IsNotExist(err) {
return &config{Integrations: make(map[string]*integration)}, nil

View File

@@ -200,10 +200,12 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
dir := filepath.Join(tmpDir, ".ollama")
// Create corrupted config.json file
dir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(dir, 0o755)
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
// Corrupted file is treated as empty, so loadIntegration returns not found
_, err := loadIntegration("test")
if err == nil {
t.Error("expected error for nonexistent integration in corrupted file")
@@ -265,7 +267,7 @@ func TestConfigPath(t *testing.T) {
t.Fatal(err)
}
expected := filepath.Join(tmpDir, ".ollama", "config.json")
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
if path != expected {
t.Errorf("expected %s, got %s", expected, path)
}
@@ -320,183 +322,6 @@ func TestLoad(t *testing.T) {
})
}
func TestMigrateConfig(t *testing.T) {
t.Run("migrates legacy file to new location", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
newPath, _ := configPath()
got, err := os.ReadFile(newPath)
if err != nil {
t.Fatalf("new config not found: %v", err)
}
if string(got) != string(data) {
t.Errorf("content mismatch: got %s", got)
}
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
t.Error("legacy file should have been removed")
}
if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
t.Error("legacy directory should have been removed")
}
})
t.Run("no-op when no legacy file exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("expected no migration")
}
})
t.Run("skips corrupt legacy file", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("should not migrate corrupt file")
}
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
t.Error("corrupt legacy file should not have been deleted")
}
})
t.Run("new path takes precedence over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
newDir := filepath.Join(tmpDir, ".ollama")
os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if _, ok := cfg.Integrations["new"]; !ok {
t.Error("expected new-path integration to be loaded")
}
if _, ok := cfg.Integrations["old"]; ok {
t.Error("legacy integration should not have been loaded")
}
})
t.Run("idempotent when called twice", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
if _, err := migrateConfig(); err != nil {
t.Fatal(err)
}
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("second migration should be a no-op")
}
})
t.Run("legacy directory preserved if not empty", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
if _, err := migrateConfig(); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
t.Error("directory with other files should not have been removed")
}
if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
t.Error("other files in legacy directory should be untouched")
}
})
t.Run("save writes to new path after migration", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
// load triggers migration, then save should write to new path
if err := saveIntegration("codex", []string{"qwen2.5"}); err != nil {
t.Fatal(err)
}
newPath := filepath.Join(tmpDir, ".ollama", "config.json")
if _, err := os.Stat(newPath); os.IsNotExist(err) {
t.Error("save should write to new path")
}
// old path should not be recreated
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
t.Error("save should not recreate legacy path")
}
})
t.Run("load triggers migration transparently", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
t.Error("migration via load() did not preserve data")
}
})
}
func TestSave(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)

View File

@@ -7,8 +7,6 @@ import (
"os/exec"
"path/filepath"
"slices"
"github.com/ollama/ollama/envconfig"
)
// Droid implements Runner and Editor for Droid integration
@@ -39,7 +37,7 @@ type modelEntry struct {
func (d *Droid) String() string { return "Droid" }
func (d *Droid) Run(model string, args []string) error {
func (d *Droid) Run(model string) error {
if _, err := exec.LookPath("droid"); err != nil {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
}
@@ -53,7 +51,7 @@ func (d *Droid) Run(model string, args []string) error {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid", args...)
cmd := exec.Command("droid")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -119,7 +117,7 @@ func (d *Droid) Edit(models []string) error {
newModels = append(newModels, modelEntry{
Model: model,
DisplayName: model,
BaseURL: envconfig.Host().String() + "/v1",
BaseURL: "http://localhost:11434/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: 64000,

View File

@@ -218,7 +218,7 @@ func TestDroidEdit(t *testing.T) {
}
}
if model["baseUrl"] != "http://127.0.0.1:11434/v1" {
if model["baseUrl"] != "http://localhost:11434/v1" {
t.Errorf("unexpected baseUrl: %s", model["baseUrl"])
}
if model["apiKey"] != "ollama" {
@@ -447,7 +447,7 @@ const testDroidSettingsFixture = `{
{
"model": "existing-ollama-model",
"displayName": "existing-ollama-model",
"baseUrl": "http://127.0.0.1:11434/v1",
"baseUrl": "http://localhost:11434/v1",
"apiKey": "ollama",
"provider": "generic-chat-completion-api",
"maxOutputTokens": 64000,

View File

@@ -13,7 +13,6 @@ import (
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
)
@@ -23,7 +22,7 @@ import (
// Runner can run an integration with a model.
type Runner interface {
Run(model string, args []string) error
Run(model string) error
// String returns the human-readable name of the integration
String() string
}
@@ -42,29 +41,9 @@ type Editor interface {
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
"clawdbot": &Openclaw{},
"codex": &Codex{},
"moltbot": &Openclaw{},
"droid": &Droid{},
"opencode": &OpenCode{},
"openclaw": &Openclaw{},
"pi": &Pi{},
}
// recommendedModels are shown when the user has no models or as suggestions.
// Order matters: local models first, then cloud models.
var recommendedModels = []selectItem{
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
{Name: "glm-4.7:cloud", Description: "Recommended"},
{Name: "kimi-k2.5:cloud", Description: "Recommended"},
}
// integrationAliases are hidden from the interactive selector but work as CLI arguments.
var integrationAliases = map[string]bool{
"clawdbot": true,
"moltbot": true,
"pi": true,
}
func selectIntegration() (string, error) {
@@ -75,9 +54,6 @@ func selectIntegration() (string, error) {
names := slices.Sorted(maps.Keys(integrations))
var items []selectItem
for _, name := range names {
if integrationAliases[name] {
continue
}
r := integrations[name]
description := r.String()
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
@@ -106,25 +82,62 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return nil, err
}
var existing []modelInfo
for _, m := range models.Models {
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
if len(models.Models) == 0 {
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
var items []selectItem
cloudModels := make(map[string]bool)
for _, m := range models.Models {
if m.RemoteModel != "" {
cloudModels[m.Name] = true
}
items = append(items, selectItem{Name: m.Name})
}
if len(items) == 0 {
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
}
// Get previously configured models (saved config takes precedence)
var preChecked []string
if saved, err := loadIntegration(name); err == nil {
preChecked = saved.Models
} else if editor, ok := r.(Editor); ok {
preChecked = editor.Models()
}
items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current)
if len(items) == 0 {
return nil, fmt.Errorf("no models available")
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
for _, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
// If current model is configured, move to front of preChecked
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
}
// Sort: checked first, then alphabetical
slices.SortFunc(items, func(a, b selectItem) int {
ac, bc := checked[a.Name], checked[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
var selected []string
// only editors support multi-model selection
if _, ok := r.(Editor); ok {
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
if err != nil {
@@ -138,27 +151,7 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
selected = []string{model}
}
var toPull []string
for _, m := range selected {
if !existingModels[m] {
toPull = append(toPull, m)
}
}
if len(toPull) > 0 {
msg := fmt.Sprintf("Download %s?", strings.Join(toPull, ", "))
if ok, err := confirmPrompt(msg); err != nil {
return nil, err
} else if !ok {
return nil, errCancelled
}
for _, m := range toPull {
fmt.Fprintf(os.Stderr, "\n")
if err := pullModel(ctx, client, m); err != nil {
return nil, fmt.Errorf("failed to pull %s: %w", m, err)
}
}
}
// if any model in selected is a cloud model, ensure signed in
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
@@ -228,13 +221,13 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return selected, nil
}
func runIntegration(name, modelName string, args []string) error {
func runIntegration(name, modelName string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName, args)
return r.Run(modelName)
}
// LaunchCmd returns the cobra command for launching integrations.
@@ -243,7 +236,7 @@ func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) erro
var configFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
Use: "launch [INTEGRATION]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
@@ -252,43 +245,19 @@ Supported integrations:
codex Codex
droid Droid
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
Examples:
ollama launch
ollama launch claude
ollama launch claude --model <model>
ollama launch droid --config (does not auto-launch)
ollama launch codex -- -p myprofile (pass extra args to integration)
ollama launch codex -- --sandbox workspace-write`,
Args: cobra.ArbitraryArgs,
ollama launch droid --config (does not auto-launch)`,
Args: cobra.MaximumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
// Extract integration name and args to pass through using -- separator
var name string
var passArgs []string
dashIdx := cmd.ArgsLenAtDash()
if dashIdx == -1 {
// No "--" separator: only allow 0 or 1 args (integration name)
if len(args) > 1 {
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
}
if len(args) == 1 {
name = args[0]
}
if len(args) > 0 {
name = args[0]
} else {
// "--" was used: args before it = integration name, args after = passthrough
if dashIdx > 1 {
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
}
if dashIdx == 1 {
name = args[0]
}
passArgs = args[dashIdx:]
}
if name == "" {
var err error
name, err = selectIntegration()
if errors.Is(err, errCancelled) {
@@ -304,14 +273,16 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
// If launching without --model, use saved config if available
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0], passArgs)
return runIntegration(name, config.Models[0])
}
}
var models []string
if modelFlag != "" {
// When --model is specified, merge with existing models (new model becomes default)
models = []string{modelFlag}
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
for _, m := range existing.Models {
@@ -366,13 +337,13 @@ Examples:
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0], passArgs)
return runIntegration(name, models[0])
}
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
return nil
}
return runIntegration(name, models[0], passArgs)
return runIntegration(name, models[0])
},
}
@@ -380,154 +351,3 @@ Examples:
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
return cmd
}
type modelInfo struct {
Name string
Remote bool
}
// buildModelList merges existing models with recommendations, sorts them, and returns
// the ordered items along with maps of existing and cloud model names.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
existingModels = make(map[string]bool)
cloudModels = make(map[string]bool)
recommended := make(map[string]bool)
var hasLocalModel, hasCloudModel bool
for _, rec := range recommendedModels {
recommended[rec.Name] = true
}
for _, m := range existing {
existingModels[m.Name] = true
if m.Remote {
cloudModels[m.Name] = true
hasCloudModel = true
} else {
hasLocalModel = true
}
displayName := strings.TrimSuffix(m.Name, ":latest")
existingModels[displayName] = true
item := selectItem{Name: displayName}
if recommended[displayName] {
item.Description = "recommended"
}
items = append(items, item)
}
for _, rec := range recommendedModels {
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
continue
}
items = append(items, rec)
if isCloudModel(rec.Name) {
cloudModels[rec.Name] = true
}
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
for _, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
}
// Non-existing models get "install?" suffix and are pushed to the bottom.
// When user has no models, preserve recommended order.
notInstalled := make(map[string]bool)
for i := range items {
if !existingModels[items[i].Name] {
notInstalled[items[i].Name] = true
if items[i].Description != "" {
items[i].Description += ", install?"
} else {
items[i].Description = "install?"
}
}
}
if hasLocalModel || hasCloudModel {
slices.SortStableFunc(items, func(a, b selectItem) int {
ac, bc := checked[a.Name], checked[b.Name]
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
if !ac && !bc && aNew != bNew {
if aNew {
return 1
}
return -1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
}
return items, preChecked, existingModels, cloudModels
}
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud")
}
func pullModel(ctx context.Context, client *api.Client, model string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if resp.Completed == 0 {
return nil
}
if spinner != nil {
spinner.Stop()
}
bar, ok := bars[resp.Digest]
if !ok {
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
}
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
request := api.PullRequest{Name: model}
return client.Pull(ctx, &request, fn)
}

View File

@@ -1,12 +1,10 @@
package config
import (
"fmt"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/spf13/cobra"
)
@@ -92,8 +90,8 @@ func TestLaunchCmd(t *testing.T) {
cmd := LaunchCmd(mockCheck)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
if cmd.Use != "launch [INTEGRATION]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
@@ -123,7 +121,7 @@ func TestLaunchCmd(t *testing.T) {
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model", nil)
err := runIntegration("unknown-integration", "model")
if err == nil {
t.Error("expected error for unknown integration, got nil")
}
@@ -176,336 +174,15 @@ func TestLaunchCmd_NilHeartbeat(t *testing.T) {
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
for name, r := range integrations {
t.Run(name, func(t *testing.T) {
// Test String() doesn't panic and returns non-empty
displayName := r.String()
if displayName == "" {
t.Error("String() should not return empty")
}
var _ func(string, []string) error = r.Run
// Test Run() exists (we can't call it without actually running the command)
// Just verify the method is available
var _ func(string) error = r.Run
})
}
}
func TestParseArgs(t *testing.T) {
// Tests reflect cobra's ArgsLenAtDash() semantics:
// - cobra strips "--" from args
// - ArgsLenAtDash() returns the index where "--" was, or -1
tests := []struct {
name string
args []string // args as cobra delivers them (no "--")
dashIdx int // what ArgsLenAtDash() returns
wantName string
wantArgs []string
wantErr bool
}{
{
name: "no extra args, no dash",
args: []string{"claude"},
dashIdx: -1,
wantName: "claude",
},
{
name: "with extra args after --",
args: []string{"codex", "-p", "myprofile"},
dashIdx: 1,
wantName: "codex",
wantArgs: []string{"-p", "myprofile"},
},
{
name: "extra args only after --",
args: []string{"codex", "--sandbox", "workspace-write"},
dashIdx: 1,
wantName: "codex",
wantArgs: []string{"--sandbox", "workspace-write"},
},
{
name: "-- at end with no args after",
args: []string{"claude"},
dashIdx: 1,
wantName: "claude",
},
{
name: "-- with no integration name",
args: []string{"--verbose"},
dashIdx: 0,
wantName: "",
wantArgs: []string{"--verbose"},
},
{
name: "multiple args before -- is error",
args: []string{"claude", "codex", "--verbose"},
dashIdx: 2,
wantErr: true,
},
{
name: "multiple args without -- is error",
args: []string{"claude", "codex"},
dashIdx: -1,
wantErr: true,
},
{
name: "no args, no dash",
args: []string{},
dashIdx: -1,
wantName: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the parsing logic from LaunchCmd using dashIdx
var name string
var parsedArgs []string
var err error
dashIdx := tt.dashIdx
args := tt.args
if dashIdx == -1 {
if len(args) > 1 {
err = fmt.Errorf("unexpected arguments: %v", args[1:])
} else if len(args) == 1 {
name = args[0]
}
} else {
if dashIdx > 1 {
err = fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
} else {
if dashIdx == 1 {
name = args[0]
}
parsedArgs = args[dashIdx:]
}
}
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if name != tt.wantName {
t.Errorf("name = %q, want %q", name, tt.wantName)
}
if !slices.Equal(parsedArgs, tt.wantArgs) {
t.Errorf("args = %v, want %v", parsedArgs, tt.wantArgs)
}
})
}
}
func TestIsCloudModel(t *testing.T) {
tests := []struct {
name string
want bool
}{
{"glm-4.7:cloud", true},
{"kimi-k2.5:cloud", true},
{"glm-4.7-flash", false},
{"glm-4.7-flash:latest", false},
{"cloud-model", false},
{"model:cloudish", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isCloudModel(tt.name); got != tt.want {
t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
}
})
}
}
func names(items []selectItem) []string {
var out []string
for _, item := range items {
out = append(out, item.Name)
}
return out
}
func TestBuildModelList_NoExistingModels(t *testing.T) {
items, _, _, _ := buildModelList(nil, nil, "")
want := []string{"glm-4.7-flash", "qwen3:8b", "glm-4.7:cloud", "kimi-k2.5:cloud"}
if diff := cmp.Diff(want, names(items)); diff != "" {
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
}
for _, item := range items {
if !strings.HasSuffix(item.Description, "install?") {
t.Errorf("item %q should have description ending with 'install?', got %q", item.Name, item.Description)
}
}
}
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "qwen2.5:latest", Remote: false},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
want := []string{"llama3.2", "qwen2.5", "glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("cloud recs should be at bottom (-want +got):\n%s", diff)
}
}
func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
want := []string{"glm-4.7:cloud", "llama3.2", "glm-4.7-flash", "kimi-k2.5:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("mixed models should be alphabetical (-want +got):\n%s", diff)
}
}
func TestBuildModelList_PreCheckedFirst(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
got := names(items)
if got[0] != "llama3.2" {
t.Errorf("pre-checked model should be first, got %v", got)
}
}
func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
existing := []modelInfo{
{Name: "glm-4.7-flash", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
for _, item := range items {
switch item.Name {
case "glm-4.7-flash", "glm-4.7:cloud":
if strings.HasSuffix(item.Description, "install?") {
t.Errorf("installed recommended %q should not have 'install?' suffix, got %q", item.Name, item.Description)
}
case "kimi-k2.5:cloud", "qwen3:8b":
if !strings.HasSuffix(item.Description, "install?") {
t.Errorf("non-installed recommended %q should have 'install?' suffix, got %q", item.Name, item.Description)
}
}
}
}
func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
existing := []modelInfo{
{Name: "glm-4.7-flash", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
// glm-4.7-flash and glm-4.7:cloud are installed so they sort normally;
// kimi-k2.5:cloud and qwen3:8b are not installed so they go to the bottom
want := []string{"glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("existing cloud models should sort normally (-want +got):\n%s", diff)
}
}
func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "kimi-k2.5:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
// kimi-k2.5:cloud is installed so it sorts normally;
// the rest of the recommendations are not installed so they go to the bottom
want := []string{"kimi-k2.5:cloud", "llama3.2", "glm-4.7-flash", "glm-4.7:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("only non-installed models should be at bottom (-want +got):\n%s", diff)
}
for _, item := range items {
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
if !strings.HasSuffix(item.Description, "install?") {
t.Errorf("non-installed %q should have 'install?' suffix, got %q", item.Name, item.Description)
}
}
}
}
func TestBuildModelList_LatestTagStripped(t *testing.T) {
existing := []modelInfo{
{Name: "glm-4.7-flash:latest", Remote: false},
{Name: "llama3.2:latest", Remote: false},
}
items, _, existingModels, _ := buildModelList(existing, nil, "")
got := names(items)
// :latest should be stripped from display names
for _, name := range got {
if strings.HasSuffix(name, ":latest") {
t.Errorf("name %q should not have :latest suffix", name)
}
}
// glm-4.7-flash should not be duplicated (existing :latest matches the recommendation)
count := 0
for _, name := range got {
if name == "glm-4.7-flash" {
count++
}
}
if count != 1 {
t.Errorf("glm-4.7-flash should appear exactly once, got %d in %v", count, got)
}
// Stripped name should be in existingModels so it won't be pulled
if !existingModels["glm-4.7-flash"] {
t.Error("glm-4.7-flash should be in existingModels")
}
}
func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
_, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if !existingModels["llama3.2"] {
t.Error("llama3.2 should be in existingModels")
}
if !existingModels["glm-4.7:cloud"] {
t.Error("glm-4.7:cloud should be in existingModels")
}
if existingModels["glm-4.7-flash"] {
t.Error("glm-4.7-flash should not be in existingModels (it's a recommendation)")
}
if !cloudModels["glm-4.7:cloud"] {
t.Error("glm-4.7:cloud should be in cloudModels")
}
if !cloudModels["kimi-k2.5:cloud"] {
t.Error("kimi-k2.5:cloud should be in cloudModels (recommended cloud)")
}
if cloudModels["llama3.2"] {
t.Error("llama3.2 should not be in cloudModels")
}
}

View File

@@ -1,254 +0,0 @@
package config
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/ollama/ollama/envconfig"
)
type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
const ansiGreen = "\033[32m"
func (c *Openclaw) Run(model string, args []string) error {
bin := "openclaw"
if _, err := exec.LookPath(bin); err != nil {
bin = "clawdbot"
if _, err := exec.LookPath(bin); err != nil {
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
}
}
models := []string{model}
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
models = config.Models
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
if !c.onboarded() {
// Onboarding not completed: run it (model already set via Edit)
// Use "ollama" as gateway token for simple local access
cmd := exec.Command(bin, "onboard",
"--auth-choice", "skip",
"--gateway-token", "ollama",
)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
// Onboarding completed: run gateway
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
cmd.Stdin = os.Stdin
// Capture output to detect "already running" message
var outputBuf bytes.Buffer
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
err := cmd.Run()
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
return nil
}
return err
}
// onboarded checks if OpenClaw onboarding wizard was completed
// by looking for the wizard.lastRunAt marker in the config
func (c *Openclaw) onboarded() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
} else {
return false
}
// Check for wizard.lastRunAt marker (set when onboarding completes)
wizard, _ := config["wizard"].(map[string]any)
if wizard == nil {
return false
}
lastRunAt, _ := wizard["lastRunAt"].(string)
return lastRunAt != ""
}
func (c *Openclaw) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".openclaw", "openclaw.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
if _, err := os.Stat(legacy); err == nil {
return []string{legacy}
}
return nil
}
func (c *Openclaw) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
// Read into map[string]any to preserve unknown fields
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
}
// Navigate/create: models.providers.ollama (preserving other providers)
modelsSection, _ := config["models"].(map[string]any)
if modelsSection == nil {
modelsSection = make(map[string]any)
}
providers, _ := modelsSection["providers"].(map[string]any)
if providers == nil {
providers = make(map[string]any)
}
ollama, _ := providers["ollama"].(map[string]any)
if ollama == nil {
ollama = make(map[string]any)
}
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
// needed to register provider
ollama["apiKey"] = "ollama-local"
// TODO(parthsareen): potentially move to responses
ollama["api"] = "openai-completions"
// Build map of existing models to preserve user customizations
existingModels, _ := ollama["models"].([]any)
existingByID := make(map[string]map[string]any)
for _, m := range existingModels {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
existingByID[id] = entry
}
}
}
var newModels []any
for _, model := range models {
entry := map[string]any{
"id": model,
"name": model,
"reasoning": false,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
// TODO(parthsareen): get these values from API
"contextWindow": 131072,
"maxTokens": 16384,
}
// Merge existing fields (user customizations)
if existing, ok := existingByID[model]; ok {
for k, v := range existing {
if _, isNew := entry[k]; !isNew {
entry[k] = v
}
}
}
newModels = append(newModels, entry)
}
ollama["models"] = newModels
providers["ollama"] = ollama
modelsSection["providers"] = providers
config["models"] = modelsSection
// Update agents.defaults.model.primary (preserving other agent settings)
agents, _ := config["agents"].(map[string]any)
if agents == nil {
agents = make(map[string]any)
}
defaults, _ := agents["defaults"].(map[string]any)
if defaults == nil {
defaults = make(map[string]any)
}
modelConfig, _ := defaults["model"].(map[string]any)
if modelConfig == nil {
modelConfig = make(map[string]any)
}
modelConfig["primary"] = "ollama/" + models[0]
defaults["model"] = modelConfig
agents["defaults"] = defaults
config["agents"] = agents
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
return writeWithBackup(configPath, data)
}
func (c *Openclaw) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
if err != nil {
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
if err != nil {
return nil
}
}
modelsSection, _ := config["models"].(map[string]any)
providers, _ := modelsSection["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
var result []string
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
result = append(result, id)
}
}
}
return result
}

View File

@@ -1,878 +0,0 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestOpenclawIntegration(t *testing.T) {
c := &Openclaw{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "OpenClaw" {
t.Errorf("String() = %q, want %q", got, "OpenClaw")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = c
})
}
func TestOpenclawEdit(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("fresh install", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("multiple models - first is primary", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawModelExists(t, configPath, "mistral")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("preserve other providers", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider was removed")
}
})
t.Run("preserve top-level keys", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("theme was removed")
}
if cfg["mcp"] == nil {
t.Error("mcp was removed")
}
})
t.Run("preserve user customizations on models", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2"})
// User adds custom field
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
entry["customField"] = "user-value"
configData, _ := json.MarshalIndent(cfg, "", " ")
os.WriteFile(configPath, configData, 0o644)
// Re-run Edit
c.Edit([]string{"llama3.2"})
data, _ = os.ReadFile(configPath)
json.Unmarshal(data, &cfg)
models = cfg["models"].(map[string]any)
providers = models["providers"].(map[string]any)
ollama = providers["ollama"].(map[string]any)
modelList = ollama["models"].([]any)
entry = modelList[0].(map[string]any)
if entry["customField"] != "user-value" {
t.Error("custom field was lost")
}
})
t.Run("edit replaces models list", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2", "mistral"})
c.Edit([]string{"llama3.2"})
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawModelNotExists(t, configPath, "mistral")
})
t.Run("empty models is no-op", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
original := `{"existing":"data"}`
os.WriteFile(configPath, []byte(original), 0o644)
c.Edit([]string{})
data, _ := os.ReadFile(configPath)
if string(data) != original {
t.Error("empty models should not modify file")
}
})
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Error("result should be valid JSON")
}
})
t.Run("wrong type models section", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2")
})
}
func TestOpenclawModels(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("no config returns nil", func(t *testing.T) {
if models := c.Models(); len(models) > 0 {
t.Errorf("expected nil/empty, got %v", models)
}
})
t.Run("returns all ollama models", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[
{"id":"llama3.2"},
{"id":"mistral"}
]}}}
}`), 0o644)
models := c.Models()
if len(models) != 2 {
t.Errorf("expected 2 models, got %v", models)
}
})
}
// Helper functions
func assertOpenclawModelExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
return
}
}
}
t.Errorf("model %s not found", model)
}
func assertOpenclawModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models, _ := cfg["models"].(map[string]any)
providers, _ := models["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
t.Errorf("model %s should not exist", model)
}
}
}
}
func assertOpenclawPrimaryModel(t *testing.T, path, expected string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
model := defaults["model"].(map[string]any)
if model["primary"] != expected {
t.Errorf("primary model = %v, want %v", model["primary"], expected)
}
}
func TestOpenclawPaths(t *testing.T) {
c := &Openclaw{}
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns nil when config missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if paths := c.Paths(); paths != nil {
t.Errorf("expected nil, got %v", paths)
}
})
}
func TestOpenclawModelsEdgeCases(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("corrupted JSON returns nil", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at models level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at providers level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at ollama level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("model entry missing id", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for missing id")
}
})
t.Run("model id is not string", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for non-string id")
}
})
}
func TestOpenclawEditSchemaFields(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
// Verify required schema fields
if entry["reasoning"] != false {
t.Error("reasoning should be false")
}
if entry["input"] == nil {
t.Error("input should be set")
}
if entry["contextWindow"] == nil {
t.Error("contextWindow should be set")
}
if entry["maxTokens"] == nil {
t.Error("maxTokens should be set")
}
cost := entry["cost"].(map[string]any)
if cost["cacheRead"] == nil {
t.Error("cost.cacheRead should be set")
}
if cost["cacheWrite"] == nil {
t.Error("cost.cacheWrite should be set")
}
}
func TestOpenclawEditModelNames(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".openclaw")) }
t.Run("model with colon tag", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2:70b")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2:70b")
})
t.Run("model with slash", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"library/model:tag"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "library/model:tag")
assertOpenclawPrimaryModel(t, configPath, "ollama/library/model:tag")
})
t.Run("model with hyphen", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"test-model"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "test-model")
})
}
func TestOpenclawEditAgentsPreservation(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("preserve other agent defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature setting was lost")
}
})
t.Run("preserve other agents besides defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent was lost")
}
})
}
const testOpenclawFixture = `{
"theme": "dark",
"mcp": {"servers": {"custom": {"enabled": true}}},
"models": {
"providers": {
"anthropic": {"apiKey": "xxx"},
"ollama": {
"baseUrl": "http://127.0.0.1:11434/v1",
"models": [{"id": "old-model", "customField": "preserved"}]
}
}
},
"agents": {
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
"custom-agent": {"foo": "bar"}
}
}`
func TestOpenclawEdit_RoundTrip(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
// Verify top-level preserved
if cfg["theme"] != "dark" {
t.Error("theme not preserved")
}
mcp := cfg["mcp"].(map[string]any)
servers := mcp["servers"].(map[string]any)
if servers["custom"] == nil {
t.Error("mcp.servers.custom not preserved")
}
// Verify other providers preserved
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider not preserved")
}
// Verify agents preserved
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent not preserved")
}
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature not preserved")
}
}
func TestOpenclawEdit_Idempotent(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
c.Edit([]string{"llama3.2", "mistral"})
firstData, _ := os.ReadFile(configPath)
c.Edit([]string{"llama3.2", "mistral"})
secondData, _ := os.ReadFile(configPath)
if string(firstData) != string(secondData) {
t.Error("repeated edits with same models produced different results")
}
}
func TestOpenclawEdit_MultipleConsecutiveEdits(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
for i := range 10 {
models := []string{"model-a", "model-b"}
if i%2 == 0 {
models = []string{"model-x", "model-y", "model-z"}
}
if err := c.Edit(models); err != nil {
t.Fatalf("edit %d failed: %v", i, err)
}
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
}
if cfg["theme"] != "dark" {
t.Error("theme lost after multiple edits")
}
}
func TestOpenclawEdit_BackupCreated(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
os.MkdirAll(configDir, 0o755)
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
os.WriteFile(configPath, []byte(original), 0o644)
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
backups, _ := filepath.Glob(filepath.Join(backupDir, "openclaw.json.*"))
foundBackup := false
for _, backup := range backups {
data, _ := os.ReadFile(backup)
if string(data) == original {
foundBackup = true
break
}
}
if !foundBackup {
t.Error("backup with original content not found")
}
}
func TestOpenclawClawdbotAlias(t *testing.T) {
for _, alias := range []string{"clawdbot", "moltbot"} {
t.Run(alias+" alias resolves to Openclaw runner", func(t *testing.T) {
r, ok := integrations[alias]
if !ok {
t.Fatalf("%s not found in integrations", alias)
}
if _, ok := r.(*Openclaw); !ok {
t.Errorf("%s integration is %T, want *Openclaw", alias, r)
}
})
t.Run(alias+" is hidden from selector", func(t *testing.T) {
if !integrationAliases[alias] {
t.Errorf("%s should be in integrationAliases", alias)
}
})
}
}
func TestOpenclawLegacyPaths(t *testing.T) {
c := &Openclaw{}
t.Run("falls back to legacy clawdbot path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Fatalf("expected 1 path, got %d", len(paths))
}
if paths[0] != filepath.Join(legacyDir, "clawdbot.json") {
t.Errorf("expected legacy path, got %s", paths[0])
}
})
t.Run("prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Fatalf("expected 1 path, got %d", len(paths))
}
if paths[0] != filepath.Join(newDir, "openclaw.json") {
t.Errorf("expected new path, got %s", paths[0])
}
})
t.Run("Models reads from legacy path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"llama3.2"}]}}}
}`), 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "llama3.2" {
t.Errorf("expected [llama3.2], got %v", models)
}
})
t.Run("Models prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"new-model"}]}}}
}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"legacy-model"}]}}}
}`), 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "new-model" {
t.Errorf("expected [new-model], got %v", models)
}
})
t.Run("Edit reads new path over legacy when both exist", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{"theme":"new"}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"legacy"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(filepath.Join(newDir, "openclaw.json"))
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "new" {
t.Errorf("expected theme from new config, got %v", cfg["theme"])
}
})
t.Run("Edit migrates from legacy config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"dark"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Should write to new path
newPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
data, err := os.ReadFile(newPath)
if err != nil {
t.Fatal("expected new config file to be created")
}
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("legacy theme setting was not migrated")
}
})
}
func TestOpenclawEdit_CreatesDirectoryIfMissing(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
t.Fatal("directory should not exist before test")
}
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(configDir); os.IsNotExist(err) {
t.Fatal("directory was not created")
}
}
func TestOpenclawOnboarded(t *testing.T) {
c := &Openclaw{}
t.Run("returns false when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if c.onboarded() {
t.Error("expected false when no config exists")
}
})
t.Run("returns false when config exists but no wizard section", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
if c.onboarded() {
t.Error("expected false when no wizard section")
}
})
t.Run("returns false when wizard section exists but no lastRunAt", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{}}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard.lastRunAt is missing")
}
})
t.Run("returns false when wizard.lastRunAt is empty string", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":""}}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard.lastRunAt is empty")
}
})
t.Run("returns true when wizard.lastRunAt is set", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if !c.onboarded() {
t.Error("expected true when wizard.lastRunAt is set")
}
})
t.Run("checks legacy clawdbot path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if !c.onboarded() {
t.Error("expected true when legacy config has wizard.lastRunAt")
}
})
t.Run("prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
// New path has no wizard marker
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
// Legacy has wizard marker
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if c.onboarded() {
t.Error("expected false - should prefer new path which has no wizard marker")
}
})
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
if c.onboarded() {
t.Error("expected false for corrupted JSON")
}
})
t.Run("handles wrong type for wizard section", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":"not a map"}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard is wrong type")
}
})
}

View File

@@ -9,8 +9,6 @@ import (
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
@@ -18,7 +16,7 @@ type OpenCode struct{}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string, args []string) error {
func (o *OpenCode) Run(model string) error {
if _, err := exec.LookPath("opencode"); err != nil {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
}
@@ -32,7 +30,7 @@ func (o *OpenCode) Run(model string, args []string) error {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("opencode", args...)
cmd := exec.Command("opencode")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -90,7 +88,7 @@ func (o *OpenCode) Edit(modelList []string) error {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)",
"options": map[string]any{
"baseURL": envconfig.Host().String() + "/v1",
"baseURL": "http://localhost:11434/v1",
},
}
}

View File

@@ -1,196 +0,0 @@
package config
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
"github.com/ollama/ollama/envconfig"
)
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
type Pi struct{}
func (p *Pi) String() string { return "Pi" }
func (p *Pi) Run(model string, args []string) error {
if _, err := exec.LookPath("pi"); err != nil {
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := p.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("pi", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (p *Pi) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
var paths []string
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
if _, err := os.Stat(modelsPath); err == nil {
paths = append(paths, modelsPath)
}
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
if _, err := os.Stat(settingsPath); err == nil {
paths = append(paths, settingsPath)
}
return paths
}
func (p *Pi) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
}
providers, ok := config["providers"].(map[string]any)
if !ok {
providers = make(map[string]any)
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"baseUrl": envconfig.Host().String() + "/v1",
"api": "openai-completions",
"apiKey": "ollama",
}
}
existingModels, ok := ollama["models"].([]any)
if !ok {
existingModels = make([]any, 0)
}
// Build set of selected models to track which need to be added
selectedSet := make(map[string]bool, len(models))
for _, m := range models {
selectedSet[m] = true
}
// Build new models list:
// 1. Keep user-managed models (no _launch marker) - untouched
// 2. Keep ollama-managed models (_launch marker) that are still selected
// 3. Add new ollama-managed models
var newModels []any
for _, m := range existingModels {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
// User-managed model (no _launch marker) - always preserve
if !isPiOllamaModel(modelObj) {
newModels = append(newModels, m)
} else if selectedSet[id] {
// Ollama-managed and still selected - keep it
newModels = append(newModels, m)
selectedSet[id] = false
}
}
}
}
// Add newly selected models that weren't already in the list
for _, model := range models {
if selectedSet[model] {
newModels = append(newModels, map[string]any{
"id": model,
"_launch": true,
})
}
}
ollama["models"] = newModels
providers["ollama"] = ollama
config["providers"] = providers
configData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
return err
}
// Update settings.json with default provider and model
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
settings := make(map[string]any)
if data, err := os.ReadFile(settingsPath); err == nil {
_ = json.Unmarshal(data, &settings)
}
settings["defaultProvider"] = "ollama"
settings["defaultModel"] = models[0]
settingsData, err := json.MarshalIndent(settings, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, settingsData)
}
func (p *Pi) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
config, err := readJSONFile(configPath)
if err != nil {
return nil
}
providers, _ := config["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
models, _ := ollama["models"].([]any)
var result []string
for _, m := range models {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
result = append(result, id)
}
}
}
slices.Sort(result)
return result
}
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
func isPiOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
}
return false
}

View File

@@ -1,609 +0,0 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestPiIntegration(t *testing.T) {
pi := &Pi{}
t.Run("String", func(t *testing.T) {
if got := pi.String(); got != "Pi" {
t.Errorf("String() = %q, want %q", got, "Pi")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = pi
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = pi
})
}
func TestPiPaths(t *testing.T) {
pi := &Pi{}
t.Run("returns empty when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
paths := pi.Paths()
if len(paths) != 0 {
t.Errorf("Paths() = %v, want empty", paths)
}
})
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
t.Fatal(err)
}
paths := pi.Paths()
if len(paths) != 1 || paths[0] != configPath {
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
}
})
}
func TestPiEdit(t *testing.T) {
pi := &Pi{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
configPath := filepath.Join(configDir, "models.json")
cleanup := func() {
os.RemoveAll(configDir)
}
readConfig := func() map[string]any {
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
return cfg
}
t.Run("returns nil for empty models", func(t *testing.T) {
if err := pi.Edit([]string{}); err != nil {
t.Errorf("Edit([]) error = %v, want nil", err)
}
})
t.Run("creates config with models", func(t *testing.T) {
cleanup()
models := []string{"llama3.2", "qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers, ok := cfg["providers"].(map[string]any)
if !ok {
t.Error("Config missing providers")
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
t.Error("Providers missing ollama")
}
modelsArray, ok := ollama["models"].([]any)
if !ok || len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %v", modelsArray)
}
if ollama["baseUrl"] == nil {
t.Error("Missing baseUrl")
}
if ollama["api"] != "openai-completions" {
t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
}
if ollama["apiKey"] != "ollama" {
t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
}
})
t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://custom:8080/v1",
"api": "custom-api",
"apiKey": "custom-key",
"models": [
{"id": "old-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"new-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
if ollama["baseUrl"] != "http://custom:8080/v1" {
t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
}
if ollama["api"] != "custom-api" {
t.Errorf("Custom api not preserved, got %v", ollama["api"])
}
if ollama["apiKey"] != "custom-key" {
t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
}
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
} else {
modelEntry := modelsArray[0].(map[string]any)
if modelEntry["id"] != "new-model" {
t.Errorf("Expected new-model, got %v", modelEntry["id"])
}
// Verify _launch marker is present
if modelEntry["_launch"] != true {
t.Errorf("Expected _launch marker to be true")
}
}
})
t.Run("replaces old models with new ones", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Old models must have _launch marker to be managed by us
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "old-model-1", "_launch": true},
{"id": "old-model-2", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"new-model-1", "new-model-2"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
t.Errorf("Expected new models, got %v", modelIDs)
}
if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
t.Errorf("Old models should have been removed, got %v", modelIDs)
}
})
t.Run("handles partial overlap in model list", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Models must have _launch marker to be managed
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "keep-model", "_launch": true},
{"id": "remove-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"keep-model", "add-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["keep-model"] || !modelIDs["add-model"] {
t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
}
if modelIDs["remove-model"] {
t.Errorf("remove-model should have been removed")
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
}
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("Failed to read config: %v", err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
}
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model, got %d", len(modelsArray))
}
})
// CRITICAL SAFETY TEST: verifies we don't stomp on user configs
t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// User has manually configured models in ollama provider (no _launch marker)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "user-model-1"},
{"id": "user-model-2", "customField": "preserved"},
{"id": "ollama-managed", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
// Add a new ollama-managed model
newModels := []string{"new-ollama-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
// Should have: new-ollama-model (managed) + 2 user models (preserved)
if len(modelsArray) != 3 {
t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
}
modelIDs := make(map[string]map[string]any)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = modelObj
}
// Verify new model has _launch marker
if m, ok := modelIDs["new-ollama-model"]; !ok {
t.Errorf("new-ollama-model should be present")
} else if m["_launch"] != true {
t.Errorf("new-ollama-model should have _launch marker")
}
// Verify user models are preserved
if _, ok := modelIDs["user-model-1"]; !ok {
t.Errorf("user-model-1 should be preserved")
}
if _, ok := modelIDs["user-model-2"]; !ok {
t.Errorf("user-model-2 should be preserved")
} else if modelIDs["user-model-2"]["customField"] != "preserved" {
t.Errorf("user-model-2 customField should be preserved")
}
// Verify old ollama-managed model is removed (not in new list)
if _, ok := modelIDs["ollama-managed"]; ok {
t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
}
})
t.Run("updates settings.json with default provider and model", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create existing settings with other fields
settingsPath := filepath.Join(configDir, "settings.json")
existingSettings := `{
"theme": "dark",
"customSetting": "value",
"defaultProvider": "anthropic",
"defaultModel": "claude-3"
}`
if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"llama3.2"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
// Verify defaultProvider is set to ollama
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
// Verify defaultModel is set to first model
if settings["defaultModel"] != "llama3.2" {
t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
}
// Verify other fields are preserved
if settings["theme"] != "dark" {
t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
}
if settings["customSetting"] != "value" {
t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
}
})
t.Run("creates settings.json if it does not exist", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
models := []string{"qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
settingsPath := filepath.Join(configDir, "settings.json")
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("settings.json should be created: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "qwen3:8b" {
t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
}
})
t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create corrupt settings
settingsPath := filepath.Join(configDir, "settings.json")
if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "test-model" {
t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
}
})
}
func TestPiModels(t *testing.T) {
pi := &Pi{}
t.Run("returns nil when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns models from config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "llama3.2"},
{"id": "qwen3:8b"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if len(models) != 2 {
t.Errorf("Models() returned %d models, want 2", len(models))
}
if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
}
})
t.Run("returns sorted models", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "z-model"},
{"id": "a-model"},
{"id": "m-model"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
}
})
t.Run("returns nil when models array is missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil when models array is missing", models)
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil for corrupt config", models)
}
})
}

View File

@@ -275,11 +275,7 @@ func parseInput(r io.Reader) (inputEvent, byte, error) {
func renderSelect(w io.Writer, prompt string, s *selectState) int {
filtered := s.filtered()
if s.filter == "" {
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
} else {
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
}
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
lineCount := 1
if len(filtered) == 0 {
@@ -318,11 +314,7 @@ func renderSelect(w io.Writer, prompt string, s *selectState) int {
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
filtered := s.filtered()
if s.filter == "" {
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
} else {
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
}
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
lineCount := 1
if len(filtered) == 0 {
@@ -353,15 +345,10 @@ func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
suffix = " " + ansiGray + "(default)" + ansiReset
}
desc := ""
if item.Description != "" {
desc = " " + ansiGray + "- " + item.Description + ansiReset
}
if idx == s.highlighted && !s.focusOnButton {
fmt.Fprintf(w, " %s%s %s %s%s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, desc, suffix)
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
} else {
fmt.Fprintf(w, " %s %s %s%s%s\r\n", prefix, checkbox, item.Name, desc, suffix)
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
}
lineCount++
}

View File

@@ -313,12 +313,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
case "GlmOcrForConditionalGeneration":
conv = &glmOcrModel{}
case "Lfm2ForCausalLM":
conv = &lfm2Model{}
case "Qwen3NextForCausalLM":
conv = &qwen3NextModel{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

View File

@@ -1,455 +0,0 @@
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"log/slog"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
// normalToNeoXRepacker creates a repacker that permutes Q/K weights from interleaved (LLaMA)
// to NeoX ordering for compatibility with GGML's M-RoPE kernel.
//
// For weights: reshape [out, in] -> [n_heads, head_dim, in], permute rotary dims, reshape back
// For biases: reshape [out] -> [n_heads, head_dim], permute rotary dims, reshape back
func normalToNeoXRepacker(nHeads, headDim int, partialRotaryFactor float32) func(string, []float32, []uint64) ([]float32, error) {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
rotaryDim := int(float32(headDim) * partialRotaryFactor)
if rotaryDim%2 != 0 {
rotaryDim = (rotaryDim / 2) * 2 // Round down to even
}
// Handle 1D (bias) or 2D (weight) tensors
is1D := len(shape) == 1
var inFeatures int
if is1D {
inFeatures = 1
} else {
inFeatures = int(shape[1])
}
outFeatures := int(shape[0])
nEffectiveHeads := outFeatures / headDim
if nEffectiveHeads != nHeads {
slog.Warn("normalToNeoX: unexpected head count", "effective", nEffectiveHeads, "expected", nHeads)
}
// Reshape to [n_heads, head_dim, in_features]
reshaped := make([]float32, len(data))
copy(reshaped, data)
// Permute the rotary dimensions: even indices first, then odd
// For each head, reorder [0,1,2,3,4,5...] to [0,2,4...,1,3,5...]
result := make([]float32, len(data))
halfRotary := rotaryDim / 2
for h := range nEffectiveHeads {
for f := range inFeatures {
for i := range halfRotary {
// Even dim (0, 2, 4, ...) -> position i
srcIdx := h*headDim*inFeatures + (2*i)*inFeatures + f
dstIdx := h*headDim*inFeatures + i*inFeatures + f
result[dstIdx] = reshaped[srcIdx]
// Odd dim (1, 3, 5, ...) -> position halfRotary + i
srcIdx = h*headDim*inFeatures + (2*i+1)*inFeatures + f
dstIdx = h*headDim*inFeatures + (halfRotary+i)*inFeatures + f
result[dstIdx] = reshaped[srcIdx]
}
// Non-rotary part: copy as-is
for i := rotaryDim; i < headDim; i++ {
srcIdx := h*headDim*inFeatures + i*inFeatures + f
result[srcIdx] = reshaped[srcIdx]
}
}
}
return result, nil
}
}
type glmOcrModel struct {
ModelParameters
TextConfig struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
MaxPositionEmbed uint32 `json:"max_position_embeddings"`
RMSNormEps float32 `json:"rms_norm_eps"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeParameters struct {
RopeType string `json:"rope_type"`
MRopeSection []int32 `json:"mrope_section"`
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionConfig struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
Depth uint32 `json:"depth"`
NumHeads uint32 `json:"num_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
OutHiddenSize uint32 `json:"out_hidden_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
ImageStartTokenID uint32 `json:"image_start_token_id"`
ImageEndTokenID uint32 `json:"image_end_token_id"`
VideoStartTokenID uint32 `json:"video_start_token_id"`
VideoEndTokenID uint32 `json:"video_end_token_id"`
ImageTokenID uint32 `json:"image_token_id"`
VideoTokenID uint32 `json:"video_token_id"`
// Preprocessor config (preprocessor_config.json)
Preprocessor struct {
Size struct {
ShortestEdge uint32 `json:"shortest_edge"`
LongestEdge uint32 `json:"longest_edge"`
} `json:"size"`
PatchSize uint32 `json:"patch_size"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
MergeSize uint32 `json:"merge_size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
} `json:"-"`
}
var _ ModelConverter = (*glmOcrModel)(nil)
func (m *glmOcrModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
if err != nil {
return err
}
return json.Unmarshal(bts, &m.Preprocessor)
}
func (m *glmOcrModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "glmocr"
// Text model parameters
kv["glmocr.block_count"] = cmp.Or(m.TextConfig.NumHiddenLayers, 16)
kv["glmocr.embedding_length"] = cmp.Or(m.TextConfig.HiddenSize, 1536)
kv["glmocr.attention.head_count"] = cmp.Or(m.TextConfig.NumAttentionHeads, 16)
kv["glmocr.attention.head_count_kv"] = cmp.Or(m.TextConfig.NumKeyValueHeads, 8)
headDim := cmp.Or(m.TextConfig.HeadDim, m.TextConfig.HiddenSize/m.TextConfig.NumAttentionHeads)
kv["glmocr.attention.key_length"] = headDim
kv["glmocr.attention.value_length"] = headDim
kv["glmocr.feed_forward_length"] = cmp.Or(m.TextConfig.IntermediateSize, 4608)
kv["glmocr.attention.layer_norm_rms_epsilon"] = cmp.Or(m.TextConfig.RMSNormEps, 1e-5)
kv["glmocr.context_length"] = cmp.Or(m.TextConfig.MaxPositionEmbed, 131072)
kv["glmocr.rope.freq_base"] = cmp.Or(m.TextConfig.RopeParameters.RopeTheta, float32(10000))
kv["glmocr.rope.partial_rotary_factor"] = cmp.Or(m.TextConfig.RopeParameters.PartialRotaryFactor, m.TextConfig.PartialRotaryFactor, float32(1.0))
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 {
kv["glmocr.rope.mrope_section"] = m.TextConfig.RopeParameters.MRopeSection
}
// Vision model parameters
kv["glmocr.vision.block_count"] = cmp.Or(m.VisionConfig.Depth, 24)
kv["glmocr.vision.embedding_length"] = cmp.Or(m.VisionConfig.HiddenSize, 1024)
kv["glmocr.vision.attention.head_count"] = cmp.Or(m.VisionConfig.NumHeads, 16)
kv["glmocr.vision.image_size"] = cmp.Or(m.VisionConfig.ImageSize, 336)
kv["glmocr.vision.patch_size"] = cmp.Or(m.VisionConfig.PatchSize, m.Preprocessor.PatchSize, 14)
kv["glmocr.vision.spatial_merge_size"] = cmp.Or(m.VisionConfig.SpatialMergeSize, m.Preprocessor.MergeSize, 2)
kv["glmocr.vision.temporal_patch_size"] = cmp.Or(m.VisionConfig.TemporalPatchSize, m.Preprocessor.TemporalPatchSize, 2)
kv["glmocr.vision.out_hidden_size"] = cmp.Or(m.VisionConfig.OutHiddenSize, 1536)
kv["glmocr.vision.intermediate_size"] = cmp.Or(m.VisionConfig.IntermediateSize, 4096)
kv["glmocr.vision.attention.layer_norm_rms_epsilon"] = cmp.Or(m.VisionConfig.RMSNormEps, 1e-5)
// Preprocessor-derived image settings (min/max pixels and normalization)
// Note: fs.Config.keyValue() auto-prepends architecture prefix, so use full key
if m.Preprocessor.Size.ShortestEdge > 0 {
kv["glmocr.vision.min_pixels"] = m.Preprocessor.Size.ShortestEdge
}
if m.Preprocessor.Size.LongestEdge > 0 {
kv["glmocr.vision.max_pixels"] = m.Preprocessor.Size.LongestEdge
}
if len(m.Preprocessor.ImageMean) == 3 {
kv["glmocr.vision.image_mean"] = m.Preprocessor.ImageMean
}
if len(m.Preprocessor.ImageStd) == 3 {
kv["glmocr.vision.image_std"] = m.Preprocessor.ImageStd
}
// Special tokens
kv["glmocr.image_token_id"] = m.ImageTokenID
kv["glmocr.image_start_token_id"] = m.ImageStartTokenID
kv["glmocr.image_end_token_id"] = m.ImageEndTokenID
kv["glmocr.video_token_id"] = m.VideoTokenID
kv["glmocr.video_start_token_id"] = m.VideoStartTokenID
kv["glmocr.video_end_token_id"] = m.VideoEndTokenID
return kv
}
func (m *glmOcrModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// Skip layers >= num_hidden_layers (Multi-Token Prediction layers not needed for basic inference)
numLayers := int(cmp.Or(m.TextConfig.NumHiddenLayers, 16))
skipLayer := func(name string) bool {
// Tensor names are already replaced to "blk.N.xxx" format
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(name)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return blkNum >= numLayers
}
for _, t := range ts {
name := t.Name()
// Skip next-n prediction layers (layers >= num_hidden_layers)
if skipLayer(name) {
continue
}
// Split ffn_gate_up into separate gate and up projections
if strings.Contains(name, "ffn_gate_up") {
for t := range splitDim(t, 0,
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_gate")},
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_up")},
) {
out = append(out, t)
}
continue
}
if strings.HasSuffix(name, "patch_embd.weight") {
shape := t.Shape()
if len(shape) == 5 && shape[2] == 2 {
newShape := []uint64{shape[0], shape[1], shape[3], shape[4]}
t0 := t.Clone()
t0.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tt.Slice(nil, nil, tensor.S(0, 1), nil, nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
if err := tt.Reshape(newDims...); err != nil {
return nil, err
}
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t0,
})
t1 := t.Clone()
t1.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tt.Slice(nil, nil, tensor.S(1, 2), nil, nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
if err := tt.Reshape(newDims...); err != nil {
return nil, err
}
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_1.weight", 1),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t1,
})
continue
}
if len(shape) == 4 {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
slog.Warn("glmocr: patch_embed weight has unexpected shape - not splitting", "shape", shape)
// Fall through to default handling
}
// Handle pre-split patch embedding weights
// Pattern 1: v.patch_embd.0.weight, v.patch_embd.1.weight -> patch_embd_0.weight, patch_embd_1.weight
// Pattern 2: v.patch_embd.weight.0, v.patch_embd.weight.1 -> patch_embd_0.weight, patch_embd_1.weight
if strings.Contains(name, "patch_embd.0.") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.0.", "patch_embd_0.", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
if strings.Contains(name, "patch_embd.1.") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.1.", "patch_embd_1.", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
// Handle .weight.0 and .weight.1 suffix patterns
if strings.HasSuffix(name, "patch_embd.weight.0") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight.0", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
if strings.HasSuffix(name, "patch_embd.weight.1") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight.1", "patch_embd_1.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
// Permute Q/K weights for M-RoPE compatibility (interleaved -> NeoX ordering)
// GGML's M-RoPE kernel uses NeoX-style rotation, but GLM-OCR uses interleaved (LLaMA-style)
// We permute at conversion time so the weights work correctly with GGML's kernel
// This aligns Q/K rotary dimensions with GGML's NeoX-style rotation
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 &&
strings.Contains(name, "blk.") && (strings.Contains(name, "attn_q.") || strings.Contains(name, "attn_k.")) {
// Get config values for permutation
nHeads := int(cmp.Or(m.TextConfig.NumAttentionHeads, 16))
nKVHeads := int(cmp.Or(m.TextConfig.NumKeyValueHeads, 8))
hiddenSize := int(cmp.Or(m.TextConfig.HiddenSize, 1536))
headDim := int(cmp.Or(m.TextConfig.HeadDim, uint32(hiddenSize/nHeads)))
partialRotaryFactor := cmp.Or(m.TextConfig.PartialRotaryFactor, m.TextConfig.RopeParameters.PartialRotaryFactor, float32(1.0))
// Use appropriate head count: nHeads for Q, nKVHeads for K
effectiveHeads := nHeads
if strings.Contains(name, "attn_k.") {
effectiveHeads = nKVHeads
}
permutedT := t.Clone()
permutedT.SetRepacker(normalToNeoXRepacker(effectiveHeads, headDim, partialRotaryFactor))
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: permutedT,
})
continue
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *glmOcrModel) Replacements() []string {
return []string{
// Vision encoder
"model.visual.patch_embed.proj_1", "v.patch_embd_1", // Second temporal split
"model.visual.patch_embed.proj", "v.patch_embd",
"model.visual.blocks", "v.blk",
"model.visual.post_layernorm", "v.post_ln",
"model.visual.downsample", "mm.patch_merger",
// Vision attention
"attn.qkv", "attn_qkv",
"attn.proj", "attn_out",
"attn.q_norm", "attn_q_norm",
"attn.k_norm", "attn_k_norm",
// Vision norms
"norm1", "ln1",
"norm2", "ln2",
// Vision MLP
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
// Merger (multimodal projector)
"model.visual.merger.proj", "mm.model.fc",
"model.visual.merger.post_projection_norm", "mm.post_norm",
"model.visual.merger.gate_proj", "mm.gate",
"model.visual.merger.up_proj", "mm.up",
"model.visual.merger.down_proj", "mm.down",
// Language model
"model.language_model.embed_tokens", "token_embd",
"model.language_model.layers", "blk",
"model.language_model.norm", "output_norm",
"lm_head", "output",
// Language model attention
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_out",
// Language model norms
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"post_self_attn_layernorm", "post_attn_norm",
"post_mlp_layernorm", "post_ffn_norm",
// Language model MLP (remove mlp. prefix so ffn_* names work)
"mlp.gate_up_proj", "ffn_gate_up",
"mlp.down_proj", "ffn_down",
}
}

View File

@@ -1,512 +0,0 @@
package convert
import (
"fmt"
"io/fs"
"math"
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type qwen3NextModel struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
// MoE config
NumExperts uint32 `json:"num_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NormTopkProb bool `json:"norm_topk_prob"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
// Hybrid attention config
FullAttentionInterval uint32 `json:"full_attention_interval"`
// Linear attention (Gated Delta Net) config
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
LinearKeyHeadDim uint32 `json:"linear_key_head_dim"`
LinearNumKeyHeads uint32 `json:"linear_num_key_heads"`
LinearNumValueHeads uint32 `json:"linear_num_value_heads"`
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
// RoPE config
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeScaling struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
} `json:"rope_scaling"`
}
var _ ModelConverter = (*qwen3NextModel)(nil)
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
if q.NumHiddenLayers == 0 {
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
}
if q.NumAttentionHeads == 0 {
return fmt.Errorf("qwen3next: num_attention_heads must be set")
}
if q.NumKeyValueHeads == 0 {
return fmt.Errorf("qwen3next: num_key_value_heads must be set")
}
if q.HeadDim == 0 {
return fmt.Errorf("qwen3next: head_dim must be set")
}
if q.RopeTheta == 0 {
return fmt.Errorf("qwen3next: rope_theta must be set")
}
if q.PartialRotaryFactor <= 0 || q.PartialRotaryFactor > 1 {
return fmt.Errorf("qwen3next: partial_rotary_factor must be in (0,1], got %v", q.PartialRotaryFactor)
}
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
}
if q.FullAttentionInterval == 0 {
return fmt.Errorf("qwen3next: full_attention_interval must be set")
}
if q.FullAttentionInterval > q.NumHiddenLayers {
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
hasFull := false
for i := range q.NumHiddenLayers {
if (i+1)%q.FullAttentionInterval == 0 {
hasFull = true
break
}
}
if !hasFull {
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
return nil
}
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen3next"
kv["tokenizer.ggml.pre"] = "qwen2"
kv["block_count"] = q.NumHiddenLayers
kv["context_length"] = q.MaxPositionEmbeddings
kv["embedding_length"] = q.HiddenSize
kv["feed_forward_length"] = q.IntermediateSize
kv["attention.head_count"] = q.NumAttentionHeads
headDim := q.HeadDim
if headDim == 0 && q.NumAttentionHeads > 0 {
headDim = q.HiddenSize / q.NumAttentionHeads
}
kv["attention.key_length"] = headDim
kv["attention.value_length"] = headDim
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
kv["rope.freq_base"] = q.RopeTheta
// RoPE dimension count (partial rotary)
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
partialRotary := q.PartialRotaryFactor
if partialRotary > 0 && partialRotary <= 1 {
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
}
// MoE config
if q.NumExperts > 0 {
kv["expert_count"] = q.NumExperts
kv["expert_used_count"] = q.NumExpertsPerToken
kv["norm_top_k_prob"] = q.NormTopkProb
if q.MoEIntermediateSize > 0 {
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
}
if q.SharedExpertIntermSize > 0 {
kv["expert_shared_feed_forward_length"] = q.SharedExpertIntermSize
}
}
// SSM/Linear attention config
// d_inner = linear_value_head_dim * linear_num_value_heads
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
kv["ssm.inner_size"] = dInner
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
interval := q.FullAttentionInterval
kv["full_attention_interval"] = interval
// Build per-layer KV head count array to identify layer types
// 0 = recurrent (linear attention), non-zero = full attention
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
for i := range q.NumHiddenLayers {
// Full attention every full_attention_interval layers (starting at interval-1)
if interval > 0 && (i+1)%interval == 0 {
kvHeadCounts[i] = q.NumKeyValueHeads
}
// else stays 0 (recurrent layer)
}
kv["attention.head_count_kv"] = kvHeadCounts
// RoPE scaling
if q.RopeScaling.Type != "" {
kv["rope.scaling.type"] = q.RopeScaling.Type
kv["rope.scaling.factor"] = q.RopeScaling.Factor
}
return kv
}
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// Create merges for expert tensors - stack individual experts into batched tensors
merges := make([]merge, q.NumHiddenLayers*3)
for i := range q.NumHiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
// Merge expert tensors
merged, remaining := mergeTensors(ts, merges...)
out = append(out, merged...)
// Process remaining tensors
for _, t := range remaining {
name := t.Name()
shape := t.Shape()
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
if strings.HasSuffix(name, ".ssm_in.weight") {
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
out = append(out, qkv, gate)
continue
}
panic(fmt.Sprintf("qwen3next: failed to split %s into attn_qkv/attn_gate (shape=%v)", name, shape))
}
switch {
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
// This matches the Python converter behavior for qwen3next
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
t.SetRepacker(q.addOne)
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
// Handle linear attention A_log -> ssm_a (negate and exp)
// Note: name has already been transformed by Replacements at this point
case strings.HasSuffix(name, ".ssm_a"):
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
// Compute -exp(A_log)
result := make([]float32, len(data))
for i, v := range data {
// -exp(v)
result[i] = -float32(math.Exp(float64(v)))
}
return result, nil
})
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
newShape := slices.Clone(shape)
if len(shape) == 3 {
if shape[0] == 1 {
// [1, D, K] -> [D, K]
newShape = []uint64{shape[1], shape[2]}
} else if shape[1] == 1 {
// [D, 1, K] -> [D, K]
newShape = []uint64{shape[0], shape[2]}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
newShape := slices.Clone(shape)
if len(shape) == 2 {
if shape[0] == 1 && shape[1] > 1 {
newShape = []uint64{shape[1]}
} else if shape[1] == 1 && shape[0] > 1 {
newShape = []uint64{shape[0]}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
default:
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
}
}
return out
}
type qkvzSplitSpec struct {
hidden int
headKDim int
headVDim int
numKHeads int
numVHeads int
qkvzDim int
qkvOut int
gateOut int
}
func (q *qwen3NextModel) qkvzSpec(shape []uint64) (qkvzSplitSpec, bool) {
if len(shape) != 2 {
return qkvzSplitSpec{}, false
}
numKHeads := int(q.LinearNumKeyHeads)
numVHeads := int(q.LinearNumValueHeads)
headKDim := int(q.LinearKeyHeadDim)
headVDim := int(q.LinearValueHeadDim)
if numKHeads == 0 || numVHeads == 0 || headKDim == 0 || headVDim == 0 {
return qkvzSplitSpec{}, false
}
if numVHeads%numKHeads != 0 {
return qkvzSplitSpec{}, false
}
hidden := int(shape[1])
vPerHead := headVDim * (numVHeads / numKHeads)
qkvzDim := 2*headKDim + 2*vPerHead
expectedOut := qkvzDim * numKHeads
if int(shape[0]) != expectedOut {
return qkvzSplitSpec{}, false
}
return qkvzSplitSpec{
hidden: hidden,
headKDim: headKDim,
headVDim: headVDim,
numKHeads: numKHeads,
numVHeads: numVHeads,
qkvzDim: qkvzDim,
qkvOut: 2*headKDim*numKHeads + headVDim*numVHeads,
gateOut: headVDim * numVHeads,
}, true
}
func (q *qwen3NextModel) splitQKVZTensor(t Tensor) (*ggml.Tensor, *ggml.Tensor, bool) {
spec, ok := q.qkvzSpec(t.Shape())
if !ok {
return nil, nil, false
}
qkvTensor := t.Clone()
qkvTensor.SetRepacker(q.repackQKVZ(spec, false))
gateTensor := t.Clone()
gateTensor.SetRepacker(q.repackQKVZ(spec, true))
qkvName := strings.Replace(t.Name(), "ssm_in", "attn_qkv", 1)
gateName := strings.Replace(t.Name(), "ssm_in", "attn_gate", 1)
return &ggml.Tensor{
Name: qkvName,
Kind: t.Kind(),
Shape: []uint64{uint64(spec.qkvOut), uint64(spec.hidden)},
WriterTo: qkvTensor,
}, &ggml.Tensor{
Name: gateName,
Kind: t.Kind(),
Shape: []uint64{uint64(spec.gateOut), uint64(spec.hidden)},
WriterTo: gateTensor,
}, true
}
func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repacker {
vPerHead := spec.headVDim * (spec.numVHeads / spec.numKHeads)
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
var err error
// Convert to [hidden, out_features] layout for slicing
tt, err = tensor.Transpose(tt, 1, 0)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
if err := tt.Reshape(spec.hidden, spec.numKHeads, spec.qkvzDim); err != nil {
return nil, err
}
offset := 0
qSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
if err != nil {
return nil, err
}
offset += spec.headKDim
kSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
if err != nil {
return nil, err
}
offset += spec.headKDim
vSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
if err != nil {
return nil, err
}
offset += vPerHead
zSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
if err != nil {
return nil, err
}
qMat := tensor.Materialize(qSlice).(*tensor.Dense)
kMat := tensor.Materialize(kSlice).(*tensor.Dense)
vMat := tensor.Materialize(vSlice).(*tensor.Dense)
zMat := tensor.Materialize(zSlice).(*tensor.Dense)
if err := qMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
return nil, err
}
if err := kMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
return nil, err
}
if err := vMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
return nil, err
}
if err := zMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
return nil, err
}
var out tensor.Tensor
if extractGate {
out = zMat
} else {
out, err = tensor.Concat(1, qMat, kMat, vMat)
if err != nil {
return nil, err
}
}
out = tensor.Materialize(out)
out, err = tensor.Transpose(out, 1, 0)
if err != nil {
return nil, err
}
out = tensor.Materialize(out)
if err := out.Reshape(out.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(out.(*tensor.Dense))
}
}
// addOne adds 1.0 to all elements in the tensor (for norm weights)
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, int(shape[0]))
n, err := n.Add(ones)
if err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 0)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
func (q *qwen3NextModel) Replacements() []string {
return []string{
// Embeddings and output
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
// Layer norms
"input_layernorm", "attn_norm",
"post_attention_layernorm", "post_attention_norm",
// Full attention (self_attn)
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
// Linear attention (Gated Delta Net)
"linear_attn.in_proj_qkvz", "ssm_in",
"linear_attn.in_proj_ba", "ssm_ba",
"linear_attn.conv1d", "ssm_conv1d",
"linear_attn.dt_bias", "ssm_dt",
"linear_attn.dt_proj", "ssm_dt",
"linear_attn.A_log", "ssm_a",
"linear_attn.norm", "ssm_norm",
"linear_attn.out_proj", "ssm_out",
// MoE (experts are stacked via mergeTensors, not replaced here)
"mlp.gate.weight", "ffn_gate_inp.weight",
"mlp.shared_expert.down_proj", "ffn_down_shexp",
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
"mlp.shared_expert.up_proj", "ffn_up_shexp",
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
// Dense FFN (if any layers use it)
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
}
}

View File

@@ -41,7 +41,6 @@ func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
strings.HasSuffix(t.name, ".bias") ||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||

View File

@@ -99,8 +99,6 @@ func (st safetensor) Kind() uint32 {
if st.dtype == "BF16" &&
!strings.HasPrefix(st.name, "v.") &&
!strings.HasPrefix(st.name, "s.") &&
!strings.HasPrefix(st.name, "mm.") &&
!strings.Contains(st.name, "ffn_gate_inp_shexp.weight") &&
kind != tensorKindFP32 {
kind = tensorKindBF16
}

View File

@@ -71,10 +71,6 @@
{
"source": "/api",
"destination": "/api/introduction"
},
{
"source": "/integrations/clawdbot",
"destination": "/integrations/openclaw"
}
],
"navigation": {
@@ -107,7 +103,6 @@
"pages": [
"/integrations/claude-code",
"/integrations/cline",
"/integrations/openclaw",
"/integrations/codex",
"/integrations/droid",
"/integrations/goose",

View File

@@ -10,7 +10,6 @@ Check your compute compatibility to see if your card is supported:
| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
| 12.1 | NVIDIA | `GB10 (DGX Spark)` |
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
| | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 9.0 | NVIDIA | `H200` `H100` |
@@ -164,4 +163,4 @@ To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`
by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -134,12 +134,22 @@ success
### Supported Quantizations
- `q4_0`
- `q4_1`
- `q5_0`
- `q5_1`
- `q8_0`
#### K-means Quantizations
- `q3_K_S`
- `q3_K_M`
- `q3_K_L`
- `q4_K_S`
- `q4_K_M`
- `q5_K_S`
- `q5_K_M`
- `q6_K`
## Sharing your model on ollama.com

View File

@@ -1,50 +0,0 @@
---
title: OpenClaw
---
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
## Install
Install [OpenClaw](https://openclaw.ai/)
```bash
npm install -g openclaw@latest
```
Then run the onboarding wizard:
```bash
openclaw onboard --install-daemon
```
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch openclaw
```
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
This configures OpenClaw to use Ollama and starts the gateway.
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
To configure without launching:
```shell
ollama launch openclaw --config
```
## Recommended Models
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

View File

@@ -9,7 +9,7 @@ OpenCode is an open-source AI coding assistant that runs in your terminal.
Install the [OpenCode CLI](https://opencode.ai):
```bash
curl -fsSL https://opencode.ai/install | bash
curl -fsSL https://opencode.ai/install.sh | bash
```
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>

View File

@@ -201,7 +201,7 @@ var (
// Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 0)
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
// Enable Vulkan backend
@@ -290,7 +290,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},

View File

@@ -282,7 +282,7 @@ func TestVar(t *testing.T) {
func TestContextLength(t *testing.T) {
cases := map[string]uint{
"": 0,
"": 4096,
"2048": 2048,
}

View File

@@ -268,10 +268,8 @@ func (kv KV) OllamaEngineRequired() bool {
"olmo3",
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3next",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
"glmocr",
"lfm2",
}, kv.Architecture())
}
@@ -861,13 +859,11 @@ func (f GGML) FlashAttention() bool {
"bert",
"gemma3",
"glm4moelite",
"glmocr",
"gptoss", "gpt-oss",
"lfm2",
"mistral3",
"olmo3",
"qwen3", "qwen3moe",
"qwen3next",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))
}

1
go.mod
View File

@@ -27,7 +27,6 @@ require (
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0

3
go.sum
View File

@@ -174,8 +174,6 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -306,7 +304,6 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=

View File

@@ -75,10 +75,3 @@ type Cache interface {
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}
// CheckpointCache optionally supports restoring recurrent state to a prior
// position to avoid full prompt reprocessing when a prefix mismatch occurs.
// The returned position is the number of tokens that can be kept (prefix length).
type CheckpointCache interface {
PrepareRestore(seq int, targetPos int32) (int32, bool)
}

View File

@@ -1,276 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan <jmorganca@gmail.com>
Date: Tue, 3 Feb 2026 12:00:00 -0800
Subject: [PATCH] ggml: metal solve_tri
---
ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++++++
ggml/src/ggml-metal/ggml-metal-device.h | 1 +
ggml/src/ggml-metal/ggml-metal-device.m | 11 ++++
ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++
ggml/src/ggml-metal/ggml-metal-ops.cpp | 63 +++++++++++++++++++++++
ggml/src/ggml-metal/ggml-metal-ops.h | 1 +
ggml/src/ggml-metal/ggml-metal.metal | 60 +++++++++++++++++++++
7 files changed, 177 insertions(+)
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 680904d13..83385c9ef 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
return res;
}
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
+ assert(op->op == GGML_OP_SOLVE_TRI);
+
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+
+ char base[256];
+ char name[256];
+
+ snprintf(base, 256, "kernel_solve_tri_f32");
+ snprintf(name, 256, "%s", base);
+
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+ if (!res.pipeline) {
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+ }
+
+ return res;
+}
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_GROUP_NORM);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 0a8b9211a..8a9d17460 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 7b5ee968c..4e5acfbe5 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+ case GGML_OP_SOLVE_TRI:
+ return ggml_is_contiguous(op->src[0]) &&
+ ggml_is_contiguous(op->src[1]) &&
+ op->src[0]->type == GGML_TYPE_F32 &&
+ op->src[1]->type == GGML_TYPE_F32 &&
+ op->type == GGML_TYPE_F32;
+ case GGML_OP_COUNT_EQUAL:
+ return has_simdgroup_reduction &&
+ op->src[0]->type == GGML_TYPE_I32 &&
+ op->src[1]->type == GGML_TYPE_I32 &&
+ op->type == GGML_TYPE_I64;
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 8944b07e9..cfdea9c07 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -500,6 +500,27 @@ typedef struct {
float eps;
} ggml_metal_kargs_l2_norm;
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ int32_t ne11;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_solve_tri;
+
typedef struct {
int64_t ne00;
int64_t ne01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 80864f303..4ac135603 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
} break;
+ case GGML_OP_SOLVE_TRI:
+ {
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
+ } break;
case GGML_OP_GROUP_NORM:
{
n_fuse = ggml_metal_op_group_norm(ctx, idx);
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
return 1;
}
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_solve_tri args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
+
+ const int64_t ncols = ne10;
+ const int64_t n_batches = (int64_t)ne02 * ne03;
+ const int64_t nr = n_batches * ncols;
+
+ int nth = 64;
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+ if (nth < 1) {
+ nth = 1;
+ }
+
+ const int64_t n_tg = (nr + nth - 1) / nth;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 902b54452..a475183d3 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index d33c16079..c37447a10 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
}
}
+kernel void kernel_solve_tri_f32(
+ constant ggml_metal_kargs_solve_tri & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ const uint64_t ncols = (uint64_t) args.ne10;
+ const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
+ const uint64_t nr = n_batches * ncols;
+
+ const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
+ if (gid >= nr) {
+ return;
+ }
+
+ const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
+ const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
+ const uint64_t i02 = rem / ncols;
+ const uint64_t i01 = rem - i02 * ncols;
+
+ const uint64_t sa0 = args.nb00 / sizeof(float);
+ const uint64_t sa1 = args.nb01 / sizeof(float);
+ const uint64_t sa2 = args.nb02 / sizeof(float);
+ const uint64_t sa3 = args.nb03 / sizeof(float);
+
+ const uint64_t sb0 = args.nb10 / sizeof(float);
+ const uint64_t sb1 = args.nb11 / sizeof(float);
+ const uint64_t sb2 = args.nb12 / sizeof(float);
+ const uint64_t sb3 = args.nb13 / sizeof(float);
+
+ const uint64_t sx0 = args.nb0 / sizeof(float);
+ const uint64_t sx1 = args.nb1 / sizeof(float);
+ const uint64_t sx2 = args.nb2 / sizeof(float);
+ const uint64_t sx3 = args.nb3 / sizeof(float);
+
+ device const float * A = (device const float *) src0;
+ device const float * B = (device const float *) src1;
+ device float * X = (device float *) dst;
+
+ const uint64_t A_base = i02 * sa2 + i03 * sa3;
+ const uint64_t B_base = i02 * sb2 + i03 * sb3;
+ const uint64_t X_base = i02 * sx2 + i03 * sx3;
+
+ const uint64_t n = (uint64_t) args.ne11;
+
+ for (uint64_t i00 = 0; i00 < n; ++i00) {
+ float sum = 0.0f;
+ for (uint64_t t = 0; t < i00; ++t) {
+ sum += A[A_base + i00 * sa1 + t * sa0] *
+ X[X_base + t * sx1 + i01 * sx0];
+ }
+
+ const float diag = A[A_base + i00 * sa1 + i00 * sa0];
+ X[X_base + i00 * sx1 + i01 * sx0] =
+ (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
+ }
+}
+
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,

View File

@@ -34,6 +34,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
type filteredEnv []string
@@ -80,7 +81,6 @@ type LlamaServer interface {
GetPort() int
GetDeviceInfos(ctx context.Context) []ml.DeviceInfo
HasExited() bool
ContextLength() int
}
// llmServer is an instance of a runner hosting a single model
@@ -116,7 +116,7 @@ type llamaServer struct {
type ollamaServer struct {
llmServer
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
tokenizer tokenizer.Tokenizer // textProcessor handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -142,11 +142,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
var llamaModel *llama.Model
var textProcessor model.TextProcessor
var tokenizer tokenizer.Tokenizer
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath)
tokenizer, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
@@ -155,7 +155,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if textProcessor == nil {
if tokenizer == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
@@ -211,7 +211,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if textProcessor == nil {
if tokenizer == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
@@ -261,7 +261,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
gpuLibs := ml.LibraryPaths(gpus)
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
textProcessor != nil,
tokenizer != nil,
modelPath,
gpuLibs,
status,
@@ -310,8 +310,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}()
if textProcessor != nil {
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
if tokenizer != nil {
return &ollamaServer{llmServer: s, tokenizer: tokenizer}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -1201,8 +1201,7 @@ func (s *llmServer) initModel(ctx context.Context, req LoadRequest, operation Lo
resp, err := http.DefaultClient.Do(r)
if err != nil {
slog.Error("do load request", "error", err)
return nil, errors.New("model failed to load, this may be due to resource limitations or an internal error, check ollama server logs for details")
return nil, fmt.Errorf("do load request: %w", err)
}
defer resp.Body.Close()
@@ -1774,7 +1773,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.textProcessor.Encode(content, false)
tokens, err := s.tokenizer.Encode(content, false)
if err != nil {
return nil, err
}
@@ -1809,7 +1808,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
content, err := s.tokenizer.Decode(toks)
if err != nil {
return "", err
}
@@ -1903,10 +1902,6 @@ func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
return 0
}
func (s *llmServer) ContextLength() int {
return s.options.NumCtx
}
func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
devices, err := ml.GetDevicesFromRunner(ctx, s)
if err != nil {

View File

@@ -170,7 +170,6 @@ type Tensor interface {
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
GELU_ERF(ctx Context) Tensor
QuickGELU(ctx Context, up ...Tensor) Tensor
SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor
@@ -207,32 +206,6 @@ type Tensor interface {
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Exp(ctx Context) Tensor
Neg(ctx Context) Tensor
// Clamp clamps values to [min, max] range
Clamp(ctx Context, min, max float32) Tensor
// Softplus computes ln(1 + exp(x))
Softplus(ctx Context) Tensor
// CumSum computes cumulative sum along dimension 0
CumSum(ctx Context) Tensor
// Diag creates a diagonal matrix from a 1D tensor
Diag(ctx Context) Tensor
// Tri converts a matrix to triangular form (0=upper+diag, 1=upper, 2=lower+diag, 3=lower)
Tri(ctx Context, triType int) Tensor
// Fill fills a tensor with a constant value (in-place)
Fill(ctx Context, value float32) Tensor
// Repeat4D repeats tensor to match target shape
Repeat4D(ctx Context, dim0, dim1, dim2, dim3 int) Tensor
// SolveTri solves a triangular system Ax = B
SolveTri(ctx Context, b Tensor, lower, left, unitDiag bool) Tensor
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}

View File

@@ -378,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
}
maxGraphNodes := max(1024, len(meta.Tensors().Items())*32)
maxGraphNodes := max(1024, len(meta.Tensors().Items())*8)
sched := C.ggml_backend_sched_new_ext(
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
@@ -1581,13 +1581,6 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) GELU_ERF(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_erf_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
var tt *C.struct_ggml_tensor
if len(t2) > 0 {
@@ -1779,76 +1772,6 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_exp(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
}
}
func (t *Tensor) Softplus(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_softplus(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) CumSum(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cumsum(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Diag(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_diag(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Tri(ctx ml.Context, triType int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tri(ctx.(*Context).ctx, t.t, C.enum_ggml_tri_type(triType)),
}
}
func (t *Tensor) Fill(ctx ml.Context, value float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_fill_inplace(ctx.(*Context).ctx, t.t, C.float(value)),
}
}
func (t *Tensor) Repeat4D(ctx ml.Context, dim0, dim1, dim2, dim3 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_repeat_4d(ctx.(*Context).ctx, t.t, C.int64_t(dim0), C.int64_t(dim1), C.int64_t(dim2), C.int64_t(dim3)),
}
}
func (t *Tensor) SolveTri(ctx ml.Context, b ml.Tensor, lower, left, unitDiag bool) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_solve_tri(ctx.(*Context).ctx, t.t, b.(*Tensor).t, C._Bool(lower), C._Bool(left), C._Bool(unitDiag)),
}
}
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
var mode C.uint32_t
switch samplingMode {

View File

@@ -1370,26 +1370,6 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_SOLVE_TRI);
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
char base[256];
char name[256];
snprintf(base, 256, "kernel_solve_tri_f32");
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_GROUP_NORM);

View File

@@ -133,7 +133,6 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);

View File

@@ -1023,17 +1023,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_SOLVE_TRI:
return ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
op->src[1]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_I64;
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:

View File

@@ -2385,27 +2385,6 @@ typedef struct {
float eps;
} ggml_metal_kargs_l2_norm;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_solve_tri;
typedef struct {
int64_t ne00;
int64_t ne01;
@@ -5834,66 +5813,6 @@ kernel void kernel_l2_norm_f32(
}
}
kernel void kernel_solve_tri_f32(
constant ggml_metal_kargs_solve_tri & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const uint64_t ncols = (uint64_t) args.ne10;
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
const uint64_t nr = n_batches * ncols;
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
if (gid >= nr) {
return;
}
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
const uint64_t i02 = rem / ncols;
const uint64_t i01 = rem - i02 * ncols;
const uint64_t sa0 = args.nb00 / sizeof(float);
const uint64_t sa1 = args.nb01 / sizeof(float);
const uint64_t sa2 = args.nb02 / sizeof(float);
const uint64_t sa3 = args.nb03 / sizeof(float);
const uint64_t sb0 = args.nb10 / sizeof(float);
const uint64_t sb1 = args.nb11 / sizeof(float);
const uint64_t sb2 = args.nb12 / sizeof(float);
const uint64_t sb3 = args.nb13 / sizeof(float);
const uint64_t sx0 = args.nb0 / sizeof(float);
const uint64_t sx1 = args.nb1 / sizeof(float);
const uint64_t sx2 = args.nb2 / sizeof(float);
const uint64_t sx3 = args.nb3 / sizeof(float);
device const float * A = (device const float *) src0;
device const float * B = (device const float *) src1;
device float * X = (device float *) dst;
const uint64_t A_base = i02 * sa2 + i03 * sa3;
const uint64_t B_base = i02 * sb2 + i03 * sb3;
const uint64_t X_base = i02 * sx2 + i03 * sx3;
const uint64_t n = (uint64_t) args.ne11;
for (uint64_t i00 = 0; i00 < n; ++i00) {
float sum = 0.0f;
for (uint64_t t = 0; t < i00; ++t) {
sum += A[A_base + i00 * sa1 + t * sa0] *
X[X_base + t * sx1 + i01 * sx0];
}
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
X[X_base + i00 * sx1 + i01 * sx0] =
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
}
}
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,

View File

@@ -500,27 +500,6 @@ typedef struct {
float eps;
} ggml_metal_kargs_l2_norm;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_solve_tri;
typedef struct {
int64_t ne00;
int64_t ne01;

View File

@@ -357,10 +357,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
} break;
case GGML_OP_SOLVE_TRI:
{
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
} break;
case GGML_OP_GROUP_NORM:
{
n_fuse = ggml_metal_op_group_norm(ctx, idx);
@@ -2935,65 +2931,6 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_solve_tri args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
const int64_t ncols = ne10;
const int64_t n_batches = (int64_t)ne02 * ne03;
const int64_t nr = n_batches * ncols;
int nth = 64;
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
if (nth < 1) {
nth = 1;
}
const int64_t n_tg = (nr + nth - 1) / nth;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
return 1;
}
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

View File

@@ -68,7 +68,6 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);

View File

@@ -3012,66 +3012,6 @@ kernel void kernel_l2_norm_f32(
}
}
kernel void kernel_solve_tri_f32(
constant ggml_metal_kargs_solve_tri & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const uint64_t ncols = (uint64_t) args.ne10;
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
const uint64_t nr = n_batches * ncols;
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
if (gid >= nr) {
return;
}
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
const uint64_t i02 = rem / ncols;
const uint64_t i01 = rem - i02 * ncols;
const uint64_t sa0 = args.nb00 / sizeof(float);
const uint64_t sa1 = args.nb01 / sizeof(float);
const uint64_t sa2 = args.nb02 / sizeof(float);
const uint64_t sa3 = args.nb03 / sizeof(float);
const uint64_t sb0 = args.nb10 / sizeof(float);
const uint64_t sb1 = args.nb11 / sizeof(float);
const uint64_t sb2 = args.nb12 / sizeof(float);
const uint64_t sb3 = args.nb13 / sizeof(float);
const uint64_t sx0 = args.nb0 / sizeof(float);
const uint64_t sx1 = args.nb1 / sizeof(float);
const uint64_t sx2 = args.nb2 / sizeof(float);
const uint64_t sx3 = args.nb3 / sizeof(float);
device const float * A = (device const float *) src0;
device const float * B = (device const float *) src1;
device float * X = (device float *) dst;
const uint64_t A_base = i02 * sa2 + i03 * sa3;
const uint64_t B_base = i02 * sb2 + i03 * sb3;
const uint64_t X_base = i02 * sx2 + i03 * sx3;
const uint64_t n = (uint64_t) args.ne11;
for (uint64_t i00 = 0; i00 < n; ++i00) {
float sum = 0.0f;
for (uint64_t t = 0; t < i00; ++t) {
sum += A[A_base + i00 * sa1 + t * sa0] *
X[X_base + t * sx1 + i01 * sx0];
}
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
X[X_base + i00 * sx1 + i01 * sx0] =
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
}
}
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,

410
model/ignore_test.go Normal file
View File

File diff suppressed because one or more lines are too long

View File

@@ -23,6 +23,7 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var (
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
tp, ok := m.(TextProcessor)
tp, ok := m.(tokenizer.Tokenizer)
if !ok {
return nil, ErrUnsupportedTokenizer
}

View File

@@ -56,18 +56,6 @@ type fakeTensor struct {
Name string
}
// Stub methods to satisfy ml.Tensor interface
func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Clamp(ctx ml.Context, _, _ float32) ml.Tensor { return f }
func (f *fakeTensor) Softplus(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) CumSum(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Diag(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor { return f }
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
func (m *fakeBackend) Get(name string) ml.Tensor {
if slices.Contains(m.names, name) {
return &fakeTensor{Name: name}

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
}
func New(c fs.Config) (model.Model, error) {
vocab := &model.Vocabulary{
vocab := &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
},
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(vocab, true)
t = tokenizer.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Tokenizer: t,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -134,8 +135,8 @@ func init() {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -43,8 +44,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
poolingType pooling.Type
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
),
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(&vocabulary)
t = tokenizer.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
processor = model.NewSentencePiece(&vocabulary)
t = tokenizer.NewSentencePiece(&vocabulary)
}
m := Model{
TextProcessor: processor,
Tokenizer: t,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),

View File

@@ -6,11 +6,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
}
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -1,174 +0,0 @@
package glmocr
import (
"image"
"log/slog"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize int
patchSize int
temporalPatchSize int
spatialMergeSize int
minPixels int
maxPixels int
factor int
imageMean [3]float32
imageStd [3]float32
}
func newImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 14))
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
// Read normalization values from config if available, otherwise use CLIP defaults
imageMean := c.Floats("vision.image_mean", imageproc.ClipDefaultMean[:])
imageStd := c.Floats("vision.image_std", imageproc.ClipDefaultSTD[:])
// Default max_pixels: 2048 * patchSize^2 * mergeSize^2 * temporal = ~3.2M pixels
// This limits to ~16k patches (4k output tokens) to keep memory stable without flash attention
defaultMaxPixels := 2048 * patchSize * patchSize * spatialMergeSize * spatialMergeSize * temporalPatchSize
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 336)),
patchSize: patchSize,
temporalPatchSize: temporalPatchSize,
spatialMergeSize: spatialMergeSize,
minPixels: int(c.Uint("vision.min_pixels", uint32(8*patchSize*patchSize*spatialMergeSize*spatialMergeSize*temporalPatchSize))),
maxPixels: int(c.Uint("vision.max_pixels", uint32(defaultMaxPixels))),
factor: patchSize * spatialMergeSize,
imageMean: [3]float32{imageMean[0], imageMean[1], imageMean[2]},
imageStd: [3]float32{imageStd[0], imageStd[1], imageStd[2]},
}
}
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
factor := p.factor
temporalFactor := p.temporalPatchSize
numFrames := temporalFactor // single image
if height < factor || width < factor {
// Scale up small images
scale := float64(factor) / float64(min(height, width))
height = int(math.Ceil(float64(height) * scale))
width = int(math.Ceil(float64(width) * scale))
}
if temporalFactor <= 0 {
slog.Warn("temporal_patch_size must be > 0, defaulting to 1")
temporalFactor = 1
}
if numFrames < temporalFactor {
slog.Warn("num_frames must be >= temporal_patch_size, adjusting num_frames", "num_frames", numFrames, "temporal_patch_size", temporalFactor)
numFrames = temporalFactor
}
if aspectRatio := float64(max(height, width)) / float64(min(height, width)); aspectRatio > 200 {
slog.Warn("aspect ratio exceeds 200, image quality may be affected", "aspect_ratio", aspectRatio)
}
round := func(x float64) int { return int(math.RoundToEven(x)) }
hBar := round(float64(height)/float64(factor)) * factor
wBar := round(float64(width)/float64(factor)) * factor
tBar := round(float64(numFrames)/float64(temporalFactor)) * temporalFactor
if tBar*hBar*wBar > p.maxPixels {
beta := math.Sqrt(float64(numFrames*height*width) / float64(p.maxPixels))
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
} else if tBar*hBar*wBar < p.minPixels {
beta := math.Sqrt(float64(p.minPixels) / float64(numFrames*height*width))
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
img = imageproc.Composite(img)
origWidth := img.Bounds().Dx()
origHeight := img.Bounds().Dy()
// Calculate smart resize dimensions
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
// Resize image
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeCatmullrom)
// Normalize pixels - output format is [C, H, W] with rescale and channelFirst
// We keep [C, H, W] for patch extraction
normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
// Calculate grid dimensions (after Conv2D patching)
grid := &Grid{
Height: resizedHeight / p.patchSize,
Width: resizedWidth / p.patchSize,
Temporal: 1, // Single image
ImageHeight: resizedHeight,
ImageWidth: resizedWidth,
}
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil {
return nil, nil, err
}
return patches, grid, nil
}
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
channels := 3
patchSize := p.patchSize
mergeSize := p.spatialMergeSize
temporalPatchSize := p.temporalPatchSize
numPatches := grid.Temporal * grid.Height * grid.Width
patchDim := channels * temporalPatchSize * patchSize * patchSize
result := make([]float32, numPatches*patchDim)
patchIndex := 0
// Single temporal frame handling (copies to all frames)
for range grid.Temporal {
for h := 0; h < grid.Height; h += mergeSize {
for w := 0; w < grid.Width; w += mergeSize {
for mh := range mergeSize {
for mw := range mergeSize {
baseOffset := patchIndex * patchDim
for c := range channels {
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
for py := range patchSize {
for px := range patchSize {
y := (h+mh)*patchSize + py
x := (w+mw)*patchSize + px
srcIdx := c*height*width + y*width + x
dstIdx := channelOffset + (py * patchSize) + px
result[dstIdx] = pixels[srcIdx]
}
}
if temporalPatchSize > 1 {
frameSize := patchSize * patchSize
for tp := 1; tp < temporalPatchSize; tp++ {
currentFrameOffset := channelOffset + (tp * frameSize)
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
result[channelOffset:channelOffset+frameSize])
}
}
}
patchIndex++
}
}
}
}
}
return result, nil
}

View File

@@ -1,235 +0,0 @@
package glmocr
import (
"bytes"
"errors"
"image"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v"`
VisionDownsample *VisionDownsample `gguf:"mm.patch_merger"`
PatchMerger *PatchMerger `gguf:"mm"`
ImageProcessor
imageTokenID int32
imageStartTokenID int32
imageEndTokenID int32
}
var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
eosTokenID := int32(c.Uint("tokenizer.ggml.eos_token_id"))
eosTokenIDs := c.Ints("tokenizer.ggml.eos_token_ids")
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: allEOS,
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
imageTokenID: int32(c.Uint("image_token_id", 59280)),
imageStartTokenID: int32(c.Uint("image_start_token_id", 59256)),
imageEndTokenID: int32(c.Uint("image_end_token_id", 59257)),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
return m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Blocks) == 0 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, grid, err := m.ImageProcessor.ProcessImage(img)
if err != nil {
return nil, err
}
// Create pixel values tensor from flattened patches
// Shape: [patchDim, numPatches]
patchDim := m.VisionModel.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
// Forward through vision encoder
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
// Forward through downsample (patch merger)
if m.VisionDownsample == nil || m.VisionDownsample.Weight == nil {
return nil, errors.New("glmocr: missing vision downsample weights")
}
visionOutputs = m.VisionDownsample.Forward(ctx, visionOutputs, grid, m.VisionModel.VisionModelOptions)
// Forward through patch merger (FC + LayerNorm + GELU + SwiGLU FFN)
if m.PatchMerger == nil {
return nil, errors.New("glmocr: missing patch merger weights")
}
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.VisionModelOptions)
return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
// Reset position cache
m.TextModel.positionCache = m.TextModel.positionCache[:0]
m.TextModel.ropeDelta = 0
pos := int32(0)
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
continue
}
// Get grid info for position calculation
grid := inp.Multimodal[0].Data.(*Grid)
mergedH := grid.Height / m.VisionModel.spatialMergeSize
mergedW := grid.Width / m.VisionModel.spatialMergeSize
// Add image start token
result = append(result, &input.Input{Token: m.imageStartTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
// Add image tokens with multimodal data
// All image tokens share the same base position for temporal dimension
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
basePos := pos
sameBatch := tokensPerGrid - 1
if sameBatch < 0 {
sameBatch = 0
}
result = append(result, &input.Input{
Token: m.imageTokenID,
Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash,
SameBatch: sameBatch,
})
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
// Add placeholder tokens for remaining positions
// All image tokens use the same base position (temporal stays constant)
for range tokensPerGrid - 1 {
result = append(result, &input.Input{Token: m.imageTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
}
// Advance position by max(mergedH, mergedW) after image tokens
pos = basePos + int32(max(mergedH, mergedW))
// Add image end token
result = append(result, &input.Input{Token: m.imageEndTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
}
// Compute rope delta for continuation after the prefill segment:
// delta = (max_position_id + 1) - sequence_length
if len(m.TextModel.positionCache) > 0 {
last := m.TextModel.positionCache[len(m.TextModel.positionCache)-1]
m.TextModel.ropeDelta = last + 1 - int32(len(m.TextModel.positionCache))
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
// Initial token embedding
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
ctx.Forward(hiddenStates)
// Build position slices for M-RoPE
positionSlice := func() [][]int32 {
s := [][]int32{
make([]int32, len(batch.Positions)), // temporal
make([]int32, len(batch.Positions)), // height
make([]int32, len(batch.Positions)), // width
make([]int32, len(batch.Positions)), // unused (zeros)
}
for i, position := range batch.Positions {
// Translate through position cache or continue sequence
if position < int32(len(m.TextModel.positionCache)) {
position = m.TextModel.positionCache[position]
} else if len(m.TextModel.positionCache) > 0 {
// Continue sequence after cached positions using ropeDelta
position = position + m.TextModel.ropeDelta
}
s[0][i] = position
s[1][i] = position
s[2][i] = position
}
return s
}()
// Inject vision embeddings and adjust positions for image tokens
for _, mi := range batch.Multimodal {
img := mi.Multimodal[0].Tensor
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
w := grid.Width / m.VisionModel.spatialMergeSize
for i := range img.Dim(1) {
positionSlice[1][mi.Index+i] += int32(i / w)
positionSlice[2][mi.Index+i] += int32(i % w)
}
}
}
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
// Process through transformer layers
for i, layer := range m.TextModel.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.TextModel.Layers)-1 {
lastLayerOutputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextModel.TextModelOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("glmocr", New)
}

View File

@@ -1,190 +0,0 @@
package glmocr
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
)
type TextModelOptions struct {
hiddenSize int
numHeads int
numKVHeads int
headDim int
rotaryDim int
intermediateSize int
eps float32
ropeBase float32
mropeSections []int
}
func (o *TextModelOptions) applyMRoPE(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
// With 4 sections for [temporal, height, width, unused]
return nn.RoPE(ctx, states, positions, o.rotaryDim, o.ropeBase, 1.0, rope.WithMRoPE(o.mropeSections))
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
// Separate Q, K, V projections
q := sa.Query.Forward(ctx, hiddenStates)
k := sa.Key.Forward(ctx, hiddenStates)
v := sa.Value.Forward(ctx, hiddenStates)
// Reshape for GQA
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
k = k.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
v = v.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
// Apply M-RoPE (multi-resolution rotary position embeddings)
q = opts.applyMRoPE(ctx, q, positions)
k = opts.applyMRoPE(ctx, k, positions)
// Scaled dot-product attention with KV cache
scaleFactor := 1.0 / math.Sqrt(float64(opts.headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
// Reshape attention output: [headDim, numHeads, batchSize] -> [numHeads*headDim, batchSize]
// Note: numHeads * headDim = 16 * 128 = 2048, which is the attention hidden size
kqv = kqv.Reshape(ctx, opts.numHeads*opts.headDim, batchSize)
return sa.Output.Forward(ctx, kqv)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextModelOptions) ml.Tensor {
// SwiGLU: down(silu(gate(x)) * up(x))
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, gate)
}
type TextDecoderLayer struct {
// Input layernorm (before attention)
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *TextSelfAttention
// Post self-attention layernorm (after attention, before residual add)
PostAttnNorm *nn.RMSNorm `gguf:"post_attn_norm"`
// FFN input layernorm (after first residual, before MLP)
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
// Post MLP layernorm (after MLP, before residual add)
PostFFNNorm *nn.RMSNorm `gguf:"post_ffn_norm"`
}
func (l *TextDecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
// Attention block
residual := hiddenStates
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = l.SelfAttention.Forward(ctx, hiddenStates, positions, cache, opts)
hiddenStates = l.PostAttnNorm.Forward(ctx, hiddenStates, opts.eps)
// Prune to output positions in final layer
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
// MLP block
residual = hiddenStates
hiddenStates = l.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = l.PostFFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []TextDecoderLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextModelOptions
// positionCache stores the M-RoPE position for each token in the sequence.
// This is needed because image tokens share the same base position but have
// different height/width offsets, and the end token position depends on the
// image grid dimensions.
positionCache []int32
ropeDelta int32
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// Clear position cache when KV cache shifts
m.positionCache = nil
m.ropeDelta = 0
return m.applyMRoPE(ctx, key, shift), nil
}
func newTextModel(c fs.Config) *TextModel {
hiddenSize := int(c.Uint("embedding_length", 1536))
numHeads := int(c.Uint("attention.head_count", 16))
numKVHeads := int(c.Uint("attention.head_count_kv", 8))
intermediateSize := int(c.Uint("feed_forward_length", 4608))
eps := c.Float("attention.layer_norm_rms_epsilon", 1e-5)
ropeBase := c.Float("rope.freq_base", 10000)
headDim := int(c.Uint("attention.key_length", uint32(hiddenSize/numHeads)))
ropeDim := int(c.Uint("rope.dimension_count", uint32(headDim)))
if ropeDim <= 0 {
ropeDim = headDim
}
mropeSections := c.Ints("rope.mrope_section")
var sectionInts []int
if len(mropeSections) > 0 {
sectionInts = make([]int, len(mropeSections))
for i, section := range mropeSections {
sectionInts[i] = int(section)
}
} else {
// Default to GLM-OCR's HF ratio (2:3:3) scaled to rotaryDim/2.
// For rotaryDim=64 this yields [8, 12, 12].
total := ropeDim / 2
if total <= 0 {
total = 32
}
s0 := total * 2 / 8
s1 := total * 3 / 8
s2 := total - s0 - s1
sectionInts = []int{s0, s1, s2}
}
// GGML rope_multi: sector = (dim_pair) % sum(sections), mapping each pair to its position dim
rotaryDim := ropeDim
return &TextModel{
Layers: make([]TextDecoderLayer, c.Uint("block_count", 16)),
TextModelOptions: &TextModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
numKVHeads: numKVHeads,
headDim: headDim,
rotaryDim: rotaryDim,
intermediateSize: intermediateSize,
eps: eps,
ropeBase: ropeBase,
mropeSections: sectionInts,
},
}
}

View File

@@ -1,355 +0,0 @@
package glmocr
import (
"log/slog"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
)
type Grid struct {
Height int // Number of patches in height direction
Width int // Number of patches in width direction
Temporal int
ImageHeight int // Full image height in pixels
ImageWidth int // Full image width in pixels
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
numChannels int
patchSize int
temporalPatchSize int
imageSize int
spatialMergeSize int
outHiddenSize int
intermediateSize int
eps float32
}
type VisionPatchEmbed struct {
Proj *nn.Conv2D `gguf:"patch_embd_0"`
Proj1 *nn.Conv2D `gguf:"patch_embd_1"`
Bias ml.Tensor `gguf:"patch_embd.bias"`
}
func (pe *VisionPatchEmbed) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
_ = grid // patches are already in merge-block order
// pixelValues shape: [patchDim, numPatches]
numPatches := pixelValues.Shape()[1]
// Reshape to [patchSize*patchSize, temporalPatchSize, numChannels, numPatches]
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
// Permute to [temporalPatchSize, patchSize*patchSize, numChannels, numPatches]
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Slice temporal frames for Conv2D (simulate Conv3D)
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
s0, s1 := opts.patchSize, opts.patchSize
p0, p1 := 0, 0
d0, d1 := 1, 1
hiddenStates := pe.Proj.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
if pe.Proj1 != nil && opts.temporalPatchSize > 1 {
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
out1 := pe.Proj1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
hiddenStates = hiddenStates.Add(ctx, out1)
}
// Flatten to [hidden_size, num_patches]
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, numPatches)
// Add patch bias - reshape from [hidden_size] to [hidden_size, 1] for broadcasting
if pe.Bias != nil {
hiddenStates = hiddenStates.Add(ctx, pe.Bias.Reshape(ctx, opts.hiddenSize, 1))
}
return hiddenStates
}
type VisionSelfAttention struct {
QKV *nn.Linear `gguf:"attn_qkv"`
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
// Combined QKV projection: [3*hidden_size, batch_size]
qkv := sa.QKV.Forward(ctx, hiddenStates)
// Split using ChunkSections along dim 0 (handles byte offsets correctly)
// ChunkSections returns views - must make contiguous before further operations
chunks := qkv.ChunkSections(ctx, 0, opts.hiddenSize, opts.hiddenSize, opts.hiddenSize)
q := chunks[0].Contiguous(ctx)
k := chunks[1].Contiguous(ctx)
v := chunks[2].Contiguous(ctx)
// Reshape for multi-head attention: [hiddenSize, N] -> [headDim, numHeads, N]
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
k = k.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
v = v.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
// Apply Q-norm and K-norm after head reshape
// Weights are [headDim]=64, tensor is [headDim, numHeads, N]
q = sa.QNorm.Forward(ctx, q, opts.eps)
k = sa.KNorm.Forward(ctx, k, opts.eps)
// Apply rotary position embeddings with vision-style 2D positions.
// ggml's vision RoPE uses two position dimensions (H/W) with half-rotation pairs.
// We provide H/W sections and leave the remaining sections empty.
ropeFreqBase := float32(10000.0)
section := opts.headDim / 4
if section <= 0 {
section = 1
}
sections := []int{section, section, 0, 0}
q = nn.RoPE(ctx, q, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
k = nn.RoPE(ctx, k, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
// Try flash attention first (ScaledDotProductAttention), fall back to manual
if sdpa, ok := q.(ml.ScaledDotProductAttention); ok {
attention := sdpa.ScaledDotProductAttention(ctx, k, v, nil, nil, nil, scale, false)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
slog.Warn("glmocr: vision attention falling back to manual attention",
"batchSize", batchSize, "numHeads", opts.numHeads,
"hint", "set OLLAMA_FLASH_ATTENTION=1 to enable flash attention")
// Manual attention fallback
// q, k, v are [headDim, numHeads, batchSize] - GGML treats as 4D with implicit dim 3 = 1
q = q.Permute(ctx, 0, 2, 1, 3)
k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
// Attention scores
kq := k.MulmatFullPrec(ctx, q)
kq = kq.Scale(ctx, scale)
kq = kq.Softmax(ctx)
// Attention output: v @ kq (note: v first)
kqv := v.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
// SwiGLU: down(silu(gate(x)) * up(x))
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, gate)
}
type VisionBlock struct {
Norm1 *nn.RMSNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
Norm2 *nn.RMSNorm `gguf:"ln2"`
MLP *VisionMLP
}
func (b *VisionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Pre-norm architecture
residual := hiddenStates
hiddenStates = b.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = b.SelfAttention.Forward(ctx, hiddenStates, positions, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = b.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = b.MLP.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type VisionDownsample struct {
*nn.Conv2D
}
func (d *VisionDownsample) Forward(ctx ml.Context, hiddenStates ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
// Apply spatial downsampling via Conv2D
// Input: [hidden_size, num_patches] where patches are in merge-block order
if d.Conv2D == nil || d.Weight == nil {
slog.Error("VisionDownsample weights not loaded - model may be corrupted or incompatible")
return hiddenStates // Return input unchanged as fallback
}
merge := opts.spatialMergeSize
numOutputTokens := (grid.Height / merge) * (grid.Width / merge)
// Step 1: Reshape to [hidden_size, merge, merge, num_output_tokens]
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, merge, merge, numOutputTokens)
// Step 2: Permute to [merge, merge, hidden_size, num_output_tokens]
// ggml semantics: result.ne[perm[i]] = input.ne[i]
// So permute(2,0,1,3) on [1024,2,2,N] gives: ne[2]=1024, ne[0]=2, ne[1]=2, ne[3]=N -> [2,2,1024,N]
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
// Step 3: Apply Conv2D without bias (bias added after reshape)
// Note: ggml_conv_2d takes (kernel, input) - kernel must be receiver in ollama
s0, s1 := merge, merge
p0, p1 := 0, 0
d0, d1 := 1, 1
hiddenStates = d.Weight.Conv2D(ctx, hiddenStates, s0, s1, p0, p1, d0, d1)
// Step 4: Reshape to [out_hidden_size, num_output_tokens]
hiddenStates = hiddenStates.Reshape(ctx, opts.outHiddenSize, numOutputTokens)
// Step 5: Add bias after reshape
// Reshape bias from [out_hidden_size] to [out_hidden_size, 1] for proper broadcasting
if d.Bias != nil {
hiddenStates = hiddenStates.Add(ctx, d.Bias.Reshape(ctx, opts.outHiddenSize, 1))
}
return hiddenStates
}
type PatchMerger struct {
// GGUF tags align with mm.* keys used by the model
Proj *nn.Linear `gguf:"model.fc"` // mm.model.fc.weight
PostLN *nn.LayerNorm `gguf:"post_norm"` // mm.post_norm.weight/bias
GateProj *nn.Linear `gguf:"gate"` // mm.gate.weight
UpProj *nn.Linear `gguf:"up"` // mm.up.weight
DownProj *nn.Linear `gguf:"down"` // mm.down.weight
}
func (m *PatchMerger) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Linear projection
hiddenStates = m.Proj.Forward(ctx, hiddenStates)
// Post-projection layer norm + GELU ERF
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = hiddenStates.GELU_ERF(ctx)
// Force a copy to avoid in-place mutation issues with GELU_ERF
hiddenStates = hiddenStates.Contiguous(ctx)
// SwiGLU MLP: down(silu(gate(x)) * up(x))
gateOut := m.GateProj.Forward(ctx, hiddenStates)
upOut := m.UpProj.Forward(ctx, hiddenStates)
gate := gateOut.SILU(ctx, upOut)
return m.DownProj.Forward(ctx, gate)
}
type VisionModel struct {
PatchEmbed *VisionPatchEmbed
Blocks []VisionBlock `gguf:"blk"`
PostLN *nn.RMSNorm `gguf:"post_ln"`
// Note: Downsample is applied at the model level so mm.patch_merger stays separate
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
// Extract patch embeddings from flattened patches
hiddenStates := m.PatchEmbed.Forward(ctx, pixelValues, grid, m.VisionModelOptions)
// Create position IDs for RoPE (spatial grid)
// Patches are already in merge-block order from preprocessing
positions := m.createPositions(ctx, grid)
// Process through vision blocks
for _, block := range m.Blocks {
hiddenStates = block.Forward(ctx, hiddenStates, positions, m.VisionModelOptions)
}
// Post-layernorm
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, m.eps)
// Note: Downsample is now applied separately in Model.EncodeMultimodal
// so mm.patch_merger remains a distinct module
return hiddenStates
}
func (m *VisionModel) createPositions(ctx ml.Context, grid *Grid) ml.Tensor {
// Create spatial position IDs for vision RoPE
// Position layout: [height, width, height, width] - 4 sections for mrope
// Patches are in MERGE-BLOCK order after VisionPatchEmbed interleaving
// This follows the GLM-OCR rot_pos_emb layout
numPatches := grid.Height * grid.Width
mergeRatio := m.spatialMergeSize
// Build position arrays in merge-block order
// Each merge_ratio x merge_ratio block of patches is grouped together
hpos := make([]int32, numPatches)
wpos := make([]int32, numPatches)
ptr := 0
for y := 0; y < grid.Height; y += mergeRatio {
for x := 0; x < grid.Width; x += mergeRatio {
for dy := range mergeRatio {
for dx := range mergeRatio {
hpos[ptr] = int32(y + dy)
wpos[ptr] = int32(x + dx)
ptr++
}
}
}
}
// Build position arrays for 4 sections (mrope). ggml vision RoPE uses only H/W;
// keep remaining sections zeroed to match its conventions.
zeros := make([]int32, numPatches)
s := [][]int32{
hpos, // Section 0: height
wpos, // Section 1: width
zeros, // Section 2: unused
zeros, // Section 3: unused
}
return ctx.Input().FromInts(slices.Concat(s...), numPatches*4)
}
func newVisionModel(c fs.Config) *VisionModel {
hiddenSize := int(c.Uint("vision.embedding_length", 1024))
numHeads := int(c.Uint("vision.attention.head_count", 16))
numChannels := int(c.Uint("vision.num_channels", 3))
patchSize := int(c.Uint("vision.patch_size", 14))
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
imageSize := int(c.Uint("vision.image_size", 336))
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
outHiddenSize := int(c.Uint("vision.out_hidden_size", 1536))
intermediateSize := int(c.Uint("vision.intermediate_size", 4096))
eps := c.Float("vision.attention.layer_norm_rms_epsilon", 1e-5)
return &VisionModel{
Blocks: make([]VisionBlock, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: hiddenSize / numHeads,
numChannels: numChannels,
patchSize: patchSize,
temporalPatchSize: temporalPatchSize,
imageSize: imageSize,
spatialMergeSize: spatialMergeSize,
outHiddenSize: outHiddenSize,
intermediateSize: intermediateSize,
eps: eps,
},
}
}

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Transformer struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}
var processor model.TextProcessor
vocabulary := model.Vocabulary{
var processor tokenizer.Tokenizer
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = model.NewSentencePiece(&vocabulary)
processor = tokenizer.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
ImageProcessor
*VisionModel `gguf:"v"`
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -28,12 +29,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
var _ model.TextProcessor = (*Model)(nil)
var _ tokenizer.Tokenizer = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -32,8 +33,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -8,7 +8,6 @@ import (
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/glm4moelite"
_ "github.com/ollama/ollama/model/models/glmocr"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/lfm2"
_ "github.com/ollama/ollama/model/models/llama"
@@ -20,6 +19,5 @@ import (
_ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3"
_ "github.com/ollama/ollama/model/models/qwen3next"
_ "github.com/ollama/ollama/model/models/qwen3vl"
)

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -178,8 +179,8 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
tokenizer := tokenizer.NewWordPiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -219,8 +220,8 @@ func New(c fs.Config) (model.Model, error) {
}
return &Model{
TextProcessor: processor,
Layers: layers,
Tokenizer: tokenizer,
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
const (
@@ -33,7 +34,7 @@ type Options struct {
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,7 +45,7 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -58,14 +59,14 @@ func New(c fs.Config) (model.Model, error) {
),
}
processor := model.NewBytePairEncoding(
tokenizer := tokenizer.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*Model
poolingType pooling.Type
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -1,103 +0,0 @@
package qwen3next
import (
"errors"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
// with the attention layer requirements.
var ErrUnsupportedBatchLayout = errors.New("qwen3next: unsupported batch layout")
// FullAttention implements gated attention with QK normalization and sigmoid-gated output.
// Key differences from standard attention:
// - Q projection outputs 2x size (Q + gate interleaved)
// - Both Q and K have RMSNorm
// - Output is gated: attn * sigmoid(gate)
type FullAttention struct {
Query *nn.Linear `gguf:"attn_q"` // outputs [n_embd_head * 2, n_head]
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
// Use Dim() instead of Shape() for consistent behavior during graph construction
hiddenDim := hiddenStates.Dim(0)
batchSize := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2) // 0 if 2D tensor
if cache != nil && cache.IsSupportedForBatch() {
seqTokens := cache.seqTokens()
seqs := cache.numSeqs()
if seqTokens > 0 && seqs > 0 {
if nSeqs > 0 {
// 3D tensor: [hiddenDim, seqTokens, nSeqs]
if batchSize != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
batchSize = seqTokens * seqs
} else if batchSize != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
}
}
}
headDim := opts.headDim()
numHeads := opts.numHeads
// Q projection outputs query + gate interleaved
qFull := sa.Query.Forward(ctx, hiddenStates)
// Reshape to [headDim * 2, numHeads, batchSize]
qFull = qFull.Reshape(ctx, headDim*2, numHeads, batchSize)
// Split Q and gate along dimension 0
// Q: first headDim elements, gate: second headDim elements
query := qFull.Slice(ctx, 0, 0, headDim, 1)
gate := qFull.Slice(ctx, 0, headDim, headDim*2, 1)
// Make query contiguous for further operations
query = query.Contiguous(ctx, headDim, numHeads, batchSize)
// K and V projections
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
// Derive numKVHeads from tensor dimensions (per-layer value)
numKVHeads := key.Dim(0) / headDim
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
// Apply QK normalization
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
// Apply RoPE
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
// Standard attention computation
scale := opts.attentionScale
if scale == 0 {
scale = 1.0 / math.Sqrt(float64(headDim))
}
attention := nn.Attention(ctx, query, key, value, scale, cache)
// Flatten heads
attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
// Apply sigmoid gate
// gate shape: [headDim, numHeads, batchSize] -> [headDim*numHeads, batchSize]
gate = gate.Contiguous(ctx, headDim*numHeads, batchSize)
gateSigmoid := gate.Sigmoid(ctx)
attention = attention.Mul(ctx, gateSigmoid)
return sa.Output.Forward(ctx, attention), nil
}

View File

@@ -1,596 +0,0 @@
package qwen3next
import (
"math"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var _ kvcache.Cache = (*HybridCache)(nil)
// HybridCache stores:
// - a standard causal KV cache for full attention layers
// - per-sequence conv state for linear attention layers
// - per-sequence delta state for linear attention layers
//
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
// Conv state dimensions
convDim int // convKernelSize - 1
convChannels int // d_inner + 2 * num_k_heads * head_k_dim
// Delta state dimensions
deltaStateSize int // headVDim * headVDim * numVHeads
// slot mapping for recurrent state (copy-on-write)
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
// per-layer delta state buffers (allocated lazily)
deltaCtxs map[int]ml.Context
deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
// recurrent checkpoints (per slot)
checkpointCount int
checkpointMinPos int32
checkpointInterval int32
checkpointCtxSize int
checkpoints map[int]*slotCheckpointStore
pendingRestore map[int]checkpointRestore
curCheckpointPos []int32
curCheckpointSlots map[int]int
reserveCheckpoints bool
checkpointConvCtxs map[int]ml.Context
checkpointDeltaCtxs map[int]ml.Context
checkpointReserved map[int]struct{}
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
writableError error
}
func NewHybridCache(
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
convDim, convChannels, deltaStateSize int,
) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
convDim: convDim,
convChannels: convChannels,
deltaStateSize: deltaStateSize,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
deltaCtxs: make(map[int]ml.Context),
deltaStates: make(map[int]ml.Tensor),
checkpointCount: checkpointCountDefault,
checkpointMinPos: checkpointMinPosDefault,
checkpointInterval: checkpointIntervalDefault,
checkpoints: make(map[int]*slotCheckpointStore),
pendingRestore: make(map[int]checkpointRestore),
curCheckpointSlots: make(map[int]int),
checkpointConvCtxs: make(map[int]ml.Context),
checkpointDeltaCtxs: make(map[int]ml.Context),
checkpointReserved: make(map[int]struct{}),
}
}
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
c.checkpoints = make(map[int]*slotCheckpointStore)
c.pendingRestore = make(map[int]checkpointRestore)
c.curCheckpointPos = c.curCheckpointPos[:0]
c.curCheckpointSlots = make(map[int]int)
c.checkpointReserved = make(map[int]struct{})
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
if c.checkpointCtxSize < 8 {
c.checkpointCtxSize = 8
}
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *HybridCache) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
for _, ctx := range c.deltaCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointConvCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointDeltaCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *HybridCache) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
// Derive equal-length sequence layout for recurrent layers
seqCounts := make(map[int]int)
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if _, ok := seqCounts[s]; !ok {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
if len(c.curSeqs) == 0 {
return nil
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
}
}
c.curSeqTokens = want
// When reserving memory for estimation, use fake slot assignments
if reserve {
c.curSlots = c.curSlots[:0]
slots := make([]int32, nSeqs)
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
slots[i] = int32(i)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
c.reserveCheckpoints = true
c.planCheckpoints(batch)
return nil
}
// Ensure slots exist for sequences in this batch
c.curSlots = c.curSlots[:0]
var newSlots []int
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
// Zero state for newly allocated slots
if len(newSlots) > 0 {
c.zeroSlots(ctx, newSlots)
}
// Create a tensor for the current slots
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
// Reset writable state for new forward pass
c.writableEnsured = false
c.writableError = nil
c.reserveCheckpoints = false
c.planCheckpoints(batch)
return nil
}
func (c *HybridCache) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, kvcache.ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *HybridCache) freeSlot(slot int) {
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroSlots zeros the recurrent state for the given slots across all layers.
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 {
return
}
inputCtx := ctx.Input()
slotIndices := make([]int32, len(slots))
for i, s := range slots {
slotIndices[i] = int32(s)
}
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
// Zero conv states
if len(c.convStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
// Zero delta states
if len(c.deltaStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
for _, buf := range c.deltaStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
}
// EnsureWritable ensures sequences have private slots (copy-on-write).
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
}
// Rebuild current slots tensor
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
for _, buf := range c.convStates {
rows := buf.Rows(ctx, src)
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
}
for _, buf := range c.deltaStates {
rows := buf.Rows(ctx, src)
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
}
}
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
// Copy-on-write for recurrent state
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
if c.validSlot(dstSlot) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
return
}
if c.validSlot(srcSlot) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *HybridCache) CanResume(seq int, pos int32) bool {
if !c.kv.CanResume(seq, pos) {
return false
}
if pos == 0 {
return true
}
return c.hasCheckpoint(seq, pos)
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if beginIndex > 0 && endIndex != math.MaxInt32 {
return kvcache.ErrNotSupported
}
if beginIndex > 0 {
restore, ok := c.pendingRestore[seq]
if !ok || restore.pos+1 != beginIndex {
return kvcache.ErrNotSupported
}
if !c.restoreComplete(restore) {
return kvcache.ErrNotSupported
}
// If the recurrent slot is shared, detach it before applying a restore.
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
newSlot, err := c.allocSlot()
if err != nil {
return err
}
ctx := c.backend.NewContext()
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
ctx.Compute()
}
ctx.Close()
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
restore.slot = newSlot
c.pendingRestore[seq] = restore
}
}
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
if beginIndex > 0 {
restore := c.pendingRestore[seq]
delete(c.pendingRestore, seq)
return c.applyCheckpointRestore(restore)
}
// Removal invalidates recurrent state
slot, ok := c.slotForSeq[seq]
delete(c.pendingRestore, seq)
if !ok {
return nil
}
if !c.validSlot(slot) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.clearCheckpoints(slot)
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
func (c *HybridCache) validSlot(slot int) bool {
return slot >= 0 && slot < len(c.refCount)
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.curSlotsInput
}
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
func (c *HybridCache) contiguousSlots() (int, bool) {
if len(c.curSlots) == 0 {
return 0, false
}
start := c.curSlots[0]
for i, s := range c.curSlots {
if s != start+i {
return 0, false
}
}
return start, true
}
func (c *HybridCache) seqTokens() int {
return c.curSeqTokens
}
func (c *HybridCache) numSeqs() int {
return len(c.curSeqs)
}
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
c.convStates[layer] = buf
return buf
}
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.deltaStates[layer]; ok {
return buf
}
if _, ok := c.deltaCtxs[layer]; !ok {
c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
// Recurrent delta state must stay in F32.
buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
c.deltaStates[layer] = buf
return buf
}
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
}
// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
}
// UpdateConvState writes a new conv state for current batch sequences.
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
srcF32 := src.Cast(ctx, ml.DTypeF32)
if start, ok := c.contiguousSlots(); ok {
// Fast path: contiguous slots allow a single view + copy
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
ctx.Forward(srcF32.Copy(ctx, view))
} else {
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
c.captureConvCheckpoint(ctx, layer, srcF32)
}
// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.deltaBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
}
// UpdateDeltaState writes a new delta state for current batch sequences.
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.deltaBuffer(ctx, layer)
src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
srcF32 := src.Cast(ctx, ml.DTypeF32)
if start, ok := c.contiguousSlots(); ok {
// Fast path: contiguous slots allow a single view + copy
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
ctx.Forward(srcF32.Copy(ctx, view))
} else {
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
c.captureDeltaCheckpoint(ctx, layer, srcF32)
}
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
return slices.Clone(c.curSeqs)
}

View File

@@ -1,498 +0,0 @@
package qwen3next
import (
"log/slog"
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
const (
checkpointCountDefault = 32
checkpointMinPosDefault = int32(16)
checkpointIntervalDefault = int32(1280)
)
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
// memory usage while preserving prefix reuse for recurrent state.
type checkpointEntry struct {
pos int32
conv map[int]ml.Tensor
delta map[int]ml.Tensor
}
type slotCheckpointStore struct {
entries []checkpointEntry
size int
next int
lastPos int32
}
type checkpointRestore struct {
slot int
idx int
pos int32
}
func newSlotCheckpointStore(n int) *slotCheckpointStore {
entries := make([]checkpointEntry, n)
for i := range entries {
entries[i].pos = -1
}
return &slotCheckpointStore{
entries: entries,
lastPos: -1,
}
}
func (s *slotCheckpointStore) reset() {
s.size = 0
s.next = 0
s.lastPos = -1
for i := range s.entries {
s.entries[i].pos = -1
}
}
func (s *slotCheckpointStore) record(pos int32) int {
if len(s.entries) == 0 {
return -1
}
idx := s.next
s.next = (s.next + 1) % len(s.entries)
if s.size < len(s.entries) {
s.size++
}
s.entries[idx].pos = pos
s.lastPos = pos
return idx
}
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
bestIdx := -1
bestPos := int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 || pos >= targetPos {
continue
}
if pos > bestPos {
bestPos = pos
bestIdx = i
}
}
if bestIdx < 0 {
return -1, -1, false
}
return bestIdx, bestPos, true
}
func (s *slotCheckpointStore) pruneAfter(pos int32) {
if len(s.entries) == 0 {
s.size = 0
s.next = 0
s.lastPos = -1
return
}
size := 0
next := -1
minPos := int32(math.MaxInt32)
minIdx := 0
for i := range s.entries {
if s.entries[i].pos > pos {
s.entries[i].pos = -1
}
if s.entries[i].pos >= 0 {
size++
if s.entries[i].pos < minPos {
minPos = s.entries[i].pos
minIdx = i
}
} else if next == -1 {
next = i
}
}
s.size = size
if size == 0 {
s.next = 0
s.lastPos = -1
return
}
if next != -1 {
s.next = next
} else {
// Full ring: overwrite the oldest checkpoint next.
s.next = minIdx
}
s.lastPos = pos
}
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
minPos = int32(math.MaxInt32)
maxPos = int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 {
continue
}
size++
if pos < minPos {
minPos = pos
}
if pos > maxPos {
maxPos = pos
}
}
if size == 0 {
minPos = -1
maxPos = -1
}
return size, minPos, maxPos, s.lastPos
}
func (c *HybridCache) planCheckpoints(batch input.Batch) {
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
c.curCheckpointPos = c.curCheckpointPos[:0]
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
return
}
if cap(c.curCheckpointPos) < len(c.curSeqs) {
c.curCheckpointPos = make([]int32, len(c.curSeqs))
} else {
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
}
for i := range c.curCheckpointPos {
c.curCheckpointPos[i] = -1
}
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
posMax := make(map[int]int32, len(c.curSeqs))
for i, seq := range batch.Sequences {
pos := batch.Positions[i]
if cur, ok := posMax[seq]; !ok || pos > cur {
posMax[seq] = pos
}
}
for i, seq := range c.curSeqs {
pos, ok := posMax[seq]
if !ok {
continue
}
if pos < c.checkpointMinPos {
continue
}
slot := c.curSlots[i]
store := c.checkpointStore(slot)
lastPos := store.lastPos
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
c.curCheckpointPos[i] = pos
}
}
}
func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
store, ok := c.checkpoints[slot]
if ok {
return store
}
store = newSlotCheckpointStore(c.checkpointCount)
c.checkpoints[slot] = store
return store
}
func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
if c.checkpointCount == 0 {
return -1
}
if idx, ok := c.curCheckpointSlots[slot]; ok {
return idx
}
store := c.checkpointStore(slot)
idx := store.record(pos)
if idx >= 0 {
c.curCheckpointSlots[slot] = idx
}
return idx
}
func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
if pos <= 0 {
return false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return false
}
store, ok := c.checkpoints[slot]
if !ok {
return false
}
_, _, ok = store.bestIndex(pos)
return ok
}
func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
if targetPos <= 0 {
return 0, false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return 0, false
}
store, ok := c.checkpoints[slot]
if !ok {
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
return 0, false
}
idx, pos, ok := store.bestIndex(targetPos)
if !ok {
size, minPos, maxPos, lastPos := store.window()
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
"min", minPos, "max", maxPos, "last", lastPos)
return 0, false
}
c.pendingRestore[seq] = checkpointRestore{
slot: slot,
idx: idx,
pos: pos,
}
return pos + 1, true
}
func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
entry, ok := c.restoreEntry(restore)
if !ok {
return kvcache.ErrNotSupported
}
ctx := c.backend.NewContext()
defer ctx.Close()
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
for layer, src := range entry.conv {
buf := c.convBuffer(ctx, layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
for layer, src := range entry.delta {
buf := c.deltaBuffer(ctx, layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
if len(entry.conv) > 0 || len(entry.delta) > 0 {
ctx.Compute()
}
store := c.checkpoints[restore.slot]
store.pruneAfter(restore.pos)
return nil
}
func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
_, ok := c.restoreEntry(restore)
return ok
}
func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
store, ok := c.checkpoints[restore.slot]
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
return nil, false
}
entry := &store.entries[restore.idx]
if entry.pos < 0 {
return nil, false
}
if !c.entryComplete(entry) {
return nil, false
}
return entry, true
}
func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
for layer := range c.convStates {
if entry.conv == nil || entry.conv[layer] == nil {
return false
}
}
for layer := range c.deltaStates {
if entry.delta == nil || entry.delta[layer] == nil {
return false
}
}
return true
}
func (c *HybridCache) clearCheckpoints(slot int) {
if store, ok := c.checkpoints[slot]; ok {
store.reset()
}
}
func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
if c.checkpointCount == 0 {
return
}
srcStore, ok := c.checkpoints[srcSlot]
if !ok || srcStore.size == 0 {
return
}
dstStore := c.checkpointStore(dstSlot)
dstStore.size = srcStore.size
dstStore.next = srcStore.next
dstStore.lastPos = srcStore.lastPos
for i := range srcStore.entries {
srcEntry := &srcStore.entries[i]
dstEntry := &dstStore.entries[i]
dstEntry.pos = srcEntry.pos
if srcEntry.conv != nil {
if dstEntry.conv == nil {
dstEntry.conv = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.conv {
dst := c.ensureCheckpointConv(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
if srcEntry.delta != nil {
if dstEntry.delta == nil {
dstEntry.delta = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.delta {
dst := c.ensureCheckpointDelta(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
}
}
func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointConv(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointConv(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointDelta(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointDelta(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
if entry.conv == nil {
entry.conv = make(map[int]ml.Tensor)
}
if t, ok := entry.conv[layer]; ok {
return t
}
ctx, ok := c.checkpointConvCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointConvCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
entry.conv[layer] = t
return t
}
func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
if entry.delta == nil {
entry.delta = make(map[int]ml.Tensor)
}
if t, ok := entry.delta[layer]; ok {
return t
}
ctx, ok := c.checkpointDeltaCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointDeltaCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
entry.delta[layer] = t
return t
}
func (c *HybridCache) reserveCheckpointConv(layer int) {
key := checkpointReserveKey(layer, 0)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointConv(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func (c *HybridCache) reserveCheckpointDelta(layer int) {
key := checkpointReserveKey(layer, 1)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointDelta(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func checkpointReserveKey(layer int, kind int) int {
return layer*2 + kind
}

View File

@@ -1,300 +0,0 @@
package qwen3next
import (
"errors"
"math"
"os"
"testing"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
func newTestBackend(tb testing.TB) ml.Backend {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
if err != nil {
tb.Fatal(err)
}
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
_ = f.Close()
tb.Fatal(err)
}
if err := f.Close(); err != nil {
tb.Fatal(err)
}
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
if err != nil {
tb.Fatal(err)
}
tb.Cleanup(func() {
b.Close()
})
return b
}
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
store := newSlotCheckpointStore(2)
store.record(10)
store.record(20)
_, pos, ok := store.bestIndex(15)
if !ok || pos != 10 {
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
}
store.record(30) // overwrite oldest (10)
if _, _, ok := store.bestIndex(15); ok {
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
}
_, pos, ok = store.bestIndex(40)
if !ok || pos != 30 {
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
}
}
func TestHybridCachePrepareRestore(t *testing.T) {
cache := NewHybridCache(nil, 1, 1, 1)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
store := cache.checkpointStore(0)
store.record(5)
store.record(9)
store.record(15)
restorePos, ok := cache.PrepareRestore(1, 12)
if !ok {
t.Fatalf("expected restore ok")
}
if restorePos != 10 {
t.Fatalf("expected restorePos 10, got %d", restorePos)
}
rest, ok := cache.pendingRestore[1]
if !ok {
t.Fatalf("expected pending restore entry")
}
if rest.pos != 9 {
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
}
}
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
store.pruneAfter(20)
if store.lastPos != 20 {
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
}
_, pos, ok := store.bestIndex(25)
if !ok || pos != 20 {
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
}
_, pos, ok = store.bestIndex(35)
if !ok || pos != 20 {
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
}
}
func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
backend := newTestBackend(t)
cache := NewHybridCache(nil, 1, 2, 2)
cache.Init(backend, ml.DTypeF16, 2, 8, 2)
cache.slotForSeq[1] = 0
cache.slotForSeq[2] = 0
cache.refCount[0] = 2
cache.refCount[1] = 0
cache.freeSlots = []int{1}
store := cache.checkpointStore(0)
idx := store.record(9)
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
t.Fatalf("Remove failed: %v", err)
}
if cache.slotForSeq[1] == cache.slotForSeq[2] {
t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
}
if cache.slotForSeq[1] != 1 {
t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
}
if cache.slotForSeq[2] != 0 {
t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
}
if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
}
if _, ok := cache.pendingRestore[1]; ok {
t.Fatalf("expected pending restore to be cleared")
}
}
func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
cache := NewHybridCache(nil, 1, 2, 2)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
// Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
cache.convStates[0] = nil // placeholder to indicate layer 0 exists
cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
store := cache.checkpointStore(0)
idx := store.record(9)
entry := &store.entries[idx]
// Only set conv checkpoint, not delta - making it incomplete
entry.conv = map[int]ml.Tensor{0: nil}
// entry.delta is not set, so checkpoint is incomplete
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
err := cache.Remove(1, 10, math.MaxInt32)
if !errors.Is(err, kvcache.ErrNotSupported) {
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
}
}
func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
cache := NewHybridCache(nil, 1, 2, 2)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
// Don't set convStates/deltaStates - with no layers to check,
// entryComplete will return true as long as entry.pos >= 0
store := cache.checkpointStore(0)
idx := store.record(9)
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
// Test that restoreComplete returns true when no layers need checkpoints
restore := cache.pendingRestore[1]
if !cache.restoreComplete(restore) {
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
}
}
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
// Test that ring buffer wrap-around reuses entries without clearing maps.
store := newSlotCheckpointStore(3)
// Fill the buffer
store.record(10)
store.record(20)
store.record(30)
// Create fake tensor data in the first entry's maps
store.entries[0].conv = make(map[int]ml.Tensor)
store.entries[0].conv[0] = nil // Simulated tensor reference
store.entries[0].delta = make(map[int]ml.Tensor)
store.entries[0].delta[0] = nil // Simulated tensor reference
// Record another entry, which should wrap around and overwrite entry 0
store.record(40)
// Verify the maps are still present (we reuse tensors)
if store.entries[0].conv == nil {
t.Fatalf("expected conv map to be preserved on reuse")
}
if store.entries[0].delta == nil {
t.Fatalf("expected delta map to be preserved on reuse")
}
// Verify the new position was recorded
if store.entries[0].pos != 40 {
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
}
}
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
// Test behavior when buffer is exactly at capacity
store := newSlotCheckpointStore(2)
idx1 := store.record(10)
idx2 := store.record(20)
if idx1 != 0 || idx2 != 1 {
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
}
if store.size != 2 {
t.Fatalf("expected size 2, got %d", store.size)
}
// Verify both checkpoints are accessible
_, pos1, ok1 := store.bestIndex(15)
_, pos2, ok2 := store.bestIndex(25)
if !ok1 || pos1 != 10 {
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
}
if !ok2 || pos2 != 20 {
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
}
}
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
// Test behavior with zero-size buffer
store := newSlotCheckpointStore(0)
idx := store.record(10)
if idx != -1 {
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
}
_, _, ok := store.bestIndex(15)
if ok {
t.Fatalf("expected no checkpoint for empty buffer")
}
}
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
// Test pruning that removes all checkpoints
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
// Prune everything by setting threshold below all positions
store.pruneAfter(5)
if store.size != 0 {
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
}
// When all checkpoints are pruned, lastPos is reset to -1
if store.lastPos != -1 {
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
}
_, _, ok := store.bestIndex(100)
if ok {
t.Fatalf("expected no checkpoint after pruning all")
}
}

View File

@@ -1,472 +0,0 @@
package qwen3next
import (
"errors"
"log/slog"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
const chunkSize = 64
// TriType constants for triangular matrix operations
const (
TriTypeUpperDiag = 0
TriTypeUpper = 1
TriTypeLowerDiag = 2
TriTypeLower = 3
)
// convKernel wraps the 1D convolution kernel tensor
type convKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// Masks holds pre-computed mask tensors for chunked attention
type Masks struct {
Causal ml.Tensor // Lower triangular [chunkSize, chunkSize]
Identity ml.Tensor // Diagonal [chunkSize, chunkSize]
Diag ml.Tensor // causal + identity
}
// GatedDeltaNet implements linear attention with SSM convolution and recurrent state.
// It implements the Operator interface directly.
type GatedDeltaNet struct {
// Optimized path: pre-split QKV and gate
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
SSMOut *nn.Linear `gguf:"ssm_out"`
// Layer index for cache access (set during model construction)
Layer int
}
// createMasks builds the constant mask tensors (called once, reused for all chunks)
func createMasks(ctx ml.Context) *Masks {
ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
ones = ones.Fill(ctx, 1.0)
causalMask := ones.Tri(ctx, TriTypeLower)
onesVec := ctx.Input().Zeros(ml.DTypeF32, chunkSize)
onesVec = onesVec.Fill(ctx, 1.0)
identity := onesVec.Diag(ctx)
diagMask := causalMask.Add(ctx, identity)
return &Masks{
Causal: causalMask,
Identity: identity,
Diag: diagMask,
}
}
func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
layer := gdn.Layer
nSeqTokens := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2)
if cache != nil && cache.IsSupportedForBatch() {
seqTokens := cache.seqTokens()
seqs := cache.numSeqs()
if seqTokens > 0 && seqs > 0 {
if nSeqs > 1 {
if nSeqTokens != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
}
} else {
if nSeqTokens != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
nSeqTokens = seqTokens
nSeqs = seqs
}
}
}
headKDim := opts.ssmDState
numKHeads := opts.ssmNGroup
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
convKernelSize := opts.convKernelSize
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Optimized path: pre-split QKV and gate
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
baNewDim := 2 * numVHeads / numKHeads
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
// Split beta and alpha
betaSize := numVHeads / numKHeads
alphaSize := numVHeads / numKHeads
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
// Reshape to merge head dimensions
beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
// Compute gate: softplus(alpha + dt_bias) * -A
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
alphaSoftplus := alphaBiased.Softplus(ctx)
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
// Get conv state from cache
convStates, err := cache.ConvState(ctx, layer)
if err != nil {
// Log this - if it happens, short-term context will be lost
slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
convStates = ctx.Input().Zeros(ml.DTypeF32, convKernelSize-1, qkvDim, nSeqs)
}
// Reshape conv states
convStates = convStates.Reshape(ctx, convKernelSize-1, qkvDim, nSeqs)
// Concatenate with input for convolution
convInput := convStates.Concat(ctx, qkvMixed, 0)
// Save new conv state (last convKernelSize-1 tokens)
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
cache.UpdateConvState(ctx, layer, lastConvStates)
// Apply SSM convolution (kernel must be F32 for Metal)
convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
convOutput = convOutput.SILU(ctx)
// Reshape for extraction
convQKVMix := convOutput.Contiguous(ctx, qkvDim, nSeqTokens*nSeqs)
// Extract convolved Q, K, V
qConv := convQKVMix.Slice(ctx, 0, 0, headKDim*numKHeads, 1)
kConv := convQKVMix.Slice(ctx, 0, headKDim*numKHeads, 2*headKDim*numKHeads, 1)
vConv := convQKVMix.Slice(ctx, 0, 2*headKDim*numKHeads, qkvDim, 1)
// Reshape to 4D
qConv = qConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
kConv = kConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)
// Get delta state from cache
state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
if err != nil {
// Log this - if it happens frequently, context will degrade
slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
state = ctx.Input().Zeros(ml.DTypeF32, headVDim, headVDim*numVHeads, nSeqs)
}
state = state.Reshape(ctx, headVDim, headVDim*numVHeads, 1, nSeqs)
// Repeat interleave Q and K if numKHeads != numVHeads
if numKHeads != numVHeads {
repeatFactor := numVHeads / numKHeads
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
}
// Choose computation mode based on sequence length
var attnOut ml.Tensor
if nSeqTokens == 1 {
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
} else {
// Use pre-computed masks from opts (created once in Model.Forward)
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
}
// Apply gated normalization
attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
// norm(attnOut, z) = RMSNorm(attnOut) * silu(z)
attnOutNorm := gdn.SSMNorm.Forward(ctx, attnOut2D, opts.eps)
zSilu := z2D.SILU(ctx)
attnOutGated := attnOutNorm.Mul(ctx, zSilu)
// Reshape for output projection
finalOutput := attnOutGated.Reshape(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)
out := gdn.SSMOut.Forward(ctx, finalOutput)
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
}
// deltaNetAutoregressive implements single-token state update.
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
func (gdn *GatedDeltaNet) deltaNetAutoregressive(
ctx ml.Context,
q, k, v, gate, beta, state ml.Tensor,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
nSeqs := q.Dim(3)
// L2 normalize Q and K
q = q.L2Norm(ctx, opts.eps)
k = k.L2Norm(ctx, opts.eps)
// Scale Q
scale := 1.0 / math.Sqrt(float64(headVDim))
q = q.Scale(ctx, scale)
// Sigmoid beta
beta = beta.Sigmoid(ctx)
// Reshape state: [headVDim, headVDim, numVHeads, nSeqs]
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
// Reshape gate and beta for broadcasting
gT := gate.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
betaT := beta.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
// Apply exponential to gate
gT = gT.Exp(ctx)
// state = state * g_t
state = state.Mul(ctx, gT)
// kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
kTUnsqueezed := k.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
kvMem := state.Mul(ctx, kTUnsqueezed)
// Sum over dim=-2 (second dimension after permute)
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
kvMem = kvMem.SumRows(ctx)
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3)
// v_t with singleton dimension
vT := v.Reshape(ctx, headVDim, 1, numVHeads, nSeqs)
// delta = (v_t - kv_mem) * beta_t
vDiff := vT.Sub(ctx, kvMem)
delta := vDiff.Mul(ctx, betaT)
// state = state + k_t.unsqueeze(-1) * delta
kTUnsqueezedBroad := kTUnsqueezed.Repeat4D(ctx, headVDim, headVDim, numVHeads, nSeqs)
kTDelta := kTUnsqueezedBroad.Mul(ctx, delta)
state = state.Add(ctx, kTDelta)
// core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2)
qTUnsqueezed := q.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
stateQ := state.Mul(ctx, qTUnsqueezed)
stateQ = stateQ.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
coreAttnOut := stateQ.SumRows(ctx)
coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
}
// deltaNetChunked implements chunked computation for prefill.
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
func (gdn *GatedDeltaNet) deltaNetChunked(
ctx ml.Context,
q, k, v, gate, beta, state ml.Tensor,
masks *Masks,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
headKDim := q.Dim(0)
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
nTokens := q.Dim(2)
nSeqs := q.Dim(3)
// L2 normalize Q and K
q = q.L2Norm(ctx, opts.eps)
k = k.L2Norm(ctx, opts.eps)
// Scale Q
scale := 1.0 / math.Sqrt(float64(headVDim))
q = q.Scale(ctx, scale)
// Sigmoid beta
beta = beta.Sigmoid(ctx)
// Permute tensors for chunked computation
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs)
beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
// Compute padding
pad := (chunkSize - nTokens%chunkSize) % chunkSize
nChunks := (nTokens + pad) / chunkSize
// Pad tensors
if pad > 0 {
q = q.Pad(ctx, 0, pad, 0, 0)
k = k.Pad(ctx, 0, pad, 0, 0)
v = v.Pad(ctx, 0, pad, 0, 0)
gate = gate.Pad(ctx, pad, 0, 0, 0)
beta = beta.Pad(ctx, 0, pad, 0, 0)
}
// Use pre-computed masks (passed in, not recreated)
causalMask := masks.Causal
identity := masks.Identity
diagMask := masks.Diag
identity4D := identity.Reshape(ctx, chunkSize, chunkSize, 1, 1)
// v_beta = v * beta, k_beta = k * beta
vBeta := v.Mul(ctx, beta)
kBeta := k.Mul(ctx, beta)
// Reshape for chunked computation
q = q.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
k = k.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
// g_cumsum = cumsum(gate)
gCumsum := gate.CumSum(ctx)
// Compute decay mask
gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
gcsJ := gCumsum.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
gcsBroadcast := gcsJ.Repeat4D(ctx, chunkSize, chunkSize, nChunks, numVHeads*nSeqs)
decayMask := gcsBroadcast.Sub(ctx, gcsI)
decayMask = decayMask.Mul(ctx, diagMask)
decayMask = decayMask.Exp(ctx)
decayMask = decayMask.Mul(ctx, diagMask)
// k @ k_beta^T
kMulKBeta := k.Mulmat(ctx, kBeta)
// k_decay = k @ k_beta^T * decay_mask
kDecay := kMulKBeta.Mul(ctx, decayMask)
// attn = -k_decay * causal_mask
attn := kDecay.Neg(ctx).Mul(ctx, causalMask)
// Triangular solve: (I - attn_lower)^-1 @ attn
attnLower := attn.Mul(ctx, causalMask)
lhs := attnLower.Neg(ctx).Add(ctx, identity4D)
linSolve := lhs.SolveTri(ctx, attn, true, true, false)
attn = linSolve.Mul(ctx, causalMask)
attn = attn.Add(ctx, identity4D)
// v = v_beta^T @ attn
vBetaT := vBeta.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
v = vBetaT.Mulmat(ctx, attn)
// Compute g_exp for state update
gCumsumT := gCumsum.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
gExp := gCumsumT.Exp(ctx)
// kbeta_gexp = k_beta * g_exp
kBetaGExp := kBeta.Mul(ctx, gExp)
// k_cumdecay = attn @ kbeta_gexp^T
kBetaGExpT := kBetaGExp.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
kCumdecay := attn.Mulmat(ctx, kBetaGExpT)
kCumdecay = kCumdecay.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Pre-compute attn_kq = (k @ q) * decay_mask * diag_mask
attnKQ := k.Mulmat(ctx, q)
attnKQ = attnKQ.Mul(ctx, decayMask)
attnKQ = attnKQ.Mul(ctx, diagMask)
// Pre-compute g_last and key_gdiff
// g_last = view of last element in g_cumsum along chunk_size dimension
// We need to get the last row of gCumsum: shape [chunkSize, 1, nChunks, H*n_seqs] -> [1, 1, nChunks, H*n_seqs]
gLast := gCumsum.Slice(ctx, 0, chunkSize-1, chunkSize, 1).Contiguous(ctx, 1, 1, nChunks, numVHeads*nSeqs)
gLastExp := gLast.Exp(ctx)
// g_diff = -(g_cumsum - g_last) = g_last - g_cumsum
gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
gDiffExp := gDiff.Exp(ctx)
// Reshapes g_diff_exp to [1, chunkSize, nChunks, ...]
gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Process chunks and update state
var coreAttnOut ml.Tensor
newState := state
for chunk := range nChunks {
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
// state^T - permute is needed but Contiguous creates a copy
stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
// v_prime = k_cumdecay @ state
vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
// v_new = v - v_prime
vNew := vChunk.Sub(ctx, vPrime)
vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// attn_inter = (q * g_exp) @ state
qGExp := qChunk.Mul(ctx, gExpChunk)
attnInter := stateT.Mulmat(ctx, qGExp)
// core_attn_out = attn_inter + attn @ v_new
vAttn := vNewT.Mulmat(ctx, attnChunk)
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
if coreAttnOut == nil {
coreAttnOut = coreAttnOutChunk
} else {
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
}
// Update state for next chunk
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
// state = state * g_last + kgdmulvnew
gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
newState = newState.Mul(ctx, gExpLastReshaped)
newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
}
// Final reshape
coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
// Slice to remove padding
if pad > 0 {
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
}
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
}

View File

@@ -1,383 +0,0 @@
package qwen3next
import (
"cmp"
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
// Options contains model configuration
type Options struct {
hiddenSize int
numHeads int
numKVHeads int
keyLength int
valueLength int
ropeDim int
eps float32
ropeBase float32
ropeScale float32
ropeType string
originalContextLength int
attentionScale float64
// MoE config
numExperts int
numExpertsUsed int
normTopKProb bool
// Linear attention (Gated Delta Net) config
ssmDInner int // d_inner = head_v_dim * num_v_heads
ssmDState int // head_k_dim
ssmNGroup int // num_k_heads
ssmDtRank int // num_v_heads
convKernelSize int // SSM conv kernel size
// Per-layer type from GGUF metadata
isRecurrent []bool
// Pre-computed masks for chunked attention (created once per forward pass)
masks *Masks
}
func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
opts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(attnFactor),
)
}
ropeDim := cmp.Or(o.ropeDim, o.headDim())
return nn.RoPE(ctx, states, positions, ropeDim, o.ropeBase, 1./o.ropeScale, opts...)
}
// Operator is the interface for attention-like operators
type Operator interface {
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
}
// MLP is the interface for feedforward networks
type MLP interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
}
// sparse implements MoE with shared experts
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
// Shared experts
SharedGateInp *nn.Linear `gguf:"ffn_gate_inp_shexp"`
SharedGate *nn.Linear `gguf:"ffn_gate_shexp"`
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
}
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
if batchSize == 0 {
batchSize = 1
}
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
// Router logits
routerLogits := mlp.Router.Forward(ctx, hiddenStates2D)
// Softmax routing weights
routingWeights := routerLogits.Softmax(ctx)
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
if opts.normTopKProb {
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
}
hiddenStates3D := hiddenStates2D.Reshape(ctx, hiddenStates2D.Dim(0), 1, hiddenStates2D.Dim(1))
// Expert computation with SILU activation
gateOut := mlp.Gate.Forward(ctx, hiddenStates3D, selectedExperts)
upOut := mlp.Up.Forward(ctx, hiddenStates3D, selectedExperts)
experts := gateOut.SILU(ctx, upOut)
experts = mlp.Down.Forward(ctx, experts, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
// Sum over experts
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
// Add shared experts if present
if mlp.SharedUp != nil {
sharedGate := mlp.SharedGate.Forward(ctx, hiddenStates2D)
sharedUp := mlp.SharedUp.Forward(ctx, hiddenStates2D)
sharedOut := sharedGate.SILU(ctx, sharedUp)
sharedOut = mlp.SharedDown.Forward(ctx, sharedOut)
// Apply shared expert gating
if mlp.SharedGateInp != nil {
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
sharedGateVal = sharedGateVal.Sigmoid(ctx)
// Broadcast gate to match dimensions
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
sharedOut = sharedOut.Mul(ctx, sharedGateVal)
}
moeOut = moeOut.Add(ctx, sharedOut)
}
return moeOut
}
// dense implements standard feedforward
type dense struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
// Layer represents a single transformer layer
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
AttentionPostNorm *nn.RMSNorm `gguf:"post_attention_norm"` // Post-attention norm before FFN
Operator Operator
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP MLP
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
residual := hiddenStates
// Pre-attention norm
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
// Attention (full or linear)
var err error
hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, positions, cache, opts)
if err != nil {
return nil, err
}
// Output projection for last layer
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
// First residual connection
hiddenStates = hiddenStates.Add(ctx, residual)
// Save for FFN residual
ffnResidual := hiddenStates
// Post-attention norm (before FFN)
hiddenStates = l.AttentionPostNorm.Forward(ctx, hiddenStates, opts.eps)
// FFN
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
// Second residual connection
return hiddenStates.Add(ctx, ffnResidual), nil
}
// Model is the main Qwen3-Next model
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Layers []Layer `gguf:"blk"`
*Options
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
cache := m.Cache.(*HybridCache)
// Create masks once per forward pass
m.Options.masks = createMasks(ctx)
for i, layer := range m.Layers {
cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
var err error
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
if err != nil {
return nil, err
}
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
var _ model.Model = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers)
// Get per-layer head counts (for detecting layer type)
type headCounts interface {
HeadCount() []uint64
HeadCountKV() []uint64
}
var isRecurrent []bool
var headCountKV []uint64
if hc, ok := c.(headCounts); ok {
headCountKV = hc.HeadCountKV()
}
isRecurrent = make([]bool, numLayers)
hasZero := false
hasFull := false
for i := range numLayers {
// If KV head count is 0, it's a recurrent layer
if i < len(headCountKV) && headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else if i < len(headCountKV) && headCountKV[i] > 0 {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
}
// Determine if MoE
isMoE := c.Uint("expert_count") > 0
for i := range layers {
if isRecurrent[i] {
layers[i].Operator = &GatedDeltaNet{Layer: i}
} else {
layers[i].Operator = &FullAttention{}
}
if isMoE {
layers[i].MLP = &sparse{}
} else {
layers[i].MLP = &dense{}
}
}
opts := &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: func() int {
for _, v := range headCountKV {
if v > 0 {
return int(v)
}
}
return 0
}(),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeType: c.String("rope.scaling.type"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
attentionScale: float64(c.Float("attention.scale")),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true),
ssmDInner: int(c.Uint("ssm.inner_size")),
ssmDState: int(c.Uint("ssm.state_size")),
ssmNGroup: int(c.Uint("ssm.group_count")),
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
convKernelSize: int(c.Uint("ssm.conv_kernel")),
isRecurrent: isRecurrent,
}
if opts.numKVHeads == 0 {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
}
// Calculate cache dimensions
convDim := max(0, opts.convKernelSize-1)
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
headVDim := 0
numVHeads := opts.ssmDtRank
if numVHeads > 0 {
headVDim = opts.ssmDInner / numVHeads
}
deltaStateSize := headVDim * headVDim * numVHeads
// Validate dimension assumption: headKDim == headVDim is required for state computations
headKDim := opts.ssmDState
if headKDim != headVDim && headKDim > 0 && headVDim > 0 {
return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
// Qwen3 tokenizers typically set add_bos_token=false and bos_token=null.
// Default to false when the GGUF key is missing to avoid injecting a spurious BOS.
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Layers: layers,
Options: opts,
}
m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
return &m, nil
}
func init() {
model.Register("qwen3next", New)
}

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -1,17 +0,0 @@
package parsers
import "github.com/ollama/ollama/api"
// GlmOcrParser is the GLM46 parser with thinking disabled.
type GlmOcrParser struct {
GLM46Parser
}
func (p *GlmOcrParser) HasThinkingSupport() bool {
return false
}
func (p *GlmOcrParser) Init(tools []api.Tool, _ *api.Message, _ *api.ThinkValue) []api.Tool {
p.tools = tools
return tools
}

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
@@ -18,34 +17,12 @@ const (
ministralCollectingToolArgs
)
// ministralEvent represents an event emitted during parsing
type ministralEvent interface {
isMinistralEvent()
}
type ministralEventContent struct {
content string
}
type ministralEventThinking struct {
thinking string
}
type ministralEventToolCall struct {
name string
args string // raw JSON string
}
func (ministralEventContent) isMinistralEvent() {}
func (ministralEventThinking) isMinistralEvent() {}
func (ministralEventToolCall) isMinistralEvent() {}
type MinistralParser struct {
state ministralParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
pendingToolName string // stores tool name while collecting args
currentTool *api.Tool
}
func (p *MinistralParser) HasToolSupport() bool {
@@ -86,251 +63,74 @@ func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
return nil, fmt.Errorf("tool '%s' not found", n)
}
const (
ministralToolCallsTag = "[TOOL_CALLS]"
ministralThinkTag = "[THINK]"
ministralThinkEndTag = "[/THINK]"
ministralArgsTag = "[ARGS]"
)
// eat consumes the parser's buffer, and returns a list of any unambiguous
// events from the current parser state. The second return value indicates
// whether to keep looping (true when state transitions, false when waiting
// for more data).
func (p *MinistralParser) eat() ([]ministralEvent, bool) {
var events []ministralEvent
switch p.state {
case ministralCollectingContent:
bufStr := p.buffer.String()
// Check for [TOOL_CALLS] tag
if strings.Contains(bufStr, ministralToolCallsTag) {
split := strings.SplitN(bufStr, ministralToolCallsTag, 2)
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
if len(before) > 0 {
events = append(events, ministralEventContent{content: before})
}
after := split[1]
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingToolName
return events, true
}
// Check for [THINK] tag
if strings.Contains(bufStr, ministralThinkTag) {
split := strings.SplitN(bufStr, ministralThinkTag, 2)
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
if len(before) > 0 {
events = append(events, ministralEventContent{content: before})
}
after := split[1]
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingThinkingContent
return events, true
}
// Check for partial tag overlap with [TOOL_CALLS] or [THINK]
overlapToolCalls := overlap(bufStr, ministralToolCallsTag)
overlapThink := overlap(bufStr, ministralThinkTag)
maxOverlap := max(overlapToolCalls, overlapThink)
if maxOverlap > 0 {
// Withhold the potential partial tag
beforePartialTag := bufStr[:len(bufStr)-maxOverlap]
trailingWS := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWS
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventContent{content: unambiguous})
}
return events, false
}
// No tag found: emit content but withhold trailing whitespace
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventContent{content: unambiguous})
}
return events, false
case ministralCollectingThinkingContent:
bufStr := p.buffer.String()
if strings.Contains(bufStr, ministralThinkEndTag) {
split := strings.SplitN(bufStr, ministralThinkEndTag, 2)
thinkingContent := split[0]
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if len(thinkingContent) > 0 {
events = append(events, ministralEventThinking{thinking: thinkingContent})
}
p.state = ministralCollectingContent
return events, true
}
// Check for partial overlap with [/THINK]
if overlapLen := overlap(bufStr, ministralThinkEndTag); overlapLen > 0 {
unambiguous := bufStr[:len(bufStr)-overlapLen]
ambiguous := bufStr[len(bufStr)-overlapLen:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventThinking{thinking: unambiguous})
}
return events, false
}
// No tag found: emit all thinking content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, ministralEventThinking{thinking: bufStr})
}
return events, false
case ministralCollectingToolName:
bufStr := p.buffer.String()
if strings.Contains(bufStr, ministralArgsTag) {
split := strings.SplitN(bufStr, ministralArgsTag, 2)
toolName := split[0]
after := split[1]
p.pendingToolName = toolName
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingToolArgs
return events, true
}
// Wait for more data
return events, false
case ministralCollectingToolArgs:
bufStr := p.buffer.String()
jsonEnd := findJSONEnd(bufStr)
if jsonEnd != -1 {
jsonStr := bufStr[:jsonEnd+1]
remaining := bufStr[jsonEnd+1:]
events = append(events, ministralEventToolCall{
name: p.pendingToolName,
args: jsonStr,
})
p.pendingToolName = ""
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = ministralCollectingContent
return events, true
}
// Wait for more data
return events, false
default:
panic("unexpected ministral event")
}
}
// parseEvents loops calling eat() until it returns false
func (p *MinistralParser) parseEvents() []ministralEvent {
var all []ministralEvent
keepLooping := true
for keepLooping {
var events []ministralEvent
events, keepLooping = p.eat()
all = append(all, events...)
}
return all
}
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var contentBuilder, thinkingBuilder strings.Builder
var toolCalls []api.ToolCall
for _, event := range events {
switch e := event.(type) {
case ministralEventContent:
contentBuilder.WriteString(e.content)
case ministralEventThinking:
thinkingBuilder.WriteString(e.thinking)
case ministralEventToolCall:
// Validate tool exists
tool, toolErr := toolByName(p.tools, e.name)
if toolErr != nil {
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, toolErr
switch p.state {
case ministralCollectingContent:
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
if before != "" {
return before, "", calls, nil
}
// Parse JSON arguments
p.state = ministralCollectingToolName
} else if strings.Contains(p.buffer.String(), "[THINK]") {
p.state = ministralCollectingThinkingContent
return "", "", calls, nil
} else {
p.buffer.Reset()
return s, "", calls, nil
}
case ministralCollectingThinkingContent:
if strings.Contains(p.buffer.String(), "[/THINK]") {
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
p.state = ministralCollectingContent
if after != "" {
p.buffer.Reset()
return after, thinkingContent, calls, nil
}
return "", thinkingContent, calls, nil
} else {
p.buffer.Reset()
return "", s, calls, nil
}
case ministralCollectingToolName:
if strings.Contains(p.buffer.String(), "[ARGS]") {
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
t, err := toolByName(p.tools, name)
if err != nil {
return "", "", calls, err
}
p.currentTool = t
p.state = ministralCollectingToolArgs
return "", "", calls, nil
}
return "", "", calls, nil
case ministralCollectingToolArgs:
if strings.Contains(p.buffer.String(), "}") {
before, _ := splitAtTag(&p.buffer, "}", false)
before += "}"
var args api.ToolCallFunctionArguments
if jsonErr := json.Unmarshal([]byte(e.args), &args); jsonErr != nil {
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, jsonErr
if err := json.Unmarshal([]byte(before), &args); err != nil {
// todo - throw a better error
return "", "", calls, err
}
toolCalls = append(toolCalls, api.ToolCall{
p.state = ministralCollectingContent
call := api.ToolCall{
Function: api.ToolCallFunction{
Name: tool.Function.Name,
Name: p.currentTool.Function.Name,
Arguments: args,
},
})
}
calls = append(calls, call)
return "", "", calls, nil
}
return "", "", calls, nil
}
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, nil
}
// findJSONEnd finds the index of the closing brace that completes a JSON object.
// It properly handles nested objects, arrays, and strings (including escaped characters).
// Returns -1 if the JSON is not yet complete.
func findJSONEnd(s string) int {
depth := 0
inString := false
escaped := false
for i, r := range s {
if inString {
switch {
case escaped:
// If the previous character was a backslash, skip this character
escaped = false
case r == '\\':
// Mark the next character as escaped
escaped = true
case r == '"':
// End of string literal
inString = false
}
continue
}
switch r {
case '"':
// Start of string literal
inString = true
case '{', '[':
// Increase nesting level for objects and arrays
depth++
case '}', ']':
// Decrease nesting level
depth--
if depth == 0 {
// Reached the end of the root JSON structure
return i
}
}
}
return -1
return p.buffer.String(), thinking, calls, nil
}

View File

@@ -1,545 +0,0 @@
package parsers
import (
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestMinistralParserStreaming(t *testing.T) {
type step struct {
input string
wantEvents []ministralEvent
}
cases := []struct {
desc string
tools []api.Tool
steps []step
think bool // whether to enable thinking support
}{
// Content streaming
{
desc: "simple content",
steps: []step{
{input: "Hello, how can I help you?", wantEvents: []ministralEvent{
ministralEventContent{content: "Hello, how can I help you?"},
}},
},
},
{
desc: "streaming content word by word",
steps: []step{
{input: "Hello,", wantEvents: []ministralEvent{ministralEventContent{content: "Hello,"}}},
{input: " how", wantEvents: []ministralEvent{ministralEventContent{content: " how"}}},
{input: " can I help?", wantEvents: []ministralEvent{ministralEventContent{content: " can I help?"}}},
},
},
// Simple tool calls
{
desc: "simple tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
steps: []step{
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "San Francisco"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "get_weather", args: `{"location": "San Francisco"}`},
}},
},
},
{
desc: "tool call with nested object",
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
steps: []step{
{input: `[TOOL_CALLS]create_entities[ARGS]{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
}},
},
},
{
desc: "tool call with deeply nested objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "update_config"}}},
steps: []step{
{input: `[TOOL_CALLS]update_config[ARGS]{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "update_config", args: `{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`},
}},
},
},
{
desc: "tool call with array of objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "process_items"}}},
steps: []step{
{input: `[TOOL_CALLS]process_items[ARGS]{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "process_items", args: `{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`},
}},
},
},
{
desc: "tool call with escaped quotes in string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "search"}}},
steps: []step{
{input: `[TOOL_CALLS]search[ARGS]{"query": "say \"hello\""}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "search", args: `{"query": "say \"hello\""}`},
}},
},
},
{
desc: "tool call with braces inside string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "format"}}},
steps: []step{
{input: `[TOOL_CALLS]format[ARGS]{"template": "Hello {name}!"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "format", args: `{"template": "Hello {name}!"}`},
}},
},
},
{
desc: "empty JSON object",
tools: []api.Tool{{Function: api.ToolFunction{Name: "no_args"}}},
steps: []step{
{input: `[TOOL_CALLS]no_args[ARGS]{}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "no_args", args: `{}`},
}},
},
},
{
desc: "JSON with newlines in string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "write"}}},
steps: []step{
{input: `[TOOL_CALLS]write[ARGS]{"content": "line1\nline2\nline3"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "write", args: `{"content": "line1\nline2\nline3"}`},
}},
},
},
{
desc: "backslash in string value",
tools: []api.Tool{{Function: api.ToolFunction{Name: "path"}}},
steps: []step{
{input: `[TOOL_CALLS]path[ARGS]{"dir": "C:\\Users\\test"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "path", args: `{"dir": "C:\\Users\\test"}`},
}},
},
},
// Content after tool call
{
desc: "content after tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
// NOTE: It's unclear if this is valid Ministral output, but the parser
// currently treats text after a tool call as regular content. This test
// documents that behavior so we notice if it changes.
{input: `[TOOL_CALLS]test[ARGS]{"a": 1}some content after`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{"a": 1}`},
ministralEventContent{content: "some content after"},
}},
},
},
// Multiple tool calls
{
desc: "multiple tool calls in sequence",
tools: []api.Tool{
{Function: api.ToolFunction{Name: "get_weather"}},
{Function: api.ToolFunction{Name: "get_time"}},
},
steps: []step{
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "NYC"}[TOOL_CALLS]get_time[ARGS]{"timezone": "EST"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "get_weather", args: `{"location": "NYC"}`},
ministralEventToolCall{name: "get_time", args: `{"timezone": "EST"}`},
}},
},
},
{
desc: "multiple tool calls streamed separately",
tools: []api.Tool{
{Function: api.ToolFunction{Name: "tool_a"}},
{Function: api.ToolFunction{Name: "tool_b"}},
},
steps: []step{
{input: `[TOOL_CALLS]tool_a[ARGS]{"x": 1}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "tool_a", args: `{"x": 1}`},
}},
{input: `[TOOL_CALLS]tool_b[ARGS]{"y": 2}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "tool_b", args: `{"y": 2}`},
}},
},
},
// Streaming tool calls
{
desc: "streaming tool call with nested objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
steps: []step{
{input: "[TOOL_CALLS]create_entities[ARGS]", wantEvents: []ministralEvent{}},
{input: `{"entities": [{"entityType": "Person",`, wantEvents: []ministralEvent{}},
{input: ` "name": "Jack",`, wantEvents: []ministralEvent{}},
{input: ` "observations": ["Works`, wantEvents: []ministralEvent{}},
{input: ` as a baker"]}`, wantEvents: []ministralEvent{}},
{input: `]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
}},
},
},
{
desc: "streaming with incomplete JSON waits for completion",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "[TOOL_CALLS]test[ARGS]{", wantEvents: []ministralEvent{}},
{input: `"a": {`, wantEvents: []ministralEvent{}},
{input: `"b": 1`, wantEvents: []ministralEvent{}},
{input: `}`, wantEvents: []ministralEvent{}},
{input: `}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{"a": {"b": 1}}`},
}},
},
},
// Partial tag handling
{
desc: "partial tool tag fakeout",
steps: []step{
{input: "abc[TOOL", wantEvents: []ministralEvent{ministralEventContent{content: "abc"}}},
{input: " not a tag", wantEvents: []ministralEvent{ministralEventContent{content: "[TOOL not a tag"}}},
},
},
{
desc: "tool call tag split across chunks",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "[TOOL_", wantEvents: []ministralEvent{}},
{input: "CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "content before tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
steps: []step{
{input: "hello [TOOL_CALLS]get_weather[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "hello"},
ministralEventToolCall{name: "get_weather", args: `{}`},
}},
},
},
{
desc: "whitespace between content and tool call is trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "content \n [TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "tabs and newlines before tool call are trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "content\t\n\t[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "non-breaking space before tool call is trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
// \u00a0 is non-breaking space, which unicode.IsSpace considers whitespace
{input: "content\u00a0[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "whitespace before THINK tag is trimmed",
steps: []step{
{input: "content \n [THINK]thinking[/THINK]after", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "after"},
}},
},
},
{
desc: "trailing whitespace withheld then emitted",
steps: []step{
{input: "Hello ", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: " world"}}},
},
},
{
desc: "trailing newline withheld then emitted",
steps: []step{
{input: "Hello\n", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: "\nworld"}}},
},
},
// Thinking support
{
desc: "thinking content",
think: true,
steps: []step{
{input: "thinking here[/THINK]", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "thinking here"},
}},
{input: "content after", wantEvents: []ministralEvent{
ministralEventContent{content: "content after"},
}},
},
},
{
desc: "thinking with whitespace after end tag",
think: true,
steps: []step{
{input: "my thoughts[/THINK] \n response", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "my thoughts"},
ministralEventContent{content: "response"},
}},
},
},
{
desc: "non-breaking space after think end tag is trimmed",
think: true,
steps: []step{
// \u00a0 is non-breaking space
{input: "thinking[/THINK]\u00a0response", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "response"},
}},
},
},
{
desc: "partial think end tag",
think: true,
steps: []step{
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
{input: "NK]after", wantEvents: []ministralEvent{ministralEventContent{content: "after"}}},
},
},
{
desc: "think tag fakeout",
think: true,
steps: []step{
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
{input: "not end tag", wantEvents: []ministralEvent{ministralEventThinking{thinking: "[/THInot end tag"}}},
},
},
{
desc: "thinking then tool call",
think: true,
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "let me think[/THINK][TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "let me think"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
// Content then THINK tag transition
{
desc: "content then think tag",
steps: []step{
{input: "content[THINK]thinking[/THINK]more", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "more"},
}},
},
},
// Unicode handling
{
desc: "unicode content",
steps: []step{
{input: "你好 🌍 مرحبا", wantEvents: []ministralEvent{
ministralEventContent{content: "你好 🌍 مرحبا"},
}},
},
},
{
desc: "unicode in tool args",
tools: []api.Tool{{Function: api.ToolFunction{Name: "greet"}}},
steps: []step{
{input: `[TOOL_CALLS]greet[ARGS]{"message": "你好 🌍"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "greet", args: `{"message": "你好 🌍"}`},
}},
},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
parser := MinistralParser{}
parser.hasThinkingSupport = tc.think
parser.Init(tc.tools, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
gotEvents := parser.parseEvents()
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
// avoid deep equal on empty vs. nil slices
continue
}
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
}
}
})
}
}
func TestMinistralParser_Errors(t *testing.T) {
t.Run("unknown tool returns error", func(t *testing.T) {
p := &MinistralParser{}
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "known_tool"}}}, nil, nil)
_, _, _, err := p.Add(`[TOOL_CALLS]unknown_tool[ARGS]{"a": 1}`, true)
if err == nil {
t.Fatal("expected error for unknown tool")
}
})
t.Run("invalid JSON returns error", func(t *testing.T) {
p := &MinistralParser{}
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "test"}}}, nil, nil)
_, _, _, err := p.Add(`[TOOL_CALLS]test[ARGS]{invalid json}`, true)
if err == nil {
t.Fatal("expected error for invalid JSON")
}
})
}
func TestFindJSONEnd(t *testing.T) {
tests := []struct {
name string
input string
expected int
}{
{
name: "simple object",
input: `{"a": 1}`,
expected: 7,
},
{
name: "nested object",
input: `{"a": {"b": 2}}`,
expected: 14,
},
{
name: "array inside object",
input: `{"items": [1, 2, 3]}`,
expected: 19,
},
{
name: "braces in string",
input: `{"template": "Hello {name}!"}`,
expected: 28,
},
{
name: "escaped quotes",
input: `{"msg": "say \"hi\""}`,
expected: 20,
},
{
name: "incomplete object",
input: `{"a": {"b": 1}`,
expected: -1,
},
{
name: "deeply nested",
input: `{"a": {"b": {"c": {"d": 1}}}}`,
expected: 28,
},
{
name: "object with trailing content",
input: `{"a": 1} extra`,
expected: 7,
},
{
name: "array",
input: `[{"a": 1}, {"b": 2}]`,
expected: 19,
},
{
name: "escaped backslash before quote",
input: `{"path": "C:\\"}`,
expected: 15,
},
{
name: "empty string",
input: "",
expected: -1,
},
{
name: "no opening brace",
input: "hello world",
expected: -1,
},
{
name: "only opening brace",
input: "{",
expected: -1,
},
{
name: "unclosed string",
input: `{"key": "unclosed`,
expected: -1,
},
{
name: "double escaped backslash then quote",
input: `{"path": "C:\\\\"}`,
expected: 17,
},
{
name: "unicode in key and value",
input: `{"키": "값"}`,
expected: 13,
},
{
name: "nested arrays",
input: `{"matrix": [[1, 2], [3, 4]]}`,
expected: 27,
},
{
name: "mixed nesting",
input: `{"a": [{"b": {"c": [1, 2, 3]}}]}`,
expected: 31,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := findJSONEnd(tt.input)
if result != tt.expected {
t.Errorf("findJSONEnd(%q) = %d, want %d", tt.input, result, tt.expected)
}
})
}
}
func TestMinistralParser_HasToolSupport(t *testing.T) {
p := &MinistralParser{}
if !p.HasToolSupport() {
t.Error("expected HasToolSupport to return true")
}
}
func TestMinistralParser_HasThinkingSupport(t *testing.T) {
p := &MinistralParser{hasThinkingSupport: false}
if p.HasThinkingSupport() {
t.Error("expected HasThinkingSupport to return false")
}
p = &MinistralParser{hasThinkingSupport: true}
if !p.HasThinkingSupport() {
t.Error("expected HasThinkingSupport to return true")
}
}

View File

@@ -3,7 +3,6 @@ package parsers
import (
"strings"
"unicode"
"unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
@@ -71,8 +70,6 @@ func ParserForName(name string) Parser {
return &FunctionGemmaParser{}
case "glm-4.7":
return &GLM47Parser{}
case "glm-ocr":
return &GlmOcrParser{}
case "lfm2":
return &LFM2Parser{hasThinkingSupport: false}
case "lfm2-thinking":
@@ -117,33 +114,3 @@ func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string
sb.WriteString(after)
return before, after // return events
}
// overlap returns the longest overlap between the suffix of s and the prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
// trailingWhitespaceLen returns the length in bytes of trailing whitespace in s
func trailingWhitespaceLen(s string) int {
remaining := s
total := 0
for len(remaining) > 0 {
r, size := utf8.DecodeLastRuneInString(remaining)
// if it's an invalid utf8 rune, assume it isn't whitespace
if r == utf8.RuneError && size == 1 {
break
}
if !unicode.IsSpace(r) {
break
}
total += size
remaining = remaining[:len(remaining)-size]
}
return total
}

View File

@@ -11,6 +11,7 @@ import (
"strconv"
"strings"
"unicode"
"unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
@@ -193,6 +194,36 @@ func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
}
}
// TODO(drifkin): move this to a shared location
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
func trailingWhitespaceLen(s string) int {
remaining := s
total := 0
for len(remaining) > 0 {
r, size := utf8.DecodeLastRuneInString(remaining)
// if it's an invalid utf8 rune, assume it isn't whitespace
if r == utf8.RuneError && size == 1 {
break
}
if !unicode.IsSpace(r) {
break
}
total += size
remaining = remaining[:len(remaining)-size]
}
return total
}
type XMLFunctionCall struct {
XMLName xml.Name `xml:"function"`
Name string `xml:"name,attr"`

View File

@@ -1,109 +0,0 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
type GlmOcrRenderer struct{}
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatGLM47ToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
enableThinking := false
thinkingExplicitlySet := false
if thinkValue != nil {
enableThinking = thinkValue.Bool()
thinkingExplicitlySet = true
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>\n")
sb.WriteString(message.Content)
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
sb.WriteString("/nothink")
}
case "assistant":
sb.WriteString("<|assistant|>\n")
if message.Thinking != "" {
sb.WriteString("<think>" + strings.TrimSpace(message.Thinking) + "</think>")
} else {
sb.WriteString("<think></think>")
}
if message.Content != "" {
sb.WriteString("\n" + strings.TrimSpace(message.Content))
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderGlmOcrToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
sb.WriteString("\n")
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>\n")
case "system":
sb.WriteString("<|system|>\n")
sb.WriteString(message.Content)
sb.WriteString("\n")
}
}
sb.WriteString("<|assistant|>\n")
if thinkingExplicitlySet && !enableThinking {
sb.WriteString("<think></think>\n")
}
return sb.String(), nil
}
func renderGlmOcrToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}

View File

@@ -167,12 +167,12 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
// only start a new user block if this is the first tool response
if i == 0 || filteredMessages[i-1].Role != "tool" {
sb.WriteString(imStartTag + "user")
sb.WriteString(imStartTag + "user\n")
}
sb.WriteString("\n<tool_response>\n")
sb.WriteString("<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>")
sb.WriteString("\n</tool_response>\n")
// close the user block only if this is the last tool response
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {

View File

@@ -1,7 +1,6 @@
package renderers
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -128,7 +127,8 @@ fahrenheit
<|im_start|>user
<tool_response>
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
</tool_response><|im_end|>
</tool_response>
<|im_end|>
<|im_start|>user
That sounds nice! What about New York?<|im_end|>
<|im_start|>assistant
@@ -233,7 +233,8 @@ I'll call double(1) and triple(2) for you.
</tool_response>
<tool_response>
{"number": 6}
</tool_response><|im_end|>
</tool_response>
<|im_end|>
<|im_start|>assistant
`,
},
@@ -279,7 +280,8 @@ call tool<|im_end|>
<|im_start|>user
<tool_response>
{"payload": {"foo": "bar"}}
</tool_response><|im_end|>
</tool_response>
<|im_end|>
<|im_start|>assistant
`,
},
@@ -335,31 +337,6 @@ func TestFormatToolCallArgument(t *testing.T) {
}
}
func TestQwen3CoderRendererToolResponseNoTrailingNewline(t *testing.T) {
msgs := []api.Message{
{Role: "user", Content: "call tool"},
{Role: "assistant", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{
Name: "echo",
Arguments: testArgs(map[string]any{"payload": "ok"}),
}},
}},
{Role: "tool", Content: "{\"payload\":\"ok\"}", ToolName: "echo"},
}
rendered, err := (&Qwen3CoderRenderer{}).Render(msgs, nil, nil)
if err != nil {
t.Fatal(err)
}
if strings.Contains(rendered, "</tool_response>\n<|im_end|>") {
t.Fatalf("expected no newline after </tool_response>, got:\n%s", rendered)
}
if !strings.Contains(rendered, "</tool_response><|im_end|>") {
t.Fatalf("expected </tool_response> to be immediately followed by <|im_end|>, got:\n%s", rendered)
}
}
func TestQwen3ToolDefinitionTypes(t *testing.T) {
tests := []struct {
name string

View File

@@ -82,8 +82,6 @@ func rendererForName(name string) Renderer {
return &FunctionGemmaRenderer{}
case "glm-4.7":
return &GLM47Renderer{}
case "glm-ocr":
return &GlmOcrRenderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
case "lfm2-thinking":

View File

@@ -124,17 +124,8 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*In
}
if c.cache != nil {
if numPast > 0 {
// Recurrent caches use checkpoints to pick a safe resume position.
if cc, ok := c.cache.(kvcache.CheckpointCache); ok {
if restored, ok := cc.PrepareRestore(slot.Id, numPast); ok {
numPast = restored
} else {
numPast = 0
}
} else if !c.cache.CanResume(slot.Id, numPast) {
numPast = 0
}
if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
numPast = 0
}
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)

View File

@@ -37,6 +37,7 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/tokenizer"
_ "github.com/ollama/ollama/model/models"
)
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tokenizer tokenizer.Tokenizer) []llm.Logprob {
decoder := func(tokenID int) string {
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
text, _ := tokenizer.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -740,11 +741,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
if seq == nil || nextBatchTokens[i] == nil {
continue
}
// If the sequence was replaced while this batch was computing, discard results.
if activeBatch.seqs[i] != seq {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
seq.lastUpdatedAt = t
if seq.numPredicted == 1 {
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
@@ -770,7 +767,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -779,14 +776,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}
@@ -877,7 +874,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return
@@ -1362,7 +1359,7 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
// Dummy load to get the backend wired up
f, err := os.CreateTemp("", "*.bin")
if err != nil {
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
return
}
defer f.Close()
@@ -1372,13 +1369,13 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
"general.architecture": "llama",
"tokenizer.ggml.model": "gpt2",
}, nil); err != nil {
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
return
}
m, err = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
if err != nil {
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
return
}
slog.Debug("dummy model load took", "duration", time.Since(startLoad))

View File

@@ -3,7 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
"github.com/ollama/ollama/x/mlxrunner"
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
)
func Execute(args []string) error {
@@ -12,18 +12,18 @@ func Execute(args []string) error {
}
var newRunner bool
var mlxRunner bool
var imageRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--mlx-engine" {
if len(args) > 0 && args[0] == "--image-engine" {
args = args[1:]
mlxRunner = true
imageRunner = true
}
if mlxRunner {
return mlxrunner.Execute(args)
if imageRunner {
return imagerunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {

View File

@@ -7,7 +7,7 @@ import (
"slices"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
// token represents information about a single token during sampling
@@ -168,15 +168,15 @@ type GrammarSampler struct {
grammar *llama.Grammar
}
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
func NewGrammarSampler(tokenizer tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(tokenizer.Vocabulary().Values))
pieces := make([]string, len(tokenizer.Vocabulary().Values))
for i := range tokenizer.Vocabulary().Values {
pieces[i], _ = tokenizer.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tokenizer.Vocabulary().EOS)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}

View File

@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
func TestWeighted(t *testing.T) {
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
}
}
func modelHelper(t testing.TB) model.BytePairEncoding {
func modelHelper(t testing.TB) tokenizer.Tokenizer {
t.Helper()
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.Join("..", "testdata", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
return model.NewBytePairEncoding(
&model.Vocabulary{
return tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: tokens,
Types: make([]int32, len(vocab)),
Merges: merges,

View File

@@ -27,12 +27,14 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// Clip images are represented as 768 tokens, each an embedding
imageNumTokens := 768
lastMsgIdx := len(msgs) - 1
currMsgIdx := 0
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n; i >= 0; i-- {
// always include the last message
if i == n {
continue
}
// Start with all messages and remove from the front until it fits in context
for i := 0; i <= lastMsgIdx; i++ {
// Collect system messages from the portion we're about to skip
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
@@ -52,26 +54,20 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
ctxLen := len(s)
if m.ProjectorPaths != nil {
for _, msg := range msgs[i:] {
ctxLen += imageNumTokens * len(msg.Images)
for _, m := range msgs[i:] {
ctxLen += imageNumTokens * len(m.Images)
}
}
if !truncate || ctxLen <= opts.NumCtx {
currMsgIdx = i
break
}
// Must always include at least the last message
if i == lastMsgIdx {
currMsgIdx = lastMsgIdx
if truncate && ctxLen > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
} else {
n = i
}
}
if currMsgIdx > 0 {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
}
currMsgIdx := n
for cnt, msg := range msgs[currMsgIdx:] {
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {

View File

@@ -2,7 +2,6 @@ package server
import (
"bytes"
"context"
"testing"
"github.com/google/go-cmp/cmp"
@@ -265,68 +264,3 @@ func TestChatPrompt(t *testing.T) {
})
}
}
func TestChatPromptTokenizeCalls(t *testing.T) {
tmpl, err := template.Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
if err != nil {
t.Fatal(err)
}
model := Model{Template: tmpl}
cases := []struct {
name string
limit int
msgs []api.Message
maxTokenizes int
}{
{
name: "all messages fit",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "message 1"},
{Role: "assistant", Content: "response 1"},
{Role: "user", Content: "message 2"},
{Role: "assistant", Content: "response 2"},
{Role: "user", Content: "message 3"},
},
maxTokenizes: 1,
},
{
name: "truncate to last message",
limit: 5,
msgs: []api.Message{
{Role: "user", Content: "message 1"},
{Role: "assistant", Content: "response 1"},
{Role: "user", Content: "message 2"},
{Role: "assistant", Content: "response 2"},
{Role: "user", Content: "message 3"},
},
maxTokenizes: 5,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tokenizeCount := 0
countingTokenize := func(ctx context.Context, s string) ([]int, error) {
tokenizeCount++
tokens, err := mockRunner{}.Tokenize(ctx, s)
return tokens, err
}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
think := false
_, _, err := chatPrompt(t.Context(), &model, countingTokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, true)
if err != nil {
t.Fatal(err)
}
if tokenizeCount > tt.maxTokenizes {
t.Errorf("tokenize called %d times, expected at most %d", tokenizeCount, tt.maxTokenizes)
}
})
}
}

Some files were not shown because too many files have changed in this diff Show More