Mirror of https://github.com/ollama/ollama.git (synced 2026-02-05 13:13:34 -05:00)

Compare commits: 36 commits, mxyng/mlx ... parth-laun
| Author | SHA1 | Date |
|---|---|---|
| | 52f757d8a2 | |
| | 86aa7cd0a6 | |
| | 255579aaa7 | |
| | f7102ba826 | |
| | cefabd79a8 | |
| | df70249520 | |
| | 77eb2ca619 | |
| | ee25219edd | |
| | b1fccabb34 | |
| | a6355329bf | |
| | 0398b24b42 | |
| | 75b1dddf91 | |
| | e1e80ffc3e | |
| | 71896485fd | |
| | ef00199fb4 | |
| | 8f4a008139 | |
| | d8cc798c2b | |
| | 6582f6da5c | |
| | 0334ffa625 | |
| | d11fbd2c60 | |
| | 6a7c3f188e | |
| | 427e2c962a | |
| | 27db7f806f | |
| | 3590fbfa76 | |
| | cd0094f772 | |
| | 06bc8e6712 | |
| | fc5f9bb448 | |
| | a0740f7ef7 | |
| | a0923cbdd0 | |
| | f92e362b2e | |
| | aa23d8ecd2 | |
| | 7b62c41060 | |
| | 26acab64b7 | |
| | e0f03790b1 | |
| | 3ab842b0f5 | |
| | b8e8ef8929 | |
@@ -358,6 +358,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
- [Screenpipe](https://github.com/mediar-ai/screenpipe) (24/7 screen & mic recording with AI-powered search, uses Ollama for local LLM features)
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
@@ -465,6 +466,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
- [Stakpak](https://github.com/stakpak/agent) (An open source, vendor neutral DevOps agent that works with any model, and any stack, for teams who just want to ship)

### Cloud
3 anthropic/anthropic.go (Normal file → Executable file)
@@ -211,6 +211,7 @@ type MessageDelta struct {

// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

@@ -721,6 +722,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
	})
}

	c.inputTokens = r.Metrics.PromptEvalCount
	c.outputTokens = r.Metrics.EvalCount
	stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)

@@ -732,6 +734,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
			StopReason: stopReason,
		},
		Usage: DeltaUsage{
			InputTokens:  c.inputTokens,
			OutputTokens: c.outputTokens,
		},
	},
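The hunks above make the Anthropic-compatible stream report cumulative token usage: the converter records PromptEvalCount and EvalCount from the final chat response and emits them in the closing message_delta event. A minimal sketch of what a consumer of that event sees, assuming the struct shapes shown in the diff (the payload literal is illustrative, not captured output):

```go
// Sketch only: decode one message_delta payload and read the cumulative
// usage it now carries, using assumed struct shapes that mirror the diff.
package main

import (
	"encoding/json"
	"fmt"
)

type DeltaUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

type messageDelta struct {
	Type  string `json:"type"`
	Delta struct {
		StopReason string `json:"stop_reason"`
	} `json:"delta"`
	Usage DeltaUsage `json:"usage"`
}

func main() {
	payload := []byte(`{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":10,"output_tokens":5}}`)

	var ev messageDelta
	if err := json.Unmarshal(payload, &ev); err != nil {
		panic(err)
	}
	fmt.Println(ev.Delta.StopReason, ev.Usage.InputTokens, ev.Usage.OutputTokens) // end_turn 10 5
}
```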
20 anthropic/anthropic_test.go (Normal file → Executable file)
@@ -642,7 +642,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{EvalCount: 5},
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
|
||||
events2 := conv.Process(resp2)
|
||||
@@ -650,6 +650,24 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||
hasStop := false
|
||||
for _, e := range events2 {
|
||||
if e.Event == "message_delta" {
|
||||
if data, ok := e.Data.(MessageDeltaEvent); ok {
|
||||
if data.Type != "message_delta" {
|
||||
t.Errorf("unexpected data type: %+v", data)
|
||||
}
|
||||
|
||||
if data.Delta.StopReason != "end_turn" {
|
||||
t.Errorf("unexpected stop reason: %+v", data.Delta.StopReason)
|
||||
}
|
||||
|
||||
if data.Usage.InputTokens != 10 || data.Usage.OutputTokens != 5 {
|
||||
t.Errorf("unexpected usage: %+v", data.Usage)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("unexpected data: %+v", e.Data)
|
||||
}
|
||||
}
|
||||
|
||||
if e.Event == "message_stop" {
|
||||
hasStop = true
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/containerd/console"
|
||||
"github.com/mattn/go-runewidth"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/pkg/browser"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -52,7 +53,7 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||
@@ -663,6 +664,7 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println()
|
||||
|
||||
if aErr.SigninURL != "" {
|
||||
_ = browser.OpenURL(aErr.SigninURL)
|
||||
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
@@ -1888,7 +1890,7 @@ func NewCLI() *cobra.Command {
|
||||
serveCmd := &cobra.Command{
|
||||
Use: "serve",
|
||||
Aliases: []string{"start"},
|
||||
Short: "Start ollama",
|
||||
Short: "Start Ollama",
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: RunServer,
|
||||
}
|
||||
|
||||
@@ -1553,7 +1553,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
Details: api.ModelDetails{
|
||||
Family: "ZImagePipeline",
|
||||
ParameterSize: "10.3B",
|
||||
QuantizationLevel: "FP8",
|
||||
QuantizationLevel: "Q8",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityImage},
|
||||
Requires: "0.14.0",
|
||||
@@ -1565,7 +1565,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
expect := " Model\n" +
|
||||
" architecture ZImagePipeline \n" +
|
||||
" parameters 10.3B \n" +
|
||||
" quantization FP8 \n" +
|
||||
" quantization Q8 \n" +
|
||||
" requires 0.14.0 \n" +
|
||||
"\n" +
|
||||
" Capabilities\n" +
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
@@ -13,11 +15,13 @@ type Claude struct{}
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string) []string {
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
return []string{"--model", model}
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
return nil
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
@@ -39,18 +43,18 @@ func (c *Claude) findPath() (string, error) {
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string) error {
|
||||
func (c *Claude) Run(model string, args []string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model)...)
|
||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL=http://localhost:11434",
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
|
||||
@@ -84,17 +84,21 @@ func TestClaudeArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil},
|
||||
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil, nil},
|
||||
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,20 +14,21 @@ type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
func (c *Codex) args(model string) []string {
|
||||
func (c *Codex) args(model string, extra []string) []string {
|
||||
args := []string{"--oss"}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string) error {
|
||||
func (c *Codex) Run(model string, args []string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model)...)
|
||||
cmd := exec.Command("codex", c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
@@ -11,17 +11,20 @@ func TestCodexArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", []string{"--oss"}},
|
||||
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", nil, []string{"--oss"}},
|
||||
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
|
||||
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -20,6 +21,14 @@ type config struct {
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "config.json"), nil
|
||||
}
|
||||
|
||||
func legacyConfigPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -27,6 +36,46 @@ func configPath() (string, error) {
|
||||
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||
}
|
||||
|
||||
// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
|
||||
func migrateConfig() (bool, error) {
|
||||
oldPath, err := legacyConfigPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
oldData, err := os.ReadFile(oldPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
var js json.RawMessage
|
||||
if err := json.Unmarshal(oldData, &js); err != nil {
|
||||
slog.Warn("legacy config has invalid JSON, skipping migration", "path", oldPath, "error", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
newPath, err := configPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
|
||||
return false, fmt.Errorf("write new config: %w", err)
|
||||
}
|
||||
|
||||
_ = os.Remove(oldPath)
|
||||
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
|
||||
|
||||
slog.Info("migrated config", "from", oldPath, "to", newPath)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func load() (*config, error) {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
@@ -34,6 +83,11 @@ func load() (*config, error) {
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
if migrated, merr := migrateConfig(); merr == nil && migrated {
|
||||
data, err = os.ReadFile(path)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &config{Integrations: make(map[string]*integration)}, nil
|
||||
|
||||
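The change above teaches load() to migrate on first read: if ~/.ollama/config.json is missing, a legacy file at ~/.ollama/config/config.json is validated, copied into place, and the read is retried. A minimal sketch of that migrate-on-read pattern, with hypothetical file names standing in for the real paths:

```go
// A minimal sketch of the migrate-on-read pattern, assuming the same rules as
// the diff: skip corrupt legacy files, write the new path, best-effort remove
// the old one. File names here are placeholders, not the real Ollama paths.
package main

import (
	"encoding/json"
	"fmt"
	"os"
)

// readWithMigration reads newPath; if it does not exist, it copies a valid
// legacy file into place, deletes the legacy copy, and returns its contents.
func readWithMigration(newPath, oldPath string) ([]byte, error) {
	data, err := os.ReadFile(newPath)
	if err == nil || !os.IsNotExist(err) {
		return data, err
	}

	old, oerr := os.ReadFile(oldPath)
	if oerr != nil {
		return nil, err // no legacy file either; report the original error
	}

	var js json.RawMessage
	if json.Unmarshal(old, &js) != nil {
		return nil, err // corrupt legacy file: skip migration, as the diff does
	}

	if werr := os.WriteFile(newPath, old, 0o644); werr != nil {
		return nil, werr
	}
	_ = os.Remove(oldPath) // best-effort cleanup of the legacy location

	return old, nil
}

func main() {
	data, err := readWithMigration("config.json", "legacy-config.json")
	fmt.Println(string(data), err)
}
```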
@@ -200,12 +200,10 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Create corrupted config.json file
|
||||
dir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
dir := filepath.Join(tmpDir, ".ollama")
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
// Corrupted file is treated as empty, so loadIntegration returns not found
|
||||
_, err := loadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
@@ -267,7 +265,7 @@ func TestConfigPath(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if path != expected {
|
||||
t.Errorf("expected %s, got %s", expected, path)
|
||||
}
|
||||
@@ -322,6 +320,183 @@ func TestLoad(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateConfig(t *testing.T) {
|
||||
t.Run("migrates legacy file to new location", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
newPath, _ := configPath()
|
||||
got, err := os.ReadFile(newPath)
|
||||
if err != nil {
|
||||
t.Fatalf("new config not found: %v", err)
|
||||
}
|
||||
if string(got) != string(data) {
|
||||
t.Errorf("content mismatch: got %s", got)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("legacy file should have been removed")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
|
||||
t.Error("legacy directory should have been removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no-op when no legacy file exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("expected no migration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips corrupt legacy file", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("should not migrate corrupt file")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
|
||||
t.Error("corrupt legacy file should not have been deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("new path takes precedence over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
|
||||
|
||||
newDir := filepath.Join(tmpDir, ".ollama")
|
||||
os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := cfg.Integrations["new"]; !ok {
|
||||
t.Error("expected new-path integration to be loaded")
|
||||
}
|
||||
if _, ok := cfg.Integrations["old"]; ok {
|
||||
t.Error("legacy integration should not have been loaded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("idempotent when called twice", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("second migration should be a no-op")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy directory preserved if not empty", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
|
||||
t.Error("directory with other files should not have been removed")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
|
||||
t.Error("other files in legacy directory should be untouched")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save writes to new path after migration", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
// load triggers migration, then save should write to new path
|
||||
if err := saveIntegration("codex", []string{"qwen2.5"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if _, err := os.Stat(newPath); os.IsNotExist(err) {
|
||||
t.Error("save should write to new path")
|
||||
}
|
||||
|
||||
// old path should not be recreated
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("save should not recreate legacy path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load triggers migration transparently", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
|
||||
t.Error("migration via load() did not preserve data")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Droid implements Runner and Editor for Droid integration
|
||||
@@ -37,7 +39,7 @@ type modelEntry struct {
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string) error {
|
||||
func (d *Droid) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
@@ -51,7 +53,7 @@ func (d *Droid) Run(model string) error {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid")
|
||||
cmd := exec.Command("droid", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -117,7 +119,7 @@ func (d *Droid) Edit(models []string) error {
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
DisplayName: model,
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
BaseURL: envconfig.Host().String() + "/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
|
||||
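The Claude and Droid integrations above now derive their base URL from envconfig.Host() instead of hardcoding localhost, so OLLAMA_HOST is respected. A small sketch, under the assumption that Host() reads the environment on each call and returns a value whose String() is a base URL like http://127.0.0.1:11434 (which is what the updated tests expect):

```go
// Sketch only: assumed behavior of envconfig.Host(), as suggested by the
// diff's use of envconfig.Host().String() and its updated test fixtures.
package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/envconfig"
)

func main() {
	// Default host: the tests in this change expect http://127.0.0.1:11434/v1.
	fmt.Println(envconfig.Host().String() + "/v1")

	// With OLLAMA_HOST set, the integrations should point there instead.
	os.Setenv("OLLAMA_HOST", "http://10.0.0.5:11434")
	fmt.Println(envconfig.Host().String() + "/v1")
}
```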
@@ -218,7 +218,7 @@ func TestDroidEdit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if model["baseUrl"] != "http://localhost:11434/v1" {
|
||||
if model["baseUrl"] != "http://127.0.0.1:11434/v1" {
|
||||
t.Errorf("unexpected baseUrl: %s", model["baseUrl"])
|
||||
}
|
||||
if model["apiKey"] != "ollama" {
|
||||
@@ -447,7 +447,7 @@ const testDroidSettingsFixture = `{
|
||||
{
|
||||
"model": "existing-ollama-model",
|
||||
"displayName": "existing-ollama-model",
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"apiKey": "ollama",
|
||||
"provider": "generic-chat-completion-api",
|
||||
"maxOutputTokens": 64000,
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -22,7 +23,7 @@ import (
|
||||
// Runner can run an integration with a model.
|
||||
|
||||
type Runner interface {
|
||||
Run(model string) error
|
||||
Run(model string, args []string) error
|
||||
// String returns the human-readable name of the integration
|
||||
String() string
|
||||
}
|
||||
@@ -41,9 +42,29 @@ type Editor interface {
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
"clawdbot": &Openclaw{},
|
||||
"codex": &Codex{},
|
||||
"moltbot": &Openclaw{},
|
||||
"droid": &Droid{},
|
||||
"opencode": &OpenCode{},
|
||||
"openclaw": &Openclaw{},
|
||||
"pi": &Pi{},
|
||||
}
|
||||
|
||||
// recommendedModels are shown when the user has no models or as suggestions.
|
||||
// Order matters: local models first, then cloud models.
|
||||
var recommendedModels = []selectItem{
|
||||
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
|
||||
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
|
||||
{Name: "glm-4.7:cloud", Description: "Recommended"},
|
||||
{Name: "kimi-k2.5:cloud", Description: "Recommended"},
|
||||
}
|
||||
|
||||
// integrationAliases are hidden from the interactive selector but work as CLI arguments.
|
||||
var integrationAliases = map[string]bool{
|
||||
"clawdbot": true,
|
||||
"moltbot": true,
|
||||
"pi": true,
|
||||
}
|
||||
|
||||
func selectIntegration() (string, error) {
|
||||
@@ -54,6 +75,9 @@ func selectIntegration() (string, error) {
|
||||
names := slices.Sorted(maps.Keys(integrations))
|
||||
var items []selectItem
|
||||
for _, name := range names {
|
||||
if integrationAliases[name] {
|
||||
continue
|
||||
}
|
||||
r := integrations[name]
|
||||
description := r.String()
|
||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||
@@ -82,62 +106,25 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(models.Models) == 0 {
|
||||
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
var items []selectItem
|
||||
cloudModels := make(map[string]bool)
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
if m.RemoteModel != "" {
|
||||
cloudModels[m.Name] = true
|
||||
}
|
||||
items = append(items, selectItem{Name: m.Name})
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
// Get previously configured models (saved config takes precedence)
|
||||
var preChecked []string
|
||||
if saved, err := loadIntegration(name); err == nil {
|
||||
preChecked = saved.Models
|
||||
} else if editor, ok := r.(Editor); ok {
|
||||
preChecked = editor.Models()
|
||||
}
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current)
|
||||
|
||||
// If current model is configured, move to front of preChecked
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no models available")
|
||||
}
|
||||
|
||||
// Sort: checked first, then alphabetical
|
||||
slices.SortFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
|
||||
var selected []string
|
||||
// only editors support multi-model selection
|
||||
if _, ok := r.(Editor); ok {
|
||||
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
if err != nil {
|
||||
@@ -151,7 +138,27 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
selected = []string{model}
|
||||
}
|
||||
|
||||
// if any model in selected is a cloud model, ensure signed in
|
||||
var toPull []string
|
||||
for _, m := range selected {
|
||||
if !existingModels[m] {
|
||||
toPull = append(toPull, m)
|
||||
}
|
||||
}
|
||||
if len(toPull) > 0 {
|
||||
msg := fmt.Sprintf("Download %s?", strings.Join(toPull, ", "))
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return nil, err
|
||||
} else if !ok {
|
||||
return nil, errCancelled
|
||||
}
|
||||
for _, m := range toPull {
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, m); err != nil {
|
||||
return nil, fmt.Errorf("failed to pull %s: %w", m, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
@@ -221,13 +228,13 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string) error {
|
||||
func runIntegration(name, modelName string, args []string) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName)
|
||||
return r.Run(modelName, args)
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
@@ -236,7 +243,7 @@ func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) erro
|
||||
var configFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "launch [INTEGRATION]",
|
||||
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||
Short: "Launch an integration with Ollama",
|
||||
Long: `Launch an integration configured with Ollama models.
|
||||
|
||||
@@ -245,19 +252,43 @@ Supported integrations:
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||
|
||||
Examples:
|
||||
ollama launch
|
||||
ollama launch claude
|
||||
ollama launch claude --model <model>
|
||||
ollama launch droid --config (does not auto-launch)`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
ollama launch droid --config (does not auto-launch)
|
||||
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||
ollama launch codex -- --sandbox workspace-write`,
|
||||
Args: cobra.ArbitraryArgs,
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Extract integration name and args to pass through using -- separator
|
||||
var name string
|
||||
if len(args) > 0 {
|
||||
name = args[0]
|
||||
var passArgs []string
|
||||
dashIdx := cmd.ArgsLenAtDash()
|
||||
|
||||
if dashIdx == -1 {
|
||||
// No "--" separator: only allow 0 or 1 args (integration name)
|
||||
if len(args) > 1 {
|
||||
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
|
||||
}
|
||||
if len(args) == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
} else {
|
||||
// "--" was used: args before it = integration name, args after = passthrough
|
||||
if dashIdx > 1 {
|
||||
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||
}
|
||||
if dashIdx == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
passArgs = args[dashIdx:]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
var err error
|
||||
name, err = selectIntegration()
|
||||
if errors.Is(err, errCancelled) {
|
||||
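The RunE logic in the hunk above splits the integration name from passthrough arguments using cobra's "--" handling. A standalone sketch of that behavior, assuming stock cobra semantics (the "--" itself is stripped from args, and ArgsLenAtDash() reports the index where it sat, or -1 if it was absent):

```go
// Sketch of the "--" handling the launch command relies on.
package main

import (
	"fmt"

	"github.com/spf13/cobra"
)

func main() {
	cmd := &cobra.Command{
		Use:  "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
		Args: cobra.ArbitraryArgs,
		RunE: func(cmd *cobra.Command, args []string) error {
			// args arrives without the "--"; ArgsLenAtDash says where it was.
			fmt.Println("args:", args, "dashIdx:", cmd.ArgsLenAtDash())
			return nil
		},
	}

	cmd.SetArgs([]string{"codex", "--", "-p", "myprofile"})
	_ = cmd.Execute() // prints: args: [codex -p myprofile] dashIdx: 1
}
```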
@@ -273,16 +304,14 @@ Examples:
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
// If launching without --model, use saved config if available
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0])
|
||||
return runIntegration(name, config.Models[0], passArgs)
|
||||
}
|
||||
}
|
||||
|
||||
var models []string
|
||||
if modelFlag != "" {
|
||||
// When --model is specified, merge with existing models (new model becomes default)
|
||||
models = []string{modelFlag}
|
||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||
for _, m := range existing.Models {
|
||||
@@ -337,13 +366,13 @@ Examples:
|
||||
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||
return runIntegration(name, models[0])
|
||||
return runIntegration(name, models[0], passArgs)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
return runIntegration(name, models[0])
|
||||
return runIntegration(name, models[0], passArgs)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -351,3 +380,154 @@ Examples:
|
||||
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||
return cmd
|
||||
}
|
||||
|
||||
type modelInfo struct {
|
||||
Name string
|
||||
Remote bool
|
||||
}
|
||||
|
||||
// buildModelList merges existing models with recommendations, sorts them, and returns
|
||||
// the ordered items along with maps of existing and cloud model names.
|
||||
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
existingModels = make(map[string]bool)
|
||||
cloudModels = make(map[string]bool)
|
||||
recommended := make(map[string]bool)
|
||||
var hasLocalModel, hasCloudModel bool
|
||||
|
||||
for _, rec := range recommendedModels {
|
||||
recommended[rec.Name] = true
|
||||
}
|
||||
|
||||
for _, m := range existing {
|
||||
existingModels[m.Name] = true
|
||||
if m.Remote {
|
||||
cloudModels[m.Name] = true
|
||||
hasCloudModel = true
|
||||
} else {
|
||||
hasLocalModel = true
|
||||
}
|
||||
displayName := strings.TrimSuffix(m.Name, ":latest")
|
||||
existingModels[displayName] = true
|
||||
item := selectItem{Name: displayName}
|
||||
if recommended[displayName] {
|
||||
item.Description = "recommended"
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
for _, rec := range recommendedModels {
|
||||
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
|
||||
continue
|
||||
}
|
||||
items = append(items, rec)
|
||||
if isCloudModel(rec.Name) {
|
||||
cloudModels[rec.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
}
|
||||
|
||||
// Non-existing models get "install?" suffix and are pushed to the bottom.
|
||||
// When user has no models, preserve recommended order.
|
||||
notInstalled := make(map[string]bool)
|
||||
for i := range items {
|
||||
if !existingModels[items[i].Name] {
|
||||
notInstalled[items[i].Name] = true
|
||||
if items[i].Description != "" {
|
||||
items[i].Description += ", install?"
|
||||
} else {
|
||||
items[i].Description = "install?"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasLocalModel || hasCloudModel {
|
||||
slices.SortStableFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
|
||||
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if !ac && !bc && aNew != bNew {
|
||||
if aNew {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
}
|
||||
|
||||
return items, preChecked, existingModels, cloudModels
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud")
|
||||
}
|
||||
|
||||
func pullModel(ctx context.Context, client *api.Client, model string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
name = strings.TrimSpace(name)
|
||||
if isDigest {
|
||||
name = name[:min(12, len(name))]
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: model}
|
||||
return client.Pull(ctx, &request, fn)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -90,8 +92,8 @@ func TestLaunchCmd(t *testing.T) {
|
||||
cmd := LaunchCmd(mockCheck)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
@@ -121,7 +123,7 @@ func TestLaunchCmd(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model")
|
||||
err := runIntegration("unknown-integration", "model", nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown integration, got nil")
|
||||
}
|
||||
@@ -174,15 +176,336 @@ func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
||||
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||
for name, r := range integrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test String() doesn't panic and returns non-empty
|
||||
displayName := r.String()
|
||||
if displayName == "" {
|
||||
t.Error("String() should not return empty")
|
||||
}
|
||||
|
||||
// Test Run() exists (we can't call it without actually running the command)
|
||||
// Just verify the method is available
|
||||
var _ func(string) error = r.Run
|
||||
var _ func(string, []string) error = r.Run
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseArgs(t *testing.T) {
|
||||
// Tests reflect cobra's ArgsLenAtDash() semantics:
|
||||
// - cobra strips "--" from args
|
||||
// - ArgsLenAtDash() returns the index where "--" was, or -1
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string // args as cobra delivers them (no "--")
|
||||
dashIdx int // what ArgsLenAtDash() returns
|
||||
wantName string
|
||||
wantArgs []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no extra args, no dash",
|
||||
args: []string{"claude"},
|
||||
dashIdx: -1,
|
||||
wantName: "claude",
|
||||
},
|
||||
{
|
||||
name: "with extra args after --",
|
||||
args: []string{"codex", "-p", "myprofile"},
|
||||
dashIdx: 1,
|
||||
wantName: "codex",
|
||||
wantArgs: []string{"-p", "myprofile"},
|
||||
},
|
||||
{
|
||||
name: "extra args only after --",
|
||||
args: []string{"codex", "--sandbox", "workspace-write"},
|
||||
dashIdx: 1,
|
||||
wantName: "codex",
|
||||
wantArgs: []string{"--sandbox", "workspace-write"},
|
||||
},
|
||||
{
|
||||
name: "-- at end with no args after",
|
||||
args: []string{"claude"},
|
||||
dashIdx: 1,
|
||||
wantName: "claude",
|
||||
},
|
||||
{
|
||||
name: "-- with no integration name",
|
||||
args: []string{"--verbose"},
|
||||
dashIdx: 0,
|
||||
wantName: "",
|
||||
wantArgs: []string{"--verbose"},
|
||||
},
|
||||
{
|
||||
name: "multiple args before -- is error",
|
||||
args: []string{"claude", "codex", "--verbose"},
|
||||
dashIdx: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "multiple args without -- is error",
|
||||
args: []string{"claude", "codex"},
|
||||
dashIdx: -1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no args, no dash",
|
||||
args: []string{},
|
||||
dashIdx: -1,
|
||||
wantName: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the parsing logic from LaunchCmd using dashIdx
|
||||
var name string
|
||||
var parsedArgs []string
|
||||
var err error
|
||||
|
||||
dashIdx := tt.dashIdx
|
||||
args := tt.args
|
||||
|
||||
if dashIdx == -1 {
|
||||
if len(args) > 1 {
|
||||
err = fmt.Errorf("unexpected arguments: %v", args[1:])
|
||||
} else if len(args) == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
} else {
|
||||
if dashIdx > 1 {
|
||||
err = fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||
} else {
|
||||
if dashIdx == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
parsedArgs = args[dashIdx:]
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if name != tt.wantName {
|
||||
t.Errorf("name = %q, want %q", name, tt.wantName)
|
||||
}
|
||||
if !slices.Equal(parsedArgs, tt.wantArgs) {
|
||||
t.Errorf("args = %v, want %v", parsedArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCloudModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{"glm-4.7:cloud", true},
|
||||
{"kimi-k2.5:cloud", true},
|
||||
{"glm-4.7-flash", false},
|
||||
{"glm-4.7-flash:latest", false},
|
||||
{"cloud-model", false},
|
||||
{"model:cloudish", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isCloudModel(tt.name); got != tt.want {
|
||||
t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func names(items []selectItem) []string {
|
||||
var out []string
|
||||
for _, item := range items {
|
||||
out = append(out, item.Name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestBuildModelList_NoExistingModels(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(nil, nil, "")
|
||||
|
||||
want := []string{"glm-4.7-flash", "qwen3:8b", "glm-4.7:cloud", "kimi-k2.5:cloud"}
|
||||
if diff := cmp.Diff(want, names(items)); diff != "" {
|
||||
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("item %q should have description ending with 'install?', got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "qwen2.5:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
want := []string{"llama3.2", "qwen2.5", "glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("cloud recs should be at bottom (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
want := []string{"glm-4.7:cloud", "llama3.2", "glm-4.7-flash", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("mixed models should be alphabetical (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_PreCheckedFirst(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
|
||||
got := names(items)
|
||||
|
||||
if got[0] != "llama3.2" {
|
||||
t.Errorf("pre-checked model should be first, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
switch item.Name {
|
||||
case "glm-4.7-flash", "glm-4.7:cloud":
|
||||
if strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("installed recommended %q should not have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
case "kimi-k2.5:cloud", "qwen3:8b":
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("non-installed recommended %q should have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// glm-4.7-flash and glm-4.7:cloud are installed so they sort normally;
|
||||
// kimi-k2.5:cloud and qwen3:8b are not installed so they go to the bottom
|
||||
want := []string{"glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("existing cloud models should sort normally (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "kimi-k2.5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// kimi-k2.5:cloud is installed so it sorts normally;
|
||||
// the rest of the recommendations are not installed so they go to the bottom
|
||||
want := []string{"kimi-k2.5:cloud", "llama3.2", "glm-4.7-flash", "glm-4.7:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("only non-installed models should be at bottom (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("non-installed %q should have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_LatestTagStripped(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash:latest", Remote: false},
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, existingModels, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// :latest should be stripped from display names
|
||||
for _, name := range got {
|
||||
if strings.HasSuffix(name, ":latest") {
|
||||
t.Errorf("name %q should not have :latest suffix", name)
|
||||
}
|
||||
}
|
||||
|
||||
// glm-4.7-flash should not be duplicated (existing :latest matches the recommendation)
|
||||
count := 0
|
||||
for _, name := range got {
|
||||
if name == "glm-4.7-flash" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("glm-4.7-flash should appear exactly once, got %d in %v", count, got)
|
||||
}
|
||||
|
||||
// Stripped name should be in existingModels so it won't be pulled
|
||||
if !existingModels["glm-4.7-flash"] {
|
||||
t.Error("glm-4.7-flash should be in existingModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
_, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
|
||||
if !existingModels["llama3.2"] {
|
||||
t.Error("llama3.2 should be in existingModels")
|
||||
}
|
||||
if !existingModels["glm-4.7:cloud"] {
|
||||
t.Error("glm-4.7:cloud should be in existingModels")
|
||||
}
|
||||
if existingModels["glm-4.7-flash"] {
|
||||
t.Error("glm-4.7-flash should not be in existingModels (it's a recommendation)")
|
||||
}
|
||||
|
||||
if !cloudModels["glm-4.7:cloud"] {
|
||||
t.Error("glm-4.7:cloud should be in cloudModels")
|
||||
}
|
||||
if !cloudModels["kimi-k2.5:cloud"] {
|
||||
t.Error("kimi-k2.5:cloud should be in cloudModels (recommended cloud)")
|
||||
}
|
||||
if cloudModels["llama3.2"] {
|
||||
t.Error("llama3.2 should not be in cloudModels")
|
||||
}
|
||||
}
|
||||
|
||||
254 cmd/config/openclaw.go (Normal file, new)
@@ -0,0 +1,254 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type Openclaw struct{}
|
||||
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
const ansiGreen = "\033[32m"
|
||||
|
||||
func (c *Openclaw) Run(model string, args []string) error {
|
||||
bin := "openclaw"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
bin = "clawdbot"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
|
||||
}
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
if !c.onboarded() {
|
||||
// Onboarding not completed: run it (model already set via Edit)
|
||||
// Use "ollama" as gateway token for simple local access
|
||||
cmd := exec.Command(bin, "onboard",
|
||||
"--auth-choice", "skip",
|
||||
"--gateway-token", "ollama",
|
||||
)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// Onboarding completed: run gateway
|
||||
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
// Capture output to detect "already running" message
|
||||
var outputBuf bytes.Buffer
|
||||
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
|
||||
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
|
||||
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// onboarded checks if OpenClaw onboarding wizard was completed
|
||||
// by looking for the wizard.lastRunAt marker in the config
|
||||
func (c *Openclaw) onboarded() bool {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for wizard.lastRunAt marker (set when onboarding completes)
|
||||
wizard, _ := config["wizard"].(map[string]any)
|
||||
if wizard == nil {
|
||||
return false
|
||||
}
|
||||
lastRunAt, _ := wizard["lastRunAt"].(string)
|
||||
return lastRunAt != ""
|
||||
}
|
||||
|
||||
func (c *Openclaw) Paths() []string {
	home, _ := os.UserHomeDir()
	p := filepath.Join(home, ".openclaw", "openclaw.json")
	if _, err := os.Stat(p); err == nil {
		return []string{p}
	}
	legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
	if _, err := os.Stat(legacy); err == nil {
		return []string{legacy}
	}
	return nil
}

func (c *Openclaw) Edit(models []string) error {
	if len(models) == 0 {
		return nil
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}

	configPath := filepath.Join(home, ".openclaw", "openclaw.json")
	legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return err
	}

	// Read into map[string]any to preserve unknown fields
	config := make(map[string]any)
	if data, err := os.ReadFile(configPath); err == nil {
		_ = json.Unmarshal(data, &config)
	} else if data, err := os.ReadFile(legacyPath); err == nil {
		_ = json.Unmarshal(data, &config)
	}

	// Navigate/create: models.providers.ollama (preserving other providers)
	modelsSection, _ := config["models"].(map[string]any)
	if modelsSection == nil {
		modelsSection = make(map[string]any)
	}
	providers, _ := modelsSection["providers"].(map[string]any)
	if providers == nil {
		providers = make(map[string]any)
	}
	ollama, _ := providers["ollama"].(map[string]any)
	if ollama == nil {
		ollama = make(map[string]any)
	}

	ollama["baseUrl"] = envconfig.Host().String() + "/v1"
	// needed to register provider
	ollama["apiKey"] = "ollama-local"
	// TODO(parthsareen): potentially move to responses
	ollama["api"] = "openai-completions"

	// Build map of existing models to preserve user customizations
	existingModels, _ := ollama["models"].([]any)
	existingByID := make(map[string]map[string]any)
	for _, m := range existingModels {
		if entry, ok := m.(map[string]any); ok {
			if id, ok := entry["id"].(string); ok {
				existingByID[id] = entry
			}
		}
	}

	var newModels []any
	for _, model := range models {
		entry := map[string]any{
			"id": model,
			"name": model,
			"reasoning": false,
			"input": []any{"text"},
			"cost": map[string]any{
				"input": 0,
				"output": 0,
				"cacheRead": 0,
				"cacheWrite": 0,
			},
			// TODO(parthsareen): get these values from API
			"contextWindow": 131072,
			"maxTokens": 16384,
		}
		// Merge existing fields (user customizations)
		if existing, ok := existingByID[model]; ok {
			for k, v := range existing {
				if _, isNew := entry[k]; !isNew {
					entry[k] = v
				}
			}
		}
		newModels = append(newModels, entry)
	}
	ollama["models"] = newModels

	providers["ollama"] = ollama
	modelsSection["providers"] = providers
	config["models"] = modelsSection

	// Update agents.defaults.model.primary (preserving other agent settings)
	agents, _ := config["agents"].(map[string]any)
	if agents == nil {
		agents = make(map[string]any)
	}
	defaults, _ := agents["defaults"].(map[string]any)
	if defaults == nil {
		defaults = make(map[string]any)
	}
	modelConfig, _ := defaults["model"].(map[string]any)
	if modelConfig == nil {
		modelConfig = make(map[string]any)
	}
	modelConfig["primary"] = "ollama/" + models[0]
	defaults["model"] = modelConfig
	agents["defaults"] = defaults
	config["agents"] = agents

	data, err := json.MarshalIndent(config, "", " ")
	if err != nil {
		return err
	}
	return writeWithBackup(configPath, data)
}

func (c *Openclaw) Models() []string {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil
	}

	config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
	if err != nil {
		config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
		if err != nil {
			return nil
		}
	}

	modelsSection, _ := config["models"].(map[string]any)
	providers, _ := modelsSection["providers"].(map[string]any)
	ollama, _ := providers["ollama"].(map[string]any)
	modelList, _ := ollama["models"].([]any)

	var result []string
	for _, m := range modelList {
		if entry, ok := m.(map[string]any); ok {
			if id, ok := entry["id"].(string); ok {
				result = append(result, id)
			}
		}
	}
	return result
}
878
cmd/config/openclaw_test.go
Normal file
@@ -0,0 +1,878 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpenclawIntegration(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "OpenClaw" {
|
||||
t.Errorf("String() = %q, want %q", got, "OpenClaw")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEdit(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("fresh install", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("multiple models - first is primary", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
assertOpenclawModelExists(t, configPath, "mistral")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other providers", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
if providers["anthropic"] == nil {
|
||||
t.Error("anthropic provider was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme was removed")
|
||||
}
|
||||
if cfg["mcp"] == nil {
|
||||
t.Error("mcp was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve user customizations on models", func(t *testing.T) {
|
||||
cleanup()
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
// User adds custom field
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
entry["customField"] = "user-value"
|
||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configPath, configData, 0o644)
|
||||
|
||||
// Re-run Edit
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ = os.ReadFile(configPath)
|
||||
json.Unmarshal(data, &cfg)
|
||||
models = cfg["models"].(map[string]any)
|
||||
providers = models["providers"].(map[string]any)
|
||||
ollama = providers["ollama"].(map[string]any)
|
||||
modelList = ollama["models"].([]any)
|
||||
entry = modelList[0].(map[string]any)
|
||||
if entry["customField"] != "user-value" {
|
||||
t.Error("custom field was lost")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("edit replaces models list", func(t *testing.T) {
|
||||
cleanup()
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
assertOpenclawModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
original := `{"existing":"data"}`
|
||||
os.WriteFile(configPath, []byte(original), 0o644)
|
||||
|
||||
c.Edit([]string{})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != original {
|
||||
t.Error("empty models should not modify file")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Error("result should be valid JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type models section", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawModels(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("no config returns nil", func(t *testing.T) {
|
||||
if models := c.Models(); len(models) > 0 {
|
||||
t.Errorf("expected nil/empty, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns all ollama models", func(t *testing.T) {
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[
|
||||
{"id":"llama3.2"},
|
||||
{"id":"mistral"}
|
||||
]}}}
|
||||
}`), 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 2 {
|
||||
t.Errorf("expected 2 models, got %v", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func assertOpenclawModelExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if entry["id"] == model {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Errorf("model %s not found", model)
|
||||
}
|
||||
|
||||
func assertOpenclawModelNotExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models, _ := cfg["models"].(map[string]any)
|
||||
providers, _ := models["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
modelList, _ := ollama["models"].([]any)
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if entry["id"] == model {
|
||||
t.Errorf("model %s should not exist", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenclawPrimaryModel(t *testing.T, path, expected string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
model := defaults["model"].(map[string]any)
|
||||
if model["primary"] != expected {
|
||||
t.Errorf("primary model = %v, want %v", model["primary"], expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawPaths(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when config missing", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
if paths := c.Paths(); paths != nil {
|
||||
t.Errorf("expected nil, got %v", paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawModelsEdgeCases(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("corrupted JSON returns nil", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at models level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at providers level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at ollama level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model entry missing id", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
|
||||
if len(c.Models()) != 0 {
|
||||
t.Error("expected empty for missing id")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model id is not string", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
|
||||
if len(c.Models()) != 0 {
|
||||
t.Error("expected empty for non-string id")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEditSchemaFields(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
|
||||
// Verify required schema fields
|
||||
if entry["reasoning"] != false {
|
||||
t.Error("reasoning should be false")
|
||||
}
|
||||
if entry["input"] == nil {
|
||||
t.Error("input should be set")
|
||||
}
|
||||
if entry["contextWindow"] == nil {
|
||||
t.Error("contextWindow should be set")
|
||||
}
|
||||
if entry["maxTokens"] == nil {
|
||||
t.Error("maxTokens should be set")
|
||||
}
|
||||
cost := entry["cost"].(map[string]any)
|
||||
if cost["cacheRead"] == nil {
|
||||
t.Error("cost.cacheRead should be set")
|
||||
}
|
||||
if cost["cacheWrite"] == nil {
|
||||
t.Error("cost.cacheWrite should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawEditModelNames(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
||||
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".openclaw")) }
|
||||
|
||||
t.Run("model with colon tag", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2:70b")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2:70b")
|
||||
})
|
||||
|
||||
t.Run("model with slash", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"library/model:tag"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "library/model:tag")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/library/model:tag")
|
||||
})
|
||||
|
||||
t.Run("model with hyphen", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"test-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "test-model")
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEditAgentsPreservation(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("preserve other agent defaults", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
if defaults["temperature"] != 0.7 {
|
||||
t.Error("temperature setting was lost")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve other agents besides defaults", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
if agents["custom-agent"] == nil {
|
||||
t.Error("custom-agent was lost")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const testOpenclawFixture = `{
|
||||
"theme": "dark",
|
||||
"mcp": {"servers": {"custom": {"enabled": true}}},
|
||||
"models": {
|
||||
"providers": {
|
||||
"anthropic": {"apiKey": "xxx"},
|
||||
"ollama": {
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"models": [{"id": "old-model", "customField": "preserved"}]
|
||||
}
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
|
||||
"custom-agent": {"foo": "bar"}
|
||||
}
|
||||
}`
|
||||
|
||||
func TestOpenclawEdit_RoundTrip(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
|
||||
// Verify top-level preserved
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme not preserved")
|
||||
}
|
||||
mcp := cfg["mcp"].(map[string]any)
|
||||
servers := mcp["servers"].(map[string]any)
|
||||
if servers["custom"] == nil {
|
||||
t.Error("mcp.servers.custom not preserved")
|
||||
}
|
||||
|
||||
// Verify other providers preserved
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
if providers["anthropic"] == nil {
|
||||
t.Error("anthropic provider not preserved")
|
||||
}
|
||||
|
||||
// Verify agents preserved
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
if agents["custom-agent"] == nil {
|
||||
t.Error("custom-agent not preserved")
|
||||
}
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
if defaults["temperature"] != 0.7 {
|
||||
t.Error("temperature not preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawEdit_Idempotent(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
firstData, _ := os.ReadFile(configPath)
|
||||
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
secondData, _ := os.ReadFile(configPath)
|
||||
|
||||
if string(firstData) != string(secondData) {
|
||||
t.Error("repeated edits with same models produced different results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawEdit_MultipleConsecutiveEdits(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
||||
|
||||
for i := range 10 {
|
||||
models := []string{"model-a", "model-b"}
|
||||
if i%2 == 0 {
|
||||
models = []string{"model-x", "model-y", "model-z"}
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
t.Fatalf("edit %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
|
||||
}
|
||||
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme lost after multiple edits")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawEdit_BackupCreated(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
|
||||
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
|
||||
os.WriteFile(configPath, []byte(original), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
backups, _ := filepath.Glob(filepath.Join(backupDir, "openclaw.json.*"))
|
||||
foundBackup := false
|
||||
for _, backup := range backups {
|
||||
data, _ := os.ReadFile(backup)
|
||||
if string(data) == original {
|
||||
foundBackup = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup with original content not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawClawdbotAlias(t *testing.T) {
|
||||
for _, alias := range []string{"clawdbot", "moltbot"} {
|
||||
t.Run(alias+" alias resolves to Openclaw runner", func(t *testing.T) {
|
||||
r, ok := integrations[alias]
|
||||
if !ok {
|
||||
t.Fatalf("%s not found in integrations", alias)
|
||||
}
|
||||
if _, ok := r.(*Openclaw); !ok {
|
||||
t.Errorf("%s integration is %T, want *Openclaw", alias, r)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(alias+" is hidden from selector", func(t *testing.T) {
|
||||
if !integrationAliases[alias] {
|
||||
t.Errorf("%s should be in integrationAliases", alias)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawLegacyPaths(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("falls back to legacy clawdbot path", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Fatalf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
if paths[0] != filepath.Join(legacyDir, "clawdbot.json") {
|
||||
t.Errorf("expected legacy path, got %s", paths[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prefers new path over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Fatalf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
if paths[0] != filepath.Join(newDir, "openclaw.json") {
|
||||
t.Errorf("expected new path, got %s", paths[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Models reads from legacy path", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[{"id":"llama3.2"}]}}}
|
||||
}`), 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 1 || models[0] != "llama3.2" {
|
||||
t.Errorf("expected [llama3.2], got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Models prefers new path over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[{"id":"new-model"}]}}}
|
||||
}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[{"id":"legacy-model"}]}}}
|
||||
}`), 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 1 || models[0] != "new-model" {
|
||||
t.Errorf("expected [new-model], got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Edit reads new path over legacy when both exist", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{"theme":"new"}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"legacy"}`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(filepath.Join(newDir, "openclaw.json"))
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "new" {
|
||||
t.Errorf("expected theme from new config, got %v", cfg["theme"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Edit migrates from legacy config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"dark"}`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should write to new path
|
||||
newPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
||||
data, err := os.ReadFile(newPath)
|
||||
if err != nil {
|
||||
t.Fatal("expected new config file to be created")
|
||||
}
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("legacy theme setting was not migrated")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEdit_CreatesDirectoryIfMissing(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
|
||||
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
|
||||
t.Fatal("directory should not exist before test")
|
||||
}
|
||||
|
||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
t.Fatal("directory was not created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawOnboarded(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("returns false when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when no config exists")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when config exists but no wizard section", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when no wizard section")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when wizard section exists but no lastRunAt", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard.lastRunAt is missing")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when wizard.lastRunAt is empty string", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":""}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard.lastRunAt is empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns true when wizard.lastRunAt is set", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if !c.onboarded() {
|
||||
t.Error("expected true when wizard.lastRunAt is set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("checks legacy clawdbot path", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if !c.onboarded() {
|
||||
t.Error("expected true when legacy config has wizard.lastRunAt")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prefers new path over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
// New path has no wizard marker
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
|
||||
// Legacy has wizard marker
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false - should prefer new path which has no wizard marker")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false for corrupted JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles wrong type for wizard section", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":"not a map"}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard is wrong type")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -9,6 +9,8 @@ import (
	"path/filepath"
	"slices"
	"strings"

	"github.com/ollama/ollama/envconfig"
)

// OpenCode implements Runner and Editor for OpenCode integration
@@ -16,7 +18,7 @@ type OpenCode struct{}

func (o *OpenCode) String() string { return "OpenCode" }

func (o *OpenCode) Run(model string) error {
func (o *OpenCode) Run(model string, args []string) error {
	if _, err := exec.LookPath("opencode"); err != nil {
		return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
	}
@@ -30,7 +32,7 @@ func (o *OpenCode) Run(model string) error {
		return fmt.Errorf("setup failed: %w", err)
	}

	cmd := exec.Command("opencode")
	cmd := exec.Command("opencode", args...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
@@ -88,7 +90,7 @@ func (o *OpenCode) Edit(modelList []string) error {
			"npm": "@ai-sdk/openai-compatible",
			"name": "Ollama (local)",
			"options": map[string]any{
				"baseURL": "http://localhost:11434/v1",
				"baseURL": envconfig.Host().String() + "/v1",
			},
		}
	}

196
cmd/config/pi.go
Normal file
@@ -0,0 +1,196 @@
package config

import (
	"encoding/json"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"slices"

	"github.com/ollama/ollama/envconfig"
)

// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
type Pi struct{}

func (p *Pi) String() string { return "Pi" }

func (p *Pi) Run(model string, args []string) error {
	if _, err := exec.LookPath("pi"); err != nil {
		return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
	}

	// Call Edit() to ensure config is up-to-date before launch
	models := []string{model}
	if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
		models = config.Models
	}
	if err := p.Edit(models); err != nil {
		return fmt.Errorf("setup failed: %w", err)
	}

	cmd := exec.Command("pi", args...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	return cmd.Run()
}

func (p *Pi) Paths() []string {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil
	}

	var paths []string
	modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
	if _, err := os.Stat(modelsPath); err == nil {
		paths = append(paths, modelsPath)
	}
	settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
	if _, err := os.Stat(settingsPath); err == nil {
		paths = append(paths, settingsPath)
	}
	return paths
}

func (p *Pi) Edit(models []string) error {
	if len(models) == 0 {
		return nil
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}

	configPath := filepath.Join(home, ".pi", "agent", "models.json")
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return err
	}

	config := make(map[string]any)
	if data, err := os.ReadFile(configPath); err == nil {
		_ = json.Unmarshal(data, &config)
	}

	providers, ok := config["providers"].(map[string]any)
	if !ok {
		providers = make(map[string]any)
	}

	ollama, ok := providers["ollama"].(map[string]any)
	if !ok {
		ollama = map[string]any{
			"baseUrl": envconfig.Host().String() + "/v1",
			"api": "openai-completions",
			"apiKey": "ollama",
		}
	}

	existingModels, ok := ollama["models"].([]any)
	if !ok {
		existingModels = make([]any, 0)
	}

	// Build set of selected models to track which need to be added
	selectedSet := make(map[string]bool, len(models))
	for _, m := range models {
		selectedSet[m] = true
	}

	// Build new models list:
	// 1. Keep user-managed models (no _launch marker) - untouched
	// 2. Keep ollama-managed models (_launch marker) that are still selected
	// 3. Add new ollama-managed models
	var newModels []any
	for _, m := range existingModels {
		if modelObj, ok := m.(map[string]any); ok {
			if id, ok := modelObj["id"].(string); ok {
				// User-managed model (no _launch marker) - always preserve
				if !isPiOllamaModel(modelObj) {
					newModels = append(newModels, m)
				} else if selectedSet[id] {
					// Ollama-managed and still selected - keep it
					newModels = append(newModels, m)
					selectedSet[id] = false
				}
			}
		}
	}

	// Add newly selected models that weren't already in the list
	for _, model := range models {
		if selectedSet[model] {
			newModels = append(newModels, map[string]any{
				"id": model,
				"_launch": true,
			})
		}
	}

	ollama["models"] = newModels
	providers["ollama"] = ollama
	config["providers"] = providers

	configData, err := json.MarshalIndent(config, "", " ")
	if err != nil {
		return err
	}
	if err := writeWithBackup(configPath, configData); err != nil {
		return err
	}

	// Update settings.json with default provider and model
	settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
	settings := make(map[string]any)
	if data, err := os.ReadFile(settingsPath); err == nil {
		_ = json.Unmarshal(data, &settings)
	}

	settings["defaultProvider"] = "ollama"
	settings["defaultModel"] = models[0]

	settingsData, err := json.MarshalIndent(settings, "", " ")
	if err != nil {
		return err
	}
	return writeWithBackup(settingsPath, settingsData)
}

func (p *Pi) Models() []string {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil
	}

	configPath := filepath.Join(home, ".pi", "agent", "models.json")
	config, err := readJSONFile(configPath)
	if err != nil {
		return nil
	}

	providers, _ := config["providers"].(map[string]any)
	ollama, _ := providers["ollama"].(map[string]any)
	models, _ := ollama["models"].([]any)

	var result []string
	for _, m := range models {
		if modelObj, ok := m.(map[string]any); ok {
			if id, ok := modelObj["id"].(string); ok {
				result = append(result, id)
			}
		}
	}
	slices.Sort(result)
	return result
}

// isPiOllamaModel reports whether a model config entry is managed by ollama launch
func isPiOllamaModel(cfg map[string]any) bool {
	if v, ok := cfg["_launch"].(bool); ok && v {
		return true
	}
	return false
}
609
cmd/config/pi_test.go
Normal file
@@ -0,0 +1,609 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPiIntegration(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := pi.String(); got != "Pi" {
|
||||
t.Errorf("String() = %q, want %q", got, "Pi")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = pi
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = pi
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiPaths(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("returns empty when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
paths := pi.Paths()
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("Paths() = %v, want empty", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
paths := pi.Paths()
|
||||
if len(paths) != 1 || paths[0] != configPath {
|
||||
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiEdit(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(configDir)
|
||||
}
|
||||
|
||||
readConfig := func() map[string]any {
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
return cfg
|
||||
}
|
||||
|
||||
t.Run("returns nil for empty models", func(t *testing.T) {
|
||||
if err := pi.Edit([]string{}); err != nil {
|
||||
t.Errorf("Edit([]) error = %v, want nil", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates config with models", func(t *testing.T) {
|
||||
cleanup()
|
||||
|
||||
models := []string{"llama3.2", "qwen3:8b"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
|
||||
providers, ok := cfg["providers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Config missing providers")
|
||||
}
|
||||
|
||||
ollama, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Providers missing ollama")
|
||||
}
|
||||
|
||||
modelsArray, ok := ollama["models"].([]any)
|
||||
if !ok || len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %v", modelsArray)
|
||||
}
|
||||
|
||||
if ollama["baseUrl"] == nil {
|
||||
t.Error("Missing baseUrl")
|
||||
}
|
||||
if ollama["api"] != "openai-completions" {
|
||||
t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
|
||||
}
|
||||
if ollama["apiKey"] != "ollama" {
|
||||
t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://custom:8080/v1",
|
||||
"api": "custom-api",
|
||||
"apiKey": "custom-key",
|
||||
"models": [
|
||||
{"id": "old-model", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"new-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
|
||||
if ollama["baseUrl"] != "http://custom:8080/v1" {
|
||||
t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
|
||||
}
|
||||
if ollama["api"] != "custom-api" {
|
||||
t.Errorf("Custom api not preserved, got %v", ollama["api"])
|
||||
}
|
||||
if ollama["apiKey"] != "custom-key" {
|
||||
t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
|
||||
}
|
||||
|
||||
modelsArray := ollama["models"].([]any)
|
||||
if len(modelsArray) != 1 {
|
||||
t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
|
||||
} else {
|
||||
modelEntry := modelsArray[0].(map[string]any)
|
||||
if modelEntry["id"] != "new-model" {
|
||||
t.Errorf("Expected new-model, got %v", modelEntry["id"])
|
||||
}
|
||||
// Verify _launch marker is present
|
||||
if modelEntry["_launch"] != true {
|
||||
t.Errorf("Expected _launch marker to be true")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("replaces old models with new ones", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Old models must have _launch marker to be managed by us
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "old-model-1", "_launch": true},
|
||||
{"id": "old-model-2", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newModels := []string{"new-model-1", "new-model-2"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]bool)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = true
|
||||
}
|
||||
|
||||
if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
|
||||
t.Errorf("Expected new models, got %v", modelIDs)
|
||||
}
|
||||
if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
|
||||
t.Errorf("Old models should have been removed, got %v", modelIDs)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles partial overlap in model list", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Models must have _launch marker to be managed
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "keep-model", "_launch": true},
|
||||
{"id": "remove-model", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newModels := []string{"keep-model", "add-model"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]bool)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = true
|
||||
}
|
||||
|
||||
if !modelIDs["keep-model"] || !modelIDs["add-model"] {
|
||||
t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
|
||||
}
|
||||
if modelIDs["remove-model"] {
|
||||
t.Errorf("remove-model should have been removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt config gracefully", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"test-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read config: %v", err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
|
||||
}
|
||||
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 1 {
|
||||
t.Errorf("Expected 1 model, got %d", len(modelsArray))
|
||||
}
|
||||
})
|
||||
|
||||
// CRITICAL SAFETY TEST: verifies we don't stomp on user configs
|
||||
t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// User has manually configured models in ollama provider (no _launch marker)
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "user-model-1"},
|
||||
{"id": "user-model-2", "customField": "preserved"},
|
||||
{"id": "ollama-managed", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add a new ollama-managed model
|
||||
newModels := []string{"new-ollama-model"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
// Should have: new-ollama-model (managed) + 2 user models (preserved)
|
||||
if len(modelsArray) != 3 {
|
||||
t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]map[string]any)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = modelObj
|
||||
}
|
||||
|
||||
// Verify new model has _launch marker
|
||||
if m, ok := modelIDs["new-ollama-model"]; !ok {
|
||||
t.Errorf("new-ollama-model should be present")
|
||||
} else if m["_launch"] != true {
|
||||
t.Errorf("new-ollama-model should have _launch marker")
|
||||
}
|
||||
|
||||
// Verify user models are preserved
|
||||
if _, ok := modelIDs["user-model-1"]; !ok {
|
||||
t.Errorf("user-model-1 should be preserved")
|
||||
}
|
||||
if _, ok := modelIDs["user-model-2"]; !ok {
|
||||
t.Errorf("user-model-2 should be preserved")
|
||||
} else if modelIDs["user-model-2"]["customField"] != "preserved" {
|
||||
t.Errorf("user-model-2 customField should be preserved")
|
||||
}
|
||||
|
||||
// Verify old ollama-managed model is removed (not in new list)
|
||||
if _, ok := modelIDs["ollama-managed"]; ok {
|
||||
t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates settings.json with default provider and model", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Create existing settings with other fields
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
existingSettings := `{
|
||||
"theme": "dark",
|
||||
"customSetting": "value",
|
||||
"defaultProvider": "anthropic",
|
||||
"defaultModel": "claude-3"
|
||||
}`
|
||||
if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"llama3.2"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read settings: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("Failed to parse settings: %v", err)
|
||||
}
|
||||
|
||||
// Verify defaultProvider is set to ollama
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
|
||||
// Verify defaultModel is set to first model
|
||||
if settings["defaultModel"] != "llama3.2" {
|
||||
t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
|
||||
}
|
||||
|
||||
// Verify other fields are preserved
|
||||
if settings["theme"] != "dark" {
|
||||
t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
|
||||
}
|
||||
if settings["customSetting"] != "value" {
|
||||
t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates settings.json if it does not exist", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
models := []string{"qwen3:8b"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("settings.json should be created: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("Failed to parse settings: %v", err)
|
||||
}
|
||||
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
if settings["defaultModel"] != "qwen3:8b" {
|
||||
t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Create corrupt settings
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"test-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read settings: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
|
||||
}
|
||||
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
if settings["defaultModel"] != "test-model" {
|
||||
t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiModels(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("returns nil when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns models from config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"models": [
|
||||
{"id": "llama3.2"},
|
||||
{"id": "qwen3:8b"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if len(models) != 2 {
|
||||
t.Errorf("Models() returned %d models, want 2", len(models))
|
||||
}
|
||||
if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
|
||||
t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns sorted models", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"models": [
|
||||
{"id": "z-model"},
|
||||
{"id": "a-model"},
|
||||
{"id": "m-model"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
|
||||
t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when models array is missing", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil when models array is missing", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt config gracefully", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil for corrupt config", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -275,7 +275,11 @@ func parseInput(r io.Reader) (inputEvent, byte, error) {
|
||||
func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
if s.filter == "" {
|
||||
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
}
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
@@ -314,7 +318,11 @@ func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
if s.filter == "" {
|
||||
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
}
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
@@ -345,10 +353,15 @@ func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
suffix = " " + ansiGray + "(default)" + ansiReset
|
||||
}
|
||||
|
||||
desc := ""
|
||||
if item.Description != "" {
|
||||
desc = " " + ansiGray + "- " + item.Description + ansiReset
|
||||
}
|
||||
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, desc, suffix)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
|
||||
fmt.Fprintf(w, " %s %s %s%s%s\r\n", prefix, checkbox, item.Name, desc, suffix)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
@@ -313,8 +313,12 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &deepseek2Model{}
|
||||
case "Glm4MoeLiteForCausalLM":
|
||||
conv = &glm4MoeLiteModel{}
|
||||
case "GlmOcrForConditionalGeneration":
|
||||
conv = &glmOcrModel{}
|
||||
case "Lfm2ForCausalLM":
|
||||
conv = &lfm2Model{}
|
||||
case "Qwen3NextForCausalLM":
|
||||
conv = &qwen3NextModel{}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
455
convert/convert_glmocr.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
// normalToNeoXRepacker creates a repacker that permutes Q/K weights from interleaved (LLaMA)
|
||||
// to NeoX ordering for compatibility with GGML's M-RoPE kernel.
|
||||
//
|
||||
// For weights: reshape [out, in] -> [n_heads, head_dim, in], permute rotary dims, reshape back
|
||||
// For biases: reshape [out] -> [n_heads, head_dim], permute rotary dims, reshape back
|
||||
func normalToNeoXRepacker(nHeads, headDim int, partialRotaryFactor float32) func(string, []float32, []uint64) ([]float32, error) {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
rotaryDim := int(float32(headDim) * partialRotaryFactor)
|
||||
if rotaryDim%2 != 0 {
|
||||
rotaryDim = (rotaryDim / 2) * 2 // Round down to even
|
||||
}
|
||||
|
||||
// Handle 1D (bias) or 2D (weight) tensors
|
||||
is1D := len(shape) == 1
|
||||
var inFeatures int
|
||||
if is1D {
|
||||
inFeatures = 1
|
||||
} else {
|
||||
inFeatures = int(shape[1])
|
||||
}
|
||||
outFeatures := int(shape[0])
|
||||
nEffectiveHeads := outFeatures / headDim
|
||||
|
||||
if nEffectiveHeads != nHeads {
|
||||
slog.Warn("normalToNeoX: unexpected head count", "effective", nEffectiveHeads, "expected", nHeads)
|
||||
}
|
||||
|
||||
// Reshape to [n_heads, head_dim, in_features]
|
||||
reshaped := make([]float32, len(data))
|
||||
copy(reshaped, data)
|
||||
|
||||
// Permute the rotary dimensions: even indices first, then odd
|
||||
// For each head, reorder [0,1,2,3,4,5...] to [0,2,4...,1,3,5...]
|
||||
result := make([]float32, len(data))
|
||||
halfRotary := rotaryDim / 2
|
||||
|
||||
for h := range nEffectiveHeads {
|
||||
for f := range inFeatures {
|
||||
for i := range halfRotary {
|
||||
// Even dim (0, 2, 4, ...) -> position i
|
||||
srcIdx := h*headDim*inFeatures + (2*i)*inFeatures + f
|
||||
dstIdx := h*headDim*inFeatures + i*inFeatures + f
|
||||
result[dstIdx] = reshaped[srcIdx]
|
||||
|
||||
// Odd dim (1, 3, 5, ...) -> position halfRotary + i
|
||||
srcIdx = h*headDim*inFeatures + (2*i+1)*inFeatures + f
|
||||
dstIdx = h*headDim*inFeatures + (halfRotary+i)*inFeatures + f
|
||||
result[dstIdx] = reshaped[srcIdx]
|
||||
}
|
||||
|
||||
// Non-rotary part: copy as-is
|
||||
for i := rotaryDim; i < headDim; i++ {
|
||||
srcIdx := h*headDim*inFeatures + i*inFeatures + f
|
||||
result[srcIdx] = reshaped[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
}
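
// Illustrative note (not part of the converter code): with headDim = 4 and
// partialRotaryFactor = 1.0 (so rotaryDim = 4 and halfRotary = 2), a single
// head's rows [d0, d1, d2, d3] are reordered to [d0, d2, d1, d3]: even rotary
// dims first, then odd rotary dims, with any non-rotary rows left in place.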
|
||||
|
||||
type glmOcrModel struct {
|
||||
ModelParameters
|
||||
|
||||
TextConfig struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
MaxPositionEmbed uint32 `json:"max_position_embeddings"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeParameters struct {
|
||||
RopeType string `json:"rope_type"`
|
||||
MRopeSection []int32 `json:"mrope_section"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
|
||||
VisionConfig struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
Depth uint32 `json:"depth"`
|
||||
NumHeads uint32 `json:"num_heads"`
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
OutHiddenSize uint32 `json:"out_hidden_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
} `json:"vision_config"`
|
||||
|
||||
ImageStartTokenID uint32 `json:"image_start_token_id"`
|
||||
ImageEndTokenID uint32 `json:"image_end_token_id"`
|
||||
VideoStartTokenID uint32 `json:"video_start_token_id"`
|
||||
VideoEndTokenID uint32 `json:"video_end_token_id"`
|
||||
ImageTokenID uint32 `json:"image_token_id"`
|
||||
VideoTokenID uint32 `json:"video_token_id"`
|
||||
|
||||
// Preprocessor config (preprocessor_config.json)
|
||||
Preprocessor struct {
|
||||
Size struct {
|
||||
ShortestEdge uint32 `json:"shortest_edge"`
|
||||
LongestEdge uint32 `json:"longest_edge"`
|
||||
} `json:"size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
MergeSize uint32 `json:"merge_size"`
|
||||
ImageMean []float32 `json:"image_mean"`
|
||||
ImageStd []float32 `json:"image_std"`
|
||||
} `json:"-"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*glmOcrModel)(nil)
|
||||
|
||||
func (m *glmOcrModel) parseMore(fsys fs.FS) error {
|
||||
bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(bts, &m.Preprocessor)
|
||||
}
|
||||
|
||||
func (m *glmOcrModel) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "glmocr"
|
||||
|
||||
// Text model parameters
|
||||
kv["glmocr.block_count"] = cmp.Or(m.TextConfig.NumHiddenLayers, 16)
|
||||
kv["glmocr.embedding_length"] = cmp.Or(m.TextConfig.HiddenSize, 1536)
|
||||
kv["glmocr.attention.head_count"] = cmp.Or(m.TextConfig.NumAttentionHeads, 16)
|
||||
kv["glmocr.attention.head_count_kv"] = cmp.Or(m.TextConfig.NumKeyValueHeads, 8)
|
||||
headDim := cmp.Or(m.TextConfig.HeadDim, m.TextConfig.HiddenSize/m.TextConfig.NumAttentionHeads)
|
||||
kv["glmocr.attention.key_length"] = headDim
|
||||
kv["glmocr.attention.value_length"] = headDim
|
||||
kv["glmocr.feed_forward_length"] = cmp.Or(m.TextConfig.IntermediateSize, 4608)
|
||||
kv["glmocr.attention.layer_norm_rms_epsilon"] = cmp.Or(m.TextConfig.RMSNormEps, 1e-5)
|
||||
kv["glmocr.context_length"] = cmp.Or(m.TextConfig.MaxPositionEmbed, 131072)
|
||||
kv["glmocr.rope.freq_base"] = cmp.Or(m.TextConfig.RopeParameters.RopeTheta, float32(10000))
|
||||
kv["glmocr.rope.partial_rotary_factor"] = cmp.Or(m.TextConfig.RopeParameters.PartialRotaryFactor, m.TextConfig.PartialRotaryFactor, float32(1.0))
|
||||
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 {
|
||||
kv["glmocr.rope.mrope_section"] = m.TextConfig.RopeParameters.MRopeSection
|
||||
}
|
||||
|
||||
// Vision model parameters
|
||||
kv["glmocr.vision.block_count"] = cmp.Or(m.VisionConfig.Depth, 24)
|
||||
kv["glmocr.vision.embedding_length"] = cmp.Or(m.VisionConfig.HiddenSize, 1024)
|
||||
kv["glmocr.vision.attention.head_count"] = cmp.Or(m.VisionConfig.NumHeads, 16)
|
||||
kv["glmocr.vision.image_size"] = cmp.Or(m.VisionConfig.ImageSize, 336)
|
||||
kv["glmocr.vision.patch_size"] = cmp.Or(m.VisionConfig.PatchSize, m.Preprocessor.PatchSize, 14)
|
||||
kv["glmocr.vision.spatial_merge_size"] = cmp.Or(m.VisionConfig.SpatialMergeSize, m.Preprocessor.MergeSize, 2)
|
||||
kv["glmocr.vision.temporal_patch_size"] = cmp.Or(m.VisionConfig.TemporalPatchSize, m.Preprocessor.TemporalPatchSize, 2)
|
||||
kv["glmocr.vision.out_hidden_size"] = cmp.Or(m.VisionConfig.OutHiddenSize, 1536)
|
||||
kv["glmocr.vision.intermediate_size"] = cmp.Or(m.VisionConfig.IntermediateSize, 4096)
|
||||
kv["glmocr.vision.attention.layer_norm_rms_epsilon"] = cmp.Or(m.VisionConfig.RMSNormEps, 1e-5)
|
||||
|
||||
// Preprocessor-derived image settings (min/max pixels and normalization)
|
||||
// Note: fs.Config.keyValue() auto-prepends architecture prefix, so use full key
|
||||
if m.Preprocessor.Size.ShortestEdge > 0 {
|
||||
kv["glmocr.vision.min_pixels"] = m.Preprocessor.Size.ShortestEdge
|
||||
}
|
||||
if m.Preprocessor.Size.LongestEdge > 0 {
|
||||
kv["glmocr.vision.max_pixels"] = m.Preprocessor.Size.LongestEdge
|
||||
}
|
||||
if len(m.Preprocessor.ImageMean) == 3 {
|
||||
kv["glmocr.vision.image_mean"] = m.Preprocessor.ImageMean
|
||||
}
|
||||
if len(m.Preprocessor.ImageStd) == 3 {
|
||||
kv["glmocr.vision.image_std"] = m.Preprocessor.ImageStd
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
kv["glmocr.image_token_id"] = m.ImageTokenID
|
||||
kv["glmocr.image_start_token_id"] = m.ImageStartTokenID
|
||||
kv["glmocr.image_end_token_id"] = m.ImageEndTokenID
|
||||
kv["glmocr.video_token_id"] = m.VideoTokenID
|
||||
kv["glmocr.video_start_token_id"] = m.VideoStartTokenID
|
||||
kv["glmocr.video_end_token_id"] = m.VideoEndTokenID
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *glmOcrModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
// Skip layers >= num_hidden_layers (Multi-Token Prediction layers not needed for basic inference)
|
||||
numLayers := int(cmp.Or(m.TextConfig.NumHiddenLayers, 16))
|
||||
skipLayer := func(name string) bool {
|
||||
// Tensor names are already replaced to "blk.N.xxx" format
|
||||
re := regexp.MustCompile(`^blk\.(\d+)`)
|
||||
matches := re.FindStringSubmatch(name)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
blkNum, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return blkNum >= numLayers
|
||||
}
|
||||
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
|
||||
// Skip next-n prediction layers (layers >= num_hidden_layers)
|
||||
if skipLayer(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Split ffn_gate_up into separate gate and up projections
|
||||
if strings.Contains(name, "ffn_gate_up") {
|
||||
for t := range splitDim(t, 0,
|
||||
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_gate")},
|
||||
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_up")},
|
||||
) {
|
||||
out = append(out, t)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasSuffix(name, "patch_embd.weight") {
|
||||
shape := t.Shape()
|
||||
if len(shape) == 5 && shape[2] == 2 {
|
||||
newShape := []uint64{shape[0], shape[1], shape[3], shape[4]}
|
||||
|
||||
t0 := t.Clone()
|
||||
t0.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tt.Slice(nil, nil, tensor.S(0, 1), nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
|
||||
if err := tt.Reshape(newDims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t0,
|
||||
})
|
||||
|
||||
t1 := t.Clone()
|
||||
t1.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tt.Slice(nil, nil, tensor.S(1, 2), nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
|
||||
if err := tt.Reshape(newDims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_1.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t1,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if len(shape) == 4 {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Warn("glmocr: patch_embed weight has unexpected shape - not splitting", "shape", shape)
|
||||
// Fall through to default handling
|
||||
}
|
||||
|
||||
// Handle pre-split patch embedding weights
|
||||
// Pattern 1: v.patch_embd.0.weight, v.patch_embd.1.weight -> patch_embd_0.weight, patch_embd_1.weight
|
||||
// Pattern 2: v.patch_embd.weight.0, v.patch_embd.weight.1 -> patch_embd_0.weight, patch_embd_1.weight
|
||||
if strings.Contains(name, "patch_embd.0.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.0.", "patch_embd_0.", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if strings.Contains(name, "patch_embd.1.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.1.", "patch_embd_1.", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
// Handle .weight.0 and .weight.1 suffix patterns
|
||||
if strings.HasSuffix(name, "patch_embd.weight.0") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight.0", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "patch_embd.weight.1") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight.1", "patch_embd_1.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Permute Q/K weights for M-RoPE compatibility (interleaved -> NeoX ordering)
|
||||
// GGML's M-RoPE kernel uses NeoX-style rotation, but GLM-OCR uses interleaved (LLaMA-style)
|
||||
// We permute at conversion time so the weights work correctly with GGML's kernel
|
||||
// This aligns Q/K rotary dimensions with GGML's NeoX-style rotation
|
||||
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 &&
|
||||
strings.Contains(name, "blk.") && (strings.Contains(name, "attn_q.") || strings.Contains(name, "attn_k.")) {
|
||||
// Get config values for permutation
|
||||
nHeads := int(cmp.Or(m.TextConfig.NumAttentionHeads, 16))
|
||||
nKVHeads := int(cmp.Or(m.TextConfig.NumKeyValueHeads, 8))
|
||||
hiddenSize := int(cmp.Or(m.TextConfig.HiddenSize, 1536))
|
||||
headDim := int(cmp.Or(m.TextConfig.HeadDim, uint32(hiddenSize/nHeads)))
|
||||
partialRotaryFactor := cmp.Or(m.TextConfig.PartialRotaryFactor, m.TextConfig.RopeParameters.PartialRotaryFactor, float32(1.0))
|
||||
|
||||
// Use appropriate head count: nHeads for Q, nKVHeads for K
|
||||
effectiveHeads := nHeads
|
||||
if strings.Contains(name, "attn_k.") {
|
||||
effectiveHeads = nKVHeads
|
||||
}
|
||||
|
||||
permutedT := t.Clone()
|
||||
permutedT.SetRepacker(normalToNeoXRepacker(effectiveHeads, headDim, partialRotaryFactor))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: permutedT,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *glmOcrModel) Replacements() []string {
|
||||
return []string{
|
||||
// Vision encoder
|
||||
"model.visual.patch_embed.proj_1", "v.patch_embd_1", // Second temporal split
|
||||
"model.visual.patch_embed.proj", "v.patch_embd",
|
||||
"model.visual.blocks", "v.blk",
|
||||
"model.visual.post_layernorm", "v.post_ln",
|
||||
"model.visual.downsample", "mm.patch_merger",
|
||||
|
||||
// Vision attention
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.proj", "attn_out",
|
||||
"attn.q_norm", "attn_q_norm",
|
||||
"attn.k_norm", "attn_k_norm",
|
||||
|
||||
// Vision norms
|
||||
"norm1", "ln1",
|
||||
"norm2", "ln2",
|
||||
|
||||
// Vision MLP
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
|
||||
// Merger (multimodal projector)
|
||||
"model.visual.merger.proj", "mm.model.fc",
|
||||
"model.visual.merger.post_projection_norm", "mm.post_norm",
|
||||
"model.visual.merger.gate_proj", "mm.gate",
|
||||
"model.visual.merger.up_proj", "mm.up",
|
||||
"model.visual.merger.down_proj", "mm.down",
|
||||
|
||||
// Language model
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.layers", "blk",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"lm_head", "output",
|
||||
|
||||
// Language model attention
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_out",
|
||||
|
||||
// Language model norms
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"post_self_attn_layernorm", "post_attn_norm",
|
||||
"post_mlp_layernorm", "post_ffn_norm",
|
||||
|
||||
// Language model MLP (remove mlp. prefix so ffn_* names work)
|
||||
"mlp.gate_up_proj", "ffn_gate_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
}
|
||||
}
|
||||
512
convert/convert_qwen3next.go
Normal file
@@ -0,0 +1,512 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type qwen3NextModel struct {
|
||||
ModelParameters
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
|
||||
// MoE config
|
||||
NumExperts uint32 `json:"num_experts"`
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
NormTopkProb bool `json:"norm_topk_prob"`
|
||||
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
|
||||
|
||||
// Hybrid attention config
|
||||
FullAttentionInterval uint32 `json:"full_attention_interval"`
|
||||
|
||||
// Linear attention (Gated Delta Net) config
|
||||
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
|
||||
LinearKeyHeadDim uint32 `json:"linear_key_head_dim"`
|
||||
LinearNumKeyHeads uint32 `json:"linear_num_key_heads"`
|
||||
LinearNumValueHeads uint32 `json:"linear_num_value_heads"`
|
||||
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
|
||||
|
||||
// RoPE config
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeScaling struct {
|
||||
Type string `json:"type"`
|
||||
Factor ropeFactor `json:"factor"`
|
||||
} `json:"rope_scaling"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*qwen3NextModel)(nil)
|
||||
|
||||
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
|
||||
if q.NumHiddenLayers == 0 {
|
||||
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
|
||||
}
|
||||
if q.NumAttentionHeads == 0 {
|
||||
return fmt.Errorf("qwen3next: num_attention_heads must be set")
|
||||
}
|
||||
if q.NumKeyValueHeads == 0 {
|
||||
return fmt.Errorf("qwen3next: num_key_value_heads must be set")
|
||||
}
|
||||
if q.HeadDim == 0 {
|
||||
return fmt.Errorf("qwen3next: head_dim must be set")
|
||||
}
|
||||
if q.RopeTheta == 0 {
|
||||
return fmt.Errorf("qwen3next: rope_theta must be set")
|
||||
}
|
||||
if q.PartialRotaryFactor <= 0 || q.PartialRotaryFactor > 1 {
|
||||
return fmt.Errorf("qwen3next: partial_rotary_factor must be in (0,1], got %v", q.PartialRotaryFactor)
|
||||
}
|
||||
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
|
||||
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
|
||||
}
|
||||
if q.FullAttentionInterval == 0 {
|
||||
return fmt.Errorf("qwen3next: full_attention_interval must be set")
|
||||
}
|
||||
if q.FullAttentionInterval > q.NumHiddenLayers {
|
||||
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
}
|
||||
|
||||
hasFull := false
|
||||
for i := range q.NumHiddenLayers {
|
||||
if (i+1)%q.FullAttentionInterval == 0 {
|
||||
hasFull = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasFull {
|
||||
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen3next"
|
||||
kv["tokenizer.ggml.pre"] = "qwen2"
|
||||
kv["block_count"] = q.NumHiddenLayers
|
||||
kv["context_length"] = q.MaxPositionEmbeddings
|
||||
kv["embedding_length"] = q.HiddenSize
|
||||
kv["feed_forward_length"] = q.IntermediateSize
|
||||
kv["attention.head_count"] = q.NumAttentionHeads
|
||||
headDim := q.HeadDim
|
||||
if headDim == 0 && q.NumAttentionHeads > 0 {
|
||||
headDim = q.HiddenSize / q.NumAttentionHeads
|
||||
}
|
||||
kv["attention.key_length"] = headDim
|
||||
kv["attention.value_length"] = headDim
|
||||
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
||||
kv["rope.freq_base"] = q.RopeTheta
|
||||
|
||||
// RoPE dimension count (partial rotary)
|
||||
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
|
||||
partialRotary := q.PartialRotaryFactor
|
||||
if partialRotary > 0 && partialRotary <= 1 {
|
||||
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
|
||||
}
|
||||
|
||||
// MoE config
|
||||
if q.NumExperts > 0 {
|
||||
kv["expert_count"] = q.NumExperts
|
||||
kv["expert_used_count"] = q.NumExpertsPerToken
|
||||
kv["norm_top_k_prob"] = q.NormTopkProb
|
||||
if q.MoEIntermediateSize > 0 {
|
||||
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
|
||||
}
|
||||
if q.SharedExpertIntermSize > 0 {
|
||||
kv["expert_shared_feed_forward_length"] = q.SharedExpertIntermSize
|
||||
}
|
||||
}
|
||||
|
||||
// SSM/Linear attention config
|
||||
// d_inner = linear_value_head_dim * linear_num_value_heads
|
||||
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
|
||||
kv["ssm.inner_size"] = dInner
|
||||
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
|
||||
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
|
||||
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
|
||||
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
|
||||
interval := q.FullAttentionInterval
|
||||
kv["full_attention_interval"] = interval
|
||||
|
||||
// Build per-layer KV head count array to identify layer types
|
||||
// 0 = recurrent (linear attention), non-zero = full attention
|
||||
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
|
||||
for i := range q.NumHiddenLayers {
|
||||
// Full attention every full_attention_interval layers (starting at interval-1)
|
||||
if interval > 0 && (i+1)%interval == 0 {
|
||||
kvHeadCounts[i] = q.NumKeyValueHeads
|
||||
}
|
||||
// else stays 0 (recurrent layer)
|
||||
}
|
||||
kv["attention.head_count_kv"] = kvHeadCounts
|
||||
|
||||
// RoPE scaling
|
||||
if q.RopeScaling.Type != "" {
|
||||
kv["rope.scaling.type"] = q.RopeScaling.Type
|
||||
kv["rope.scaling.factor"] = q.RopeScaling.Factor
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
// Create merges for expert tensors - stack individual experts into batched tensors
|
||||
merges := make([]merge, q.NumHiddenLayers*3)
|
||||
for i := range q.NumHiddenLayers {
|
||||
merges[i*3+0] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}
|
||||
merges[i*3+1] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}
|
||||
merges[i*3+2] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}
|
||||
}
|
||||
|
||||
// Merge expert tensors
|
||||
merged, remaining := mergeTensors(ts, merges...)
|
||||
out = append(out, merged...)
|
||||
|
||||
// Process remaining tensors
|
||||
for _, t := range remaining {
|
||||
name := t.Name()
|
||||
shape := t.Shape()
|
||||
|
||||
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
|
||||
if strings.HasSuffix(name, ".ssm_in.weight") {
|
||||
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
|
||||
out = append(out, qkv, gate)
|
||||
continue
|
||||
}
|
||||
panic(fmt.Sprintf("qwen3next: failed to split %s into attn_qkv/attn_gate (shape=%v)", name, shape))
|
||||
}
|
||||
|
||||
switch {
|
||||
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
|
||||
// This matches the Python converter behavior for qwen3next
|
||||
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
|
||||
t.SetRepacker(q.addOne)
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
// Handle linear attention A_log -> ssm_a (negate and exp)
|
||||
// Note: name has already been transformed by Replacements at this point
|
||||
case strings.HasSuffix(name, ".ssm_a"):
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
// Compute -exp(A_log)
|
||||
result := make([]float32, len(data))
|
||||
for i, v := range data {
|
||||
// -exp(v)
|
||||
result[i] = -float32(math.Exp(float64(v)))
|
||||
}
|
||||
return result, nil
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
|
||||
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 3 {
|
||||
if shape[0] == 1 {
|
||||
// [1, D, K] -> [D, K]
|
||||
newShape = []uint64{shape[1], shape[2]}
|
||||
} else if shape[1] == 1 {
|
||||
// [D, 1, K] -> [D, K]
|
||||
newShape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
|
||||
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 2 {
|
||||
if shape[0] == 1 && shape[1] > 1 {
|
||||
newShape = []uint64{shape[1]}
|
||||
} else if shape[1] == 1 && shape[0] > 1 {
|
||||
newShape = []uint64{shape[0]}
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
default:
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
type qkvzSplitSpec struct {
|
||||
hidden int
|
||||
headKDim int
|
||||
headVDim int
|
||||
numKHeads int
|
||||
numVHeads int
|
||||
qkvzDim int
|
||||
qkvOut int
|
||||
gateOut int
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) qkvzSpec(shape []uint64) (qkvzSplitSpec, bool) {
|
||||
if len(shape) != 2 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
numKHeads := int(q.LinearNumKeyHeads)
|
||||
numVHeads := int(q.LinearNumValueHeads)
|
||||
headKDim := int(q.LinearKeyHeadDim)
|
||||
headVDim := int(q.LinearValueHeadDim)
|
||||
if numKHeads == 0 || numVHeads == 0 || headKDim == 0 || headVDim == 0 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
if numVHeads%numKHeads != 0 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
hidden := int(shape[1])
|
||||
vPerHead := headVDim * (numVHeads / numKHeads)
|
||||
qkvzDim := 2*headKDim + 2*vPerHead
|
||||
expectedOut := qkvzDim * numKHeads
|
||||
if int(shape[0]) != expectedOut {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
return qkvzSplitSpec{
|
||||
hidden: hidden,
|
||||
headKDim: headKDim,
|
||||
headVDim: headVDim,
|
||||
numKHeads: numKHeads,
|
||||
numVHeads: numVHeads,
|
||||
qkvzDim: qkvzDim,
|
||||
qkvOut: 2*headKDim*numKHeads + headVDim*numVHeads,
|
||||
gateOut: headVDim * numVHeads,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) splitQKVZTensor(t Tensor) (*ggml.Tensor, *ggml.Tensor, bool) {
|
||||
spec, ok := q.qkvzSpec(t.Shape())
|
||||
if !ok {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qkvTensor := t.Clone()
|
||||
qkvTensor.SetRepacker(q.repackQKVZ(spec, false))
|
||||
|
||||
gateTensor := t.Clone()
|
||||
gateTensor.SetRepacker(q.repackQKVZ(spec, true))
|
||||
|
||||
qkvName := strings.Replace(t.Name(), "ssm_in", "attn_qkv", 1)
|
||||
gateName := strings.Replace(t.Name(), "ssm_in", "attn_gate", 1)
|
||||
|
||||
return &ggml.Tensor{
|
||||
Name: qkvName,
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(spec.qkvOut), uint64(spec.hidden)},
|
||||
WriterTo: qkvTensor,
|
||||
}, &ggml.Tensor{
|
||||
Name: gateName,
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(spec.gateOut), uint64(spec.hidden)},
|
||||
WriterTo: gateTensor,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repacker {
|
||||
vPerHead := spec.headVDim * (spec.numVHeads / spec.numKHeads)
|
||||
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
var err error
|
||||
|
||||
// Convert to [hidden, out_features] layout for slicing
|
||||
tt, err = tensor.Transpose(tt, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
|
||||
if err := tt.Reshape(spec.hidden, spec.numKHeads, spec.qkvzDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
offset := 0
|
||||
qSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += spec.headKDim
|
||||
kSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += spec.headKDim
|
||||
vSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += vPerHead
|
||||
zSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qMat := tensor.Materialize(qSlice).(*tensor.Dense)
|
||||
kMat := tensor.Materialize(kSlice).(*tensor.Dense)
|
||||
vMat := tensor.Materialize(vSlice).(*tensor.Dense)
|
||||
zMat := tensor.Materialize(zSlice).(*tensor.Dense)
|
||||
|
||||
if err := qMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := kMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := vMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := zMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out tensor.Tensor
|
||||
if extractGate {
|
||||
out = zMat
|
||||
} else {
|
||||
out, err = tensor.Concat(1, qMat, kMat, vMat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
out = tensor.Materialize(out)
|
||||
out, err = tensor.Transpose(out, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = tensor.Materialize(out)
|
||||
|
||||
if err := out.Reshape(out.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(out.(*tensor.Dense))
|
||||
}
|
||||
}
|
||||
|
||||
// addOne adds 1.0 to all elements in the tensor (for norm weights)
|
||||
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
|
||||
ones := tensor.Ones(tensor.Float32, int(shape[0]))
|
||||
|
||||
n, err := n.Add(ones)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ts, err := native.SelectF32(n, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
for _, t := range ts {
|
||||
f32s = append(f32s, t...)
|
||||
}
|
||||
|
||||
return f32s, nil
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) Replacements() []string {
|
||||
return []string{
|
||||
// Embeddings and output
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.norm", "output_norm",
|
||||
"model.layers", "blk",
|
||||
|
||||
// Layer norms
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
|
||||
// Full attention (self_attn)
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
|
||||
// Linear attention (Gated Delta Net)
|
||||
"linear_attn.in_proj_qkvz", "ssm_in",
|
||||
"linear_attn.in_proj_ba", "ssm_ba",
|
||||
"linear_attn.conv1d", "ssm_conv1d",
|
||||
"linear_attn.dt_bias", "ssm_dt",
|
||||
"linear_attn.dt_proj", "ssm_dt",
|
||||
"linear_attn.A_log", "ssm_a",
|
||||
"linear_attn.norm", "ssm_norm",
|
||||
"linear_attn.out_proj", "ssm_out",
|
||||
|
||||
// MoE (experts are stacked via mergeTensors, not replaced here)
|
||||
"mlp.gate.weight", "ffn_gate_inp.weight",
|
||||
"mlp.shared_expert.down_proj", "ffn_down_shexp",
|
||||
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
|
||||
"mlp.shared_expert.up_proj", "ffn_up_shexp",
|
||||
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
|
||||
|
||||
// Dense FFN (if any layers use it)
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,7 @@ func (t tensorBase) Kind() uint32 {
|
||||
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
|
||||
strings.HasSuffix(t.name, ".bias") ||
|
||||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
|
||||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
|
||||
@@ -99,6 +99,8 @@ func (st safetensor) Kind() uint32 {
|
||||
if st.dtype == "BF16" &&
|
||||
!strings.HasPrefix(st.name, "v.") &&
|
||||
!strings.HasPrefix(st.name, "s.") &&
|
||||
!strings.HasPrefix(st.name, "mm.") &&
|
||||
!strings.Contains(st.name, "ffn_gate_inp_shexp.weight") &&
|
||||
kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
@@ -71,6 +71,10 @@
|
||||
{
|
||||
"source": "/api",
|
||||
"destination": "/api/introduction"
|
||||
},
|
||||
{
|
||||
"source": "/integrations/clawdbot",
|
||||
"destination": "/integrations/openclaw"
|
||||
}
|
||||
],
|
||||
"navigation": {
|
||||
@@ -103,6 +107,7 @@
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/cline",
|
||||
"/integrations/openclaw",
|
||||
"/integrations/codex",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose",
|
||||
|
||||
@@ -10,6 +10,7 @@ Check your compute compatibility to see if your card is supported:

| Compute Capability | Family              | Cards |
| ------------------ | ------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
| 12.1               | NVIDIA              | `GB10 (DGX Spark)` |
| 12.0               | GeForce RTX 50xx    | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
|                    | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 9.0                | NVIDIA              | `H200` `H100` |
@@ -163,4 +164,4 @@ To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`
by setting `GGML_VK_VISIBLE_DEVICES=-1`
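
For example, to restrict the server to specific Vulkan devices (a minimal sketch; the IDs `0,1` are only illustrative and depend on how devices are enumerated on your system):

```bash
GGML_VK_VISIBLE_DEVICES=0,1 ollama serve
```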

@@ -134,22 +134,12 @@ success

### Supported Quantizations

- `q4_0`
- `q4_1`
- `q5_0`
- `q5_1`
- `q8_0`

#### K-means Quantizations

- `q3_K_S`
- `q3_K_M`
- `q3_K_L`
- `q4_K_S`
- `q4_K_M`
- `q5_K_S`
- `q5_K_M`
- `q6_K`

## Sharing your model on ollama.com

50
docs/integrations/openclaw.mdx
Normal file
@@ -0,0 +1,50 @@
---
title: OpenClaw
---

OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.

## Install

Install [OpenClaw](https://openclaw.ai/)

```bash
npm install -g openclaw@latest
```

Then run the onboarding wizard:

```bash
openclaw onboard --install-daemon
```

<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
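
One way to raise the context window for locally served models is the Ollama server's `OLLAMA_CONTEXT_LENGTH` setting (a minimal sketch; 65536 is only an illustrative value):

```bash
OLLAMA_CONTEXT_LENGTH=65536 ollama serve
```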

## Usage with Ollama

### Quick setup

```bash
ollama launch openclaw
```

<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>

This configures OpenClaw to use Ollama and starts the gateway.
If the gateway is already running, no restart is needed; the gateway automatically reloads the updated configuration.

To configure without launching:

```shell
ollama launch openclaw --config
```

## Recommended Models

- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`

Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

@@ -9,7 +9,7 @@ OpenCode is an open-source AI coding assistant that runs in your terminal.
Install the [OpenCode CLI](https://opencode.ai):

```bash
curl -fsSL https://opencode.ai/install.sh | bash
curl -fsSL https://opencode.ai/install | bash
```

<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
@@ -201,7 +201,7 @@ var (
|
||||
// Enable the new Ollama engine
|
||||
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||
// ContextLength sets the default context length
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 0)
|
||||
// Auth enables authentication between the Ollama client and server
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
// Enable Vulkan backend
|
||||
@@ -290,7 +290,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
|
||||
|
||||
@@ -282,7 +282,7 @@ func TestVar(t *testing.T) {
|
||||
|
||||
func TestContextLength(t *testing.T) {
|
||||
cases := map[string]uint{
|
||||
"": 4096,
|
||||
"": 0,
|
||||
"2048": 2048,
|
||||
}
|
||||
|
||||
|
||||
@@ -268,8 +268,10 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"olmo3",
|
||||
"qwen25vl",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3next",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
"glm4moelite",
|
||||
"glmocr",
|
||||
"lfm2",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
@@ -859,11 +861,13 @@ func (f GGML) FlashAttention() bool {
|
||||
"bert",
|
||||
"gemma3",
|
||||
"glm4moelite",
|
||||
"glmocr",
|
||||
"gptoss", "gpt-oss",
|
||||
"lfm2",
|
||||
"mistral3",
|
||||
"olmo3",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3next",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
}, f.KV().String("general.architecture"))
|
||||
}
|
||||
|
||||
1
go.mod
@@ -27,6 +27,7 @@ require (
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
|
||||
3
go.sum
@@ -174,6 +174,8 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
|
||||
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
|
||||
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -304,6 +306,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
|
||||
@@ -75,3 +75,10 @@ type Cache interface {
|
||||
// removed by calling Remove(seq, 0, math.MaxInt32)
|
||||
Remove(seq int, beginIndex, endIndex int32) error
|
||||
}
|
||||
|
||||
// CheckpointCache optionally supports restoring recurrent state to a prior
|
||||
// position to avoid full prompt reprocessing when a prefix mismatch occurs.
|
||||
// The returned position is the number of tokens that can be kept (prefix length).
|
||||
type CheckpointCache interface {
|
||||
PrepareRestore(seq int, targetPos int32) (int32, bool)
|
||||
}
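
// Illustrative sketch (not part of this package): a caller holding a Cache can
// probe for checkpoint support with a type assertion and fall back to full
// prompt reprocessing when no usable checkpoint exists. The helper below is
// hypothetical.
func restorePoint(c Cache, seq int, targetPos int32) int32 {
	if cc, ok := c.(CheckpointCache); ok {
		if pos, ok := cc.PrepareRestore(seq, targetPos); ok {
			return pos // prefix length that can be kept
		}
	}
	return 0 // no checkpoint available; reprocess the prompt from the start
}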
|
||||
|
||||
276
llama/patches/0033-ggml-metal-solve_tri.patch
Normal file
@@ -0,0 +1,276 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Jeffrey Morgan <jmorganca@gmail.com>
|
||||
Date: Tue, 3 Feb 2026 12:00:00 -0800
|
||||
Subject: [PATCH] ggml: metal solve_tri
|
||||
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++++++
|
||||
ggml/src/ggml-metal/ggml-metal-device.h | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal-device.m | 11 ++++
|
||||
ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 63 +++++++++++++++++++++++
|
||||
ggml/src/ggml-metal/ggml-metal-ops.h | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 60 +++++++++++++++++++++
|
||||
7 files changed, 177 insertions(+)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
index 680904d13..83385c9ef 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||
return res;
|
||||
}
|
||||
|
||||
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
+ assert(op->op == GGML_OP_SOLVE_TRI);
|
||||
+
|
||||
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
+
|
||||
+ char base[256];
|
||||
+ char name[256];
|
||||
+
|
||||
+ snprintf(base, 256, "kernel_solve_tri_f32");
|
||||
+ snprintf(name, 256, "%s", base);
|
||||
+
|
||||
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
+ if (!res.pipeline) {
|
||||
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
+ }
|
||||
+
|
||||
+ return res;
|
||||
+}
|
||||
+
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_GROUP_NORM);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
index 0a8b9211a..8a9d17460 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
index 7b5ee968c..4e5acfbe5 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
+ case GGML_OP_SOLVE_TRI:
|
||||
+ return ggml_is_contiguous(op->src[0]) &&
|
||||
+ ggml_is_contiguous(op->src[1]) &&
|
||||
+ op->src[0]->type == GGML_TYPE_F32 &&
|
||||
+ op->src[1]->type == GGML_TYPE_F32 &&
|
||||
+ op->type == GGML_TYPE_F32;
|
||||
+ case GGML_OP_COUNT_EQUAL:
|
||||
+ return has_simdgroup_reduction &&
|
||||
+ op->src[0]->type == GGML_TYPE_I32 &&
|
||||
+ op->src[1]->type == GGML_TYPE_I32 &&
|
||||
+ op->type == GGML_TYPE_I64;
|
||||
case GGML_OP_ARGMAX:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_NORM:
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
index 8944b07e9..cfdea9c07 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
@@ -500,6 +500,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
+typedef struct {
|
||||
+ int32_t ne00;
|
||||
+ int32_t ne01;
|
||||
+ int32_t ne02;
|
||||
+ int32_t ne03;
|
||||
+ uint64_t nb00;
|
||||
+ uint64_t nb01;
|
||||
+ uint64_t nb02;
|
||||
+ uint64_t nb03;
|
||||
+ int32_t ne10;
|
||||
+ int32_t ne11;
|
||||
+ uint64_t nb10;
|
||||
+ uint64_t nb11;
|
||||
+ uint64_t nb12;
|
||||
+ uint64_t nb13;
|
||||
+ uint64_t nb0;
|
||||
+ uint64_t nb1;
|
||||
+ uint64_t nb2;
|
||||
+ uint64_t nb3;
|
||||
+} ggml_metal_kargs_solve_tri;
|
||||
+
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index 80864f303..4ac135603 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
||||
} break;
|
||||
+ case GGML_OP_SOLVE_TRI:
|
||||
+ {
|
||||
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
||||
+ } break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
||||
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
+ ggml_tensor * op = ctx->node(idx);
|
||||
+
|
||||
+ ggml_metal_library_t lib = ctx->lib;
|
||||
+ ggml_metal_encoder_t enc = ctx->enc;
|
||||
+
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
+
|
||||
+ ggml_metal_kargs_solve_tri args = {
|
||||
+ /*.ne00 =*/ ne00,
|
||||
+ /*.ne01 =*/ ne01,
|
||||
+ /*.ne02 =*/ ne02,
|
||||
+ /*.ne03 =*/ ne03,
|
||||
+ /*.nb00 =*/ nb00,
|
||||
+ /*.nb01 =*/ nb01,
|
||||
+ /*.nb02 =*/ nb02,
|
||||
+ /*.nb03 =*/ nb03,
|
||||
+ /*.ne10 =*/ ne10,
|
||||
+ /*.ne11 =*/ ne11,
|
||||
+ /*.nb10 =*/ nb10,
|
||||
+ /*.nb11 =*/ nb11,
|
||||
+ /*.nb12 =*/ nb12,
|
||||
+ /*.nb13 =*/ nb13,
|
||||
+ /*.nb0 =*/ nb0,
|
||||
+ /*.nb1 =*/ nb1,
|
||||
+ /*.nb2 =*/ nb2,
|
||||
+ /*.nb3 =*/ nb3,
|
||||
+ };
|
||||
+
|
||||
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
|
||||
+
|
||||
+ const int64_t ncols = ne10;
|
||||
+ const int64_t n_batches = (int64_t)ne02 * ne03;
|
||||
+ const int64_t nr = n_batches * ncols;
|
||||
+
|
||||
+ int nth = 64;
|
||||
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
+ if (nth < 1) {
|
||||
+ nth = 1;
|
||||
+ }
|
||||
+
|
||||
+ const int64_t n_tg = (nr + nth - 1) / nth;
|
||||
+
|
||||
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
+
|
||||
+ ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
|
||||
+
|
||||
+ return 1;
|
||||
+}
|
||||
+
|
||||
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
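
The dispatch above flattens the problem to one thread per right-hand-side column per batch: nr = ne02 * ne03 * ne10 threads in total, grouped into threadgroups of at most 64 threads (clamped to the pipeline limit), so ceil(nr / nth) threadgroups are launched. A small Go sketch of the same arithmetic, with made-up extents, for illustration only:

// threadgroupCount mirrors the (nr + nth - 1) / nth rounding used when
// dispatching kernel_solve_tri_f32; the extents passed in are hypothetical.
func threadgroupCount(ne02, ne03, ne10, nth int64) int64 {
    nr := ne02 * ne03 * ne10 // one thread per (batch, right-hand-side column) pair
    return (nr + nth - 1) / nth
}

// threadgroupCount(4, 3, 128, 64) == 24
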
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
index 902b54452..a475183d3 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
||||
+int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index d33c16079..c37447a10 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
+kernel void kernel_solve_tri_f32(
|
||||
+ constant ggml_metal_kargs_solve_tri & args,
|
||||
+ device const char * src0,
|
||||
+ device const char * src1,
|
||||
+ device char * dst,
|
||||
+ uint tgpig[[threadgroup_position_in_grid]],
|
||||
+ ushort tpitg[[thread_position_in_threadgroup]],
|
||||
+ ushort ntg[[threads_per_threadgroup]]) {
|
||||
+ const uint64_t ncols = (uint64_t) args.ne10;
|
||||
+ const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
+ const uint64_t nr = n_batches * ncols;
|
||||
+
|
||||
+ const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
+ if (gid >= nr) {
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
+ const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
+ const uint64_t i02 = rem / ncols;
|
||||
+ const uint64_t i01 = rem - i02 * ncols;
|
||||
+
|
||||
+ const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
+ const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
+ const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
+ const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
+
|
||||
+ const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
+ const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
+ const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
+ const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
+
|
||||
+ const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
+ const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
+ const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
+ const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
+
|
||||
+ device const float * A = (device const float *) src0;
|
||||
+ device const float * B = (device const float *) src1;
|
||||
+ device float * X = (device float *) dst;
|
||||
+
|
||||
+ const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
+ const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
+ const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
+
|
||||
+ const uint64_t n = (uint64_t) args.ne11;
|
||||
+
|
||||
+ for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
+ float sum = 0.0f;
|
||||
+ for (uint64_t t = 0; t < i00; ++t) {
|
||||
+ sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
+ X[X_base + t * sx1 + i01 * sx0];
|
||||
+ }
|
||||
+
|
||||
+ const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
+ X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
+ (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
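
Each thread of kernel_solve_tri_f32 runs a plain forward substitution over one right-hand-side column: row i of the solution is the matching entry of B minus the dot product of the already-solved entries with row i of A, divided by the diagonal element of A. A CPU reference of the same recurrence, offered as an illustrative sketch rather than code from this change:

// solveLowerTri solves A*x = b for a lower-triangular, non-singular A by
// forward substitution; A is n x n in row-major order.
func solveLowerTri(A, b []float32, n int) []float32 {
    x := make([]float32, n)
    for i := 0; i < n; i++ {
        sum := float32(0)
        for t := 0; t < i; t++ {
            sum += A[i*n+t] * x[t] // contributions from already-solved rows
        }
        x[i] = (b[i] - sum) / A[i*n+i] // divide by the diagonal element
    }
    return x
}
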
@@ -80,6 +80,7 @@ type LlamaServer interface {
|
||||
GetPort() int
|
||||
GetDeviceInfos(ctx context.Context) []ml.DeviceInfo
|
||||
HasExited() bool
|
||||
ContextLength() int
|
||||
}
|
||||
|
||||
// llmServer is an instance of a runner hosting a single model
|
||||
@@ -1200,7 +1201,8 @@ func (s *llmServer) initModel(ctx context.Context, req LoadRequest, operation Lo
|
||||
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
- return nil, fmt.Errorf("do load request: %w", err)
+ slog.Error("do load request", "error", err)
+ return nil, errors.New("model failed to load, this may be due to resource limitations or an internal error, check ollama server logs for details")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
@@ -1901,6 +1903,10 @@ func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *llmServer) ContextLength() int {
|
||||
return s.options.NumCtx
|
||||
}
|
||||
|
||||
func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
devices, err := ml.GetDevicesFromRunner(ctx, s)
|
||||
if err != nil {
|
||||
|
||||
@@ -170,6 +170,7 @@ type Tensor interface {
|
||||
Cos(ctx Context) Tensor
|
||||
Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context, up ...Tensor) Tensor
|
||||
GELU_ERF(ctx Context) Tensor
|
||||
QuickGELU(ctx Context, up ...Tensor) Tensor
|
||||
SILU(ctx Context, up ...Tensor) Tensor
|
||||
RELU(ctx Context, up ...Tensor) Tensor
|
||||
@@ -206,6 +207,32 @@ type Tensor interface {
|
||||
Stddev(ctx Context) Tensor
|
||||
Sqr(ctx Context) Tensor
|
||||
Sqrt(ctx Context) Tensor
|
||||
Exp(ctx Context) Tensor
|
||||
Neg(ctx Context) Tensor
|
||||
|
||||
// Clamp clamps values to [min, max] range
|
||||
Clamp(ctx Context, min, max float32) Tensor
|
||||
|
||||
// Softplus computes ln(1 + exp(x))
|
||||
Softplus(ctx Context) Tensor
|
||||
|
||||
// CumSum computes cumulative sum along dimension 0
|
||||
CumSum(ctx Context) Tensor
|
||||
|
||||
// Diag creates a diagonal matrix from a 1D tensor
|
||||
Diag(ctx Context) Tensor
|
||||
|
||||
// Tri converts a matrix to triangular form (0=upper+diag, 1=upper, 2=lower+diag, 3=lower)
|
||||
Tri(ctx Context, triType int) Tensor
|
||||
|
||||
// Fill fills a tensor with a constant value (in-place)
|
||||
Fill(ctx Context, value float32) Tensor
|
||||
|
||||
// Repeat4D repeats tensor to match target shape
|
||||
Repeat4D(ctx Context, dim0, dim1, dim2, dim3 int) Tensor
|
||||
|
||||
// SolveTri solves a triangular system Ax = B
|
||||
SolveTri(ctx Context, b Tensor, lower, left, unitDiag bool) Tensor
|
||||
|
||||
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
|
||||
}
|
||||
|
||||
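
Combined with the Tri helper above, the new SolveTri method lets model code build and solve a small triangular system directly on the backend. A hedged usage sketch (the function and its inputs are hypothetical; it assumes github.com/ollama/ollama/ml is imported as ml):

// triangularSolve keeps the lower triangle of a (including the diagonal)
// and solves lower * x = b with the left-hand, non-unit-diagonal variant.
func triangularSolve(ctx ml.Context, a, b ml.Tensor) ml.Tensor {
    lower := a.Tri(ctx, 2) // 2 = lower + diagonal, per the Tri comment above
    return lower.SolveTri(ctx, b, true, true, false) // lower=true, left=true, unitDiag=false
}
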
@@ -378,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
}
|
||||
|
||||
- maxGraphNodes := max(1024, len(meta.Tensors().Items())*8)
+ maxGraphNodes := max(1024, len(meta.Tensors().Items())*32)
|
||||
|
||||
sched := C.ggml_backend_sched_new_ext(
|
||||
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
|
||||
@@ -1581,6 +1581,13 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) GELU_ERF(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_gelu_erf_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
var tt *C.struct_ggml_tensor
|
||||
if len(t2) > 0 {
|
||||
@@ -1772,6 +1779,76 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_exp(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Softplus(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_softplus(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) CumSum(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cumsum(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Diag(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_diag(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Tri(ctx ml.Context, triType int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tri(ctx.(*Context).ctx, t.t, C.enum_ggml_tri_type(triType)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Fill(ctx ml.Context, value float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_fill_inplace(ctx.(*Context).ctx, t.t, C.float(value)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Repeat4D(ctx ml.Context, dim0, dim1, dim2, dim3 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_repeat_4d(ctx.(*Context).ctx, t.t, C.int64_t(dim0), C.int64_t(dim1), C.int64_t(dim2), C.int64_t(dim3)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SolveTri(ctx ml.Context, b ml.Tensor, lower, left, unitDiag bool) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_solve_tri(ctx.(*Context).ctx, t.t, b.(*Tensor).t, C._Bool(lower), C._Bool(left), C._Bool(unitDiag)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
|
||||
var mode C.uint32_t
|
||||
switch samplingMode {
|
||||
|
||||
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_SOLVE_TRI);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_solve_tri_f32");
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_GROUP_NORM);
|
||||
|
||||
|
||||
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
return has_simdgroup_reduction &&
|
||||
op->src[0]->type == GGML_TYPE_I32 &&
|
||||
op->src[1]->type == GGML_TYPE_I32 &&
|
||||
op->type == GGML_TYPE_I64;
|
||||
case GGML_OP_ARGMAX:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_NORM:
|
||||
|
||||
@@ -2385,6 +2385,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_solve_tri;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
@@ -5813,6 +5834,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_solve_tri_f32(
|
||||
constant ggml_metal_kargs_solve_tri & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const uint64_t ncols = (uint64_t) args.ne10;
|
||||
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
const uint64_t nr = n_batches * ncols;
|
||||
|
||||
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
if (gid >= nr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
const uint64_t i02 = rem / ncols;
|
||||
const uint64_t i01 = rem - i02 * ncols;
|
||||
|
||||
const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
|
||||
const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
|
||||
const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
|
||||
device const float * A = (device const float *) src0;
|
||||
device const float * B = (device const float *) src1;
|
||||
device float * X = (device float *) dst;
|
||||
|
||||
const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
|
||||
const uint64_t n = (uint64_t) args.ne11;
|
||||
|
||||
for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (uint64_t t = 0; t < i00; ++t) {
|
||||
sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
X[X_base + t * sx1 + i01 * sx0];
|
||||
}
|
||||
|
||||
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
|
||||
@@ -500,6 +500,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_solve_tri;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
||||
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
||||
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_solve_tri args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
|
||||
|
||||
const int64_t ncols = ne10;
|
||||
const int64_t n_batches = (int64_t)ne02 * ne03;
|
||||
const int64_t nr = n_batches * ncols;
|
||||
|
||||
int nth = 64;
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
if (nth < 1) {
|
||||
nth = 1;
|
||||
}
|
||||
|
||||
const int64_t n_tg = (nr + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
||||
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_solve_tri_f32(
|
||||
constant ggml_metal_kargs_solve_tri & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const uint64_t ncols = (uint64_t) args.ne10;
|
||||
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
const uint64_t nr = n_batches * ncols;
|
||||
|
||||
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
if (gid >= nr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
const uint64_t i02 = rem / ncols;
|
||||
const uint64_t i01 = rem - i02 * ncols;
|
||||
|
||||
const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
|
||||
const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
|
||||
const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
|
||||
device const float * A = (device const float *) src0;
|
||||
device const float * B = (device const float *) src1;
|
||||
device float * X = (device float *) dst;
|
||||
|
||||
const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
|
||||
const uint64_t n = (uint64_t) args.ne11;
|
||||
|
||||
for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (uint64_t t = 0; t < i00; ++t) {
|
||||
sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
X[X_base + t * sx1 + i01 * sx0];
|
||||
}
|
||||
|
||||
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
|
||||
@@ -56,6 +56,18 @@ type fakeTensor struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// Stub methods to satisfy ml.Tensor interface
|
||||
func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Clamp(ctx ml.Context, _, _ float32) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Softplus(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) CumSum(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Diag(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
|
||||
|
||||
func (m *fakeBackend) Get(name string) ml.Tensor {
|
||||
if slices.Contains(m.names, name) {
|
||||
return &fakeTensor{Name: name}
|
||||
|
||||
diff --git a/model/models/glmocr/imageprocessor.go b/model/models/glmocr/imageprocessor.go
new file (174 lines)
@@ -0,0 +1,174 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"image"
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
imageSize int
|
||||
patchSize int
|
||||
temporalPatchSize int
|
||||
spatialMergeSize int
|
||||
minPixels int
|
||||
maxPixels int
|
||||
factor int
|
||||
imageMean [3]float32
|
||||
imageStd [3]float32
|
||||
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
|
||||
|
||||
// Read normalization values from config if available, otherwise use CLIP defaults
|
||||
imageMean := c.Floats("vision.image_mean", imageproc.ClipDefaultMean[:])
|
||||
imageStd := c.Floats("vision.image_std", imageproc.ClipDefaultSTD[:])
|
||||
|
||||
// Default max_pixels: 2048 * patchSize^2 * mergeSize^2 * temporal = ~3.2M pixels
|
||||
// This limits to ~16k patches (4k output tokens) to keep memory stable without flash attention
|
||||
defaultMaxPixels := 2048 * patchSize * patchSize * spatialMergeSize * spatialMergeSize * temporalPatchSize
|
||||
|
||||
return ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size", 336)),
|
||||
patchSize: patchSize,
|
||||
temporalPatchSize: temporalPatchSize,
|
||||
spatialMergeSize: spatialMergeSize,
|
||||
minPixels: int(c.Uint("vision.min_pixels", uint32(8*patchSize*patchSize*spatialMergeSize*spatialMergeSize*temporalPatchSize))),
|
||||
maxPixels: int(c.Uint("vision.max_pixels", uint32(defaultMaxPixels))),
|
||||
factor: patchSize * spatialMergeSize,
|
||||
imageMean: [3]float32{imageMean[0], imageMean[1], imageMean[2]},
|
||||
imageStd: [3]float32{imageStd[0], imageStd[1], imageStd[2]},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
|
||||
factor := p.factor
|
||||
temporalFactor := p.temporalPatchSize
|
||||
numFrames := temporalFactor // single image
|
||||
|
||||
if height < factor || width < factor {
|
||||
// Scale up small images
|
||||
scale := float64(factor) / float64(min(height, width))
|
||||
height = int(math.Ceil(float64(height) * scale))
|
||||
width = int(math.Ceil(float64(width) * scale))
|
||||
}
|
||||
|
||||
if temporalFactor <= 0 {
|
||||
slog.Warn("temporal_patch_size must be > 0, defaulting to 1")
|
||||
temporalFactor = 1
|
||||
}
|
||||
if numFrames < temporalFactor {
|
||||
slog.Warn("num_frames must be >= temporal_patch_size, adjusting num_frames", "num_frames", numFrames, "temporal_patch_size", temporalFactor)
|
||||
numFrames = temporalFactor
|
||||
}
|
||||
if aspectRatio := float64(max(height, width)) / float64(min(height, width)); aspectRatio > 200 {
|
||||
slog.Warn("aspect ratio exceeds 200, image quality may be affected", "aspect_ratio", aspectRatio)
|
||||
}
|
||||
|
||||
round := func(x float64) int { return int(math.RoundToEven(x)) }
|
||||
|
||||
hBar := round(float64(height)/float64(factor)) * factor
|
||||
wBar := round(float64(width)/float64(factor)) * factor
|
||||
tBar := round(float64(numFrames)/float64(temporalFactor)) * temporalFactor
|
||||
|
||||
if tBar*hBar*wBar > p.maxPixels {
|
||||
beta := math.Sqrt(float64(numFrames*height*width) / float64(p.maxPixels))
|
||||
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
|
||||
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
|
||||
} else if tBar*hBar*wBar < p.minPixels {
|
||||
beta := math.Sqrt(float64(p.minPixels) / float64(numFrames*height*width))
|
||||
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
|
||||
img = imageproc.Composite(img)
|
||||
|
||||
origWidth := img.Bounds().Dx()
|
||||
origHeight := img.Bounds().Dy()
|
||||
|
||||
// Calculate smart resize dimensions
|
||||
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
|
||||
|
||||
// Resize image
|
||||
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeCatmullrom)
|
||||
|
||||
// Normalize pixels - output format is [C, H, W] with rescale and channelFirst
|
||||
// We keep [C, H, W] for patch extraction
|
||||
normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
|
||||
|
||||
// Calculate grid dimensions (after Conv2D patching)
|
||||
grid := &Grid{
|
||||
Height: resizedHeight / p.patchSize,
|
||||
Width: resizedWidth / p.patchSize,
|
||||
Temporal: 1, // Single image
|
||||
ImageHeight: resizedHeight,
|
||||
ImageWidth: resizedWidth,
|
||||
}
|
||||
|
||||
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return patches, grid, nil
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
|
||||
channels := 3
|
||||
patchSize := p.patchSize
|
||||
mergeSize := p.spatialMergeSize
|
||||
temporalPatchSize := p.temporalPatchSize
|
||||
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||
result := make([]float32, numPatches*patchDim)
|
||||
patchIndex := 0
|
||||
|
||||
// Single temporal frame handling (copies to all frames)
|
||||
for range grid.Temporal {
|
||||
for h := 0; h < grid.Height; h += mergeSize {
|
||||
for w := 0; w < grid.Width; w += mergeSize {
|
||||
for mh := range mergeSize {
|
||||
for mw := range mergeSize {
|
||||
baseOffset := patchIndex * patchDim
|
||||
for c := range channels {
|
||||
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
|
||||
for py := range patchSize {
|
||||
for px := range patchSize {
|
||||
y := (h+mh)*patchSize + py
|
||||
x := (w+mw)*patchSize + px
|
||||
srcIdx := c*height*width + y*width + x
|
||||
dstIdx := channelOffset + (py * patchSize) + px
|
||||
result[dstIdx] = pixels[srcIdx]
|
||||
}
|
||||
}
|
||||
|
||||
if temporalPatchSize > 1 {
|
||||
frameSize := patchSize * patchSize
|
||||
for tp := 1; tp < temporalPatchSize; tp++ {
|
||||
currentFrameOffset := channelOffset + (tp * frameSize)
|
||||
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
|
||||
result[channelOffset:channelOffset+frameSize])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
patchIndex++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
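
SmartResize snaps both sides to multiples of factor = patchSize * spatialMergeSize (28 with the defaults above) before the min/max pixel clamp is applied. A worked example of the rounding step, as a standalone sketch with hypothetical input dimensions (assumes the standard math package is imported):

// roundToFactor mirrors the math.RoundToEven rounding used in SmartResize.
func roundToFactor(dim, factor int) int {
    return int(math.RoundToEven(float64(dim)/float64(factor))) * factor
}

// With factor = 28: roundToFactor(1080, 28) == 1092, roundToFactor(1920, 28) == 1932.
// Whether the maxPixels branch then rescales depends on tBar*hBar*wBar versus the budget.
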
diff --git a/model/models/glmocr/model.go b/model/models/glmocr/model.go
new file (235 lines)
@@ -0,0 +1,235 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"image"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
VisionDownsample *VisionDownsample `gguf:"mm.patch_merger"`
|
||||
PatchMerger *PatchMerger `gguf:"mm"`
|
||||
|
||||
ImageProcessor
|
||||
|
||||
imageTokenID int32
|
||||
imageStartTokenID int32
|
||||
imageEndTokenID int32
|
||||
}
|
||||
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
eosTokenID := int32(c.Uint("tokenizer.ggml.eos_token_id"))
|
||||
eosTokenIDs := c.Ints("tokenizer.ggml.eos_token_ids")
|
||||
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
|
||||
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: allEOS,
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
imageTokenID: int32(c.Uint("image_token_id", 59280)),
|
||||
imageStartTokenID: int32(c.Uint("image_start_token_id", 59256)),
|
||||
imageEndTokenID: int32(c.Uint("image_end_token_id", 59257)),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
if len(m.VisionModel.Blocks) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, grid, err := m.ImageProcessor.ProcessImage(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create pixel values tensor from flattened patches
|
||||
// Shape: [patchDim, numPatches]
|
||||
patchDim := m.VisionModel.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
|
||||
|
||||
// Forward through vision encoder
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
|
||||
|
||||
// Forward through downsample (patch merger)
|
||||
if m.VisionDownsample == nil || m.VisionDownsample.Weight == nil {
|
||||
return nil, errors.New("glmocr: missing vision downsample weights")
|
||||
}
|
||||
visionOutputs = m.VisionDownsample.Forward(ctx, visionOutputs, grid, m.VisionModel.VisionModelOptions)
|
||||
|
||||
// Forward through patch merger (FC + LayerNorm + GELU + SwiGLU FFN)
|
||||
if m.PatchMerger == nil {
|
||||
return nil, errors.New("glmocr: missing patch merger weights")
|
||||
}
|
||||
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.VisionModelOptions)
|
||||
|
||||
return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
// Reset position cache
|
||||
m.TextModel.positionCache = m.TextModel.positionCache[:0]
|
||||
m.TextModel.ropeDelta = 0
|
||||
|
||||
pos := int32(0)
|
||||
for _, inp := range inputs {
|
||||
if inp.Multimodal == nil {
|
||||
result = append(result, inp)
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
// Get grid info for position calculation
|
||||
grid := inp.Multimodal[0].Data.(*Grid)
|
||||
mergedH := grid.Height / m.VisionModel.spatialMergeSize
|
||||
mergedW := grid.Width / m.VisionModel.spatialMergeSize
|
||||
|
||||
// Add image start token
|
||||
result = append(result, &input.Input{Token: m.imageStartTokenID})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
|
||||
pos++
|
||||
|
||||
// Add image tokens with multimodal data
|
||||
// All image tokens share the same base position for temporal dimension
|
||||
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
|
||||
basePos := pos
|
||||
sameBatch := tokensPerGrid - 1
|
||||
if sameBatch < 0 {
|
||||
sameBatch = 0
|
||||
}
|
||||
result = append(result, &input.Input{
|
||||
Token: m.imageTokenID,
|
||||
Multimodal: inp.Multimodal,
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
SameBatch: sameBatch,
|
||||
})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
|
||||
|
||||
// Add placeholder tokens for remaining positions
|
||||
// All image tokens use the same base position (temporal stays constant)
|
||||
for range tokensPerGrid - 1 {
|
||||
result = append(result, &input.Input{Token: m.imageTokenID})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
|
||||
}
|
||||
|
||||
// Advance position by max(mergedH, mergedW) after image tokens
|
||||
pos = basePos + int32(max(mergedH, mergedW))
|
||||
|
||||
// Add image end token
|
||||
result = append(result, &input.Input{Token: m.imageEndTokenID})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
|
||||
pos++
|
||||
}
|
||||
|
||||
// Compute rope delta for continuation after the prefill segment:
|
||||
// delta = (max_position_id + 1) - sequence_length
|
||||
if len(m.TextModel.positionCache) > 0 {
|
||||
last := m.TextModel.positionCache[len(m.TextModel.positionCache)-1]
|
||||
m.TextModel.ropeDelta = last + 1 - int32(len(m.TextModel.positionCache))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
// Initial token embedding
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
|
||||
ctx.Forward(hiddenStates)
|
||||
|
||||
// Build position slices for M-RoPE
|
||||
positionSlice := func() [][]int32 {
|
||||
s := [][]int32{
|
||||
make([]int32, len(batch.Positions)), // temporal
|
||||
make([]int32, len(batch.Positions)), // height
|
||||
make([]int32, len(batch.Positions)), // width
|
||||
make([]int32, len(batch.Positions)), // unused (zeros)
|
||||
}
|
||||
for i, position := range batch.Positions {
|
||||
// Translate through position cache or continue sequence
|
||||
if position < int32(len(m.TextModel.positionCache)) {
|
||||
position = m.TextModel.positionCache[position]
|
||||
} else if len(m.TextModel.positionCache) > 0 {
|
||||
// Continue sequence after cached positions using ropeDelta
|
||||
position = position + m.TextModel.ropeDelta
|
||||
}
|
||||
|
||||
s[0][i] = position
|
||||
s[1][i] = position
|
||||
s[2][i] = position
|
||||
}
|
||||
return s
|
||||
}()
|
||||
|
||||
// Inject vision embeddings and adjust positions for image tokens
|
||||
for _, mi := range batch.Multimodal {
|
||||
img := mi.Multimodal[0].Tensor
|
||||
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
|
||||
|
||||
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
|
||||
w := grid.Width / m.VisionModel.spatialMergeSize
|
||||
for i := range img.Dim(1) {
|
||||
positionSlice[1][mi.Index+i] += int32(i / w)
|
||||
positionSlice[2][mi.Index+i] += int32(i % w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
|
||||
|
||||
// Process through transformer layers
|
||||
for i, layer := range m.TextModel.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.TextModel.Layers)-1 {
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextModel.TextModelOptions)
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("glmocr", New)
|
||||
}
|
||||
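
The rope delta computed at the end of PostTokenize bridges the compressed M-RoPE positions and the plain token index used once decoding continues past the cached prompt. A worked example with made-up counts: 5 text tokens, an image start token, 64 image tokens sharing one base position on an 8x8 merged grid, an image end token, and 2 trailing text tokens:

// ropeDeltaExample reproduces the delta arithmetic for the layout described above.
func ropeDeltaExample() int32 {
    seqLen := int32(5 + 1 + 64 + 1 + 2) // 73 entries in positionCache
    lastPos := int32(16)                // positions: 0..4, 5, 64 tokens at 6, 14 (= 6 + max(8, 8)), 15, 16
    return lastPos + 1 - seqLen         // -56; the token at raw index 73 lands at 73 - 56 = 17
}
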
diff --git a/model/models/glmocr/model_text.go b/model/models/glmocr/model_text.go
new file (190 lines)
@@ -0,0 +1,190 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
type TextModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
numKVHeads int
|
||||
headDim int
|
||||
rotaryDim int
|
||||
intermediateSize int
|
||||
eps float32
|
||||
ropeBase float32
|
||||
mropeSections []int
|
||||
}
|
||||
|
||||
func (o *TextModelOptions) applyMRoPE(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
// With 4 sections for [temporal, height, width, unused]
|
||||
return nn.RoPE(ctx, states, positions, o.rotaryDim, o.ropeBase, 1.0, rope.WithMRoPE(o.mropeSections))
|
||||
}
|
||||
|
||||
type TextSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
|
||||
// Separate Q, K, V projections
|
||||
q := sa.Query.Forward(ctx, hiddenStates)
|
||||
k := sa.Key.Forward(ctx, hiddenStates)
|
||||
v := sa.Value.Forward(ctx, hiddenStates)
|
||||
|
||||
// Reshape for GQA
|
||||
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
k = k.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
|
||||
v = v.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
// Apply M-RoPE (multimodal rotary position embeddings)
|
||||
q = opts.applyMRoPE(ctx, q, positions)
|
||||
k = opts.applyMRoPE(ctx, k, positions)
|
||||
|
||||
// Scaled dot-product attention with KV cache
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
// Reshape attention output: [headDim, numHeads, batchSize] -> [numHeads*headDim, batchSize]
|
||||
// Note: numHeads * headDim = 16 * 128 = 2048, which is the attention hidden size
|
||||
kqv = kqv.Reshape(ctx, opts.numHeads*opts.headDim, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextModelOptions) ml.Tensor {
|
||||
// SwiGLU: down(silu(gate(x)) * up(x))
|
||||
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, gate)
|
||||
}
|
||||
|
||||
type TextDecoderLayer struct {
|
||||
// Input layernorm (before attention)
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *TextSelfAttention
|
||||
// Post self-attention layernorm (after attention, before residual add)
|
||||
PostAttnNorm *nn.RMSNorm `gguf:"post_attn_norm"`
|
||||
|
||||
// FFN input layernorm (after first residual, before MLP)
|
||||
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *TextMLP
|
||||
// Post MLP layernorm (after MLP, before residual add)
|
||||
PostFFNNorm *nn.RMSNorm `gguf:"post_ffn_norm"`
|
||||
}
|
||||
|
||||
func (l *TextDecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
|
||||
// Attention block
|
||||
residual := hiddenStates
|
||||
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = l.SelfAttention.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||
hiddenStates = l.PostAttnNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// Prune to output positions in final layer
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
// MLP block
|
||||
residual = hiddenStates
|
||||
hiddenStates = l.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
|
||||
hiddenStates = l.PostFFNNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []TextDecoderLayer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
*TextModelOptions
|
||||
|
||||
// positionCache stores the M-RoPE position for each token in the sequence.
|
||||
// This is needed because image tokens share the same base position but have
|
||||
// different height/width offsets, and the end token position depends on the
|
||||
// image grid dimensions.
|
||||
positionCache []int32
|
||||
ropeDelta int32
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// Clear position cache when KV cache shifts
|
||||
m.positionCache = nil
|
||||
m.ropeDelta = 0
|
||||
return m.applyMRoPE(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
hiddenSize := int(c.Uint("embedding_length", 1536))
|
||||
numHeads := int(c.Uint("attention.head_count", 16))
|
||||
numKVHeads := int(c.Uint("attention.head_count_kv", 8))
|
||||
intermediateSize := int(c.Uint("feed_forward_length", 4608))
|
||||
eps := c.Float("attention.layer_norm_rms_epsilon", 1e-5)
|
||||
ropeBase := c.Float("rope.freq_base", 10000)
|
||||
|
||||
headDim := int(c.Uint("attention.key_length", uint32(hiddenSize/numHeads)))
|
||||
ropeDim := int(c.Uint("rope.dimension_count", uint32(headDim)))
|
||||
if ropeDim <= 0 {
|
||||
ropeDim = headDim
|
||||
}
|
||||
|
||||
mropeSections := c.Ints("rope.mrope_section")
|
||||
var sectionInts []int
|
||||
|
||||
if len(mropeSections) > 0 {
|
||||
sectionInts = make([]int, len(mropeSections))
|
||||
for i, section := range mropeSections {
|
||||
sectionInts[i] = int(section)
|
||||
}
|
||||
} else {
|
||||
// Default to GLM-OCR's HF ratio (2:3:3) scaled to rotaryDim/2.
|
||||
// For rotaryDim=64 this yields [8, 12, 12].
|
||||
total := ropeDim / 2
|
||||
if total <= 0 {
|
||||
total = 32
|
||||
}
|
||||
s0 := total * 2 / 8
|
||||
s1 := total * 3 / 8
|
||||
s2 := total - s0 - s1
|
||||
sectionInts = []int{s0, s1, s2}
|
||||
}
|
||||
|
||||
// GGML rope_multi: sector = (dim_pair) % sum(sections), mapping each pair to its position dim
|
||||
rotaryDim := ropeDim
|
||||
|
||||
return &TextModel{
|
||||
Layers: make([]TextDecoderLayer, c.Uint("block_count", 16)),
|
||||
TextModelOptions: &TextModelOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
numKVHeads: numKVHeads,
|
||||
headDim: headDim,
|
||||
rotaryDim: rotaryDim,
|
||||
intermediateSize: intermediateSize,
|
||||
eps: eps,
|
||||
ropeBase: ropeBase,
|
||||
mropeSections: sectionInts,
|
||||
},
|
||||
}
|
||||
}
|
||||
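
The default M-RoPE section split above follows a 2:3:3 ratio over half the rotary dimension. A tiny sketch of the same arithmetic, illustrative only:

// defaultSections reproduces the fallback split from newTextModel.
func defaultSections(ropeDim int) []int {
    total := ropeDim / 2
    s0 := total * 2 / 8
    s1 := total * 3 / 8
    return []int{s0, s1, total - s0 - s1}
}

// defaultSections(64) == [8 12 12]; defaultSections(128) == [16 24 24].
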
diff --git a/model/models/glmocr/model_vision.go b/model/models/glmocr/model_vision.go
new file (355 lines)
@@ -0,0 +1,355 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
type Grid struct {
|
||||
Height int // Number of patches in height direction
|
||||
Width int // Number of patches in width direction
|
||||
Temporal int
|
||||
ImageHeight int // Full image height in pixels
|
||||
ImageWidth int // Full image width in pixels
|
||||
}
|
||||
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
headDim int
|
||||
numChannels int
|
||||
patchSize int
|
||||
temporalPatchSize int
|
||||
imageSize int
|
||||
spatialMergeSize int
|
||||
outHiddenSize int
|
||||
intermediateSize int
|
||||
eps float32
|
||||
}
|
||||
|
||||
type VisionPatchEmbed struct {
|
||||
Proj *nn.Conv2D `gguf:"patch_embd_0"`
|
||||
Proj1 *nn.Conv2D `gguf:"patch_embd_1"`
|
||||
Bias ml.Tensor `gguf:"patch_embd.bias"`
|
||||
}
|
||||
|
||||
func (pe *VisionPatchEmbed) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
|
||||
_ = grid // patches are already in merge-block order
|
||||
|
||||
// pixelValues shape: [patchDim, numPatches]
|
||||
numPatches := pixelValues.Shape()[1]
|
||||
|
||||
// Reshape to [patchSize*patchSize, temporalPatchSize, numChannels, numPatches]
|
||||
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
|
||||
// Permute to [temporalPatchSize, patchSize*patchSize, numChannels, numPatches]
|
||||
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Slice temporal frames for Conv2D (simulate Conv3D)
|
||||
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
|
||||
|
||||
s0, s1 := opts.patchSize, opts.patchSize
|
||||
p0, p1 := 0, 0
|
||||
d0, d1 := 1, 1
|
||||
hiddenStates := pe.Proj.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
|
||||
|
||||
if pe.Proj1 != nil && opts.temporalPatchSize > 1 {
|
||||
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
|
||||
out1 := pe.Proj1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
|
||||
hiddenStates = hiddenStates.Add(ctx, out1)
|
||||
}
|
||||
|
||||
// Flatten to [hidden_size, num_patches]
|
||||
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, numPatches)
|
||||
|
||||
// Add patch bias - reshape from [hidden_size] to [hidden_size, 1] for broadcasting
|
||||
if pe.Bias != nil {
|
||||
hiddenStates = hiddenStates.Add(ctx, pe.Bias.Reshape(ctx, opts.hiddenSize, 1))
|
||||
}
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
QKV *nn.Linear `gguf:"attn_qkv"`
|
||||
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
|
||||
// Combined QKV projection: [3*hidden_size, batch_size]
|
||||
qkv := sa.QKV.Forward(ctx, hiddenStates)
|
||||
|
||||
// Split using ChunkSections along dim 0 (handles byte offsets correctly)
|
||||
// ChunkSections returns views - must make contiguous before further operations
|
||||
chunks := qkv.ChunkSections(ctx, 0, opts.hiddenSize, opts.hiddenSize, opts.hiddenSize)
|
||||
q := chunks[0].Contiguous(ctx)
|
||||
k := chunks[1].Contiguous(ctx)
|
||||
v := chunks[2].Contiguous(ctx)
|
||||
|
||||
// Reshape for multi-head attention: [hiddenSize, N] -> [headDim, numHeads, N]
|
||||
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
k = k.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
v = v.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
|
||||
// Apply Q-norm and K-norm after head reshape
|
||||
// Weights are [headDim]=64, tensor is [headDim, numHeads, N]
|
||||
q = sa.QNorm.Forward(ctx, q, opts.eps)
|
||||
k = sa.KNorm.Forward(ctx, k, opts.eps)
|
||||
|
||||
// Apply rotary position embeddings with vision-style 2D positions.
|
||||
// ggml's vision RoPE uses two position dimensions (H/W) with half-rotation pairs.
|
||||
// We provide H/W sections and leave the remaining sections empty.
|
||||
ropeFreqBase := float32(10000.0)
|
||||
section := opts.headDim / 4
|
||||
if section <= 0 {
|
||||
section = 1
|
||||
}
|
||||
sections := []int{section, section, 0, 0}
|
||||
q = nn.RoPE(ctx, q, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
|
||||
k = nn.RoPE(ctx, k, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
|
||||
|
||||
// Scale factor for scaled dot-product attention
|
||||
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||
|
||||
// Try flash attention first (ScaledDotProductAttention), fall back to manual
|
||||
if sdpa, ok := q.(ml.ScaledDotProductAttention); ok {
|
||||
attention := sdpa.ScaledDotProductAttention(ctx, k, v, nil, nil, nil, scale, false)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
slog.Warn("glmocr: vision attention falling back to manual attention",
|
||||
"batchSize", batchSize, "numHeads", opts.numHeads,
|
||||
"hint", "set OLLAMA_FLASH_ATTENTION=1 to enable flash attention")
|
||||
|
||||
// Manual attention fallback
|
||||
// q, k, v are [headDim, numHeads, batchSize] - GGML treats as 4D with implicit dim 3 = 1
|
||||
q = q.Permute(ctx, 0, 2, 1, 3)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3)
|
||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
// Attention scores
|
||||
kq := k.MulmatFullPrec(ctx, q)
|
||||
kq = kq.Scale(ctx, scale)
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
// Attention output: v @ kq (note: v first)
|
||||
kqv := v.Mulmat(ctx, kq)
|
||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
// SwiGLU: down(silu(gate(x)) * up(x))
|
||||
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, gate)
|
||||
}
|
||||
|
||||
type VisionBlock struct {
|
||||
Norm1 *nn.RMSNorm `gguf:"ln1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
Norm2 *nn.RMSNorm `gguf:"ln2"`
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
func (b *VisionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
// Pre-norm architecture
|
||||
residual := hiddenStates
|
||||
hiddenStates = b.Norm1.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = b.SelfAttention.Forward(ctx, hiddenStates, positions, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
residual = hiddenStates
|
||||
hiddenStates = b.Norm2.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = b.MLP.Forward(ctx, hiddenStates)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type VisionDownsample struct {
	*nn.Conv2D
}

func (d *VisionDownsample) Forward(ctx ml.Context, hiddenStates ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
	// Apply spatial downsampling via Conv2D
	// Input: [hidden_size, num_patches] where patches are in merge-block order

	if d.Conv2D == nil || d.Weight == nil {
		slog.Error("VisionDownsample weights not loaded - model may be corrupted or incompatible")
		return hiddenStates // Return input unchanged as fallback
	}

	merge := opts.spatialMergeSize
	numOutputTokens := (grid.Height / merge) * (grid.Width / merge)

	// Step 1: Reshape to [hidden_size, merge, merge, num_output_tokens]
	hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, merge, merge, numOutputTokens)

	// Step 2: Permute to [merge, merge, hidden_size, num_output_tokens]
	// ggml semantics: result.ne[perm[i]] = input.ne[i]
	// So permute(2,0,1,3) on [1024,2,2,N] gives: ne[2]=1024, ne[0]=2, ne[1]=2, ne[3]=N -> [2,2,1024,N]
	hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)

	// Step 3: Apply Conv2D without bias (bias added after reshape)
	// Note: ggml_conv_2d takes (kernel, input) - kernel must be receiver in ollama
	s0, s1 := merge, merge
	p0, p1 := 0, 0
	d0, d1 := 1, 1
	hiddenStates = d.Weight.Conv2D(ctx, hiddenStates, s0, s1, p0, p1, d0, d1)

	// Step 4: Reshape to [out_hidden_size, num_output_tokens]
	hiddenStates = hiddenStates.Reshape(ctx, opts.outHiddenSize, numOutputTokens)

	// Step 5: Add bias after reshape
	// Reshape bias from [out_hidden_size] to [out_hidden_size, 1] for proper broadcasting
	if d.Bias != nil {
		hiddenStates = hiddenStates.Add(ctx, d.Bias.Reshape(ctx, opts.outHiddenSize, 1))
	}

	return hiddenStates
}

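To make the shape bookkeeping above concrete, here is a small self-contained sketch (plain Go, no ml package) that computes the token counts and dimensions for an assumed 1024-dim, merge-size-2 configuration on a 28x28 patch grid; the concrete numbers are illustrative only:

```go
package main

import "fmt"

func main() {
	// Assumed example values; the real ones come from VisionModelOptions and the image grid.
	hiddenSize, outHiddenSize := 1024, 1536
	merge := 2
	gridH, gridW := 28, 28

	numPatches := gridH * gridW                          // patches entering the downsample
	numOutputTokens := (gridH / merge) * (gridW / merge) // tokens after the 2x2 spatial merge

	fmt.Println("input:         ", []int{hiddenSize, numPatches})                       // [1024 784]
	fmt.Println("step 1 reshape:", []int{hiddenSize, merge, merge, numOutputTokens})    // [1024 2 2 196]
	fmt.Println("step 2 permute:", []int{merge, merge, hiddenSize, numOutputTokens})    // [2 2 1024 196]
	fmt.Println("step 4 reshape:", []int{outHiddenSize, numOutputTokens})               // [1536 196]
}
```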
type PatchMerger struct {
	// GGUF tags align with mm.* keys used by the model
	Proj     *nn.Linear    `gguf:"model.fc"`  // mm.model.fc.weight
	PostLN   *nn.LayerNorm `gguf:"post_norm"` // mm.post_norm.weight/bias
	GateProj *nn.Linear    `gguf:"gate"`      // mm.gate.weight
	UpProj   *nn.Linear    `gguf:"up"`        // mm.up.weight
	DownProj *nn.Linear    `gguf:"down"`      // mm.down.weight
}

func (m *PatchMerger) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	// Linear projection
	hiddenStates = m.Proj.Forward(ctx, hiddenStates)

	// Post-projection layer norm + GELU ERF
	hiddenStates = m.PostLN.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = hiddenStates.GELU_ERF(ctx)
	// Force a copy to avoid in-place mutation issues with GELU_ERF
	hiddenStates = hiddenStates.Contiguous(ctx)

	// SwiGLU MLP: down(silu(gate(x)) * up(x))
	gateOut := m.GateProj.Forward(ctx, hiddenStates)
	upOut := m.UpProj.Forward(ctx, hiddenStates)
	gate := gateOut.SILU(ctx, upOut)
	return m.DownProj.Forward(ctx, gate)
}

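GELU_ERF above refers to the erf-based ("exact") GELU; for reference:

```latex
\mathrm{GELU}(x) = \frac{x}{2}\left(1 + \mathrm{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
```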
type VisionModel struct {
	PatchEmbed *VisionPatchEmbed
	Blocks     []VisionBlock `gguf:"blk"`
	PostLN     *nn.RMSNorm   `gguf:"post_ln"`
	// Note: Downsample is applied at the model level so mm.patch_merger stays separate

	*VisionModelOptions
}

func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
	// Extract patch embeddings from flattened patches
	hiddenStates := m.PatchEmbed.Forward(ctx, pixelValues, grid, m.VisionModelOptions)

	// Create position IDs for RoPE (spatial grid)
	// Patches are already in merge-block order from preprocessing
	positions := m.createPositions(ctx, grid)

	// Process through vision blocks
	for _, block := range m.Blocks {
		hiddenStates = block.Forward(ctx, hiddenStates, positions, m.VisionModelOptions)
	}

	// Post-layernorm
	hiddenStates = m.PostLN.Forward(ctx, hiddenStates, m.eps)

	// Note: Downsample is now applied separately in Model.EncodeMultimodal
	// so mm.patch_merger remains a distinct module

	return hiddenStates
}

func (m *VisionModel) createPositions(ctx ml.Context, grid *Grid) ml.Tensor {
	// Create spatial position IDs for vision RoPE
	// Position layout: [height, width, height, width] - 4 sections for mrope
	// Patches are in MERGE-BLOCK order after VisionPatchEmbed interleaving
	// This follows the GLM-OCR rot_pos_emb layout
	numPatches := grid.Height * grid.Width
	mergeRatio := m.spatialMergeSize

	// Build position arrays in merge-block order
	// Each merge_ratio x merge_ratio block of patches is grouped together
	hpos := make([]int32, numPatches)
	wpos := make([]int32, numPatches)
	ptr := 0
	for y := 0; y < grid.Height; y += mergeRatio {
		for x := 0; x < grid.Width; x += mergeRatio {
			for dy := range mergeRatio {
				for dx := range mergeRatio {
					hpos[ptr] = int32(y + dy)
					wpos[ptr] = int32(x + dx)
					ptr++
				}
			}
		}
	}

	// Build position arrays for 4 sections (mrope). ggml vision RoPE uses only H/W;
	// keep remaining sections zeroed to match its conventions.
	zeros := make([]int32, numPatches)
	s := [][]int32{
		hpos,  // Section 0: height
		wpos,  // Section 1: width
		zeros, // Section 2: unused
		zeros, // Section 3: unused
	}

	return ctx.Input().FromInts(slices.Concat(s...), numPatches*4)
}

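As a sanity check on the merge-block ordering, this standalone sketch (assuming a 4x4 grid and merge ratio 2, values chosen only for illustration) prints the (h, w) coordinates in the same order the loop above produces them:

```go
package main

import "fmt"

func main() {
	gridH, gridW, merge := 4, 4, 2 // assumed example values

	for y := 0; y < gridH; y += merge {
		for x := 0; x < gridW; x += merge {
			for dy := 0; dy < merge; dy++ {
				for dx := 0; dx < merge; dx++ {
					// Patches within one 2x2 merge block are emitted consecutively:
					// (0,0) (0,1) (1,0) (1,1), then (0,2) (0,3) (1,2) (1,3), ...
					fmt.Printf("(%d,%d) ", y+dy, x+dx)
				}
			}
		}
	}
	fmt.Println()
}
```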
func newVisionModel(c fs.Config) *VisionModel {
	hiddenSize := int(c.Uint("vision.embedding_length", 1024))
	numHeads := int(c.Uint("vision.attention.head_count", 16))
	numChannels := int(c.Uint("vision.num_channels", 3))
	patchSize := int(c.Uint("vision.patch_size", 14))
	temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
	imageSize := int(c.Uint("vision.image_size", 336))
	spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
	outHiddenSize := int(c.Uint("vision.out_hidden_size", 1536))
	intermediateSize := int(c.Uint("vision.intermediate_size", 4096))
	eps := c.Float("vision.attention.layer_norm_rms_epsilon", 1e-5)

	return &VisionModel{
		Blocks: make([]VisionBlock, c.Uint("vision.block_count", 24)),
		VisionModelOptions: &VisionModelOptions{
			hiddenSize:        hiddenSize,
			numHeads:          numHeads,
			headDim:           hiddenSize / numHeads,
			numChannels:       numChannels,
			patchSize:         patchSize,
			temporalPatchSize: temporalPatchSize,
			imageSize:         imageSize,
			spatialMergeSize:  spatialMergeSize,
			outHiddenSize:     outHiddenSize,
			intermediateSize:  intermediateSize,
			eps:               eps,
		},
	}
}

@@ -8,6 +8,7 @@ import (
	_ "github.com/ollama/ollama/model/models/gemma3"
	_ "github.com/ollama/ollama/model/models/gemma3n"
	_ "github.com/ollama/ollama/model/models/glm4moelite"
	_ "github.com/ollama/ollama/model/models/glmocr"
	_ "github.com/ollama/ollama/model/models/gptoss"
	_ "github.com/ollama/ollama/model/models/lfm2"
	_ "github.com/ollama/ollama/model/models/llama"
@@ -19,5 +20,6 @@ import (
	_ "github.com/ollama/ollama/model/models/qwen2"
	_ "github.com/ollama/ollama/model/models/qwen25vl"
	_ "github.com/ollama/ollama/model/models/qwen3"
	_ "github.com/ollama/ollama/model/models/qwen3next"
	_ "github.com/ollama/ollama/model/models/qwen3vl"
)

103
model/models/qwen3next/attention.go
Normal file
@@ -0,0 +1,103 @@
package qwen3next

import (
	"errors"
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
// with the attention layer requirements.
var ErrUnsupportedBatchLayout = errors.New("qwen3next: unsupported batch layout")

// FullAttention implements gated attention with QK normalization and sigmoid-gated output.
// Key differences from standard attention:
// - Q projection outputs 2x size (Q + gate interleaved)
// - Both Q and K have RMSNorm
// - Output is gated: attn * sigmoid(gate)
type FullAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"` // outputs [n_embd_head * 2, n_head]
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
	// Use Dim() instead of Shape() for consistent behavior during graph construction
	hiddenDim := hiddenStates.Dim(0)
	batchSize := hiddenStates.Dim(1)
	nSeqs := hiddenStates.Dim(2) // 0 if 2D tensor

	if cache != nil && cache.IsSupportedForBatch() {
		seqTokens := cache.seqTokens()
		seqs := cache.numSeqs()
		if seqTokens > 0 && seqs > 0 {
			if nSeqs > 0 {
				// 3D tensor: [hiddenDim, seqTokens, nSeqs]
				if batchSize != seqTokens || nSeqs != seqs {
					return nil, ErrUnsupportedBatchLayout
				}
				hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
				batchSize = seqTokens * seqs
			} else if batchSize != seqTokens*seqs {
				return nil, ErrUnsupportedBatchLayout
			}
		}
	}
	headDim := opts.headDim()
	numHeads := opts.numHeads

	// Q projection outputs query + gate interleaved
	qFull := sa.Query.Forward(ctx, hiddenStates)

	// Reshape to [headDim * 2, numHeads, batchSize]
	qFull = qFull.Reshape(ctx, headDim*2, numHeads, batchSize)

	// Split Q and gate along dimension 0
	// Q: first headDim elements, gate: second headDim elements
	query := qFull.Slice(ctx, 0, 0, headDim, 1)
	gate := qFull.Slice(ctx, 0, headDim, headDim*2, 1)

	// Make query contiguous for further operations
	query = query.Contiguous(ctx, headDim, numHeads, batchSize)

	// K and V projections
	key := sa.Key.Forward(ctx, hiddenStates)
	value := sa.Value.Forward(ctx, hiddenStates)

	// Derive numKVHeads from tensor dimensions (per-layer value)
	numKVHeads := key.Dim(0) / headDim

	key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
	value = value.Reshape(ctx, headDim, numKVHeads, batchSize)

	// Apply QK normalization
	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
	key = sa.KeyNorm.Forward(ctx, key, opts.eps)

	// Apply RoPE
	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)

	// Standard attention computation
	scale := opts.attentionScale
	if scale == 0 {
		scale = 1.0 / math.Sqrt(float64(headDim))
	}
	attention := nn.Attention(ctx, query, key, value, scale, cache)

	// Flatten heads
	attention = attention.Reshape(ctx, headDim*numHeads, batchSize)

	// Apply sigmoid gate
	// gate shape: [headDim, numHeads, batchSize] -> [headDim*numHeads, batchSize]
	gate = gate.Contiguous(ctx, headDim*numHeads, batchSize)
	gateSigmoid := gate.Sigmoid(ctx)
	attention = attention.Mul(ctx, gateSigmoid)

	return sa.Output.Forward(ctx, attention), nil
}

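In summary, the gated attention above computes the following, where q and g are the two halves of the doubled Q projection and σ is the sigmoid (a compact restatement of the code, not an independent derivation):

```latex
[q \,\Vert\, g] = W_q x,\qquad
o = W_o\Bigl(\mathrm{Attention}\bigl(\mathrm{RoPE}(\mathrm{RMSNorm}(q)),\ \mathrm{RoPE}(\mathrm{RMSNorm}(k)),\ v\bigr)\odot \sigma(g)\Bigr)
```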
596
model/models/qwen3next/cache.go
Normal file
@@ -0,0 +1,596 @@
package qwen3next
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var _ kvcache.Cache = (*HybridCache)(nil)
|
||||
|
||||
// HybridCache stores:
|
||||
// - a standard causal KV cache for full attention layers
|
||||
// - per-sequence conv state for linear attention layers
|
||||
// - per-sequence delta state for linear attention layers
|
||||
//
|
||||
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
|
||||
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
|
||||
type HybridCache struct {
|
||||
kv *kvcache.Causal
|
||||
|
||||
backend ml.Backend
|
||||
dtype ml.DType
|
||||
maxSequences int
|
||||
|
||||
// Conv state dimensions
|
||||
convDim int // convKernelSize - 1
|
||||
convChannels int // d_inner + 2 * num_k_heads * head_k_dim
|
||||
|
||||
// Delta state dimensions
|
||||
deltaStateSize int // headVDim * headVDim * numVHeads
|
||||
|
||||
// slot mapping for recurrent state (copy-on-write)
|
||||
slotForSeq map[int]int
|
||||
refCount []int
|
||||
freeSlots []int
|
||||
|
||||
// per-layer conv state buffers (allocated lazily)
|
||||
convCtxs map[int]ml.Context
|
||||
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
|
||||
|
||||
// per-layer delta state buffers (allocated lazily)
|
||||
deltaCtxs map[int]ml.Context
|
||||
deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
|
||||
|
||||
// recurrent checkpoints (per slot)
|
||||
checkpointCount int
|
||||
checkpointMinPos int32
|
||||
checkpointInterval int32
|
||||
checkpointCtxSize int
|
||||
checkpoints map[int]*slotCheckpointStore
|
||||
pendingRestore map[int]checkpointRestore
|
||||
curCheckpointPos []int32
|
||||
curCheckpointSlots map[int]int
|
||||
reserveCheckpoints bool
|
||||
checkpointConvCtxs map[int]ml.Context
|
||||
checkpointDeltaCtxs map[int]ml.Context
|
||||
checkpointReserved map[int]struct{}
|
||||
|
||||
// current forward batch (derived in StartForward)
|
||||
curSeqs []int
|
||||
curSlots []int
|
||||
curSlotsInput ml.Tensor
|
||||
curSeqTokens int
|
||||
|
||||
// track if EnsureWritable has been called for this forward pass
|
||||
writableEnsured bool
|
||||
writableError error
|
||||
}
|
||||
|
||||
func NewHybridCache(
|
||||
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
|
||||
convDim, convChannels, deltaStateSize int,
|
||||
) *HybridCache {
|
||||
return &HybridCache{
|
||||
kv: kvcache.NewCausalCache(shift),
|
||||
convDim: convDim,
|
||||
convChannels: convChannels,
|
||||
deltaStateSize: deltaStateSize,
|
||||
slotForSeq: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
deltaCtxs: make(map[int]ml.Context),
|
||||
deltaStates: make(map[int]ml.Tensor),
|
||||
checkpointCount: checkpointCountDefault,
|
||||
checkpointMinPos: checkpointMinPosDefault,
|
||||
checkpointInterval: checkpointIntervalDefault,
|
||||
checkpoints: make(map[int]*slotCheckpointStore),
|
||||
pendingRestore: make(map[int]checkpointRestore),
|
||||
curCheckpointSlots: make(map[int]int),
|
||||
checkpointConvCtxs: make(map[int]ml.Context),
|
||||
checkpointDeltaCtxs: make(map[int]ml.Context),
|
||||
checkpointReserved: make(map[int]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.backend = backend
|
||||
c.dtype = dtype
|
||||
c.maxSequences = maxSequences
|
||||
c.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
c.pendingRestore = make(map[int]checkpointRestore)
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
c.curCheckpointSlots = make(map[int]int)
|
||||
c.checkpointReserved = make(map[int]struct{})
|
||||
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
|
||||
if c.checkpointCtxSize < 8 {
|
||||
c.checkpointCtxSize = 8
|
||||
}
|
||||
|
||||
// initialize slot allocator
|
||||
c.refCount = make([]int, maxSequences)
|
||||
c.freeSlots = c.freeSlots[:0]
|
||||
for i := maxSequences - 1; i >= 0; i-- {
|
||||
c.freeSlots = append(c.freeSlots, i)
|
||||
}
|
||||
|
||||
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Close() {
|
||||
for _, ctx := range c.convCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.deltaCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointConvCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointDeltaCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
c.kv.Close()
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
|
||||
c.kv.SetConfig(config)
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetLayer(layer int) {
|
||||
c.kv.SetLayer(layer)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.kv.Get(ctx)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.kv.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Derive equal-length sequence layout for recurrent layers
|
||||
seqCounts := make(map[int]int)
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
for _, s := range batch.Sequences {
|
||||
if _, ok := seqCounts[s]; !ok {
|
||||
c.curSeqs = append(c.curSeqs, s)
|
||||
}
|
||||
seqCounts[s]++
|
||||
}
|
||||
|
||||
if len(c.curSeqs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
nTokens := len(batch.Sequences)
|
||||
nSeqs := len(c.curSeqs)
|
||||
want := nTokens / nSeqs
|
||||
for _, s := range c.curSeqs {
|
||||
if seqCounts[s] != want {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
c.curSeqTokens = want
|
||||
|
||||
// When reserving memory for estimation, use fake slot assignments
|
||||
if reserve {
|
||||
c.curSlots = c.curSlots[:0]
|
||||
slots := make([]int32, nSeqs)
|
||||
for i := range nSeqs {
|
||||
c.curSlots = append(c.curSlots, i)
|
||||
slots[i] = int32(i)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
c.reserveCheckpoints = true
|
||||
c.planCheckpoints(batch)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure slots exist for sequences in this batch
|
||||
c.curSlots = c.curSlots[:0]
|
||||
var newSlots []int
|
||||
for _, s := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[s]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.slotForSeq[s] = slot
|
||||
c.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
c.curSlots = append(c.curSlots, slot)
|
||||
}
|
||||
|
||||
// Zero state for newly allocated slots
|
||||
if len(newSlots) > 0 {
|
||||
c.zeroSlots(ctx, newSlots)
|
||||
}
|
||||
|
||||
// Create a tensor for the current slots
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
// Reset writable state for new forward pass
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
c.reserveCheckpoints = false
|
||||
c.planCheckpoints(batch)
|
||||
|
||||
return nil
|
||||
}
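As an illustration of the equal-length layout check performed in StartForward, here is a standalone sketch (plain Go, hypothetical batch values) of the same derivation:

```go
package main

import "fmt"

func main() {
	// Hypothetical batch: token i belongs to sequences[i].
	sequences := []int{0, 0, 0, 7, 7, 7}

	seqCounts := make(map[int]int)
	var seqs []int
	for _, s := range sequences {
		if _, ok := seqCounts[s]; !ok {
			seqs = append(seqs, s) // preserve first-seen order
		}
		seqCounts[s]++
	}

	want := len(sequences) / len(seqs) // tokens per sequence if the layout is rectangular
	for _, s := range seqs {
		if seqCounts[s] != want {
			fmt.Println("unsupported: sequences have unequal token counts")
			return
		}
	}
	fmt.Printf("seqs=%v tokensPerSeq=%d\n", seqs, want) // seqs=[0 7] tokensPerSeq=3
}
```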
|
||||
|
||||
func (c *HybridCache) allocSlot() (int, error) {
|
||||
if len(c.freeSlots) == 0 {
|
||||
return 0, kvcache.ErrKvCacheFull
|
||||
}
|
||||
slot := c.freeSlots[len(c.freeSlots)-1]
|
||||
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
||||
return slot, nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) freeSlot(slot int) {
|
||||
if slot >= 0 && slot < c.maxSequences {
|
||||
c.freeSlots = append(c.freeSlots, slot)
|
||||
}
|
||||
}
|
||||
|
||||
// zeroSlots zeros the recurrent state for the given slots across all layers.
|
||||
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
|
||||
if len(slots) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
inputCtx := ctx.Input()
|
||||
|
||||
slotIndices := make([]int32, len(slots))
|
||||
for i, s := range slots {
|
||||
slotIndices[i] = int32(s)
|
||||
}
|
||||
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
|
||||
|
||||
// Zero conv states
|
||||
if len(c.convStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
|
||||
for _, buf := range c.convStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
|
||||
// Zero delta states
|
||||
if len(c.deltaStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
|
||||
for _, buf := range c.deltaStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureWritable ensures sequences have private slots (copy-on-write).
|
||||
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
|
||||
for i, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.refCount[slot] <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
c.curSlots[i] = newSlot
|
||||
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
}
|
||||
|
||||
// Rebuild current slots tensor
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
return nil
|
||||
}
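The copy-on-write bookkeeping above follows a familiar reference-counting pattern; a minimal standalone sketch (hypothetical values, no tensor copies) of the same slot-detach step looks like:

```go
package main

import "fmt"

func main() {
	slotForSeq := map[int]int{1: 0, 2: 0} // both sequences currently share slot 0
	refCount := []int{2, 0}
	freeSlots := []int{1}

	// Before sequence 1 writes its recurrent state, give it a private slot.
	seq := 1
	if slot := slotForSeq[seq]; refCount[slot] > 1 {
		newSlot := freeSlots[len(freeSlots)-1]
		freeSlots = freeSlots[:len(freeSlots)-1]
		refCount[slot]--
		refCount[newSlot] = 1
		slotForSeq[seq] = newSlot
		// ...the real cache also copies slot -> newSlot state and checkpoints here...
	}
	fmt.Println(slotForSeq, refCount) // map[1:1 2:0] [1 1]
}
```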
|
||||
|
||||
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
|
||||
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
|
||||
|
||||
for _, buf := range c.convStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
||||
}
|
||||
|
||||
for _, buf := range c.deltaStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
|
||||
// Copy-on-write for recurrent state
|
||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
||||
if c.validSlot(dstSlot) {
|
||||
c.refCount[dstSlot]--
|
||||
if c.refCount[dstSlot] <= 0 {
|
||||
c.refCount[dstSlot] = 0
|
||||
c.freeSlot(dstSlot)
|
||||
}
|
||||
}
|
||||
delete(c.slotForSeq, dstSeq)
|
||||
}
|
||||
|
||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if c.validSlot(srcSlot) {
|
||||
c.slotForSeq[dstSeq] = srcSlot
|
||||
c.refCount[srcSlot]++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CanResume(seq int, pos int32) bool {
|
||||
if !c.kv.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
if pos == 0 {
|
||||
return true
|
||||
}
|
||||
return c.hasCheckpoint(seq, pos)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore, ok := c.pendingRestore[seq]
|
||||
if !ok || restore.pos+1 != beginIndex {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
if !c.restoreComplete(restore) {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
// If the recurrent slot is shared, detach it before applying a restore.
|
||||
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
|
||||
restore.slot = newSlot
|
||||
c.pendingRestore[seq] = restore
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore := c.pendingRestore[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
return c.applyCheckpointRestore(restore)
|
||||
}
|
||||
|
||||
// Removal invalidates recurrent state
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.validSlot(slot) {
|
||||
delete(c.slotForSeq, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.refCount[slot]--
|
||||
if c.refCount[slot] <= 0 {
|
||||
c.refCount[slot] = 0
|
||||
c.clearCheckpoints(slot)
|
||||
c.freeSlot(slot)
|
||||
}
|
||||
delete(c.slotForSeq, seq)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) validSlot(slot int) bool {
|
||||
return slot >= 0 && slot < len(c.refCount)
|
||||
}
|
||||
|
||||
func (c *HybridCache) slotsTensor() ml.Tensor {
|
||||
return c.curSlotsInput
|
||||
}
|
||||
|
||||
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
|
||||
func (c *HybridCache) contiguousSlots() (int, bool) {
|
||||
if len(c.curSlots) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
start := c.curSlots[0]
|
||||
for i, s := range c.curSlots {
|
||||
if s != start+i {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return start, true
|
||||
}
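contiguousSlots only enables the fast path when the batch's slots form an ascending consecutive run; a tiny standalone check (hypothetical slot lists) illustrates the condition:

```go
package main

import "fmt"

// contiguous reports whether slots are consecutive and ascending, returning the start slot.
func contiguous(slots []int) (int, bool) {
	if len(slots) == 0 {
		return 0, false
	}
	start := slots[0]
	for i, s := range slots {
		if s != start+i {
			return 0, false
		}
	}
	return start, true
}

func main() {
	fmt.Println(contiguous([]int{2, 3, 4})) // 2 true  -> single view + copy
	fmt.Println(contiguous([]int{2, 4, 5})) // 0 false -> scatter via SetRows
}
```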
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.curSeqTokens
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return len(c.curSeqs)
|
||||
}
|
||||
|
||||
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.convStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.convCtxs[layer]; !ok {
|
||||
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
|
||||
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
|
||||
c.convStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.deltaStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.deltaCtxs[layer]; !ok {
|
||||
c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
// Recurrent delta state must stay in F32.
|
||||
buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
|
||||
c.deltaStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
|
||||
if !c.writableEnsured {
|
||||
needsWritable := false
|
||||
for _, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
||||
needsWritable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsWritable {
|
||||
if err := c.EnsureWritable(ctx); err != nil {
|
||||
c.writableError = err
|
||||
}
|
||||
}
|
||||
c.writableEnsured = true
|
||||
}
|
||||
}
|
||||
|
||||
// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
|
||||
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
||||
c.ensureWritableOnce(ctx)
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateConvState writes a new conv state for current batch sequences.
|
||||
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
// Fast path: contiguous slots allow a single view + copy
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
|
||||
ctx.Forward(srcF32.Copy(ctx, view))
|
||||
} else {
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
}
|
||||
|
||||
c.captureConvCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
|
||||
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
|
||||
c.ensureWritableOnce(ctx)
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateDeltaState writes a new delta state for current batch sequences.
|
||||
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
// Fast path: contiguous slots allow a single view + copy
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
|
||||
ctx.Forward(srcF32.Copy(ctx, view))
|
||||
} else {
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
}
|
||||
|
||||
c.captureDeltaCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
|
||||
func (c *HybridCache) IsSupportedForBatch() bool {
|
||||
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
||||
}
|
||||
|
||||
// Seqs returns the ordered unique sequences for the current forward pass.
|
||||
func (c *HybridCache) Seqs() []int {
|
||||
return slices.Clone(c.curSeqs)
|
||||
}
|
||||
498
model/models/qwen3next/checkpoints.go
Normal file
@@ -0,0 +1,498 @@
package qwen3next
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
checkpointCountDefault = 32
|
||||
checkpointMinPosDefault = int32(16)
|
||||
checkpointIntervalDefault = int32(1280)
|
||||
)
|
||||
|
||||
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
|
||||
// memory usage while preserving prefix reuse for recurrent state.
|
||||
|
||||
type checkpointEntry struct {
|
||||
pos int32
|
||||
conv map[int]ml.Tensor
|
||||
delta map[int]ml.Tensor
|
||||
}
|
||||
|
||||
type slotCheckpointStore struct {
|
||||
entries []checkpointEntry
|
||||
size int
|
||||
next int
|
||||
lastPos int32
|
||||
}
|
||||
|
||||
type checkpointRestore struct {
|
||||
slot int
|
||||
idx int
|
||||
pos int32
|
||||
}
|
||||
|
||||
func newSlotCheckpointStore(n int) *slotCheckpointStore {
|
||||
entries := make([]checkpointEntry, n)
|
||||
for i := range entries {
|
||||
entries[i].pos = -1
|
||||
}
|
||||
return &slotCheckpointStore{
|
||||
entries: entries,
|
||||
lastPos: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) reset() {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
for i := range s.entries {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) record(pos int32) int {
|
||||
if len(s.entries) == 0 {
|
||||
return -1
|
||||
}
|
||||
idx := s.next
|
||||
s.next = (s.next + 1) % len(s.entries)
|
||||
if s.size < len(s.entries) {
|
||||
s.size++
|
||||
}
|
||||
s.entries[idx].pos = pos
|
||||
s.lastPos = pos
|
||||
return idx
|
||||
}
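The checkpoint store is a fixed-size ring keyed by position; a minimal standalone model of record() (positions only, no tensors) behaves like this:

```go
package main

import "fmt"

func main() {
	positions := []int32{-1, -1, -1} // ring of 3 checkpoint slots, -1 = empty
	next, size := 0, 0

	record := func(pos int32) {
		positions[next] = pos
		next = (next + 1) % len(positions)
		if size < len(positions) {
			size++
		}
	}

	for _, p := range []int32{10, 20, 30, 40} {
		record(p)
	}
	// The fourth record wraps around and overwrites the oldest entry (10).
	fmt.Println(positions, next, size) // [40 20 30] 1 3
}
```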
|
||||
|
||||
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
|
||||
bestIdx := -1
|
||||
bestPos := int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 || pos >= targetPos {
|
||||
continue
|
||||
}
|
||||
if pos > bestPos {
|
||||
bestPos = pos
|
||||
bestIdx = i
|
||||
}
|
||||
}
|
||||
if bestIdx < 0 {
|
||||
return -1, -1, false
|
||||
}
|
||||
return bestIdx, bestPos, true
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) pruneAfter(pos int32) {
|
||||
if len(s.entries) == 0 {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
|
||||
size := 0
|
||||
next := -1
|
||||
minPos := int32(math.MaxInt32)
|
||||
minIdx := 0
|
||||
for i := range s.entries {
|
||||
if s.entries[i].pos > pos {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
if s.entries[i].pos >= 0 {
|
||||
size++
|
||||
if s.entries[i].pos < minPos {
|
||||
minPos = s.entries[i].pos
|
||||
minIdx = i
|
||||
}
|
||||
} else if next == -1 {
|
||||
next = i
|
||||
}
|
||||
}
|
||||
|
||||
s.size = size
|
||||
if size == 0 {
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
if next != -1 {
|
||||
s.next = next
|
||||
} else {
|
||||
// Full ring: overwrite the oldest checkpoint next.
|
||||
s.next = minIdx
|
||||
}
|
||||
s.lastPos = pos
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
|
||||
minPos = int32(math.MaxInt32)
|
||||
maxPos = int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
size++
|
||||
if pos < minPos {
|
||||
minPos = pos
|
||||
}
|
||||
if pos > maxPos {
|
||||
maxPos = pos
|
||||
}
|
||||
}
|
||||
if size == 0 {
|
||||
minPos = -1
|
||||
maxPos = -1
|
||||
}
|
||||
return size, minPos, maxPos, s.lastPos
|
||||
}
|
||||
|
||||
func (c *HybridCache) planCheckpoints(batch input.Batch) {
|
||||
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cap(c.curCheckpointPos) < len(c.curSeqs) {
|
||||
c.curCheckpointPos = make([]int32, len(c.curSeqs))
|
||||
} else {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
|
||||
}
|
||||
for i := range c.curCheckpointPos {
|
||||
c.curCheckpointPos[i] = -1
|
||||
}
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
|
||||
posMax := make(map[int]int32, len(c.curSeqs))
|
||||
for i, seq := range batch.Sequences {
|
||||
pos := batch.Positions[i]
|
||||
if cur, ok := posMax[seq]; !ok || pos > cur {
|
||||
posMax[seq] = pos
|
||||
}
|
||||
}
|
||||
|
||||
for i, seq := range c.curSeqs {
|
||||
pos, ok := posMax[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if pos < c.checkpointMinPos {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
store := c.checkpointStore(slot)
|
||||
lastPos := store.lastPos
|
||||
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
|
||||
c.curCheckpointPos[i] = pos
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
|
||||
store, ok := c.checkpoints[slot]
|
||||
if ok {
|
||||
return store
|
||||
}
|
||||
store = newSlotCheckpointStore(c.checkpointCount)
|
||||
c.checkpoints[slot] = store
|
||||
return store
|
||||
}
|
||||
|
||||
func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
|
||||
if c.checkpointCount == 0 {
|
||||
return -1
|
||||
}
|
||||
if idx, ok := c.curCheckpointSlots[slot]; ok {
|
||||
return idx
|
||||
}
|
||||
store := c.checkpointStore(slot)
|
||||
idx := store.record(pos)
|
||||
if idx >= 0 {
|
||||
c.curCheckpointSlots[slot] = idx
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
|
||||
if pos <= 0 {
|
||||
return false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, _, ok = store.bestIndex(pos)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
|
||||
if targetPos <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
|
||||
return 0, false
|
||||
}
|
||||
idx, pos, ok := store.bestIndex(targetPos)
|
||||
if !ok {
|
||||
size, minPos, maxPos, lastPos := store.window()
|
||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
|
||||
"min", minPos, "max", maxPos, "last", lastPos)
|
||||
return 0, false
|
||||
}
|
||||
c.pendingRestore[seq] = checkpointRestore{
|
||||
slot: slot,
|
||||
idx: idx,
|
||||
pos: pos,
|
||||
}
|
||||
return pos + 1, true
|
||||
}
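Concretely (mirroring the unit test later in this change): with checkpoints recorded at positions 5, 9 and 15 and a resume target of 12, bestIndex selects the checkpoint at position 9, so PrepareRestore returns 10 and the caller removes everything from position 10 onward before recomputing from there.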
|
||||
|
||||
func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
|
||||
entry, ok := c.restoreEntry(restore)
|
||||
if !ok {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
|
||||
for layer, src := range entry.conv {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
for layer, src := range entry.delta {
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
|
||||
if len(entry.conv) > 0 || len(entry.delta) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
store := c.checkpoints[restore.slot]
|
||||
store.pruneAfter(restore.pos)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
|
||||
_, ok := c.restoreEntry(restore)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
|
||||
store, ok := c.checkpoints[restore.slot]
|
||||
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
|
||||
return nil, false
|
||||
}
|
||||
entry := &store.entries[restore.idx]
|
||||
if entry.pos < 0 {
|
||||
return nil, false
|
||||
}
|
||||
if !c.entryComplete(entry) {
|
||||
return nil, false
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
|
||||
for layer := range c.convStates {
|
||||
if entry.conv == nil || entry.conv[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for layer := range c.deltaStates {
|
||||
if entry.delta == nil || entry.delta[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *HybridCache) clearCheckpoints(slot int) {
|
||||
if store, ok := c.checkpoints[slot]; ok {
|
||||
store.reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
srcStore, ok := c.checkpoints[srcSlot]
|
||||
if !ok || srcStore.size == 0 {
|
||||
return
|
||||
}
|
||||
dstStore := c.checkpointStore(dstSlot)
|
||||
dstStore.size = srcStore.size
|
||||
dstStore.next = srcStore.next
|
||||
dstStore.lastPos = srcStore.lastPos
|
||||
|
||||
for i := range srcStore.entries {
|
||||
srcEntry := &srcStore.entries[i]
|
||||
dstEntry := &dstStore.entries[i]
|
||||
dstEntry.pos = srcEntry.pos
|
||||
if srcEntry.conv != nil {
|
||||
if dstEntry.conv == nil {
|
||||
dstEntry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.conv {
|
||||
dst := c.ensureCheckpointConv(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
if srcEntry.delta != nil {
|
||||
if dstEntry.delta == nil {
|
||||
dstEntry.delta = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.delta {
|
||||
dst := c.ensureCheckpointDelta(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointConv(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointConv(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointDelta(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointDelta(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.conv == nil {
|
||||
entry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.conv[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointConvCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointConvCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
|
||||
entry.conv[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.delta == nil {
|
||||
entry.delta = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.delta[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointDeltaCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointDeltaCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
|
||||
entry.delta[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *HybridCache) reserveCheckpointConv(layer int) {
|
||||
key := checkpointReserveKey(layer, 0)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointConv(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func (c *HybridCache) reserveCheckpointDelta(layer int) {
|
||||
key := checkpointReserveKey(layer, 1)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointDelta(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func checkpointReserveKey(layer int, kind int) int {
|
||||
return layer*2 + kind
|
||||
}
|
||||
300
model/models/qwen3next/checkpoints_test.go
Normal file
@@ -0,0 +1,300 @@
package qwen3next
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func newTestBackend(tb testing.TB) ml.Backend {
|
||||
tb.Helper()
|
||||
|
||||
f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
|
||||
_ = f.Close()
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
tb.Cleanup(func() {
|
||||
b.Close()
|
||||
})
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
|
||||
store := newSlotCheckpointStore(2)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
|
||||
_, pos, ok := store.bestIndex(15)
|
||||
if !ok || pos != 10 {
|
||||
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
store.record(30) // overwrite oldest (10)
|
||||
|
||||
if _, _, ok := store.bestIndex(15); ok {
|
||||
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(40)
|
||||
if !ok || pos != 30 {
|
||||
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCachePrepareRestore(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 1, 1)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
store := cache.checkpointStore(0)
|
||||
store.record(5)
|
||||
store.record(9)
|
||||
store.record(15)
|
||||
|
||||
restorePos, ok := cache.PrepareRestore(1, 12)
|
||||
if !ok {
|
||||
t.Fatalf("expected restore ok")
|
||||
}
|
||||
if restorePos != 10 {
|
||||
t.Fatalf("expected restorePos 10, got %d", restorePos)
|
||||
}
|
||||
rest, ok := cache.pendingRestore[1]
|
||||
if !ok {
|
||||
t.Fatalf("expected pending restore entry")
|
||||
}
|
||||
if rest.pos != 9 {
|
||||
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
store.pruneAfter(20)
|
||||
|
||||
if store.lastPos != 20 {
|
||||
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, pos, ok := store.bestIndex(25)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(35)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
|
||||
backend := newTestBackend(t)
|
||||
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.Init(backend, ml.DTypeF16, 2, 8, 2)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.slotForSeq[2] = 0
|
||||
cache.refCount[0] = 2
|
||||
cache.refCount[1] = 0
|
||||
cache.freeSlots = []int{1}
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
|
||||
t.Fatalf("Remove failed: %v", err)
|
||||
}
|
||||
|
||||
if cache.slotForSeq[1] == cache.slotForSeq[2] {
|
||||
t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
|
||||
}
|
||||
if cache.slotForSeq[1] != 1 {
|
||||
t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
|
||||
}
|
||||
if cache.slotForSeq[2] != 0 {
|
||||
t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
|
||||
}
|
||||
if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
|
||||
t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
|
||||
}
|
||||
if _, ok := cache.pendingRestore[1]; ok {
|
||||
t.Fatalf("expected pending restore to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
|
||||
cache.convStates[0] = nil // placeholder to indicate layer 0 exists
|
||||
cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
entry := &store.entries[idx]
|
||||
// Only set conv checkpoint, not delta - making it incomplete
|
||||
entry.conv = map[int]ml.Tensor{0: nil}
|
||||
// entry.delta is not set, so checkpoint is incomplete
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
err := cache.Remove(1, 10, math.MaxInt32)
|
||||
if !errors.Is(err, kvcache.ErrNotSupported) {
|
||||
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Don't set convStates/deltaStates - with no layers to check,
|
||||
// entryComplete will return true as long as entry.pos >= 0
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
// Test that restoreComplete returns true when no layers need checkpoints
|
||||
restore := cache.pendingRestore[1]
|
||||
if !cache.restoreComplete(restore) {
|
||||
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
|
||||
// Test that ring buffer wrap-around reuses entries without clearing maps.
|
||||
store := newSlotCheckpointStore(3)
|
||||
|
||||
// Fill the buffer
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
// Create fake tensor data in the first entry's maps
|
||||
store.entries[0].conv = make(map[int]ml.Tensor)
|
||||
store.entries[0].conv[0] = nil // Simulated tensor reference
|
||||
store.entries[0].delta = make(map[int]ml.Tensor)
|
||||
store.entries[0].delta[0] = nil // Simulated tensor reference
|
||||
|
||||
// Record another entry, which should wrap around and overwrite entry 0
|
||||
store.record(40)
|
||||
|
||||
// Verify the maps are still present (we reuse tensors)
|
||||
if store.entries[0].conv == nil {
|
||||
t.Fatalf("expected conv map to be preserved on reuse")
|
||||
}
|
||||
if store.entries[0].delta == nil {
|
||||
t.Fatalf("expected delta map to be preserved on reuse")
|
||||
}
|
||||
|
||||
// Verify the new position was recorded
|
||||
if store.entries[0].pos != 40 {
|
||||
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
|
||||
// Test behavior when buffer is exactly at capacity
|
||||
store := newSlotCheckpointStore(2)
|
||||
|
||||
idx1 := store.record(10)
|
||||
idx2 := store.record(20)
|
||||
|
||||
if idx1 != 0 || idx2 != 1 {
|
||||
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
|
||||
}
|
||||
|
||||
if store.size != 2 {
|
||||
t.Fatalf("expected size 2, got %d", store.size)
|
||||
}
|
||||
|
||||
// Verify both checkpoints are accessible
|
||||
_, pos1, ok1 := store.bestIndex(15)
|
||||
_, pos2, ok2 := store.bestIndex(25)
|
||||
|
||||
if !ok1 || pos1 != 10 {
|
||||
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
|
||||
}
|
||||
if !ok2 || pos2 != 20 {
|
||||
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
|
||||
// Test behavior with zero-size buffer
|
||||
store := newSlotCheckpointStore(0)
|
||||
|
||||
idx := store.record(10)
|
||||
if idx != -1 {
|
||||
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(15)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint for empty buffer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
|
||||
// Test pruning that removes all checkpoints
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
// Prune everything by setting threshold below all positions
|
||||
store.pruneAfter(5)
|
||||
|
||||
if store.size != 0 {
|
||||
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
|
||||
}
|
||||
// When all checkpoints are pruned, lastPos is reset to -1
|
||||
if store.lastPos != -1 {
|
||||
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(100)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint after pruning all")
|
||||
}
|
||||
}
|
||||
472
model/models/qwen3next/deltanet.go
Normal file
@@ -0,0 +1,472 @@
package qwen3next
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
const chunkSize = 64
|
||||
|
||||
// TriType constants for triangular matrix operations
|
||||
const (
|
||||
TriTypeUpperDiag = 0
|
||||
TriTypeUpper = 1
|
||||
TriTypeLowerDiag = 2
|
||||
TriTypeLower = 3
|
||||
)
|
||||
|
||||
// convKernel wraps the 1D convolution kernel tensor
|
||||
type convKernel struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
// Masks holds pre-computed mask tensors for chunked attention
|
||||
type Masks struct {
|
||||
Causal ml.Tensor // Lower triangular [chunkSize, chunkSize]
|
||||
Identity ml.Tensor // Diagonal [chunkSize, chunkSize]
|
||||
Diag ml.Tensor // causal + identity
|
||||
}
|
||||
|
||||
// GatedDeltaNet implements linear attention with SSM convolution and recurrent state.
|
||||
// It implements the Operator interface directly.
|
||||
type GatedDeltaNet struct {
|
||||
// Optimized path: pre-split QKV and gate
|
||||
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
|
||||
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
|
||||
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
|
||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
||||
SSMOut *nn.Linear `gguf:"ssm_out"`
|
||||
|
||||
// Layer index for cache access (set during model construction)
|
||||
Layer int
|
||||
}
|
||||
|
||||
// createMasks builds the constant mask tensors (called once, reused for all chunks)
|
||||
func createMasks(ctx ml.Context) *Masks {
|
||||
ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
|
||||
ones = ones.Fill(ctx, 1.0)
|
||||
causalMask := ones.Tri(ctx, TriTypeLower)
|
||||
|
||||
onesVec := ctx.Input().Zeros(ml.DTypeF32, chunkSize)
|
||||
onesVec = onesVec.Fill(ctx, 1.0)
|
||||
identity := onesVec.Diag(ctx)
|
||||
|
||||
diagMask := causalMask.Add(ctx, identity)
|
||||
|
||||
return &Masks{
|
||||
Causal: causalMask,
|
||||
Identity: identity,
|
||||
Diag: diagMask,
|
||||
}
|
||||
}
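For a small chunk size the three masks built above look like this; a standalone sketch using plain float64 matrices, assuming a chunk size of 4 for illustration (the model uses 64) and assuming TriTypeLower is the strict below-diagonal variant, which is consistent with Diag being defined as causal + identity:

```go
package main

import "fmt"

func main() {
	const n = 4 // assumed small chunk size for illustration

	causal := make([][]float64, n)   // strictly lower triangular
	identity := make([][]float64, n) // diagonal
	diag := make([][]float64, n)     // causal + identity: lower triangular including the diagonal
	for i := range causal {
		causal[i] = make([]float64, n)
		identity[i] = make([]float64, n)
		diag[i] = make([]float64, n)
		for j := 0; j < n; j++ {
			if j < i {
				causal[i][j] = 1
			}
			if j == i {
				identity[i][j] = 1
			}
			diag[i][j] = causal[i][j] + identity[i][j]
		}
	}
	fmt.Println("causal:  ", causal)
	fmt.Println("identity:", identity)
	fmt.Println("diag:    ", diag)
}
```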
|
||||
|
||||
func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
layer := gdn.Layer
|
||||
nSeqTokens := hiddenStates.Dim(1)
|
||||
nSeqs := hiddenStates.Dim(2)
|
||||
if cache != nil && cache.IsSupportedForBatch() {
|
||||
seqTokens := cache.seqTokens()
|
||||
seqs := cache.numSeqs()
|
||||
if seqTokens > 0 && seqs > 0 {
|
||||
if nSeqs > 1 {
|
||||
if nSeqTokens != seqTokens || nSeqs != seqs {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
} else {
|
||||
if nSeqTokens != seqTokens*seqs {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
|
||||
nSeqTokens = seqTokens
|
||||
nSeqs = seqs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
headKDim := opts.ssmDState
|
||||
numKHeads := opts.ssmNGroup
|
||||
numVHeads := opts.ssmDtRank
|
||||
headVDim := opts.ssmDInner / numVHeads
|
||||
convKernelSize := opts.convKernelSize
|
||||
|
||||
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
|
||||
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
|
||||
|
||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
|
||||
}
|
||||
// Optimized path: pre-split QKV and gate
|
||||
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
|
||||
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
|
||||
|
||||
baNewDim := 2 * numVHeads / numKHeads
|
||||
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
|
||||
|
||||
// Split beta and alpha
|
||||
betaSize := numVHeads / numKHeads
|
||||
alphaSize := numVHeads / numKHeads
|
||||
|
||||
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
|
||||
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
|
||||
|
||||
// Reshape to merge head dimensions
|
||||
beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
|
||||
alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||
alphaSoftplus := alphaBiased.Softplus(ctx)
|
||||
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
|
||||
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
// Get conv state from cache
|
||||
convStates, err := cache.ConvState(ctx, layer)
|
||||
if err != nil {
|
||||
// Log this - if it happens, short-term context will be lost
|
||||
slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
|
||||
convStates = ctx.Input().Zeros(ml.DTypeF32, convKernelSize-1, qkvDim, nSeqs)
|
||||
}
|
||||
|
||||
// Reshape conv states
|
||||
convStates = convStates.Reshape(ctx, convKernelSize-1, qkvDim, nSeqs)
|
||||
|
||||
// Concatenate with input for convolution
|
||||
convInput := convStates.Concat(ctx, qkvMixed, 0)
|
||||
|
||||
// Save new conv state (last convKernelSize-1 tokens)
|
||||
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
|
||||
cache.UpdateConvState(ctx, layer, lastConvStates)
|
||||
|
||||
// Apply SSM convolution (kernel must be F32 for Metal)
|
||||
convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
|
||||
convOutput = convOutput.SILU(ctx)
|
||||
|
||||
// Reshape for extraction
|
||||
convQKVMix := convOutput.Contiguous(ctx, qkvDim, nSeqTokens*nSeqs)
|
||||
|
||||
// Extract convolved Q, K, V
|
||||
qConv := convQKVMix.Slice(ctx, 0, 0, headKDim*numKHeads, 1)
|
||||
kConv := convQKVMix.Slice(ctx, 0, headKDim*numKHeads, 2*headKDim*numKHeads, 1)
|
||||
vConv := convQKVMix.Slice(ctx, 0, 2*headKDim*numKHeads, qkvDim, 1)
|
||||
|
||||
// Reshape to 4D
|
||||
qConv = qConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
|
||||
kConv = kConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
|
||||
vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
// Get delta state from cache
|
||||
state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
|
||||
if err != nil {
|
||||
// Log this - if it happens frequently, context will degrade
|
||||
slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
|
||||
state = ctx.Input().Zeros(ml.DTypeF32, headVDim, headVDim*numVHeads, nSeqs)
|
||||
}
|
||||
state = state.Reshape(ctx, headVDim, headVDim*numVHeads, 1, nSeqs)
|
||||
|
||||
// Repeat interleave Q and K if numKHeads != numVHeads
|
||||
if numKHeads != numVHeads {
|
||||
repeatFactor := numVHeads / numKHeads
|
||||
|
||||
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
|
||||
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
|
||||
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
}
|
||||
|
||||
// Choose computation mode based on sequence length
|
||||
var attnOut ml.Tensor
|
||||
if nSeqTokens == 1 {
|
||||
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
|
||||
} else {
|
||||
// Use pre-computed masks from opts (created once in Model.Forward)
|
||||
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
|
||||
}
|
||||
|
||||
// Apply gated normalization
|
||||
attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
|
||||
z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
|
||||
|
||||
// norm(attnOut, z) = RMSNorm(attnOut) * silu(z)
|
||||
attnOutNorm := gdn.SSMNorm.Forward(ctx, attnOut2D, opts.eps)
|
||||
zSilu := z2D.SILU(ctx)
|
||||
attnOutGated := attnOutNorm.Mul(ctx, zSilu)
|
||||
|
||||
// Reshape for output projection
|
||||
finalOutput := attnOutGated.Reshape(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
out := gdn.SSMOut.Forward(ctx, finalOutput)
|
||||
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
|
||||
}
|
||||
|
||||
// deltaNetAutoregressive implements single-token state update.
|
||||
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
|
||||
func (gdn *GatedDeltaNet) deltaNetAutoregressive(
|
||||
ctx ml.Context,
|
||||
q, k, v, gate, beta, state ml.Tensor,
|
||||
opts *Options,
|
||||
layer int,
|
||||
cache *HybridCache,
|
||||
) ml.Tensor {
|
||||
numVHeads := v.Dim(1)
|
||||
headVDim := v.Dim(0)
|
||||
nSeqs := q.Dim(3)
|
||||
|
||||
// L2 normalize Q and K
|
||||
q = q.L2Norm(ctx, opts.eps)
|
||||
k = k.L2Norm(ctx, opts.eps)
|
||||
|
||||
// Scale Q
|
||||
scale := 1.0 / math.Sqrt(float64(headVDim))
|
||||
q = q.Scale(ctx, scale)
|
||||
|
||||
// Sigmoid beta
|
||||
beta = beta.Sigmoid(ctx)
|
||||
|
||||
// Reshape state: [headVDim, headVDim, numVHeads, nSeqs]
|
||||
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||
|
||||
// Reshape gate and beta for broadcasting
|
||||
gT := gate.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
|
||||
betaT := beta.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
|
||||
|
||||
// Apply exponential to gate
|
||||
gT = gT.Exp(ctx)
|
||||
|
||||
// state = state * g_t
|
||||
state = state.Mul(ctx, gT)
|
||||
|
||||
// kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
|
||||
kTUnsqueezed := k.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
|
||||
kvMem := state.Mul(ctx, kTUnsqueezed)
|
||||
// Sum over dim=-2 (second dimension after permute)
|
||||
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
kvMem = kvMem.SumRows(ctx)
|
||||
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
// v_t with singleton dimension
|
||||
vT := v.Reshape(ctx, headVDim, 1, numVHeads, nSeqs)
|
||||
|
||||
// delta = (v_t - kv_mem) * beta_t
|
||||
vDiff := vT.Sub(ctx, kvMem)
|
||||
delta := vDiff.Mul(ctx, betaT)
|
||||
|
||||
// state = state + k_t.unsqueeze(-1) * delta
|
||||
kTUnsqueezedBroad := kTUnsqueezed.Repeat4D(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||
kTDelta := kTUnsqueezedBroad.Mul(ctx, delta)
|
||||
state = state.Add(ctx, kTDelta)
|
||||
|
||||
// core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2)
|
||||
qTUnsqueezed := q.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
|
||||
stateQ := state.Mul(ctx, qTUnsqueezed)
|
||||
stateQ = stateQ.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
coreAttnOut := stateQ.SumRows(ctx)
|
||||
coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
// Update delta state in cache
|
||||
cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
|
||||
|
||||
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
|
||||
}
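For reference, a minimal scalar sketch of the same single-token gated delta rule, written over plain slices for one head and one sequence. The function name, the row/column orientation of state, and the toy values in main are illustrative assumptions; the real code performs these steps with ml.Tensor ops and the head/batch dimensions shown above:

package main

import (
	"fmt"
	"math"
)

// deltaStep applies one gated delta-rule update for a single head (sketch).
// state is d x d, q/k/v are length-d vectors (q and k assumed L2-normalized,
// q already scaled), g is the log-decay and beta the write gate.
func deltaStep(state [][]float64, q, k, v []float64, g, beta float64) []float64 {
	d := len(v)
	decay := math.Exp(g)
	for i := range state { // state = state * exp(g)
		for j := range state[i] {
			state[i][j] *= decay
		}
	}
	kvMem := make([]float64, d) // kvMem = k . state (what the state recalls for this key)
	for i := 0; i < d; i++ {
		for j := 0; j < d; j++ {
			kvMem[j] += k[i] * state[i][j]
		}
	}
	for i := 0; i < d; i++ { // state += k outer ((v - kvMem) * beta)
		for j := 0; j < d; j++ {
			state[i][j] += k[i] * (v[j] - kvMem[j]) * beta
		}
	}
	out := make([]float64, d) // out = q . state
	for i := 0; i < d; i++ {
		for j := 0; j < d; j++ {
			out[j] += q[i] * state[i][j]
		}
	}
	return out
}

func main() {
	state := [][]float64{{0, 0}, {0, 0}}
	k := []float64{1, 0}
	v := []float64{0.5, 0.25}
	fmt.Println(deltaStep(state, k, k, v, 0, 1)) // first write into an empty state returns v
}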
|
||||
|
||||
// deltaNetChunked implements chunked computation for prefill.
|
||||
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
|
||||
func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
ctx ml.Context,
|
||||
q, k, v, gate, beta, state ml.Tensor,
|
||||
masks *Masks,
|
||||
opts *Options,
|
||||
layer int,
|
||||
cache *HybridCache,
|
||||
) ml.Tensor {
|
||||
headKDim := q.Dim(0)
|
||||
numVHeads := v.Dim(1)
|
||||
headVDim := v.Dim(0)
|
||||
nTokens := q.Dim(2)
|
||||
nSeqs := q.Dim(3)
|
||||
|
||||
// L2 normalize Q and K
|
||||
q = q.L2Norm(ctx, opts.eps)
|
||||
k = k.L2Norm(ctx, opts.eps)
|
||||
|
||||
// Scale Q
|
||||
scale := 1.0 / math.Sqrt(float64(headVDim))
|
||||
q = q.Scale(ctx, scale)
|
||||
|
||||
// Sigmoid beta
|
||||
beta = beta.Sigmoid(ctx)
|
||||
|
||||
// Permute tensors for chunked computation
|
||||
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
|
||||
gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs)
|
||||
|
||||
beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||
|
||||
// Compute padding
|
||||
pad := (chunkSize - nTokens%chunkSize) % chunkSize
|
||||
nChunks := (nTokens + pad) / chunkSize
|
||||
|
||||
// Pad tensors
|
||||
if pad > 0 {
|
||||
q = q.Pad(ctx, 0, pad, 0, 0)
|
||||
k = k.Pad(ctx, 0, pad, 0, 0)
|
||||
v = v.Pad(ctx, 0, pad, 0, 0)
|
||||
gate = gate.Pad(ctx, pad, 0, 0, 0)
|
||||
beta = beta.Pad(ctx, 0, pad, 0, 0)
|
||||
}
|
||||
|
||||
// Use pre-computed masks (passed in, not recreated)
|
||||
causalMask := masks.Causal
|
||||
identity := masks.Identity
|
||||
diagMask := masks.Diag
|
||||
identity4D := identity.Reshape(ctx, chunkSize, chunkSize, 1, 1)
|
||||
|
||||
// v_beta = v * beta, k_beta = k * beta
|
||||
vBeta := v.Mul(ctx, beta)
|
||||
kBeta := k.Mul(ctx, beta)
|
||||
|
||||
// Reshape for chunked computation
|
||||
q = q.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
k = k.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
|
||||
gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
|
||||
|
||||
// g_cumsum = cumsum(gate)
|
||||
gCumsum := gate.CumSum(ctx)
|
||||
|
||||
// Compute decay mask
|
||||
gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
|
||||
gcsJ := gCumsum.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
gcsBroadcast := gcsJ.Repeat4D(ctx, chunkSize, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
decayMask := gcsBroadcast.Sub(ctx, gcsI)
|
||||
|
||||
decayMask = decayMask.Mul(ctx, diagMask)
|
||||
decayMask = decayMask.Exp(ctx)
|
||||
decayMask = decayMask.Mul(ctx, diagMask)
|
||||
|
||||
// k @ k_beta^T
|
||||
kMulKBeta := k.Mulmat(ctx, kBeta)
|
||||
|
||||
// k_decay = k @ k_beta^T * decay_mask
|
||||
kDecay := kMulKBeta.Mul(ctx, decayMask)
|
||||
|
||||
// attn = -k_decay * causal_mask
|
||||
attn := kDecay.Neg(ctx).Mul(ctx, causalMask)
|
||||
|
||||
// Triangular solve: (I - attn_lower)^-1 @ attn
|
||||
attnLower := attn.Mul(ctx, causalMask)
|
||||
lhs := attnLower.Neg(ctx).Add(ctx, identity4D)
|
||||
linSolve := lhs.SolveTri(ctx, attn, true, true, false)
|
||||
attn = linSolve.Mul(ctx, causalMask)
|
||||
attn = attn.Add(ctx, identity4D)
|
||||
|
||||
// v = v_beta^T @ attn
|
||||
vBetaT := vBeta.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
v = vBetaT.Mulmat(ctx, attn)
|
||||
|
||||
// Compute g_exp for state update
|
||||
gCumsumT := gCumsum.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
gExp := gCumsumT.Exp(ctx)
|
||||
|
||||
// kbeta_gexp = k_beta * g_exp
|
||||
kBetaGExp := kBeta.Mul(ctx, gExp)
|
||||
|
||||
// k_cumdecay = attn @ kbeta_gexp^T
|
||||
kBetaGExpT := kBetaGExp.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
kCumdecay := attn.Mulmat(ctx, kBetaGExpT)
|
||||
kCumdecay = kCumdecay.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Pre-compute attn_kq = (k @ q) * decay_mask * diag_mask
|
||||
attnKQ := k.Mulmat(ctx, q)
|
||||
attnKQ = attnKQ.Mul(ctx, decayMask)
|
||||
attnKQ = attnKQ.Mul(ctx, diagMask)
|
||||
|
||||
// Pre-compute g_last and key_gdiff
|
||||
// g_last = view of last element in g_cumsum along chunk_size dimension
|
||||
// We need to get the last row of gCumsum: shape [chunkSize, 1, nChunks, H*n_seqs] -> [1, 1, nChunks, H*n_seqs]
|
||||
gLast := gCumsum.Slice(ctx, 0, chunkSize-1, chunkSize, 1).Contiguous(ctx, 1, 1, nChunks, numVHeads*nSeqs)
|
||||
gLastExp := gLast.Exp(ctx)
|
||||
|
||||
// g_diff = -(g_cumsum - g_last) = g_last - g_cumsum
|
||||
gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
|
||||
gDiffExp := gDiff.Exp(ctx)
|
||||
|
||||
// Reshape g_diff_exp to [1, chunkSize, nChunks, ...] for broadcasting against k
|
||||
gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
|
||||
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Process chunks and update state
|
||||
var coreAttnOut ml.Tensor
|
||||
newState := state
|
||||
|
||||
for chunk := range nChunks {
|
||||
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
|
||||
|
||||
// state^T - permute is needed but Contiguous creates a copy
|
||||
stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||
|
||||
// v_prime = k_cumdecay @ state
|
||||
vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
|
||||
|
||||
// v_new = v - v_prime
|
||||
vNew := vChunk.Sub(ctx, vPrime)
|
||||
vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// attn_inter = (q * g_exp) @ state
|
||||
qGExp := qChunk.Mul(ctx, gExpChunk)
|
||||
attnInter := stateT.Mulmat(ctx, qGExp)
|
||||
|
||||
// core_attn_out = attn_inter + attn @ v_new
|
||||
vAttn := vNewT.Mulmat(ctx, attnChunk)
|
||||
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
||||
|
||||
if coreAttnOut == nil {
|
||||
coreAttnOut = coreAttnOutChunk
|
||||
} else {
|
||||
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
|
||||
}
|
||||
|
||||
// Update state for next chunk
|
||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
|
||||
|
||||
// state = state * g_last + kgdmulvnew
|
||||
gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
|
||||
newState = newState.Mul(ctx, gExpLastReshaped)
|
||||
newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
|
||||
}
|
||||
|
||||
// Final reshape
|
||||
coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||
|
||||
// Slice to remove padding
|
||||
if pad > 0 {
|
||||
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
|
||||
}
|
||||
|
||||
// Update delta state in cache
|
||||
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
|
||||
|
||||
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
|
||||
}
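For concreteness, the padding above rounds a prefill up to whole 64-token chunks: a 100-token batch gets pad = (64 - 100 % 64) % 64 = 28 and is processed as (100 + 28) / 64 = 2 chunks, while a 128-token batch needs no padding; the padded positions are sliced off again after the chunk loop.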
|
||||
model/models/qwen3next/model.go (new file, 383 lines)
@@ -0,0 +1,383 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
// Options contains model configuration
|
||||
type Options struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
numKVHeads int
|
||||
keyLength int
|
||||
valueLength int
|
||||
ropeDim int
|
||||
|
||||
eps float32
|
||||
ropeBase float32
|
||||
ropeScale float32
|
||||
ropeType string
|
||||
originalContextLength int
|
||||
attentionScale float64
|
||||
|
||||
// MoE config
|
||||
numExperts int
|
||||
numExpertsUsed int
|
||||
normTopKProb bool
|
||||
|
||||
// Linear attention (Gated Delta Net) config
|
||||
ssmDInner int // d_inner = head_v_dim * num_v_heads
|
||||
ssmDState int // head_k_dim
|
||||
ssmNGroup int // num_k_heads
|
||||
ssmDtRank int // num_v_heads
|
||||
convKernelSize int // SSM conv kernel size
|
||||
|
||||
// Per-layer type from GGUF metadata
|
||||
isRecurrent []bool
|
||||
|
||||
// Pre-computed masks for chunked attention (created once per forward pass)
|
||||
masks *Masks
|
||||
}
|
||||
|
||||
func (o Options) headDim() int {
|
||||
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
opts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
if o.ropeType == "yarn" {
|
||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||
opts = append(opts,
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(1.),
|
||||
rope.WithAttentionFactor(attnFactor),
|
||||
)
|
||||
}
|
||||
ropeDim := cmp.Or(o.ropeDim, o.headDim())
|
||||
return nn.RoPE(ctx, states, positions, ropeDim, o.ropeBase, 1./o.ropeScale, opts...)
|
||||
}
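As a worked example of the yarn branch above: with ropeScale = 4, attnFactor = 1 / (1 + 0.1 * ln 4) ≈ 1 / 1.139 ≈ 0.88, and nn.RoPE is called with a frequency scale of 1/ropeScale = 0.25.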
|
||||
|
||||
// Operator is the interface for attention-like operators
|
||||
type Operator interface {
|
||||
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
|
||||
}
|
||||
|
||||
// MLP is the interface for feedforward networks
|
||||
type MLP interface {
|
||||
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
|
||||
}
|
||||
|
||||
// sparse implements MoE with shared experts
|
||||
type sparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
|
||||
// Shared experts
|
||||
SharedGateInp *nn.Linear `gguf:"ffn_gate_inp_shexp"`
|
||||
SharedGate *nn.Linear `gguf:"ffn_gate_shexp"`
|
||||
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
|
||||
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
|
||||
}
|
||||
|
||||
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
|
||||
if batchSize == 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
|
||||
|
||||
// Router logits
|
||||
routerLogits := mlp.Router.Forward(ctx, hiddenStates2D)
|
||||
|
||||
// Softmax routing weights
|
||||
routingWeights := routerLogits.Softmax(ctx)
|
||||
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
|
||||
if opts.normTopKProb {
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
|
||||
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
|
||||
}
|
||||
|
||||
hiddenStates3D := hiddenStates2D.Reshape(ctx, hiddenStates2D.Dim(0), 1, hiddenStates2D.Dim(1))
|
||||
|
||||
// Expert computation with SILU activation
|
||||
gateOut := mlp.Gate.Forward(ctx, hiddenStates3D, selectedExperts)
|
||||
upOut := mlp.Up.Forward(ctx, hiddenStates3D, selectedExperts)
|
||||
experts := gateOut.SILU(ctx, upOut)
|
||||
experts = mlp.Down.Forward(ctx, experts, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
// Sum over experts
|
||||
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
|
||||
// Add shared experts if present
|
||||
if mlp.SharedUp != nil {
|
||||
sharedGate := mlp.SharedGate.Forward(ctx, hiddenStates2D)
|
||||
sharedUp := mlp.SharedUp.Forward(ctx, hiddenStates2D)
|
||||
sharedOut := sharedGate.SILU(ctx, sharedUp)
|
||||
sharedOut = mlp.SharedDown.Forward(ctx, sharedOut)
|
||||
|
||||
// Apply shared expert gating
|
||||
if mlp.SharedGateInp != nil {
|
||||
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
|
||||
sharedGateVal = sharedGateVal.Sigmoid(ctx)
|
||||
// Broadcast gate to match dimensions
|
||||
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
|
||||
sharedOut = sharedOut.Mul(ctx, sharedGateVal)
|
||||
}
|
||||
|
||||
moeOut = moeOut.Add(ctx, sharedOut)
|
||||
}
|
||||
|
||||
return moeOut
|
||||
}
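A minimal sketch (hypothetical helper and toy logits, plain Go) of the routing arithmetic performed above with tensor ops: softmax over the expert logits, keep the top-k weights, and, when normTopKProb is set, renormalize them to sum to 1 before they scale the expert outputs:

package main

import (
	"fmt"
	"math"
	"sort"
)

// routeTopK picks the top-k experts for one token and returns their indices
// and routing weights, optionally renormalized (illustrative only).
func routeTopK(logits []float64, k int, normTopK bool) ([]int, []float64) {
	// Softmax over all expert logits.
	maxL := math.Inf(-1)
	for _, l := range logits {
		maxL = math.Max(maxL, l)
	}
	probs := make([]float64, len(logits))
	var sum float64
	for i, l := range logits {
		probs[i] = math.Exp(l - maxL)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}

	// Keep the k largest probabilities.
	idx := make([]int, len(logits))
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool { return probs[idx[a]] > probs[idx[b]] })
	top := idx[:k]
	weights := make([]float64, k)
	var topSum float64
	for i, e := range top {
		weights[i] = probs[e]
		topSum += probs[e]
	}
	if normTopK { // corresponds to opts.normTopKProb
		for i := range weights {
			weights[i] /= topSum
		}
	}
	return top, weights
}

func main() {
	experts, weights := routeTopK([]float64{2.0, 0.5, 1.5, -1.0}, 2, true)
	fmt.Println(experts, weights) // [0 2] with the two weights summing to 1
}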
|
||||
|
||||
// dense implements standard feedforward
|
||||
type dense struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
// Layer represents a single transformer layer
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
AttentionPostNorm *nn.RMSNorm `gguf:"post_attention_norm"` // Post-attention norm before FFN
|
||||
Operator Operator
|
||||
|
||||
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
residual := hiddenStates
|
||||
|
||||
// Pre-attention norm
|
||||
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// Attention (full or linear)
|
||||
var err error
|
||||
hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Output projection for last layer
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
// First residual connection
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
// Save for FFN residual
|
||||
ffnResidual := hiddenStates
|
||||
|
||||
// Post-attention norm (before FFN)
|
||||
hiddenStates = l.AttentionPostNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// FFN
|
||||
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
|
||||
|
||||
// Second residual connection
|
||||
return hiddenStates.Add(ctx, ffnResidual), nil
|
||||
}
|
||||
|
||||
// Model is the main Qwen3-Next model
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Layers []Layer `gguf:"blk"`
|
||||
|
||||
*Options
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
cache := m.Cache.(*HybridCache)
|
||||
|
||||
// Create masks once per forward pass
|
||||
m.Options.masks = createMasks(ctx)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
cache.SetLayer(i)
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
var err error
|
||||
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
var _ model.Model = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
layers := make([]Layer, numLayers)
|
||||
|
||||
// Get per-layer head counts (for detecting layer type)
|
||||
type headCounts interface {
|
||||
HeadCount() []uint64
|
||||
HeadCountKV() []uint64
|
||||
}
|
||||
|
||||
var isRecurrent []bool
|
||||
var headCountKV []uint64
|
||||
if hc, ok := c.(headCounts); ok {
|
||||
headCountKV = hc.HeadCountKV()
|
||||
}
|
||||
|
||||
isRecurrent = make([]bool, numLayers)
|
||||
hasZero := false
|
||||
hasFull := false
|
||||
for i := range numLayers {
|
||||
// If KV head count is 0, it's a recurrent layer
|
||||
if i < len(headCountKV) && headCountKV[i] == 0 {
|
||||
isRecurrent[i] = true
|
||||
hasZero = true
|
||||
} else if i < len(headCountKV) && headCountKV[i] > 0 {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasZero || !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
|
||||
}
|
||||
|
||||
// Determine if MoE
|
||||
isMoE := c.Uint("expert_count") > 0
|
||||
|
||||
for i := range layers {
|
||||
if isRecurrent[i] {
|
||||
layers[i].Operator = &GatedDeltaNet{Layer: i}
|
||||
} else {
|
||||
layers[i].Operator = &FullAttention{}
|
||||
}
|
||||
|
||||
if isMoE {
|
||||
layers[i].MLP = &sparse{}
|
||||
} else {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
}
|
||||
|
||||
opts := &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: func() int {
|
||||
for _, v := range headCountKV {
|
||||
if v > 0 {
|
||||
return int(v)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}(),
|
||||
keyLength: int(c.Uint("attention.key_length")),
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeType: c.String("rope.scaling.type"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
attentionScale: float64(c.Float("attention.scale")),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||
ssmDInner: int(c.Uint("ssm.inner_size")),
|
||||
ssmDState: int(c.Uint("ssm.state_size")),
|
||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||
isRecurrent: isRecurrent,
|
||||
}
|
||||
if opts.numKVHeads == 0 {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
||||
}
|
||||
|
||||
// Calculate cache dimensions
|
||||
convDim := max(0, opts.convKernelSize-1)
|
||||
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
|
||||
headVDim := 0
|
||||
numVHeads := opts.ssmDtRank
|
||||
if numVHeads > 0 {
|
||||
headVDim = opts.ssmDInner / numVHeads
|
||||
}
|
||||
deltaStateSize := headVDim * headVDim * numVHeads
|
||||
|
||||
// Validate dimension assumption: headKDim == headVDim is required for state computations
|
||||
headKDim := opts.ssmDState
|
||||
if headKDim != headVDim && headKDim > 0 && headVDim > 0 {
|
||||
return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
// Qwen3 tokenizers typically set add_bos_token=false and bos_token=null.
|
||||
// Default to false when the GGUF key is missing to avoid injecting a spurious BOS.
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: opts,
|
||||
}
|
||||
|
||||
m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("qwen3next", New)
|
||||
}
|
||||
model/parsers/glmocr.go (new file, 17 lines)
@@ -0,0 +1,17 @@
|
||||
package parsers
|
||||
|
||||
import "github.com/ollama/ollama/api"
|
||||
|
||||
// GlmOcrParser is the GLM46 parser with thinking disabled.
|
||||
type GlmOcrParser struct {
|
||||
GLM46Parser
|
||||
}
|
||||
|
||||
func (p *GlmOcrParser) HasThinkingSupport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *GlmOcrParser) Init(tools []api.Tool, _ *api.Message, _ *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
return tools
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -17,12 +18,34 @@ const (
|
||||
ministralCollectingToolArgs
|
||||
)
|
||||
|
||||
// ministralEvent represents an event emitted during parsing
|
||||
type ministralEvent interface {
|
||||
isMinistralEvent()
|
||||
}
|
||||
|
||||
type ministralEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type ministralEventThinking struct {
|
||||
thinking string
|
||||
}
|
||||
|
||||
type ministralEventToolCall struct {
|
||||
name string
|
||||
args string // raw JSON string
|
||||
}
|
||||
|
||||
func (ministralEventContent) isMinistralEvent() {}
|
||||
func (ministralEventThinking) isMinistralEvent() {}
|
||||
func (ministralEventToolCall) isMinistralEvent() {}
|
||||
|
||||
type MinistralParser struct {
|
||||
state ministralParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
hasThinkingSupport bool
|
||||
currentTool *api.Tool
|
||||
pendingToolName string // stores tool name while collecting args
|
||||
}
|
||||
|
||||
func (p *MinistralParser) HasToolSupport() bool {
|
||||
@@ -63,74 +86,251 @@ func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
|
||||
return nil, fmt.Errorf("tool '%s' not found", n)
|
||||
}
|
||||
|
||||
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
const (
|
||||
ministralToolCallsTag = "[TOOL_CALLS]"
|
||||
ministralThinkTag = "[THINK]"
|
||||
ministralThinkEndTag = "[/THINK]"
|
||||
ministralArgsTag = "[ARGS]"
|
||||
)
|
||||
|
||||
// eat consumes the parser's buffer, and returns a list of any unambiguous
|
||||
// events from the current parser state. The second return value indicates
|
||||
// whether to keep looping (true when state transitions, false when waiting
|
||||
// for more data).
|
||||
func (p *MinistralParser) eat() ([]ministralEvent, bool) {
|
||||
var events []ministralEvent
|
||||
|
||||
switch p.state {
|
||||
case ministralCollectingContent:
|
||||
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
|
||||
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
|
||||
if before != "" {
|
||||
return before, "", calls, nil
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
// Check for [TOOL_CALLS] tag
|
||||
if strings.Contains(bufStr, ministralToolCallsTag) {
|
||||
split := strings.SplitN(bufStr, ministralToolCallsTag, 2)
|
||||
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, ministralEventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingToolName
|
||||
} else if strings.Contains(p.buffer.String(), "[THINK]") {
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for [THINK] tag
|
||||
if strings.Contains(bufStr, ministralThinkTag) {
|
||||
split := strings.SplitN(bufStr, ministralThinkTag, 2)
|
||||
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, ministralEventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingThinkingContent
|
||||
return "", "", calls, nil
|
||||
} else {
|
||||
p.buffer.Reset()
|
||||
return s, "", calls, nil
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for partial tag overlap with [TOOL_CALLS] or [THINK]
|
||||
overlapToolCalls := overlap(bufStr, ministralToolCallsTag)
|
||||
overlapThink := overlap(bufStr, ministralThinkTag)
|
||||
maxOverlap := max(overlapToolCalls, overlapThink)
|
||||
|
||||
if maxOverlap > 0 {
|
||||
// Withhold the potential partial tag
|
||||
beforePartialTag := bufStr[:len(bufStr)-maxOverlap]
|
||||
trailingWS := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWS
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
// No tag found: emit content but withhold trailing whitespace
|
||||
whitespaceLen := trailingWhitespaceLen(bufStr)
|
||||
ambiguousStart := len(bufStr) - whitespaceLen
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case ministralCollectingThinkingContent:
|
||||
if strings.Contains(p.buffer.String(), "[/THINK]") {
|
||||
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
|
||||
p.state = ministralCollectingContent
|
||||
if after != "" {
|
||||
p.buffer.Reset()
|
||||
return after, thinkingContent, calls, nil
|
||||
}
|
||||
return "", thinkingContent, calls, nil
|
||||
} else {
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
if strings.Contains(bufStr, ministralThinkEndTag) {
|
||||
split := strings.SplitN(bufStr, ministralThinkEndTag, 2)
|
||||
thinkingContent := split[0]
|
||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
return "", s, calls, nil
|
||||
}
|
||||
case ministralCollectingToolName:
|
||||
if strings.Contains(p.buffer.String(), "[ARGS]") {
|
||||
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
|
||||
|
||||
t, err := toolByName(p.tools, name)
|
||||
if err != nil {
|
||||
return "", "", calls, err
|
||||
p.buffer.WriteString(after)
|
||||
if len(thinkingContent) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: thinkingContent})
|
||||
}
|
||||
p.currentTool = t
|
||||
p.state = ministralCollectingToolArgs
|
||||
return "", "", calls, nil
|
||||
}
|
||||
return "", "", calls, nil
|
||||
case ministralCollectingToolArgs:
|
||||
if strings.Contains(p.buffer.String(), "}") {
|
||||
before, _ := splitAtTag(&p.buffer, "}", false)
|
||||
before += "}"
|
||||
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(before), &args); err != nil {
|
||||
// todo - throw a better error
|
||||
return "", "", calls, err
|
||||
}
|
||||
|
||||
p.state = ministralCollectingContent
|
||||
return events, true
|
||||
}
|
||||
|
||||
call := api.ToolCall{
|
||||
// Check for partial overlap with [/THINK]
|
||||
if overlapLen := overlap(bufStr, ministralThinkEndTag); overlapLen > 0 {
|
||||
unambiguous := bufStr[:len(bufStr)-overlapLen]
|
||||
ambiguous := bufStr[len(bufStr)-overlapLen:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
// No tag found: emit all thinking content
|
||||
p.buffer.Reset()
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: bufStr})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case ministralCollectingToolName:
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
if strings.Contains(bufStr, ministralArgsTag) {
|
||||
split := strings.SplitN(bufStr, ministralArgsTag, 2)
|
||||
toolName := split[0]
|
||||
after := split[1]
|
||||
p.pendingToolName = toolName
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingToolArgs
|
||||
return events, true
|
||||
}
|
||||
// Wait for more data
|
||||
return events, false
|
||||
|
||||
case ministralCollectingToolArgs:
|
||||
bufStr := p.buffer.String()
|
||||
jsonEnd := findJSONEnd(bufStr)
|
||||
|
||||
if jsonEnd != -1 {
|
||||
jsonStr := bufStr[:jsonEnd+1]
|
||||
remaining := bufStr[jsonEnd+1:]
|
||||
|
||||
events = append(events, ministralEventToolCall{
|
||||
name: p.pendingToolName,
|
||||
args: jsonStr,
|
||||
})
|
||||
|
||||
p.pendingToolName = ""
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = ministralCollectingContent
|
||||
return events, true
|
||||
}
|
||||
// Wait for more data
|
||||
return events, false
|
||||
|
||||
default:
|
||||
panic("unexpected ministral event")
|
||||
}
|
||||
}
|
||||
|
||||
// parseEvents loops calling eat() until it returns false
|
||||
func (p *MinistralParser) parseEvents() []ministralEvent {
|
||||
var all []ministralEvent
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []ministralEvent
|
||||
events, keepLooping = p.eat()
|
||||
all = append(all, events...)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentBuilder, thinkingBuilder strings.Builder
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, event := range events {
|
||||
switch e := event.(type) {
|
||||
case ministralEventContent:
|
||||
contentBuilder.WriteString(e.content)
|
||||
case ministralEventThinking:
|
||||
thinkingBuilder.WriteString(e.thinking)
|
||||
case ministralEventToolCall:
|
||||
// Validate tool exists
|
||||
tool, toolErr := toolByName(p.tools, e.name)
|
||||
if toolErr != nil {
|
||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, toolErr
|
||||
}
|
||||
// Parse JSON arguments
|
||||
var args api.ToolCallFunctionArguments
|
||||
if jsonErr := json.Unmarshal([]byte(e.args), &args); jsonErr != nil {
|
||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, jsonErr
|
||||
}
|
||||
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: tool.Function.Name,
Arguments: args,
},
})
}
}

return contentBuilder.String(), thinkingBuilder.String(), toolCalls, nil
|
||||
}
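As an example of the resulting behavior (mirroring the streaming tests below), feeding the chunk "hello [TOOL_CALLS]get_weather[ARGS]{}" through Add yields a content event for "hello" (the whitespace before the tag is trimmed) followed by a tool-call event for get_weather with empty arguments, while a chunk ending in the partial tag "[TOOL" is withheld until the next chunk shows whether it completes "[TOOL_CALLS]".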
|
||||
|
||||
// findJSONEnd finds the index of the closing brace that completes a JSON object.
|
||||
// It properly handles nested objects, arrays, and strings (including escaped characters).
|
||||
// Returns -1 if the JSON is not yet complete.
|
||||
func findJSONEnd(s string) int {
|
||||
depth := 0
|
||||
inString := false
|
||||
escaped := false
|
||||
|
||||
for i, r := range s {
|
||||
if inString {
|
||||
switch {
|
||||
case escaped:
|
||||
// If the previous character was a backslash, skip this character
|
||||
escaped = false
|
||||
case r == '\\':
|
||||
// Mark the next character as escaped
|
||||
escaped = true
|
||||
case r == '"':
|
||||
// End of string literal
|
||||
inString = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
// Start of string literal
|
||||
inString = true
|
||||
case '{', '[':
|
||||
// Increase nesting level for objects and arrays
|
||||
depth++
|
||||
case '}', ']':
|
||||
// Decrease nesting level
|
||||
depth--
|
||||
if depth == 0 {
|
||||
// Reached the end of the root JSON structure
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
model/parsers/ministral_test.go (new file, 545 lines)
@@ -0,0 +1,545 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestMinistralParserStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantEvents []ministralEvent
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
tools []api.Tool
|
||||
steps []step
|
||||
think bool // whether to enable thinking support
|
||||
}{
|
||||
// Content streaming
|
||||
{
|
||||
desc: "simple content",
|
||||
steps: []step{
|
||||
{input: "Hello, how can I help you?", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "Hello, how can I help you?"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "streaming content word by word",
|
||||
steps: []step{
|
||||
{input: "Hello,", wantEvents: []ministralEvent{ministralEventContent{content: "Hello,"}}},
|
||||
{input: " how", wantEvents: []ministralEvent{ministralEventContent{content: " how"}}},
|
||||
{input: " can I help?", wantEvents: []ministralEvent{ministralEventContent{content: " can I help?"}}},
|
||||
},
|
||||
},
|
||||
|
||||
// Simple tool calls
|
||||
{
|
||||
desc: "simple tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "San Francisco"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "get_weather", args: `{"location": "San Francisco"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with nested object",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]create_entities[ARGS]{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with deeply nested objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "update_config"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]update_config[ARGS]{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "update_config", args: `{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with array of objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "process_items"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]process_items[ARGS]{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "process_items", args: `{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with escaped quotes in string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "search"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]search[ARGS]{"query": "say \"hello\""}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "search", args: `{"query": "say \"hello\""}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with braces inside string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "format"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]format[ARGS]{"template": "Hello {name}!"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "format", args: `{"template": "Hello {name}!"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty JSON object",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "no_args"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]no_args[ARGS]{}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "no_args", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "JSON with newlines in string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "write"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]write[ARGS]{"content": "line1\nline2\nline3"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "write", args: `{"content": "line1\nline2\nline3"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "backslash in string value",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "path"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]path[ARGS]{"dir": "C:\\Users\\test"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "path", args: `{"dir": "C:\\Users\\test"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Content after tool call
|
||||
{
|
||||
desc: "content after tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
// NOTE: It's unclear if this is valid Ministral output, but the parser
|
||||
// currently treats text after a tool call as regular content. This test
|
||||
// documents that behavior so we notice if it changes.
|
||||
{input: `[TOOL_CALLS]test[ARGS]{"a": 1}some content after`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{"a": 1}`},
|
||||
ministralEventContent{content: "some content after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Multiple tool calls
|
||||
{
|
||||
desc: "multiple tool calls in sequence",
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "get_weather"}},
|
||||
{Function: api.ToolFunction{Name: "get_time"}},
|
||||
},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "NYC"}[TOOL_CALLS]get_time[ARGS]{"timezone": "EST"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "get_weather", args: `{"location": "NYC"}`},
|
||||
ministralEventToolCall{name: "get_time", args: `{"timezone": "EST"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple tool calls streamed separately",
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "tool_a"}},
|
||||
{Function: api.ToolFunction{Name: "tool_b"}},
|
||||
},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]tool_a[ARGS]{"x": 1}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "tool_a", args: `{"x": 1}`},
|
||||
}},
|
||||
{input: `[TOOL_CALLS]tool_b[ARGS]{"y": 2}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "tool_b", args: `{"y": 2}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Streaming tool calls
|
||||
{
|
||||
desc: "streaming tool call with nested objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_CALLS]create_entities[ARGS]", wantEvents: []ministralEvent{}},
|
||||
{input: `{"entities": [{"entityType": "Person",`, wantEvents: []ministralEvent{}},
|
||||
{input: ` "name": "Jack",`, wantEvents: []ministralEvent{}},
|
||||
{input: ` "observations": ["Works`, wantEvents: []ministralEvent{}},
|
||||
{input: ` as a baker"]}`, wantEvents: []ministralEvent{}},
|
||||
{input: `]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "streaming with incomplete JSON waits for completion",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_CALLS]test[ARGS]{", wantEvents: []ministralEvent{}},
|
||||
{input: `"a": {`, wantEvents: []ministralEvent{}},
|
||||
{input: `"b": 1`, wantEvents: []ministralEvent{}},
|
||||
{input: `}`, wantEvents: []ministralEvent{}},
|
||||
{input: `}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{"a": {"b": 1}}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Partial tag handling
|
||||
{
|
||||
desc: "partial tool tag fakeout",
|
||||
steps: []step{
|
||||
{input: "abc[TOOL", wantEvents: []ministralEvent{ministralEventContent{content: "abc"}}},
|
||||
{input: " not a tag", wantEvents: []ministralEvent{ministralEventContent{content: "[TOOL not a tag"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call tag split across chunks",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_", wantEvents: []ministralEvent{}},
|
||||
{input: "CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "content before tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
|
||||
steps: []step{
|
||||
{input: "hello [TOOL_CALLS]get_weather[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "hello"},
|
||||
ministralEventToolCall{name: "get_weather", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace between content and tool call is trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "content \n [TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tabs and newlines before tool call are trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "content\t\n\t[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "non-breaking space before tool call is trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
// \u00a0 is non-breaking space, which unicode.IsSpace considers whitespace
|
||||
{input: "content\u00a0[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace before THINK tag is trimmed",
|
||||
steps: []step{
|
||||
{input: "content \n [THINK]thinking[/THINK]after", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace withheld then emitted",
|
||||
steps: []step{
|
||||
{input: "Hello ", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
|
||||
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: " world"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing newline withheld then emitted",
|
||||
steps: []step{
|
||||
{input: "Hello\n", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
|
||||
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: "\nworld"}}},
|
||||
},
|
||||
},
|
||||
|
||||
// Thinking support
|
||||
{
|
||||
desc: "thinking content",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking here[/THINK]", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "thinking here"},
|
||||
}},
|
||||
{input: "content after", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking with whitespace after end tag",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "my thoughts[/THINK] \n response", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "my thoughts"},
|
||||
ministralEventContent{content: "response"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "non-breaking space after think end tag is trimmed",
|
||||
think: true,
|
||||
steps: []step{
|
||||
// \u00a0 is non-breaking space
|
||||
{input: "thinking[/THINK]\u00a0response", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "response"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial think end tag",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
|
||||
{input: "NK]after", wantEvents: []ministralEvent{ministralEventContent{content: "after"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "think tag fakeout",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
|
||||
{input: "not end tag", wantEvents: []ministralEvent{ministralEventThinking{thinking: "[/THInot end tag"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking then tool call",
|
||||
think: true,
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "let me think[/THINK][TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "let me think"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Content then THINK tag transition
|
||||
{
|
||||
desc: "content then think tag",
|
||||
steps: []step{
|
||||
{input: "content[THINK]thinking[/THINK]more", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "more"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Unicode handling
|
||||
{
|
||||
desc: "unicode content",
|
||||
steps: []step{
|
||||
{input: "你好 🌍 مرحبا", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "你好 🌍 مرحبا"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "unicode in tool args",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "greet"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]greet[ARGS]{"message": "你好 🌍"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "greet", args: `{"message": "你好 🌍"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
parser := MinistralParser{}
|
||||
parser.hasThinkingSupport = tc.think
|
||||
parser.Init(tc.tools, nil, nil)
|
||||
|
||||
for i, step := range tc.steps {
|
||||
parser.buffer.WriteString(step.input)
|
||||
gotEvents := parser.parseEvents()
|
||||
|
||||
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
|
||||
// avoid deep equal on empty vs. nil slices
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
|
||||
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinistralParser_Errors(t *testing.T) {
|
||||
t.Run("unknown tool returns error", func(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "known_tool"}}}, nil, nil)
|
||||
|
||||
_, _, _, err := p.Add(`[TOOL_CALLS]unknown_tool[ARGS]{"a": 1}`, true)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown tool")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON returns error", func(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "test"}}}, nil, nil)
|
||||
|
||||
_, _, _, err := p.Add(`[TOOL_CALLS]test[ARGS]{invalid json}`, true)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFindJSONEnd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "simple object",
|
||||
input: `{"a": 1}`,
|
||||
expected: 7,
|
||||
},
|
||||
{
|
||||
name: "nested object",
|
||||
input: `{"a": {"b": 2}}`,
|
||||
expected: 14,
|
||||
},
|
||||
{
|
||||
name: "array inside object",
|
||||
input: `{"items": [1, 2, 3]}`,
|
||||
expected: 19,
|
||||
},
|
||||
{
|
||||
name: "braces in string",
|
||||
input: `{"template": "Hello {name}!"}`,
|
||||
expected: 28,
|
||||
},
|
||||
{
|
||||
name: "escaped quotes",
|
||||
input: `{"msg": "say \"hi\""}`,
|
||||
expected: 20,
|
||||
},
|
||||
{
|
||||
name: "incomplete object",
|
||||
input: `{"a": {"b": 1}`,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "deeply nested",
|
||||
input: `{"a": {"b": {"c": {"d": 1}}}}`,
|
||||
expected: 28,
|
||||
},
|
||||
{
|
||||
name: "object with trailing content",
|
||||
input: `{"a": 1} extra`,
|
||||
expected: 7,
|
||||
},
|
||||
{
|
||||
name: "array",
|
||||
input: `[{"a": 1}, {"b": 2}]`,
|
||||
expected: 19,
|
||||
},
|
||||
{
|
||||
name: "escaped backslash before quote",
|
||||
input: `{"path": "C:\\"}`,
|
||||
expected: 15,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "no opening brace",
|
||||
input: "hello world",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "only opening brace",
|
||||
input: "{",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "unclosed string",
|
||||
input: `{"key": "unclosed`,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "double escaped backslash then quote",
|
||||
input: `{"path": "C:\\\\"}`,
|
||||
expected: 17,
|
||||
},
|
||||
{
|
||||
name: "unicode in key and value",
|
||||
input: `{"키": "값"}`,
|
||||
expected: 13,
|
||||
},
|
||||
{
|
||||
name: "nested arrays",
|
||||
input: `{"matrix": [[1, 2], [3, 4]]}`,
|
||||
expected: 27,
|
||||
},
|
||||
{
|
||||
name: "mixed nesting",
|
||||
input: `{"a": [{"b": {"c": [1, 2, 3]}}]}`,
|
||||
expected: 31,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := findJSONEnd(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("findJSONEnd(%q) = %d, want %d", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
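The findJSONEnd implementation itself is not part of this diff, but the table above pins down its contract: it returns the byte index of the character that closes the first top-level JSON object or array, and -1 when the input is empty, truncated, or never opens a value. A minimal scanner consistent with those cases might look like the following sketch (illustrative only, not the code under test):

// findJSONEndSketch is an illustrative stand-in for findJSONEnd, written only
// to satisfy the expectations in TestFindJSONEnd above.
func findJSONEndSketch(s string) int {
    depth := 0        // current brace/bracket nesting
    inString := false // inside a JSON string literal
    escaped := false  // previous byte was a backslash inside a string
    for i := 0; i < len(s); i++ {
        c := s[i]
        if inString {
            switch {
            case escaped:
                escaped = false
            case c == '\\':
                escaped = true
            case c == '"':
                inString = false
            }
            continue
        }
        switch c {
        case '"':
            inString = true
        case '{', '[':
            depth++
        case '}', ']':
            depth--
            if depth == 0 {
                return i // byte index of the closing brace or bracket
            }
        }
    }
    return -1 // incomplete input or no JSON value found
}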
|
||||
|
||||
func TestMinistralParser_HasToolSupport(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
if !p.HasToolSupport() {
|
||||
t.Error("expected HasToolSupport to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinistralParser_HasThinkingSupport(t *testing.T) {
|
||||
p := &MinistralParser{hasThinkingSupport: false}
|
||||
if p.HasThinkingSupport() {
|
||||
t.Error("expected HasThinkingSupport to return false")
|
||||
}
|
||||
|
||||
p = &MinistralParser{hasThinkingSupport: true}
|
||||
if !p.HasThinkingSupport() {
|
||||
t.Error("expected HasThinkingSupport to return true")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package parsers
import (
    "strings"
    "unicode"
    "unicode/utf8"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/harmony"

@@ -70,6 +71,8 @@ func ParserForName(name string) Parser {
        return &FunctionGemmaParser{}
    case "glm-4.7":
        return &GLM47Parser{}
    case "glm-ocr":
        return &GlmOcrParser{}
    case "lfm2":
        return &LFM2Parser{hasThinkingSupport: false}
    case "lfm2-thinking":

@@ -114,3 +117,33 @@ func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string
    sb.WriteString(after)
    return before, after // return events
}

// overlap returns the longest overlap between the suffix of s and the prefix of delim
func overlap(s, delim string) int {
    max := min(len(delim), len(s))
    for i := max; i > 0; i-- {
        if strings.HasSuffix(s, delim[:i]) {
            return i
        }
    }
    return 0
}

// trailingWhitespaceLen returns the length in bytes of trailing whitespace in s
func trailingWhitespaceLen(s string) int {
    remaining := s
    total := 0
    for len(remaining) > 0 {
        r, size := utf8.DecodeLastRuneInString(remaining)
        // if it's an invalid utf8 rune, assume it isn't whitespace
        if r == utf8.RuneError && size == 1 {
            break
        }
        if !unicode.IsSpace(r) {
            break
        }
        total += size
        remaining = remaining[:len(remaining)-size]
    }
    return total
}

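These two helpers are now shared by all parsers (the next hunk deletes the duplicated copies from the Qwen3-Coder parser). As a quick sanity check of their behavior, a test along these lines could sit next to them in the parsers package (illustrative sketch only, not part of the change):

func TestOverlapAndTrailingWhitespaceSketch(t *testing.T) {
    // "thinking[/THI" ends with the first five bytes of "[/THINK]".
    if got := overlap("thinking[/THI", "[/THINK]"); got != 5 {
        t.Errorf("overlap = %d, want 5", got)
    }
    if got := overlap("no match here", "[/THINK]"); got != 0 {
        t.Errorf("overlap = %d, want 0", got)
    }
    // One ASCII space (1 byte) plus one non-breaking space (2 bytes in UTF-8).
    if got := trailingWhitespaceLen("response \u00a0"); got != 3 {
        t.Errorf("trailingWhitespaceLen = %d, want 3", got)
    }
}
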
@@ -11,7 +11,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
@@ -194,36 +193,6 @@ func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(drifkin): move this to a shared location
|
||||
// longest overlap between suffix of s and prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func trailingWhitespaceLen(s string) int {
|
||||
remaining := s
|
||||
total := 0
|
||||
for len(remaining) > 0 {
|
||||
r, size := utf8.DecodeLastRuneInString(remaining)
|
||||
// if it's an invalid utf8 rune, assume it isn't whitespace
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
break
|
||||
}
|
||||
if !unicode.IsSpace(r) {
|
||||
break
|
||||
}
|
||||
total += size
|
||||
remaining = remaining[:len(remaining)-size]
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
type XMLFunctionCall struct {
|
||||
XMLName xml.Name `xml:"function"`
|
||||
Name string `xml:"name,attr"`
|
||||
|
||||
model/renderers/glmocr.go (new file, 109 lines)
@@ -0,0 +1,109 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type GlmOcrRenderer struct{}
|
||||
|
||||
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("[gMASK]<sop>")
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString("# Tools\n\n")
|
||||
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
|
||||
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
|
||||
sb.WriteString("<tools>\n")
|
||||
for _, tool := range tools {
|
||||
d, _ := json.Marshal(tool)
|
||||
sb.WriteString(formatGLM47ToolJSON(d))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("</tools>\n\n")
|
||||
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
|
||||
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
|
||||
}
|
||||
|
||||
enableThinking := false
|
||||
thinkingExplicitlySet := false
|
||||
if thinkValue != nil {
|
||||
enableThinking = thinkValue.Bool()
|
||||
thinkingExplicitlySet = true
|
||||
}
|
||||
|
||||
for i, message := range messages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|user|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
|
||||
sb.WriteString("/nothink")
|
||||
}
|
||||
case "assistant":
|
||||
sb.WriteString("<|assistant|>\n")
|
||||
if message.Thinking != "" {
|
||||
sb.WriteString("<think>" + strings.TrimSpace(message.Thinking) + "</think>")
|
||||
} else {
|
||||
sb.WriteString("<think></think>")
|
||||
}
|
||||
if message.Content != "" {
|
||||
sb.WriteString("\n" + strings.TrimSpace(message.Content))
|
||||
}
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("\n<tool_call>" + toolCall.Function.Name)
|
||||
sb.WriteString(renderGlmOcrToolArguments(toolCall.Function.Arguments))
|
||||
sb.WriteString("</tool_call>")
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
case "tool":
|
||||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString("<|observation|>")
|
||||
}
|
||||
sb.WriteString("\n<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n</tool_response>\n")
|
||||
case "system":
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<|assistant|>\n")
|
||||
if thinkingExplicitlySet && !enableThinking {
|
||||
sb.WriteString("<think></think>\n")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func renderGlmOcrToolArguments(args api.ToolCallFunctionArguments) string {
|
||||
var sb strings.Builder
|
||||
for key, value := range args.All() {
|
||||
sb.WriteString("<arg_key>" + key + "</arg_key>")
|
||||
var valueStr string
|
||||
if str, ok := value.(string); ok {
|
||||
valueStr = str
|
||||
} else {
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
valueStr = fmt.Sprintf("%v", value)
|
||||
} else {
|
||||
valueStr = string(jsonBytes)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
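A rough usage sketch for the new renderer (illustrative only; assumes the package path github.com/ollama/ollama/model/renderers and is not part of the change):

package main

import (
    "fmt"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/model/renderers"
)

func main() {
    // Render a minimal conversation with no explicit think value.
    prompt, err := (&renderers.GlmOcrRenderer{}).Render([]api.Message{
        {Role: "system", Content: "You are a helpful assistant."},
        {Role: "user", Content: "Read the text in this receipt."},
    }, nil, nil)
    if err != nil {
        panic(err)
    }
    // The prompt starts with "[gMASK]<sop>", contains the <|system|> and
    // <|user|> blocks, and ends with "<|assistant|>\n" ready for generation.
    fmt.Println(prompt)
}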
|
||||
@@ -167,12 +167,12 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
|
||||
|
||||
// only start a new user block if this is the first tool response
|
||||
if i == 0 || filteredMessages[i-1].Role != "tool" {
|
||||
sb.WriteString(imStartTag + "user\n")
|
||||
sb.WriteString(imStartTag + "user")
|
||||
}
|
||||
|
||||
sb.WriteString("<tool_response>\n")
|
||||
sb.WriteString("\n<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n</tool_response>\n")
|
||||
sb.WriteString("\n</tool_response>")
|
||||
|
||||
// close the user block only if this is the last tool response
|
||||
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -127,8 +128,7 @@ fahrenheit
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>user
|
||||
That sounds nice! What about New York?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
@@ -233,8 +233,7 @@ I'll call double(1) and triple(2) for you.
|
||||
</tool_response>
|
||||
<tool_response>
|
||||
{"number": 6}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
@@ -280,8 +279,7 @@ call tool<|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"payload": {"foo": "bar"}}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
@@ -337,6 +335,31 @@ func TestFormatToolCallArgument(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderRendererToolResponseNoTrailingNewline(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{Role: "user", Content: "call tool"},
|
||||
{Role: "assistant", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{
|
||||
Name: "echo",
|
||||
Arguments: testArgs(map[string]any{"payload": "ok"}),
|
||||
}},
|
||||
}},
|
||||
{Role: "tool", Content: "{\"payload\":\"ok\"}", ToolName: "echo"},
|
||||
}
|
||||
|
||||
rendered, err := (&Qwen3CoderRenderer{}).Render(msgs, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if strings.Contains(rendered, "</tool_response>\n<|im_end|>") {
|
||||
t.Fatalf("expected no newline after </tool_response>, got:\n%s", rendered)
|
||||
}
|
||||
if !strings.Contains(rendered, "</tool_response><|im_end|>") {
|
||||
t.Fatalf("expected </tool_response> to be immediately followed by <|im_end|>, got:\n%s", rendered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ToolDefinitionTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -82,6 +82,8 @@ func rendererForName(name string) Renderer {
|
||||
return &FunctionGemmaRenderer{}
|
||||
case "glm-4.7":
|
||||
return &GLM47Renderer{}
|
||||
case "glm-ocr":
|
||||
return &GlmOcrRenderer{}
|
||||
case "lfm2":
|
||||
return &LFM2Renderer{IsThinking: false}
|
||||
case "lfm2-thinking":
|
||||
|
||||
@@ -124,8 +124,17 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*In
|
||||
}
|
||||
|
||||
if c.cache != nil {
|
||||
if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
|
||||
numPast = 0
|
||||
if numPast > 0 {
|
||||
// Recurrent caches use checkpoints to pick a safe resume position.
|
||||
if cc, ok := c.cache.(kvcache.CheckpointCache); ok {
|
||||
if restored, ok := cc.PrepareRestore(slot.Id, numPast); ok {
|
||||
numPast = restored
|
||||
} else {
|
||||
numPast = 0
|
||||
}
|
||||
} else if !c.cache.CanResume(slot.Id, numPast) {
|
||||
numPast = 0
|
||||
}
|
||||
}
|
||||
|
||||
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
|
||||
|
||||
@@ -740,7 +740,11 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
if seq == nil || nextBatchTokens[i] == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// If the sequence was replaced while this batch was computing, discard results.
|
||||
if activeBatch.seqs[i] != seq {
|
||||
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
|
||||
continue
|
||||
}
|
||||
seq.lastUpdatedAt = t
|
||||
if seq.numPredicted == 1 {
|
||||
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
|
||||
@@ -1358,7 +1362,7 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
|
||||
// Dummy load to get the backend wired up
|
||||
f, err := os.CreateTemp("", "*.bin")
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
@@ -1368,13 +1372,13 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
|
||||
"general.architecture": "llama",
|
||||
"tokenizer.ggml.model": "gpt2",
|
||||
}, nil); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
m, err = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
slog.Debug("dummy model load took", "duration", time.Since(startLoad))
|
||||
|
||||
@@ -3,7 +3,7 @@ package runner
|
||||
import (
|
||||
"github.com/ollama/ollama/runner/llamarunner"
|
||||
"github.com/ollama/ollama/runner/ollamarunner"
|
||||
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
@@ -12,18 +12,18 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
var newRunner bool
|
||||
var imageRunner bool
|
||||
var mlxRunner bool
|
||||
if len(args) > 0 && args[0] == "--ollama-engine" {
|
||||
args = args[1:]
|
||||
newRunner = true
|
||||
}
|
||||
if len(args) > 0 && args[0] == "--image-engine" {
|
||||
if len(args) > 0 && args[0] == "--mlx-engine" {
|
||||
args = args[1:]
|
||||
imageRunner = true
|
||||
mlxRunner = true
|
||||
}
|
||||
|
||||
if imageRunner {
|
||||
return imagerunner.Execute(args)
|
||||
if mlxRunner {
|
||||
return mlxrunner.Execute(args)
|
||||
} else if newRunner {
|
||||
return ollamarunner.Execute(args)
|
||||
} else {
|
||||
|
||||
@@ -27,14 +27,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
// Clip images are represented as 768 tokens, each an embedding
|
||||
imageNumTokens := 768
|
||||
|
||||
n := len(msgs) - 1
|
||||
// in reverse, find all messages that fit into context window
|
||||
for i := n; i >= 0; i-- {
|
||||
// always include the last message
|
||||
if i == n {
|
||||
continue
|
||||
}
|
||||
lastMsgIdx := len(msgs) - 1
|
||||
currMsgIdx := 0
|
||||
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
@@ -54,20 +52,26 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, m := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(m.Images)
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
}
|
||||
}
|
||||
|
||||
if truncate && ctxLen > opts.NumCtx {
|
||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
||||
if !truncate || ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
} else {
|
||||
n = i
|
||||
}
|
||||
}
|
||||
|
||||
currMsgIdx := n
|
||||
if currMsgIdx > 0 {
|
||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
||||
}
|
||||
|
||||
for cnt, msg := range msgs[currMsgIdx:] {
|
||||
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -264,3 +265,68 @@ func TestChatPrompt(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptTokenizeCalls(t *testing.T) {
|
||||
tmpl, err := template.Parse(`
|
||||
{{- if .System }}{{ .System }} {{ end }}
|
||||
{{- if .Prompt }}{{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}{{ .Response }} {{ end }}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
model := Model{Template: tmpl}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
limit int
|
||||
msgs []api.Message
|
||||
maxTokenizes int
|
||||
}{
|
||||
{
|
||||
name: "all messages fit",
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "message 1"},
|
||||
{Role: "assistant", Content: "response 1"},
|
||||
{Role: "user", Content: "message 2"},
|
||||
{Role: "assistant", Content: "response 2"},
|
||||
{Role: "user", Content: "message 3"},
|
||||
},
|
||||
maxTokenizes: 1,
|
||||
},
|
||||
{
|
||||
name: "truncate to last message",
|
||||
limit: 5,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "message 1"},
|
||||
{Role: "assistant", Content: "response 1"},
|
||||
{Role: "user", Content: "message 2"},
|
||||
{Role: "assistant", Content: "response 2"},
|
||||
{Role: "user", Content: "message 3"},
|
||||
},
|
||||
maxTokenizes: 5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenizeCount := 0
|
||||
countingTokenize := func(ctx context.Context, s string) ([]int, error) {
|
||||
tokenizeCount++
|
||||
tokens, err := mockRunner{}.Tokenize(ctx, s)
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
think := false
|
||||
_, _, err := chatPrompt(t.Context(), &model, countingTokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tokenizeCount > tt.maxTokenizes {
|
||||
t.Errorf("tokenize called %d times, expected at most %d", tokenizeCount, tt.maxTokenizes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,6 +58,48 @@ func useMoreBits(iLayer, nLayers int) bool {
|
||||
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
|
||||
}
|
||||
|
||||
func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
|
||||
switch {
|
||||
// Full attention
|
||||
case strings.HasSuffix(name, ".attn_q.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".attn_k.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".attn_v.weight"):
|
||||
return fsggml.TensorTypeQ6_K, true
|
||||
case strings.HasSuffix(name, ".attn_output.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
|
||||
// Linear attention (Gated Delta Net) after split
|
||||
case strings.HasSuffix(name, ".attn_qkv.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".attn_gate.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
|
||||
// SSM
|
||||
case strings.HasSuffix(name, ".ssm_ba.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ssm_out.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
|
||||
// MoE experts + shared experts
|
||||
case strings.HasSuffix(name, ".ffn_down_exps.weight"):
|
||||
return fsggml.TensorTypeQ6_K, true
|
||||
case strings.HasSuffix(name, ".ffn_down_shexp.weight"):
|
||||
return fsggml.TensorTypeQ6_K, true
|
||||
case strings.HasSuffix(name, ".ffn_gate_exps.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ffn_gate_shexp.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ffn_up_exps.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ffn_up_shexp.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
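Put differently, for qwen3next the table above pins certain tensor classes regardless of the requested file type, and anything not listed falls through to the generic logic. A small illustration (the "blk.12." prefix is hypothetical; only the suffix matters):

func exampleQwen3NextOverrides() (fsggml.TensorType, bool) {
    // Expert down projections are forced to Q6_K.
    if qt, ok := qwen3nextQuantType("blk.12.ffn_down_exps.weight"); ok {
        return qt, ok // fsggml.TensorTypeQ6_K
    }
    // A routing-gate tensor such as "blk.12.ffn_gate_inp.weight" is not in the
    // table, so it is decided by the generic getTensorNewType path instead.
    return 0, false
}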
|
||||
|
||||
func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType {
|
||||
// Ported from llama_tensor_get_type, removed unsupported quantization types
|
||||
nExperts := max(1, kv.Uint("expert_count", 0))
|
||||
@@ -217,6 +259,7 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
|
||||
// do not quantize expert gating tensors
|
||||
quantize = quantize && !strings.Contains(name, "ffn_gate_inp.weight")
|
||||
quantize = quantize && !strings.Contains(name, "ffn_gate_inp_shexp.weight")
|
||||
|
||||
// do not quantize positional embeddings and token types (BERT)
|
||||
quantize = quantize && (name != "position_embd.weight")
|
||||
@@ -244,6 +287,12 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
|
||||
newType := fsggml.TensorType(t.Kind)
|
||||
if quantize {
|
||||
if kv.Architecture() == "qwen3next" && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
|
||||
if qt, ok := qwen3nextQuantType(name); ok {
|
||||
return qt
|
||||
}
|
||||
}
|
||||
|
||||
// get more optimal quantization type based on the tensor shape, layer, etc.
|
||||
newType = getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype)
|
||||
if newType != defaultType {
|
||||
|
||||
@@ -75,16 +75,12 @@ func experimentEnabled(name string) bool {
|
||||
|
||||
var useClient2 = experimentEnabled("client2")
|
||||
|
||||
// Low VRAM mode is based on the sum of total VRAM (not free) and triggers
|
||||
// reduced context length on some models
|
||||
var lowVRAMThreshold uint64 = 20 * format.GibiByte
|
||||
|
||||
var mode string = gin.DebugMode
|
||||
|
||||
type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
lowVRAM bool
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
defaultNumCtx int
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -107,8 +103,12 @@ var (
|
||||
errBadTemplate = errors.New("template error")
|
||||
)
|
||||
|
||||
func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
|
||||
func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
|
||||
opts := api.DefaultOptions()
|
||||
if opts.NumCtx == 0 {
|
||||
opts.NumCtx = s.defaultNumCtx
|
||||
}
|
||||
|
||||
if err := opts.FromMap(model.Options); err != nil {
|
||||
return api.Options{}, err
|
||||
}
|
||||
@@ -140,20 +140,11 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
|
||||
}
|
||||
|
||||
opts, err := modelOptions(model, requestOpts)
|
||||
opts, err := s.modelOptions(model, requestOpts)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// This model is much more capable with a larger context, so set that
|
||||
// unless it would penalize performance too much
|
||||
if !s.lowVRAM && slices.Contains([]string{
|
||||
"gptoss", "gpt-oss",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
}, model.Config.ModelFamily) {
|
||||
opts.NumCtx = max(opts.NumCtx, 8192)
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
@@ -1720,10 +1711,18 @@ func Serve(ln net.Listener) error {
|
||||
for _, gpu := range gpus {
|
||||
totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
|
||||
}
|
||||
if totalVRAM < lowVRAMThreshold {
|
||||
s.lowVRAM = true
|
||||
slog.Info("entering low vram mode", "total vram", format.HumanBytes2(totalVRAM), "threshold", format.HumanBytes2(lowVRAMThreshold))
|
||||
|
||||
// Set default context based on VRAM tier
|
||||
// Use slightly lower thresholds (47/23 GiB vs. 48/24 GiB) to account for small differences in the exact value
|
||||
switch {
|
||||
case totalVRAM >= 47*format.GibiByte:
|
||||
s.defaultNumCtx = 262144
|
||||
case totalVRAM >= 23*format.GibiByte:
|
||||
s.defaultNumCtx = 32768
|
||||
default:
|
||||
s.defaultNumCtx = 4096
|
||||
}
|
||||
slog.Info("vram-based default context", "total_vram", format.HumanBytes2(totalVRAM), "default_num_ctx", s.defaultNumCtx)
|
||||
|
||||
err = srvr.Serve(ln)
|
||||
// If server is closed from the signal handler, wait for the ctx to be done
|
||||
@@ -1897,8 +1896,8 @@ func (s *Server) PsHandler(c *gin.Context) {
|
||||
Details: modelDetails,
|
||||
ExpiresAt: v.expiresAt,
|
||||
}
|
||||
if v.Options != nil {
|
||||
mr.ContextLength = v.Options.NumCtx
|
||||
if v.llama != nil {
|
||||
mr.ContextLength = v.llama.ContextLength()
|
||||
}
|
||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||
// possible that it will be set to the unix epoch. For those cases, just
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
@@ -208,6 +209,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChatDebugRenderOnly(t *testing.T) {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
// TestGenerateWithBuiltinRenderer tests that api/generate uses built-in renderers
|
||||
// when in chat-like flow (messages present, no suffix, no template)
|
||||
func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
@@ -204,6 +205,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
|
||||
// TestGenerateWithDebugRenderOnly tests that debug_render_only works with built-in renderers
|
||||
func TestGenerateWithDebugRenderOnly(t *testing.T) {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
|
||||
@@ -162,6 +162,7 @@ func TestGenerateChatRemote(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGenerateChat(t *testing.T) {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
@@ -878,6 +879,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
@@ -2355,6 +2357,7 @@ func TestGenerateWithImages(t *testing.T) {
|
||||
// TestImageGenerateStreamFalse tests that image generation respects stream=false
|
||||
// and returns a single JSON response instead of streaming ndjson.
|
||||
func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
|
||||
server/routes_options_test.go (new file, 127 lines)
@@ -0,0 +1,127 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestModelOptionsNumCtxPriority(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envContextLen string // empty means not set (uses 0 sentinel)
|
||||
defaultNumCtx int // VRAM-based default
|
||||
modelNumCtx int // 0 means not set in model
|
||||
requestNumCtx int // 0 means not set in request
|
||||
expectedNumCtx int
|
||||
}{
|
||||
{
|
||||
name: "vram default when nothing else set",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 0,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 32768,
|
||||
},
|
||||
{
|
||||
name: "env var overrides vram default",
|
||||
envContextLen: "8192",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 0,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 8192,
|
||||
},
|
||||
{
|
||||
name: "model overrides vram default",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 16384,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 16384,
|
||||
},
|
||||
{
|
||||
name: "model overrides env var",
|
||||
envContextLen: "8192",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 16384,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 16384,
|
||||
},
|
||||
{
|
||||
name: "request overrides everything",
|
||||
envContextLen: "8192",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 16384,
|
||||
requestNumCtx: 4096,
|
||||
expectedNumCtx: 4096,
|
||||
},
|
||||
{
|
||||
name: "request overrides vram default",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 0,
|
||||
requestNumCtx: 4096,
|
||||
expectedNumCtx: 4096,
|
||||
},
|
||||
{
|
||||
name: "request overrides model",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 16384,
|
||||
requestNumCtx: 4096,
|
||||
expectedNumCtx: 4096,
|
||||
},
|
||||
{
|
||||
name: "low vram tier default",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 4096,
|
||||
modelNumCtx: 0,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 4096,
|
||||
},
|
||||
{
|
||||
name: "high vram tier default",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 262144,
|
||||
modelNumCtx: 0,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 262144,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set or clear environment variable
|
||||
if tt.envContextLen != "" {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", tt.envContextLen)
|
||||
}
|
||||
|
||||
// Create server with VRAM-based default
|
||||
s := &Server{
|
||||
defaultNumCtx: tt.defaultNumCtx,
|
||||
}
|
||||
|
||||
// Create model options (use float64 as FromMap expects JSON-style numbers)
|
||||
var modelOpts map[string]any
|
||||
if tt.modelNumCtx != 0 {
|
||||
modelOpts = map[string]any{"num_ctx": float64(tt.modelNumCtx)}
|
||||
}
|
||||
model := &Model{
|
||||
Options: modelOpts,
|
||||
}
|
||||
|
||||
// Create request options (use float64 as FromMap expects JSON-style numbers)
|
||||
var requestOpts map[string]any
|
||||
if tt.requestNumCtx != 0 {
|
||||
requestOpts = map[string]any{"num_ctx": float64(tt.requestNumCtx)}
|
||||
}
|
||||
|
||||
opts, err := s.modelOptions(model, requestOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("modelOptions failed: %v", err)
|
||||
}
|
||||
|
||||
if opts.NumCtx != tt.expectedNumCtx {
|
||||
t.Errorf("NumCtx = %d, want %d", opts.NumCtx, tt.expectedNumCtx)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
)
|
||||
|
||||
type LlmRequest struct {
|
||||
@@ -195,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
|
||||
}
|
||||
|
||||
// Check for image generation model before attempting GGML load
|
||||
// Check for image generation models - all use MLX runner
|
||||
if slices.Contains(pending.model.Config.Capabilities, "image") {
|
||||
if s.loadImageGen(pending) {
|
||||
if s.loadMLX(pending) {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for experimental safetensors LLM models
|
||||
if pending.model.Config.ModelFormat == "safetensors" {
|
||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||
// LLM model with safetensors format - use MLX runner
|
||||
if s.loadMLX(pending) {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Load model for fitting
|
||||
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
|
||||
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
|
||||
@@ -552,11 +563,20 @@ iGPUScan:
|
||||
return false
|
||||
}
|
||||
|
||||
// loadImageGen loads an image generation model.
|
||||
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
|
||||
// Use model name for imagegen (it resolves manifests by name, not file path)
|
||||
// loadMLX loads an experimental safetensors model using the unified MLX runner.
|
||||
// This supports both LLM (completion) and image generation models.
|
||||
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
// Determine mode based on capabilities
|
||||
var mode mlxrunner.ModelMode
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
mode = mlxrunner.ModeImageGen
|
||||
} else {
|
||||
mode = mlxrunner.ModeLLM
|
||||
}
|
||||
|
||||
// Use model name for MLX (it resolves manifests by name, not file path)
|
||||
modelName := req.model.ShortName
|
||||
server, err := imagegen.NewServer(modelName)
|
||||
server, err := mlxrunner.NewServer(modelName, mode)
|
||||
if err != nil {
|
||||
req.errCh <- err
|
||||
return true
|
||||
|
||||
@@ -804,6 +804,7 @@ func (s *mockLlm) GetPort() int { return -
|
||||
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
|
||||
func (s *mockLlm) HasExited() bool { return false }
|
||||
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
|
||||
func (s *mockLlm) ContextLength() int { return 0 }
|
||||
|
||||
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
|
||||
// loaded in the scheduler can be evicted when idle.
|
||||
|
||||
@@ -3,13 +3,13 @@ package model
|
||||
type Capability string
|
||||
|
||||
const (
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
CapabilityImage = Capability("image")
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
CapabilityImage = Capability("image")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/progress"
|
||||
@@ -34,7 +35,7 @@ type ModelfileConfig struct {
|
||||
type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "fp8" for quantization
|
||||
Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
||||
}
|
||||
|
||||
@@ -53,10 +54,20 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
// Determine model type settings
|
||||
var modelType, spinnerKey string
|
||||
var capabilities []string
|
||||
var parserName, rendererName string
|
||||
if isSafetensors {
|
||||
modelType = "safetensors model"
|
||||
spinnerKey = "create"
|
||||
capabilities = []string{"completion"}
|
||||
|
||||
// Check if model supports thinking based on architecture
|
||||
if supportsThinking(opts.ModelDir) {
|
||||
capabilities = append(capabilities, "thinking")
|
||||
}
|
||||
|
||||
// Set parser and renderer name based on architecture
|
||||
parserName = getParserName(opts.ModelDir)
|
||||
rendererName = getRendererName(opts.ModelDir)
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
@@ -81,14 +92,14 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
err = create.CreateSafetensorsModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
newManifestWriter(opts, capabilities, parserName, rendererName),
|
||||
progressFn,
|
||||
)
|
||||
} else {
|
||||
err = create.CreateImageGenModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
newManifestWriter(opts, capabilities, "", ""),
|
||||
progressFn,
|
||||
)
|
||||
}
|
||||
@@ -204,7 +215,7 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
|
||||
}
|
||||
|
||||
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
|
||||
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
|
||||
func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
|
||||
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
@@ -229,6 +240,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: caps,
|
||||
Requires: MinOllamaVersion,
|
||||
Parser: parserName,
|
||||
Renderer: rendererName,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
@@ -295,3 +308,146 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// supportsThinking checks if the model supports thinking mode based on its architecture.
|
||||
// This reads the config.json from the model directory and checks the architectures field.
|
||||
func supportsThinking(modelDir string) bool {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check architectures that support thinking
|
||||
thinkingArchitectures := []string{
|
||||
"glm4moe", // GLM-4 MoE models
|
||||
"deepseek", // DeepSeek models
|
||||
"qwen3", // Qwen3 models
|
||||
}
|
||||
|
||||
// Check the architecture list
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
for _, thinkArch := range thinkingArchitectures {
|
||||
if strings.Contains(archLower, thinkArch) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
for _, thinkArch := range thinkingArchitectures {
|
||||
if strings.Contains(typeLower, thinkArch) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getParserName returns the parser name for a model based on its architecture.
|
||||
// This reads the config.json from the model directory and determines the appropriate parser.
|
||||
func getParserName(modelDir string) string {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check architectures for known parsers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(archLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(typeLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// getRendererName returns the renderer name for a model based on its architecture.
|
||||
// This reads the config.json from the model directory and determines the appropriate renderer.
|
||||
func getRendererName(modelDir string) string {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check architectures for known renderers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(archLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(typeLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
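A quick way to see these three helpers working together (illustrative sketch only, assuming a sibling _test.go in the same package with "os", "path/filepath", and "testing" imported; the config contents are hypothetical):

func TestArchitectureDetectionSketch(t *testing.T) {
    dir := t.TempDir()
    cfg := []byte(`{"architectures": ["Qwen3MoeForCausalLM"], "model_type": "qwen3_moe"}`)
    if err := os.WriteFile(filepath.Join(dir, "config.json"), cfg, 0o644); err != nil {
        t.Fatal(err)
    }
    if !supportsThinking(dir) {
        t.Error("expected a qwen3 architecture to report thinking support")
    }
    if got := getParserName(dir); got != "qwen3-coder" {
        t.Errorf("parser = %q, want qwen3-coder", got)
    }
    if got := getRendererName(dir); got != "qwen3-coder" {
        t.Errorf("renderer = %q, want qwen3-coder", got)
    }
}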
|
||||
|
||||
@@ -13,7 +13,11 @@ import (
|
||||
|
||||
// quantizeTensor loads a tensor from safetensors format, quantizes it,
|
||||
// and returns safetensors data for the quantized weights, scales, and biases.
|
||||
// Supported quantization types: "fp8" (affine 8-bit)
|
||||
// Supported quantization types:
|
||||
// - "q4": affine 4-bit, group_size=32 (with qbiases)
|
||||
// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
|
||||
// - "q8": affine 8-bit, group_size=64 (with qbiases)
|
||||
// - "mxfp8": Microsoft MX FP8, group_size=32 (no qbiases, E4M3 scales)
|
||||
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
tmpDir := ensureTempDir()
|
||||
@@ -54,12 +58,18 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
|
||||
// Quantize based on quantization type
|
||||
var qweight, scales, qbiases *mlx.Array
|
||||
switch quantize {
|
||||
case "fp4":
|
||||
// affine mode: group_size=32, bits=4
|
||||
case "q4":
|
||||
// affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
|
||||
case "fp8":
|
||||
// affine mode: group_size=32, bits=8
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
|
||||
case "nvfp4":
|
||||
// NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
|
||||
case "q8":
|
||||
// affine mode: group_size=64, bits=8 (with qbiases for zero-point offset)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine")
|
||||
case "mxfp8":
|
||||
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbiases)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8")
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
|
||||
}
|
||||
|
||||
@@ -228,7 +228,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
|
||||
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
|
||||
|
||||
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
|
||||
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
|
||||
// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases).
|
||||
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
|
||||
|
||||
// ManifestWriter writes the manifest file.
|
||||
@@ -262,36 +262,134 @@ func ShouldQuantize(name, component string) bool {
|
||||
return strings.HasSuffix(name, ".weight")
|
||||
}
|
||||
|
||||
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
|
||||
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
|
||||
// This is a more detailed check that also considers tensor dimensions.
|
||||
func ShouldQuantizeTensor(name string, shape []int32) bool {
|
||||
// The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8").
|
||||
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
|
||||
return GetTensorQuantization(name, shape, quantize) != ""
|
||||
}
|
||||
|
||||
// normalizeQuantType converts various quantization type aliases to canonical forms.
|
||||
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> q4, q8/Q8/int8/INT8/fp8/FP8 -> q8, nvfp4/NVFP4, mxfp8/MXFP8
|
||||
func normalizeQuantType(quantize string) string {
|
||||
switch strings.ToUpper(quantize) {
|
||||
case "Q4", "INT4", "FP4":
|
||||
return "q4"
|
||||
case "Q8", "INT8", "FP8":
|
||||
return "q8"
|
||||
case "NVFP4":
|
||||
return "nvfp4"
|
||||
case "MXFP8":
|
||||
return "mxfp8"
|
||||
default:
|
||||
return quantize
|
||||
}
|
||||
}
|
||||
|
||||
// getQuantGroupSize returns the group size for a given quantization type.
|
||||
// These must match the values used in quantize.go when creating quantized models.
|
||||
func getQuantGroupSize(quantize string) int {
|
||||
switch normalizeQuantType(quantize) {
|
||||
case "nvfp4":
|
||||
return 16
|
||||
case "q4":
|
||||
return 32
|
||||
case "mxfp8":
|
||||
return 32
|
||||
case "q8":
|
||||
return 64
|
||||
default:
|
||||
return 32
|
||||
}
|
||||
}
|
||||
|
||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||
// Returns "" if the tensor should not be quantized.
|
||||
// This implements mixed-precision quantization:
|
||||
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
|
||||
// - Output projection, gate/up weights: q4 (less sensitive)
|
||||
// - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||
// - Norms, embeddings, biases, routing gates: no quantization
|
||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
// Use basic name-based check first
|
||||
if !ShouldQuantize(name, "") {
|
||||
return false
|
||||
return ""
|
||||
}
|
||||
|
||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||
if len(shape) != 2 {
|
||||
return false
|
||||
return ""
|
||||
}
|
||||
|
||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
||||
return false
|
||||
return ""
|
||||
}
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size (32)
|
||||
if shape[len(shape)-1]%32 != 0 {
|
||||
return false
|
||||
// Normalize quantization type to canonical form
|
||||
quantNorm := normalizeQuantType(quantize)
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size
|
||||
// nvfp4: 16, q4/mxfp8: 32, q8: 64
|
||||
groupSize := int32(32)
|
||||
switch quantNorm {
|
||||
case "nvfp4":
|
||||
groupSize = 16
|
||||
case "q8":
|
||||
groupSize = 64
|
||||
}
|
||||
if shape[len(shape)-1]%groupSize != 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return true
|
||||
// Skip routing gate weights (should stay high precision)
|
||||
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
|
||||
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
|
||||
return ""
|
||||
}
|
||||
|
||||
// For NVFP4 or MXFP8, use the same quantization for all (no mixed precision)
|
||||
if quantNorm == "nvfp4" || quantNorm == "mxfp8" {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// Attention MLA weights - keep unquantized (bf16)
|
||||
// These are highly sensitive: errors accumulate in the KV cache over time
|
||||
// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
|
||||
if strings.Contains(name, "q_a_proj") ||
|
||||
strings.Contains(name, "q_b_proj") ||
|
||||
strings.Contains(name, "kv_a_proj") ||
|
||||
strings.Contains(name, "kv_b_proj") {
|
||||
return "" // No quantization - keep bf16
|
||||
}
|
||||
|
||||
// Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
|
||||
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
|
||||
if strings.Contains(name, "down_proj") {
|
||||
return "q8"
|
||||
}
|
||||
|
||||
// Output projection, gate/up weights - use requested quantization (Q4)
|
||||
// o_proj, gate_proj, up_proj
|
||||
if strings.Contains(name, "o_proj") ||
|
||||
strings.Contains(name, "gate_proj") ||
|
||||
strings.Contains(name, "up_proj") {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// LM head - use requested quantization
|
||||
if strings.Contains(name, "lm_head") {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// Default to requested quantization for other weights
|
||||
return quantNorm
|
||||
}
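To make the mixed-precision rules concrete, here is how a few representative tensors resolve when the user requests "q4" (illustrative only; names and shapes are hypothetical, and this assumes ShouldQuantize keeps its ".weight"-suffix behavior):

func exampleMixedPrecision() {
    // MLA attention weights stay unquantized (bf16).
    _ = GetTensorQuantization("model.layers.0.self_attn.kv_a_proj_with_mqa.weight", []int32{4096, 576}, "q4") // ""
    // Down projections are bumped to 8-bit.
    _ = GetTensorQuantization("model.layers.0.mlp.down_proj.weight", []int32{4096, 11008}, "q4") // "q8"
    // Gate/up projections follow the requested type.
    _ = GetTensorQuantization("model.layers.0.mlp.gate_proj.weight", []int32{11008, 4096}, "q4") // "q4"
    // 1D tensors such as norms are never quantized.
    _ = GetTensorQuantization("model.layers.0.input_layernorm.weight", []int32{4096}, "q4") // ""
}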
|
||||
|
||||
// CreateSafetensorsModel imports a standard safetensors model from a directory.
|
||||
// This handles Hugging Face style models with config.json and *.safetensors files.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
|
||||
// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized.
|
||||
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
@@ -330,9 +428,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
}
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
|
||||
quantizeType = quantize
|
||||
if quantize != "" {
|
||||
quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize)
|
||||
}
|
||||
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
@@ -388,6 +487,23 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
return fmt.Errorf("config.json not found in %s", modelDir)
|
||||
}
|
||||
|
||||
// Create model_index.json with quantization info if quantizing
|
||||
if quantize != "" {
|
||||
modelIndex := map[string]any{
|
||||
"quantization": strings.ToUpper(quantize),
|
||||
"group_size": getQuantGroupSize(quantize),
|
||||
}
|
||||
indexData, err := json.MarshalIndent(modelIndex, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal model_index.json: %w", err)
|
||||
}
|
||||
indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create model_index.json layer: %w", err)
|
||||
}
|
||||
layers = append(layers, indexLayer)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("writing manifest for %s", modelName))
|
||||
|
||||
if err := writeManifest(modelName, configLayer, layers); err != nil {
|
||||
|
||||
@@ -536,41 +536,51 @@ func TestShouldQuantize(t *testing.T) {
|
||||
|
||||
func TestShouldQuantizeTensor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
shape []int32
|
||||
want bool
|
||||
name string
|
||||
tensor string
|
||||
shape []int32
|
||||
quantize string
|
||||
want bool
|
||||
}{
|
||||
// 2D tensors with sufficient size should be quantized
|
||||
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
|
||||
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
|
||||
{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
|
||||
{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
|
||||
{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
|
||||
|
||||
// Small tensors should not be quantized (< 1024 elements)
|
||||
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
|
||||
{"small 2D weight", "small.weight", []int32{31, 31}, false},
|
||||
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
|
||||
{"small 2D weight", "small.weight", []int32{31, 31}, "fp8", false},
|
||||
|
||||
// 1D tensors should not be quantized
|
||||
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
|
||||
{"1D tensor", "layer_norm.weight", []int32{4096}, "fp8", false},
|
||||
|
||||
// 3D+ tensors should not be quantized
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
||||
|
||||
// Embeddings should not be quantized regardless of shape
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
||||
|
||||
// Norms should not be quantized regardless of shape
|
||||
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
|
||||
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, "fp8", false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
|
||||
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
|
||||
|
||||
// Group size divisibility tests
|
||||
// FP8/FP4 require divisible by 32
|
||||
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
|
||||
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
|
||||
// NVFP4 requires divisible by 16
|
||||
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
|
||||
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
|
||||
got := ShouldQuantizeTensor(tt.tensor, tt.shape, tt.quantize)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
|
||||
t.Errorf("ShouldQuantizeTensor(%q, %v, %q) = %v, want %v", tt.tensor, tt.shape, tt.quantize, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -741,7 +751,7 @@ func TestCreateImageGenModel_WithQuantize(t *testing.T) {
|
||||
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
err := CreateImageGenModel("test-imagegen", dir, "q8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateImageGenModel failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -15,15 +15,15 @@ import (
|
||||
// CreateImageGenModel imports an image generation model from a directory.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
|
||||
// Supported quantization types: fp8 (or empty for no quantization).
|
||||
// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization).
|
||||
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
||||
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
// Validate quantization type
|
||||
switch quantize {
|
||||
case "", "fp4", "fp8":
|
||||
case "", "q4", "q8", "nvfp4", "mxfp8":
|
||||
// valid
|
||||
default:
|
||||
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
|
||||
return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize)
|
||||
}
|
||||
|
||||
var layers []LayerInfo
|
||||
@@ -89,7 +89,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
|
||||
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape, quantize) {
|
||||
quantizeType = quantize
|
||||
}
|
||||
|
||||
@@ -213,10 +213,18 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
|
||||
}
|
||||
|
||||
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
|
||||
// MLX requires the last dimension to be divisible by the group size (32).
|
||||
func canQuantizeShape(shape []int32) bool {
|
||||
// MLX requires the last dimension to be divisible by the group size.
|
||||
// nvfp4: 16, q4/mxfp8: 32, q8: 64
|
||||
func canQuantizeShape(shape []int32, quantize string) bool {
|
||||
if len(shape) < 2 {
|
||||
return false
|
||||
}
|
||||
return shape[len(shape)-1]%32 == 0
|
||||
groupSize := int32(32)
|
||||
switch strings.ToUpper(quantize) {
|
||||
case "NVFP4":
|
||||
groupSize = 16
|
||||
case "Q8":
|
||||
groupSize = 64
|
||||
}
|
||||
return shape[len(shape)-1]%groupSize == 0
|
||||
}
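// Illustrative calls with hypothetical shapes: a last dimension of 48 is
// compatible with nvfp4 (48 % 16 == 0) but not with q4/mxfp8 (48 % 32 != 0)
// or q8 (48 % 64 != 0), and 1D tensors are always rejected.
//
//	canQuantizeShape([]int32{128, 48}, "nvfp4") // true
//	canQuantizeShape([]int32{128, 48}, "q8")    // false
//	canQuantizeShape([]int32{4096}, "q8")       // false (fewer than 2 dims)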
|
||||
|
||||
16	x/imagegen/cache/cache.go (vendored)
@@ -9,6 +9,7 @@ type Cache interface {
|
||||
Offset() int
|
||||
Len() int
|
||||
State() []*mlx.Array
|
||||
Reset()
|
||||
}
|
||||
|
||||
type KVCache struct {
|
||||
@@ -63,6 +64,13 @@ func (c *KVCache) State() []*mlx.Array {
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
func (c *KVCache) Len() int { return c.offset }
|
||||
|
||||
// Reset clears the cache state for a new generation session
|
||||
func (c *KVCache) Reset() {
|
||||
c.keys = nil
|
||||
c.values = nil
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
// RotatingKVCache implements sliding window attention with bounded memory
|
||||
type RotatingKVCache struct {
|
||||
keys, values *mlx.Array
|
||||
@@ -154,3 +162,11 @@ func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
|
||||
func (c *RotatingKVCache) Offset() int { return c.offset }
|
||||
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
|
||||
|
||||
// Reset clears the cache state for a new generation session
|
||||
func (c *RotatingKVCache) Reset() {
|
||||
c.keys = nil
|
||||
c.values = nil
|
||||
c.offset = 0
|
||||
c.idx = 0
|
||||
}
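// Sketch of the intended usage (hypothetical caller, not part of this package):
// Reset lets a runner reuse the same cache objects across generation sessions
// instead of reallocating them.
//
//	for _, req := range requests {
//		for _, c := range caches {
//			c.Reset()
//		}
//		// run prefill/decode for req against the now-empty caches
//	}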
|
||||
|
||||
@@ -102,14 +102,17 @@ func (m *ModelManifest) BlobPath(digest string) string {
|
||||
return filepath.Join(m.BlobDir, blobName)
|
||||
}
|
||||
|
||||
// GetTensorLayers returns all tensor layers for a given component.
|
||||
// Component should be "text_encoder", "transformer", or "vae".
|
||||
// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight").
|
||||
// GetTensorLayers returns tensor layers, optionally filtered by component.
|
||||
// If component is empty, returns all tensor layers (for LLM models).
|
||||
// If component is specified (e.g., "text_encoder", "transformer", "vae"),
|
||||
// returns only layers with that prefix.
|
||||
func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
|
||||
prefix := component + "/"
|
||||
var layers []ManifestLayer
|
||||
for _, layer := range m.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) {
|
||||
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
||||
continue
|
||||
}
|
||||
if component == "" || strings.HasPrefix(layer.Name, component+"/") {
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
}
|
||||
@@ -206,7 +209,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
|
||||
if info.Quantization == "" {
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if strings.HasSuffix(layer.Name, ".weight_scale") {
|
||||
info.Quantization = "FP8"
|
||||
info.Quantization = "Q8"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -991,6 +991,19 @@ func Concat(a, b *Array, axis int) *Array {
|
||||
return Concatenate([]*Array{a, b}, axis)
|
||||
}
|
||||
|
||||
// Stack stacks arrays along a new axis given by the axis argument
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
handles := make([]C.mlx_array, len(arrays))
|
||||
for i, arr := range arrays {
|
||||
handles[i] = arr.c
|
||||
}
|
||||
vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
|
||||
res := C.mlx_array_new()
|
||||
C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream())
|
||||
C.mlx_vector_array_free(vec)
|
||||
return newArray(res)
|
||||
}
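// Illustrative use with hypothetical arrays: Stack adds a new leading axis,
// whereas Concat joins along an existing one. Stacking two [4, 8] arrays along
// axis 0 yields a [2, 4, 8] array; this is how per-expert weights are combined
// into a single [num_experts, out, in] tensor elsewhere in this change.
//
//	a := Zeros([]int32{4, 8}, DtypeFloat32)
//	b := Zeros([]int32{4, 8}, DtypeFloat32)
//	s := Stack([]*Array{a, b}, 0) // shape [2, 4, 8]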
|
||||
|
||||
// Slice slices the array
|
||||
func Slice(a *Array, start, stop []int32) *Array {
|
||||
n := len(start)
|
||||
|
||||
840	x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go (new file)
@@ -0,0 +1,840 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
|
||||
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// RopeScaling holds RoPE scaling configuration
|
||||
type RopeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
MscaleAllDim float32 `json:"mscale_all_dim"`
|
||||
}
|
||||
|
||||
// Config holds GLM4-MoE-Lite model configuration
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
AttentionBias bool `json:"attention_bias"`
|
||||
|
||||
// MLA (Multi-head Latent Attention) parameters
|
||||
QLoraRank int32 `json:"q_lora_rank"`
|
||||
KVLoraRank int32 `json:"kv_lora_rank"`
|
||||
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
|
||||
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
|
||||
VHeadDim int32 `json:"v_head_dim"`
|
||||
|
||||
// MoE parameters
|
||||
NRoutedExperts int32 `json:"n_routed_experts"`
|
||||
NSharedExperts int32 `json:"n_shared_experts"`
|
||||
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
|
||||
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
|
||||
NormTopKProb bool `json:"norm_topk_prob"`
|
||||
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
|
||||
NGroup int32 `json:"n_group"`
|
||||
TopKGroup int32 `json:"topk_group"`
|
||||
|
||||
// RoPE scaling
|
||||
RopeScaling *RopeScaling `json:"rope_scaling"`
|
||||
|
||||
// Quantization parameters (set during load based on model quantization)
|
||||
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
|
||||
QuantBits int `json:"-"` // Bits per weight (4 or 8)
|
||||
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
|
||||
|
||||
// Computed fields
|
||||
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
|
||||
Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
|
||||
}
|
||||
|
||||
// MLAAttention implements Multi-head Latent Attention with absorption.
|
||||
// This uses absorbed MLA which operates in latent space for reduced KV cache.
|
||||
type MLAAttention struct {
|
||||
// Low-rank query projections
|
||||
QAProj nn.LinearLayer `weight:"self_attn.q_a_proj"`
|
||||
QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"`
|
||||
QBProj nn.LinearLayer `weight:"self_attn.q_b_proj"`
|
||||
|
||||
// Low-rank KV projections (with shared rope component)
|
||||
KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"`
|
||||
KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"`
|
||||
|
||||
// Absorbed MLA projections (derived from kv_b_proj)
|
||||
// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
|
||||
// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
|
||||
EmbedQ *nn.MultiLinear `weight:"-"`
|
||||
UnembedOut *nn.MultiLinear `weight:"-"`
|
||||
|
||||
// Output projection
|
||||
OProj nn.LinearLayer `weight:"self_attn.o_proj"`
|
||||
}
|
||||
|
||||
// Forward computes absorbed MLA attention output.
|
||||
// This operates in latent space for reduced KV cache memory.
|
||||
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Query path: q_a_proj -> layernorm -> q_b_proj
|
||||
q := a.QAProj.Forward(x)
|
||||
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
|
||||
q = a.QBProj.Forward(q)
|
||||
|
||||
// Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
|
||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
|
||||
// Split Q into nope and rope parts
|
||||
qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
|
||||
qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
|
||||
|
||||
// KV path: get compressed KV and k_pe
|
||||
compressedKV := a.KVAProjWithMQA.Forward(x)
|
||||
|
||||
// Split into compressed_kv and k_pe (shared rope component)
|
||||
kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
|
||||
kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
|
||||
|
||||
// k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
|
||||
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
|
||||
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
|
||||
|
||||
// Apply layernorm to get kv latent representation
|
||||
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
|
||||
// kvLatent: [B, L, kv_lora_rank] -> [B, 1, L, kv_lora_rank] for broadcasting
|
||||
kvLatent = mlx.ExpandDims(kvLatent, 1)
|
||||
|
||||
// Apply RoPE to the rope parts
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
|
||||
// ABSORBED MLA: project q_nope to latent space
|
||||
// qNope: [B, num_heads, L, qk_nope_head_dim]
|
||||
// EmbedQ: [num_heads, kv_lora_rank, qk_nope_head_dim]
|
||||
// Result: [B, num_heads, L, kv_lora_rank]
|
||||
qLatent := a.EmbedQ.Forward(qNope)
|
||||
|
||||
// Keys = concat(kvLatent, kPE)
|
||||
// kvLatent: [B, 1, L, kv_lora_rank]
|
||||
// kPE: [B, 1, L, qk_rope_head_dim]
|
||||
// keys: [B, 1, L, kv_lora_rank + qk_rope_head_dim]
|
||||
keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
|
||||
|
||||
// Cache the smaller latent representation
|
||||
// We cache keys (latent + rope) and use empty values since values are derived from keys
|
||||
cachedL := L
|
||||
if c != nil {
|
||||
// Create placeholder values with 0 dims for cache (we don't actually use cached values)
|
||||
placeholderValues := mlx.Zeros([]int32{B, 1, L, 0}, mlx.DtypeFloat32)
|
||||
keys, _ = c.Update(keys, placeholderValues, int(L))
|
||||
cachedL = int32(keys.Shape()[2])
|
||||
}
|
||||
|
||||
// Values are the first kv_lora_rank dims of keys (slice off rope part)
|
||||
values := mlx.Slice(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})
|
||||
|
||||
// Queries = concat(qLatent, qPE)
|
||||
// qLatent: [B, num_heads, L, kv_lora_rank]
|
||||
// qPE: [B, num_heads, L, qk_rope_head_dim]
|
||||
// queries: [B, num_heads, L, kv_lora_rank + qk_rope_head_dim]
|
||||
queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
|
||||
|
||||
// Attention in latent space
|
||||
// queries: [B, num_heads, L, kv_lora_rank + rope_dim]
|
||||
// keys: [B, 1, cachedL, kv_lora_rank + rope_dim]
|
||||
// values: [B, 1, cachedL, kv_lora_rank]
|
||||
out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)
|
||||
|
||||
// ABSORBED MLA: unembed from latent space
|
||||
// out: [B, num_heads, L, kv_lora_rank]
|
||||
// UnembedOut: [num_heads, v_head_dim, kv_lora_rank]
|
||||
// Result: [B, num_heads, L, v_head_dim]
|
||||
out = a.UnembedOut.Forward(out)
|
||||
|
||||
// Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
|
||||
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// DenseMLP implements the standard SwiGLU MLP for dense layers
|
||||
type DenseMLP struct {
|
||||
GateProj nn.LinearLayer `weight:"mlp.gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"mlp.up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"mlp.down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the SwiGLU MLP
|
||||
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||
up := m.UpProj.Forward(x)
|
||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
|
||||
// MoEGate implements the expert gating mechanism
|
||||
type MoEGate struct {
|
||||
Gate nn.LinearLayer `weight:"mlp.gate"`
|
||||
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
|
||||
}
|
||||
|
||||
// Forward computes expert selection indices and scores
|
||||
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
|
||||
// Compute gate logits through linear layer (handles both quantized and non-quantized)
|
||||
gates := g.Gate.Forward(x)
|
||||
|
||||
// Sigmoid scoring
|
||||
scores := mlx.Sigmoid(gates)
|
||||
origScores := scores
|
||||
|
||||
// Add correction bias if present
|
||||
if g.EScoreCorrectionBias != nil {
|
||||
scores = mlx.Add(scores, g.EScoreCorrectionBias)
|
||||
}
|
||||
|
||||
// Group-wise expert selection (simplified for n_group=1)
|
||||
// Select top-k experts
|
||||
topK := cfg.NumExpertsPerTok
|
||||
negScores := mlx.Neg(scores)
|
||||
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
|
||||
|
||||
shape := inds.Shape()
|
||||
inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})
|
||||
|
||||
// Get scores for selected experts
|
||||
scores = mlx.TakeAlongAxis(origScores, inds, -1)
|
||||
|
||||
// Normalize if configured
|
||||
if topK > 1 && cfg.NormTopKProb {
|
||||
sumScores := mlx.Sum(scores, -1, true)
|
||||
scores = mlx.Div(scores, sumScores)
|
||||
}
|
||||
|
||||
// Apply routing scaling factor
|
||||
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
|
||||
|
||||
return inds, scores
|
||||
}
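// Worked example with hypothetical numbers (no correction bias): with 4 routed
// experts, topK=2, and sigmoid scores [0.9, 0.1, 0.7, 0.3] for one token,
// Argpartition on the negated scores places experts 0 and 2 (in some order) in
// the first two positions, so inds selects {0, 2}. With NormTopKProb enabled
// the selected scores [0.9, 0.7] are divided by their sum 1.6, giving
// [0.5625, 0.4375], and then multiplied by RoutedScalingFactor.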
|
||||
|
||||
// SwitchMLP implements the MoE expert computation using stacked weights
|
||||
// Note: No weight tags - these are populated manually by stacking expert weights
|
||||
type SwitchMLP struct {
|
||||
// Dequantized weights (used when GatherQMM not available)
|
||||
GateWeight *mlx.Array
|
||||
UpWeight *mlx.Array
|
||||
DownWeight *mlx.Array
|
||||
|
||||
// Quantized weights (used with GatherQMM for 4/8-bit affine)
|
||||
GateWeightQ, GateScales, GateBiases *mlx.Array
|
||||
UpWeightQ, UpScales, UpBiases *mlx.Array
|
||||
DownWeightQ, DownScales, DownBiases *mlx.Array
|
||||
|
||||
// Quantization bits per projection (supports mixed precision Q4/Q8)
|
||||
GateBits int
|
||||
UpBits int
|
||||
DownBits int
|
||||
|
||||
// Quantization group size per projection (detected from tensor shapes)
|
||||
GateGroupSize int
|
||||
UpGroupSize int
|
||||
DownGroupSize int
|
||||
|
||||
// If true, use GatherQMM with quantized weights
|
||||
UseQuantized bool
|
||||
}
|
||||
|
||||
// Forward applies the switched expert MLP
|
||||
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
topK := cfg.NumExpertsPerTok
|
||||
|
||||
// Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
|
||||
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
|
||||
|
||||
// Flatten for gather_mm: [B*L, 1, 1, D]
|
||||
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
|
||||
|
||||
// Flatten indices: [B, L, topK] -> [B*L, topK]
|
||||
idxFlat := mlx.Reshape(indices, B*L, topK)
|
||||
|
||||
// Sort for efficient gather (when we have many tokens)
|
||||
doSort := B*L >= 64
|
||||
var invOrder *mlx.Array
|
||||
n := B * L * topK
|
||||
|
||||
if doSort {
|
||||
idxAll := mlx.Flatten(idxFlat)
|
||||
order := mlx.Argsort(idxAll, 0)
|
||||
invOrder = mlx.Argsort(order, 0)
|
||||
// Reorder x based on sorted indices
|
||||
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
|
||||
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
|
||||
}
|
||||
|
||||
var gate, up, hidden, down *mlx.Array
|
||||
|
||||
if s.UseQuantized {
|
||||
// Use GatherQMM for quantized weights (faster, keeps weights quantized)
|
||||
// Each projection may have different bits and group sizes (mixed precision: Q4 for gate/up, Q8 for down)
|
||||
gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
|
||||
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
|
||||
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
||||
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
||||
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
|
||||
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
||||
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
||||
} else {
|
||||
// Use GatherMM for dequantized/non-quantized weights
|
||||
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
|
||||
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
}
|
||||
|
||||
// Unsort if we sorted
|
||||
if doSort {
|
||||
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
|
||||
} else {
|
||||
down = mlx.Squeeze(down, 2)
|
||||
}
|
||||
|
||||
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
|
||||
}
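// The sort/unsort trick above relies on argsort(argsort(x)) being the inverse
// permutation of argsort(x). A plain-Go sketch of the same idea (hypothetical
// values, not using mlx):
//
//	idx := []int{2, 0, 1}      // expert id per (token, slot)
//	order := []int{1, 2, 0}    // argsort(idx): rows reordered so expert ids ascend
//	invOrder := []int{2, 0, 1} // argsort(order): maps grouped rows back
//	// gathering rows by order groups work by expert; gathering the results by
//	// invOrder restores the original (token, slot) order.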
|
||||
|
||||
// SharedExperts implements the shared expert MLP
|
||||
type SharedExperts struct {
|
||||
GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"mlp.shared_experts.up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the shared expert MLP
|
||||
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(s.GateProj.Forward(x))
|
||||
up := s.UpProj.Forward(x)
|
||||
return s.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
|
||||
// MoE implements the full Mixture of Experts layer
|
||||
type MoE struct {
|
||||
Gate *MoEGate
|
||||
SwitchMLP *SwitchMLP
|
||||
SharedExperts *SharedExperts
|
||||
}
|
||||
|
||||
// Forward applies the MoE layer
|
||||
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
|
||||
// Get expert indices and scores
|
||||
inds, scores := m.Gate.Forward(x, cfg)
|
||||
|
||||
// Apply routed experts
|
||||
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
|
||||
|
||||
// Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
|
||||
scoresExpanded := mlx.ExpandDims(scores, -1)
|
||||
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
|
||||
|
||||
// Add shared experts if present
|
||||
if m.SharedExperts != nil {
|
||||
y = mlx.Add(y, m.SharedExperts.Forward(x))
|
||||
}
|
||||
|
||||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
|
||||
type DenseBlock struct {
|
||||
Attention *MLAAttention
|
||||
MLP *DenseMLP
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the dense block
|
||||
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Pre-norm attention with residual
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
// Pre-norm MLP with residual
|
||||
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
// MoEBlock represents a MoE transformer block
|
||||
type MoEBlock struct {
|
||||
Attention *MLAAttention
|
||||
MoE *MoE
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the MoE block
|
||||
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Pre-norm attention with residual
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
// Pre-norm MoE with residual
|
||||
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
// Block interface for both dense and MoE blocks
|
||||
type Block interface {
|
||||
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
|
||||
}
|
||||
|
||||
// Model represents the complete GLM4-MoE-Lite model
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []Block `weight:"-"` // Loaded manually due to different block types
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
LMHead nn.LinearLayer `weight:"lm_head"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
}
|
||||
|
||||
// computeScale computes the attention scale.
|
||||
// Uses the full key head dimension (qkNopeHeadDim + qkRopeHeadDim) to match the Ollama runner.
|
||||
func computeScale(cfg *Config) float32 {
|
||||
keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
||||
scale := float32(1.0 / math.Sqrt(float64(keyLength)))
|
||||
if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
|
||||
s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
|
||||
scale *= s * s
|
||||
}
|
||||
return scale
|
||||
}
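// Worked example with a hypothetical config: qk_nope_head_dim=128 and
// qk_rope_head_dim=64 give keyLength=192 and a base scale of 1/sqrt(192) ≈ 0.0722.
// With rope_scaling factor=40 and mscale_all_dim=1.0,
// s = 0.1*1.0*ln(40) + 1.0 ≈ 1.369, so the final scale ≈ 0.0722 * 1.369 * 1.369 ≈ 0.135.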
|
||||
|
||||
// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
|
||||
// Currently only 4-bit and 8-bit affine quantization are supported.
|
||||
func supportsGatherQMM(mode string, bits int) bool {
|
||||
return mode == "affine" && (bits == 4 || bits == 8)
|
||||
}
|
||||
|
||||
// ExpertWeight holds a single expert's weight with optional quantization components.
|
||||
type ExpertWeight struct {
|
||||
Weight *mlx.Array // Quantized weight (if quantized) or dequantized weight
|
||||
Scales *mlx.Array // Quantization scales (nil if not quantized)
|
||||
Biases *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases)
|
||||
Bits int // Quantization bits (4 or 8), 0 if not quantized
|
||||
GroupSize int // Quantization group size, 0 if not quantized
|
||||
}
|
||||
|
||||
// getQuantParams returns quantization parameters from model metadata.
|
||||
// Returns groupSize, bits, and mode for the model's quantization type.
|
||||
func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) {
|
||||
groupSize, bits, mode = safetensors.QuantizationParams(weights.Quantization())
|
||||
// Use metadata group_size if available (overrides default)
|
||||
if gs := weights.GroupSize(); gs > 0 {
|
||||
groupSize = gs
|
||||
}
|
||||
return groupSize, bits, mode
|
||||
}
|
||||
|
||||
// loadExpertWeight loads an expert weight.
|
||||
// If useQuantized is true and the weight is quantized with a supported mode, returns quantized components.
|
||||
// Otherwise dequantizes and returns only the weight.
|
||||
func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight {
|
||||
w, _ := weights.GetTensor(path + ".weight")
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if this is a quantized weight by looking for scales
|
||||
scalePath := path + ".weight_scale"
|
||||
if weights.HasTensor(scalePath) {
|
||||
scales, _ := weights.GetTensor(scalePath)
|
||||
var qbiases *mlx.Array
|
||||
qbiasPath := path + ".weight_qbias"
|
||||
if weights.HasTensor(qbiasPath) {
|
||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||
}
|
||||
|
||||
// Get quantization params from metadata
|
||||
groupSize, bits, mode := getQuantParams(weights)
|
||||
|
||||
// Update config with group size (for GatherQMM calls)
|
||||
if cfg.QuantGroupSize == 0 {
|
||||
cfg.QuantGroupSize = groupSize
|
||||
}
|
||||
|
||||
// If GatherQMM is supported and requested, return quantized components
|
||||
if useQuantized && supportsGatherQMM(mode, bits) {
|
||||
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
|
||||
}
|
||||
|
||||
// Otherwise dequantize
|
||||
return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
|
||||
}
|
||||
|
||||
return &ExpertWeight{Weight: w}
|
||||
}
|
||||
|
||||
// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
|
||||
// Returns embed_q and unembed_out weights for per-head projections.
|
||||
//
|
||||
// kv_b_proj.weight shape: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
// Output:
|
||||
// - embed_q: [num_heads, kv_lora_rank, qk_nope_head_dim] - projects q_nope to latent
|
||||
// - unembed_out: [num_heads, v_head_dim, kv_lora_rank] - projects latent to output
|
||||
func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
|
||||
path := prefix + ".self_attn.kv_b_proj"
|
||||
w, err := weights.GetTensor(path + ".weight")
|
||||
if err != nil || w == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check if quantized and dequantize
|
||||
scalePath := path + ".weight_scale"
|
||||
if weights.HasTensor(scalePath) {
|
||||
scales, _ := weights.GetTensor(scalePath)
|
||||
var qbiases *mlx.Array
|
||||
qbiasPath := path + ".weight_qbias"
|
||||
if weights.HasTensor(qbiasPath) {
|
||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||
}
|
||||
|
||||
groupSize, bits, mode := getQuantParams(weights)
|
||||
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
// w: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
// Reshape to [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
|
||||
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
|
||||
w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)
|
||||
|
||||
// Split into wk and wv
|
||||
// wk: [num_heads, qk_nope_head_dim, kv_lora_rank]
|
||||
// wv: [num_heads, v_head_dim, kv_lora_rank]
|
||||
wk := mlx.Slice(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
|
||||
wv := mlx.Slice(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})
|
||||
|
||||
// Transform for absorbed MLA:
|
||||
// embed_q: transpose(wk) -> [num_heads, kv_lora_rank, qk_nope_head_dim]
|
||||
// This allows: q_nope @ embed_q.T = q_nope @ wk (absorbed key projection)
|
||||
embedQ := mlx.Transpose(wk, 0, 2, 1)
|
||||
|
||||
// unembed_out: wv stays [num_heads, v_head_dim, kv_lora_rank]
|
||||
// This allows: latent_out @ unembed_out.T = latent_out @ wv.T (absorbed value projection)
|
||||
unembedOut := wv
|
||||
|
||||
return embedQ, unembedOut
|
||||
}
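// Why the absorption works (informal sketch): for each head the standard path
// computes k_nope = wk @ c, where c is the cached kv_lora_rank latent, and then
// scores q_nope · k_nope = q_nope · (wk @ c) = (wk^T @ q_nope) · c. Projecting
// the query once through embed_q (effectively wk^T) therefore lets attention
// run directly against the cached latent, and unembed_out folds wv into the
// output projection the same way.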
|
||||
|
||||
// StackedExpertWeights holds stacked weights for all experts.
|
||||
type StackedExpertWeights struct {
|
||||
Weight *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed]
|
||||
Scales *mlx.Array // Stacked scales (nil if not quantized)
|
||||
Biases *mlx.Array // Stacked biases (nil if not quantized)
|
||||
Bits int // Quantization bits (4 or 8), 0 if not quantized
|
||||
GroupSize int // Quantization group size, 0 if not quantized
|
||||
}
|
||||
|
||||
// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
|
||||
func collectAndStackExpertWeights(
|
||||
weights safetensors.WeightSource,
|
||||
prefix string,
|
||||
projName string,
|
||||
numExperts int32,
|
||||
useQuantized bool,
|
||||
cfg *Config,
|
||||
) *StackedExpertWeights {
|
||||
var w, s, b []*mlx.Array
|
||||
var bits, groupSize int
|
||||
|
||||
for e := int32(0); e < numExperts; e++ {
|
||||
path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
|
||||
ew := loadExpertWeight(weights, path, useQuantized, cfg)
|
||||
if ew == nil {
|
||||
continue
|
||||
}
|
||||
w = append(w, ew.Weight)
|
||||
if ew.Scales != nil {
|
||||
s = append(s, ew.Scales)
|
||||
}
|
||||
if ew.Biases != nil {
|
||||
b = append(b, ew.Biases)
|
||||
}
|
||||
if e == 0 {
|
||||
bits = ew.Bits
|
||||
groupSize = ew.GroupSize
|
||||
}
|
||||
}
|
||||
|
||||
result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
|
||||
if len(w) > 0 {
|
||||
result.Weight = mlx.Stack(w, 0)
|
||||
if len(s) > 0 {
|
||||
result.Scales = mlx.Stack(s, 0)
|
||||
}
|
||||
if len(b) > 0 {
|
||||
result.Biases = mlx.Stack(b, 0)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// sanitizeExpertWeights stacks individual expert weights into tensors.
|
||||
// If useQuantized is true and weights support GatherQMM, returns quantized components.
|
||||
// Otherwise returns dequantized weights with nil scales/biases.
|
||||
// Bits and GroupSize are detected per-weight to support mixed-precision (Q4 for gate/up, Q8 for down).
|
||||
func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
|
||||
gate = collectAndStackExpertWeights(weights, prefix, "gate_proj", numExperts, useQuantized, cfg)
|
||||
up = collectAndStackExpertWeights(weights, prefix, "up_proj", numExperts, useQuantized, cfg)
|
||||
down = collectAndStackExpertWeights(weights, prefix, "down_proj", numExperts, useQuantized, cfg)
|
||||
return gate, up, down
|
||||
}
|
||||
|
||||
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
|
||||
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
|
||||
// Read config from manifest
|
||||
configData, err := manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Compute derived fields
|
||||
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
||||
cfg.Scale = computeScale(&cfg)
|
||||
|
||||
// Load weights from manifest blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
if err := weights.Load(0); err != nil {
|
||||
return nil, fmt.Errorf("load weight data: %w", err)
|
||||
}
|
||||
|
||||
// Set up quantization parameters (only if model is actually quantized)
|
||||
// Note: QuantGroupSize will be detected dynamically from tensor shapes during weight loading
|
||||
quantization := weights.Quantization()
|
||||
useQuantized := false
|
||||
if quantization != "" {
|
||||
_, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization)
|
||||
useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
|
||||
}
|
||||
|
||||
// Load tokenizer from manifest with config files for EOS token detection
|
||||
tokData, err := manifest.ReadConfig("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||
}
|
||||
|
||||
// Build tokenizer config with companion files for EOS/BOS token loading
|
||||
tokConfig := &tokenizer.TokenizerConfig{
|
||||
ConfigJSON: configData, // Already loaded above, contains eos_token_id
|
||||
}
|
||||
|
||||
// Try to load generation_config.json if available (preferred source for EOS)
|
||||
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = genConfigData
|
||||
}
|
||||
|
||||
// Try to load tokenizer_config.json if available
|
||||
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||
}
|
||||
|
||||
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]Block, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Load embedding, norm, and lm_head
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load layers manually due to different block types
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
||||
|
||||
// Load attention (same for both block types)
|
||||
attn := &MLAAttention{}
|
||||
if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d attention: %w", i, err)
|
||||
}
|
||||
|
||||
// Sanitize MLA weights for absorbed attention
|
||||
embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg)
|
||||
attn.EmbedQ = nn.NewMultiLinear(embedQ)
|
||||
attn.UnembedOut = nn.NewMultiLinear(unembedOut)
|
||||
|
||||
if i < cfg.FirstKDenseReplace {
|
||||
// Dense block
|
||||
block := &DenseBlock{Attention: attn}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d dense: %w", i, err)
|
||||
}
|
||||
m.Layers[i] = block
|
||||
} else {
|
||||
// MoE block
|
||||
block := &MoEBlock{Attention: attn}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
|
||||
}
|
||||
|
||||
// Stack expert weights (pass cfg so group sizes can be detected)
|
||||
gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg)
|
||||
|
||||
switchMLP := &SwitchMLP{UseQuantized: useQuantized}
|
||||
if useQuantized {
|
||||
switchMLP.GateWeightQ = gate.Weight
|
||||
switchMLP.GateScales = gate.Scales
|
||||
switchMLP.GateBiases = gate.Biases
|
||||
switchMLP.GateBits = gate.Bits
|
||||
switchMLP.GateGroupSize = gate.GroupSize
|
||||
switchMLP.UpWeightQ = up.Weight
|
||||
switchMLP.UpScales = up.Scales
|
||||
switchMLP.UpBiases = up.Biases
|
||||
switchMLP.UpBits = up.Bits
|
||||
switchMLP.UpGroupSize = up.GroupSize
|
||||
switchMLP.DownWeightQ = down.Weight
|
||||
switchMLP.DownScales = down.Scales
|
||||
switchMLP.DownBiases = down.Biases
|
||||
switchMLP.DownBits = down.Bits
|
||||
switchMLP.DownGroupSize = down.GroupSize
|
||||
} else {
|
||||
switchMLP.GateWeight = gate.Weight
|
||||
switchMLP.UpWeight = up.Weight
|
||||
switchMLP.DownWeight = down.Weight
|
||||
}
|
||||
|
||||
block.MoE = &MoE{
|
||||
Gate: &MoEGate{},
|
||||
SwitchMLP: switchMLP,
|
||||
}
|
||||
|
||||
// Load gate weights
|
||||
if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d gate: %w", i, err)
|
||||
}
|
||||
|
||||
// Load shared experts if present
|
||||
if cfg.NSharedExperts > 0 {
|
||||
block.MoE.SharedExperts = &SharedExperts{}
|
||||
if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.Layers[i] = block
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward computes the forward pass of the model
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
h = m.Norm.Forward(h, m.RMSNormEps)
|
||||
return m.LMHead.Forward(h)
|
||||
}
|
||||
|
||||
// Interface methods
|
||||
|
||||
// NumLayers returns the number of transformer layers
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// MaxContextLength returns the maximum context length
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
|
||||
// VocabSize returns the vocabulary size
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
// Tokenizer returns the model's tokenizer
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
|
||||
// NewCache creates a new KV cache for the model
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
|
||||
// This follows the GLM-4.7 format with <think> tag for reasoning mode.
|
||||
func (m *Model) FormatPrompt(prompt string) string {
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
||||
}
|
||||
|
||||
// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
|
||||
// When think is true, the prompt ends with <think> to enable reasoning mode.
|
||||
// When think is false, the prompt ends with </think> to skip reasoning.
|
||||
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
|
||||
if think {
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
||||
}
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
|
||||
}
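// Example outputs (prompt text shown for illustration only):
//
//	FormatPromptWithThinking("What is 2+2?", true)
//	// "[gMASK]<sop><|user|>What is 2+2?<|assistant|><think>"
//	FormatPromptWithThinking("What is 2+2?", false)
//	// "[gMASK]<sop><|user|>What is 2+2?<|assistant|></think>"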
|
||||
|
||||
// NewRenderer returns a new Renderer for formatting multi-turn conversations.
|
||||
func (m *Model) NewRenderer() *Renderer {
|
||||
return &Renderer{}
|
||||
}
|
||||
|
||||
// NewParser returns a new Parser for extracting thinking and tool calls from output.
|
||||
func (m *Model) NewParser() *Parser {
|
||||
return &Parser{}
|
||||
}
|
||||
479	x/imagegen/models/glm4_moe_lite/parser.go (new file)
@@ -0,0 +1,479 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type parserState int
|
||||
|
||||
const (
|
||||
parserState_LookingForThinkingOpen parserState = iota
|
||||
parserState_ThinkingStartedEatingWhitespace
|
||||
parserState_CollectingThinking
|
||||
parserState_ThinkingDoneEatingWhitespace
|
||||
parserState_CollectingContent
|
||||
parserState_ToolStartedEatingWhitespace
|
||||
parserState_CollectingToolContent
|
||||
)
|
||||
|
||||
const (
|
||||
thinkingOpenTag = "<think>"
|
||||
thinkingCloseTag = "</think>"
|
||||
toolOpenTag = "<tool_call>"
|
||||
toolCloseTag = "</tool_call>"
|
||||
)
|
||||
|
||||
// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
|
||||
// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
|
||||
// must start in CollectingThinking state (the model outputs thinking content directly).
|
||||
type Parser struct {
|
||||
state parserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
// HasToolSupport returns true as GLM4 supports tool calling.
|
||||
func (p *Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// HasThinkingSupport returns true as GLM4 supports thinking mode.
|
||||
func (p *Parser) HasThinkingSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Init initializes the parser with tools and thinking configuration.
|
||||
func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||
// so model output starts directly with thinking content (no opening tag).
|
||||
if thinkValue == nil || thinkValue.Bool() {
|
||||
p.state = parserState_CollectingThinking
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
type parserEvent interface {
|
||||
isParserEvent()
|
||||
}
|
||||
|
||||
type eventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (eventContent) isParserEvent() {}
|
||||
|
||||
type eventRawToolCall struct {
|
||||
raw string
|
||||
}
|
||||
|
||||
func (eventRawToolCall) isParserEvent() {}
|
||||
|
||||
type eventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (eventThinkingContent) isParserEvent() {}
|
||||
|
||||
// Add processes new output text and returns parsed content, thinking, and tool calls.
|
||||
func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case eventRawToolCall:
|
||||
toolCall, err := parseToolCall(event, p.tools)
|
||||
if err != nil {
|
||||
slog.Warn("glm-4 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case eventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case eventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseEvents() []parserEvent {
|
||||
var all []parserEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []parserEvent
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
|
||||
// and transitions to the next state. Returns (nil, false) if only whitespace remains
|
||||
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
|
||||
func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
if trimmed == "" {
|
||||
return nil, false // Still only whitespace, keep waiting for more input
|
||||
}
|
||||
p.state = nextState
|
||||
p.buffer.WriteString(trimmed)
|
||||
return nil, true // Successfully transitioned
|
||||
}
|
||||
|
||||
// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
|
||||
// the content after (optionally trimmed of leading whitespace), and updates the buffer
|
||||
func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
|
||||
split := strings.SplitN(p.buffer.String(), tag, 2)
|
||||
before := split[0]
|
||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||
after := split[1]
|
||||
if trimAfter {
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
}
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
return before, after
|
||||
}
|
||||
|
||||
func (p *Parser) eat() ([]parserEvent, bool) {
|
||||
var events []parserEvent
|
||||
|
||||
switch p.state {
|
||||
case parserState_LookingForThinkingOpen:
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, thinkingOpenTag) {
|
||||
// Found <think> opening tag
|
||||
after := strings.TrimPrefix(trimmed, thinkingOpenTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if after == "" {
|
||||
p.state = parserState_ThinkingStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = parserState_CollectingThinking
|
||||
}
|
||||
return events, true
|
||||
} else if strings.HasPrefix(thinkingOpenTag, trimmed) {
|
||||
// Partial opening tag seen, keep accumulating
|
||||
return events, false
|
||||
} else if trimmed == "" {
|
||||
// Only whitespace, keep accumulating
|
||||
return events, false
|
||||
} else {
|
||||
// No thinking tag found, skip to content collection
|
||||
p.state = parserState_CollectingContent
|
||||
// Don't trim - we want to keep the original content
|
||||
return events, true
|
||||
}
|
||||
|
||||
case parserState_ThinkingStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)
|
||||
|
||||
case parserState_CollectingThinking:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, thinkingCloseTag) {
|
||||
thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, eventThinkingContent{content: thinking})
|
||||
}
|
||||
if remaining == "" {
|
||||
p.state = parserState_ThinkingDoneEatingWhitespace
|
||||
} else {
|
||||
p.state = parserState_CollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
|
||||
// Partial closing tag - withhold it along with any trailing whitespace before it
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, eventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, eventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case parserState_ThinkingDoneEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)
|
||||
|
||||
case parserState_CollectingContent:
|
||||
if strings.Contains(p.buffer.String(), toolOpenTag) {
|
||||
before, after := p.splitAtTag(toolOpenTag, true)
|
||||
if len(before) > 0 {
|
||||
events = append(events, eventContent{content: before})
|
||||
}
|
||||
if after == "" {
|
||||
p.state = parserState_ToolStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = parserState_CollectingToolContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
|
||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, eventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, eventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case parserState_ToolStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)
|
||||
|
||||
case parserState_CollectingToolContent:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, toolCloseTag) {
|
||||
toolContent, _ := p.splitAtTag(toolCloseTag, true)
|
||||
if len(toolContent) == 0 {
|
||||
slog.Warn("glm4 tool call closing tag found but no content before it")
|
||||
}
|
||||
events = append(events, eventRawToolCall{raw: toolContent})
|
||||
p.state = parserState_CollectingContent
|
||||
return events, true
|
||||
} else {
|
||||
// Keep accumulating - tool calls are not streamed
|
||||
// We just wait for the closing tag
|
||||
return events, false
|
||||
}
|
||||
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// overlap returns the length of the overlap between the end of s and the start of tag.
|
||||
func overlap(s, tag string) int {
|
||||
for i := 1; i <= len(tag) && i <= len(s); i++ {
|
||||
if strings.HasSuffix(s, tag[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
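// Illustrative cases with hypothetical inputs:
//
//	overlap("some text </thi", "</think>") // 5: the buffer ends in a partial tag
//	overlap("some text", "</think>")       // 0: no suffix of s starts the tag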
|
||||
|
||||
// trailingWhitespaceLen returns the length of trailing whitespace in s.
|
||||
func trailingWhitespaceLen(s string) int {
|
||||
trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
|
||||
return len(s) - len(trimmed)
|
||||
}
|
||||
|
||||
// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
|
||||
type ToolCallXML struct {
|
||||
XMLName xml.Name `xml:"tool_call"`
|
||||
Content string `xml:",chardata"` // Function name (text nodes between tags)
|
||||
Keys []string `xml:"arg_key"` // All arg_key elements in document order
|
||||
Values []string `xml:"arg_value"` // All arg_value elements in document order
|
||||
}
|
||||
|
||||
// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
|
||||
func escapeContent(s string) string {
|
||||
var result strings.Builder
|
||||
inTag := false
|
||||
|
||||
for i := range len(s) {
|
||||
ch := s[i]
|
||||
|
||||
if ch == '<' {
|
||||
// Check if this is a known tag
|
||||
if strings.HasPrefix(s[i:], "<arg_key>") ||
|
||||
strings.HasPrefix(s[i:], "</arg_key>") ||
|
||||
strings.HasPrefix(s[i:], "<arg_value>") ||
|
||||
strings.HasPrefix(s[i:], "</arg_value>") {
|
||||
inTag = true
|
||||
}
|
||||
}
|
||||
|
||||
if inTag {
|
||||
result.WriteByte(ch)
|
||||
if ch == '>' {
|
||||
inTag = false
|
||||
}
|
||||
} else {
|
||||
// Escape special characters in text content
|
||||
switch ch {
case '&':
result.WriteString("&amp;")
case '<':
result.WriteString("&lt;")
case '>':
result.WriteString("&gt;")
default:
result.WriteByte(ch)
}
|
||||
}
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
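// Example with a hypothetical raw tool-call body:
//
//	get_weather<arg_key>city</arg_key><arg_value>Paris & London</arg_value>
//
// becomes
//
//	get_weather<arg_key>city</arg_key><arg_value>Paris &amp; London</arg_value>
//
// The '&' in free text is escaped while the arg_key/arg_value tags are kept
// intact so the XML parser can still see them.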
|
||||
|
||||
func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
// Escape any unescaped entities in text content
|
||||
escaped := escapeContent(raw.raw)
|
||||
|
||||
// Wrap the content in a root element to make it valid XML
|
||||
xmlString := "<tool_call>" + escaped + "</tool_call>"
|
||||
|
||||
// Parse XML into struct
|
||||
var parsed ToolCallXML
|
||||
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
|
||||
}
|
||||
|
||||
// Extract and trim function name
|
||||
functionName := strings.TrimSpace(parsed.Content)
|
||||
if functionName == "" {
|
||||
return api.ToolCall{}, fmt.Errorf("empty function name")
|
||||
}
|
||||
|
||||
// Verify keys and values are paired correctly
|
||||
if len(parsed.Keys) != len(parsed.Values) {
|
||||
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
|
||||
}
|
||||
|
||||
// Find the matching tool to get parameter types
|
||||
var matchedTool *api.Tool
|
||||
for i := range tools {
|
||||
if tools[i].Function.Name == functionName {
|
||||
matchedTool = &tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Build arguments map by pairing keys and values
|
||||
toolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: functionName,
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
}
|
||||
|
||||
for i := range parsed.Keys {
|
||||
key := strings.TrimSpace(parsed.Keys[i])
|
||||
value := parsed.Values[i] // Don't trim here - parseValue handles it
|
||||
|
||||
// Look up parameter type
|
||||
var paramType api.PropertyType
|
||||
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
|
||||
// Handle anyOf by collecting all types from the union
|
||||
if len(prop.AnyOf) > 0 {
|
||||
for _, anyOfProp := range prop.AnyOf {
|
||||
paramType = append(paramType, anyOfProp.Type...)
|
||||
}
|
||||
} else {
|
||||
paramType = prop.Type
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse value with type coercion
|
||||
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
// parseValue parses a string value and coerces it to the appropriate type based on paramType.
|
||||
func parseValue(value string, paramType api.PropertyType) any {
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// If no type specified, return as string
|
||||
if len(paramType) == 0 {
|
||||
return value
|
||||
}
|
||||
|
||||
// Try to parse based on specified types
|
||||
for _, t := range paramType {
|
||||
switch t {
|
||||
case "boolean":
|
||||
if value == "true" {
|
||||
return true
|
||||
}
|
||||
if value == "false" {
|
||||
return false
|
||||
}
|
||||
case "integer":
|
||||
var i int64
|
||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
||||
return i
|
||||
}
|
||||
case "number":
|
||||
var f float64
|
||||
if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
|
||||
return f
|
||||
}
|
||||
case "array", "object":
|
||||
// Try to parse as JSON
|
||||
var result any
|
||||
if err := json.Unmarshal([]byte(value), &result); err == nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default to string
|
||||
return value
|
||||
}
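
The type coercion in parseValue is easiest to see with concrete inputs. The sketch below is illustrative only and is not part of this diff; exampleParseValueCoercion is a hypothetical helper, and the expected results assume parseValue behaves exactly as defined above.

// exampleParseValueCoercion is a hypothetical sketch (not in this diff) showing
// how parseValue coerces string values based on the declared parameter type.
func exampleParseValueCoercion() {
	fmt.Println(parseValue("42", api.PropertyType{"integer"}))      // int64(42)
	fmt.Println(parseValue("3.5", api.PropertyType{"number"}))      // float64(3.5)
	fmt.Println(parseValue("true", api.PropertyType{"boolean"}))    // true
	fmt.Println(parseValue(`["a","b"]`, api.PropertyType{"array"})) // []any{"a", "b"} via json.Unmarshal
	fmt.Println(parseValue("oops", api.PropertyType{"integer"}))    // "oops" (falls back to string)
}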
|
||||
192
x/imagegen/models/glm4_moe_lite/parser_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestParserThinking(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
thinkEnabled bool
|
||||
wantContent string
|
||||
wantThinking string
|
||||
wantToolCalls int
|
||||
}{
|
||||
{
|
||||
name: "thinking enabled - simple thinking then content",
|
||||
input: "Let me think about this...</think>Here is my answer.",
|
||||
thinkEnabled: true,
|
||||
wantThinking: "Let me think about this...",
|
||||
wantContent: "Here is my answer.",
|
||||
},
|
||||
{
|
||||
name: "thinking enabled - only thinking",
|
||||
input: "I need to consider multiple factors...",
|
||||
thinkEnabled: true,
|
||||
wantThinking: "I need to consider multiple factors...",
|
||||
wantContent: "",
|
||||
},
|
||||
{
|
||||
name: "thinking disabled - direct content",
|
||||
input: "Here is my direct answer.",
|
||||
thinkEnabled: false,
|
||||
wantThinking: "",
|
||||
wantContent: "Here is my direct answer.",
|
||||
},
|
||||
{
|
||||
name: "thinking with tool call",
|
||||
input: "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
|
||||
thinkEnabled: true,
|
||||
wantThinking: "Let me search for that...",
|
||||
wantContent: "I'll use a tool.",
|
||||
wantToolCalls: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Parser{}
|
||||
|
||||
var thinkValue *api.ThinkValue
|
||||
if tt.thinkEnabled {
|
||||
thinkValue = &api.ThinkValue{Value: true}
|
||||
} else {
|
||||
thinkValue = &api.ThinkValue{Value: false}
|
||||
}
|
||||
|
||||
// Define tools for tool call tests
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Function: api.ToolFunction{
|
||||
Name: "search",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
p.Init(tools, nil, thinkValue)
|
||||
|
||||
content, thinking, calls, err := p.Add(tt.input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if thinking != tt.wantThinking {
|
||||
t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
|
||||
}
|
||||
if content != tt.wantContent {
|
||||
t.Errorf("content = %q, want %q", content, tt.wantContent)
|
||||
}
|
||||
if len(calls) != tt.wantToolCalls {
|
||||
t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserToolCall(t *testing.T) {
|
||||
p := &Parser{}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize with thinking disabled
|
||||
tv := &api.ThinkValue{Value: false}
|
||||
p.Init(tools, nil, tv)
|
||||
|
||||
input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"
|
||||
|
||||
_, _, calls, err := p.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
|
||||
call := calls[0]
|
||||
if call.Function.Name != "get_weather" {
|
||||
t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
location, ok := call.Function.Arguments.Get("location")
|
||||
if !ok || location != "San Francisco" {
|
||||
t.Errorf("location = %v, want %q", location, "San Francisco")
|
||||
}
|
||||
|
||||
unit, ok := call.Function.Arguments.Get("unit")
|
||||
if !ok || unit != "celsius" {
|
||||
t.Errorf("unit = %v, want %q", unit, "celsius")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOverlap(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
tag string
|
||||
want int
|
||||
}{
|
||||
{"hello<", "</think>", 1},
|
||||
{"hello</", "</think>", 2},
|
||||
{"hello</t", "</think>", 3},
|
||||
{"hello</th", "</think>", 4},
|
||||
{"hello</thi", "</think>", 5},
|
||||
{"hello</thin", "</think>", 6},
|
||||
{"hello</think", "</think>", 7},
|
||||
{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
|
||||
{"hello", "</think>", 0},
|
||||
{"", "</think>", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
|
||||
got := overlap(tt.s, tt.tag)
|
||||
if got != tt.want {
|
||||
t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrailingWhitespaceLen(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
want int
|
||||
}{
|
||||
{"hello ", 3},
|
||||
{"hello\n\t ", 3},
|
||||
{"hello", 0},
|
||||
{"", 0},
|
||||
{" ", 3},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.s, func(t *testing.T) {
|
||||
got := trailingWhitespaceLen(tt.s)
|
||||
if got != tt.want {
|
||||
t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
175
x/imagegen/models/glm4_moe_lite/render.go
Normal file
@@ -0,0 +1,175 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)

// Renderer renders messages for GLM4-MoE-Lite models.
//
// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses "enable_thinking" parameter:
// - enable_thinking=true: outputs <think> to start reasoning
// - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
//   included in <think>...</think> blocks, equivalent to clear_thinking=false)
// - Users can disable thinking per-turn via thinkValue=false
type Renderer struct{}
|
||||
|
||||
// Render renders messages into the GLM4 chat format.
|
||||
func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("[gMASK]<sop>")
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString("# Tools\n\n")
|
||||
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
|
||||
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
|
||||
sb.WriteString("<tools>\n")
|
||||
for _, tool := range tools {
|
||||
d, _ := json.Marshal(tool)
|
||||
sb.WriteString(formatToolJSON(d))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("</tools>\n\n")
|
||||
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
|
||||
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
|
||||
}
|
||||
|
||||
think := true
|
||||
if thinkValue != nil && !thinkValue.Bool() {
|
||||
think = false
|
||||
}
|
||||
|
||||
for i, message := range messages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|user|>")
|
||||
sb.WriteString(message.Content)
|
||||
case "assistant":
|
||||
sb.WriteString("<|assistant|>")
|
||||
if message.Thinking != "" {
|
||||
sb.WriteString("<think>" + message.Thinking + "</think>")
|
||||
} else {
|
||||
sb.WriteString("</think>")
|
||||
}
|
||||
if message.Content != "" {
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("<tool_call>" + toolCall.Function.Name)
|
||||
sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
|
||||
sb.WriteString("</tool_call>")
|
||||
}
|
||||
}
|
||||
case "tool":
|
||||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString("<|observation|>")
|
||||
}
|
||||
sb.WriteString("<tool_response>")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("</tool_response>")
|
||||
case "system":
|
||||
sb.WriteString("<|system|>")
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<|assistant|>")
|
||||
if think {
|
||||
sb.WriteString("<think>")
|
||||
} else {
|
||||
sb.WriteString("</think>")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
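
One subtlety in Render above is worth calling out: <|observation|> is written only when the previous message is not a tool message, so consecutive tool results share a single marker. The sketch below is illustrative only and is not part of this diff; exampleObservationGrouping is a hypothetical helper.

// exampleObservationGrouping is a hypothetical sketch showing that back-to-back
// tool results are grouped under one <|observation|> marker.
func exampleObservationGrouping() {
	r := &Renderer{}
	out, _ := r.Render([]api.Message{
		{Role: "user", Content: "Weather in SF and NYC?"},
		{Role: "tool", Content: "SF: Sunny"},
		{Role: "tool", Content: "NYC: Rainy"},
	}, nil, nil)
	fmt.Println(out)
	// The output should contain:
	// <|observation|><tool_response>SF: Sunny</tool_response><tool_response>NYC: Rainy</tool_response>
}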
|
||||
|
||||
// renderToolArguments converts tool call arguments to GLM4 XML format.
|
||||
func renderToolArguments(args api.ToolCallFunctionArguments) string {
|
||||
var sb strings.Builder
|
||||
for key, value := range args.All() {
|
||||
sb.WriteString("<arg_key>" + key + "</arg_key>")
|
||||
var valueStr string
|
||||
if str, ok := value.(string); ok {
|
||||
valueStr = str
|
||||
} else {
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
valueStr = fmt.Sprintf("%v", value)
|
||||
} else {
|
||||
valueStr = string(jsonBytes)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and ,
|
||||
func formatToolJSON(raw []byte) string {
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(raw) + len(raw)/10)
|
||||
|
||||
inString := false
|
||||
escaped := false
|
||||
for i := range raw {
|
||||
ch := raw[i]
|
||||
sb.WriteByte(ch)
|
||||
|
||||
if inString {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
inString = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '"' {
|
||||
inString = true
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == ':' || ch == ',' {
|
||||
sb.WriteByte(' ')
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
205
x/imagegen/models/glm4_moe_lite/render_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestRendererSimple(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
// Thinking enabled (default)
|
||||
result, err := r.Render(messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expected := "[gMASK]<sop><|user|>Hello<|assistant|><think>"
|
||||
if result != expected {
|
||||
t.Errorf("result = %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRendererThinkingDisabled(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
tv := &api.ThinkValue{Value: false}
|
||||
|
||||
result, err := r.Render(messages, nil, tv)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expected := "[gMASK]<sop><|user|>Hello<|assistant|></think>"
|
||||
if result != expected {
|
||||
t.Errorf("result = %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRendererMultiTurn(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
|
||||
{Role: "user", Content: "And 3+3?"},
|
||||
}
|
||||
|
||||
result, err := r.Render(messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check key parts
|
||||
if !strings.Contains(result, "[gMASK]<sop>") {
|
||||
t.Error("missing [gMASK]<sop> prefix")
|
||||
}
|
||||
if !strings.Contains(result, "<|user|>What is 2+2?") {
|
||||
t.Error("missing first user message")
|
||||
}
|
||||
if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
|
||||
t.Error("missing assistant message with thinking")
|
||||
}
|
||||
if !strings.Contains(result, "<|user|>And 3+3?") {
|
||||
t.Error("missing second user message")
|
||||
}
|
||||
if !strings.HasSuffix(result, "<|assistant|><think>") {
|
||||
t.Errorf("should end with <|assistant|><think>, got suffix: %q", result[len(result)-30:])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRendererWithSystem(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
result, err := r.Render(messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
|
||||
t.Error("missing system message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRendererWithTools(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the weather for a location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := r.Render(messages, tools, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check for tool system prompt
|
||||
if !strings.Contains(result, "<|system|>") {
|
||||
t.Error("missing system tag for tools")
|
||||
}
|
||||
if !strings.Contains(result, "# Tools") {
|
||||
t.Error("missing tools header")
|
||||
}
|
||||
if !strings.Contains(result, "<tools>") {
|
||||
t.Error("missing tools tag")
|
||||
}
|
||||
if !strings.Contains(result, "get_weather") {
|
||||
t.Error("missing tool name")
|
||||
}
|
||||
if !strings.Contains(result, "</tools>") {
|
||||
t.Error("missing closing tools tag")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRendererWithToolCalls(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("location", "San Francisco")
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "What's the weather in SF?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny, 72F"},
|
||||
}
|
||||
|
||||
result, err := r.Render(messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "<tool_call>get_weather") {
|
||||
t.Error("missing tool call")
|
||||
}
|
||||
if !strings.Contains(result, "<arg_key>location</arg_key>") {
|
||||
t.Error("missing arg_key")
|
||||
}
|
||||
if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
|
||||
t.Error("missing arg_value")
|
||||
}
|
||||
if !strings.Contains(result, "</tool_call>") {
|
||||
t.Error("missing tool call closing tag")
|
||||
}
|
||||
if !strings.Contains(result, "<|observation|>") {
|
||||
t.Error("missing observation tag")
|
||||
}
|
||||
if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
|
||||
t.Error("missing tool response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatToolJSON(t *testing.T) {
|
||||
input := []byte(`{"name":"test","value":123}`)
|
||||
result := formatToolJSON(input)
|
||||
|
||||
// Should add spaces after : and ,
|
||||
if !strings.Contains(result, ": ") {
|
||||
t.Error("should add space after colon")
|
||||
}
|
||||
if !strings.Contains(result, ", ") {
|
||||
t.Error("should add space after comma")
|
||||
}
|
||||
}
|
||||
@@ -32,10 +32,16 @@ func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
|
||||
|
||||
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
|
||||
// Quantizes the weight immediately and evaluates to break lazy dependencies.
|
||||
// Note: For modes like "nvfp4", qbiases will be nil.
|
||||
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
|
||||
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
|
||||
// Eval immediately so bf16 weight can be freed
|
||||
mlx.Eval(qw, scales, qbiases)
|
||||
// Handle modes that don't return qbiases (e.g., nvfp4)
|
||||
if qbiases != nil {
|
||||
mlx.Eval(qw, scales, qbiases)
|
||||
} else {
|
||||
mlx.Eval(qw, scales)
|
||||
}
|
||||
return &QuantizedLinear{
|
||||
Weight: qw,
|
||||
Scales: scales,
|
||||
@@ -77,10 +83,13 @@ func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear
|
||||
|
||||
// QuantizedLinear applies an affine transformation using quantized weights.
|
||||
// Equivalent to mlx.nn.QuantizedLinear.
|
||||
// Supports multiple quantization modes:
|
||||
// - "affine": scale + zero-point bias (QBiases required)
|
||||
// - "nvfp4": NVIDIA FP4 with E4M3 scales (QBiases nil)
|
||||
type QuantizedLinear struct {
|
||||
Weight *mlx.Array // Quantized weight data
|
||||
Scales *mlx.Array // Scale factors for dequantization
|
||||
QBiases *mlx.Array // Quantization biases (NOT layer bias)
|
||||
QBiases *mlx.Array // Quantization biases (NOT layer bias), nil for nvfp4
|
||||
Bias *mlx.Array // Layer bias [output_dims] or nil
|
||||
GroupSize int
|
||||
Bits int
|
||||
@@ -220,3 +229,32 @@ func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// MultiLinearLayer is an interface for per-head linear layers.
|
||||
// This allows swapping between MultiLinear (bf16) and pre-dequantized weights.
|
||||
type MultiLinearLayer interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// MultiLinear performs per-head linear projections.
|
||||
// Weight shape: [num_heads, output_dims, input_dims]
|
||||
// Input shape: [B, num_heads, L, input_dims]
|
||||
// Output shape: [B, num_heads, L, output_dims]
|
||||
type MultiLinear struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
}
|
||||
|
||||
// NewMultiLinear creates a MultiLinear layer with the given weight.
|
||||
func NewMultiLinear(weight *mlx.Array) *MultiLinear {
|
||||
return &MultiLinear{Weight: weight}
|
||||
}
|
||||
|
||||
// Forward applies per-head linear transformation: x @ weight.T per head via broadcasting.
|
||||
func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Weight: [num_heads, output_dims, input_dims]
|
||||
// x: [B, num_heads, L, input_dims]
|
||||
// wT: [num_heads, input_dims, output_dims]
|
||||
// Result: [B, num_heads, L, output_dims]
|
||||
wT := mlx.Transpose(ml.Weight, 0, 2, 1)
|
||||
return mlx.Matmul(x, wT)
|
||||
}
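
As a concrete shape check for the broadcasting described above (illustrative numbers, not taken from this diff):

// Example with 8 heads, input_dims=128, output_dims=64, batch=1, sequence length=10:
//
//	Weight                          [8, 64, 128]
//	x                               [1, 8, 10, 128]
//	wT = Transpose(Weight, 0, 2, 1) [8, 128, 64]
//	Matmul(x, wT)                   [1, 8, 10, 64]   (head dimension broadcasts)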
|
||||
|
||||
@@ -1,284 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package runner provides a subprocess server for image generation.
|
||||
// It listens on a port and handles HTTP requests for image generation.
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"image"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
)
|
||||
|
||||
// Request is the image generation request format
|
||||
type Request struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
|
||||
}
|
||||
|
||||
// Response is streamed back for each progress update
|
||||
type Response struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Image string `json:"image,omitempty"` // Base64-encoded PNG
|
||||
Done bool `json:"done"`
|
||||
Step int `json:"step,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
}
|
||||
|
||||
// ImageModel is the interface for image generation models
|
||||
type ImageModel interface {
|
||||
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
// ImageEditModel extends ImageModel with image editing/conditioning capability.
|
||||
// Models that support input images for editing should implement this interface.
|
||||
type ImageEditModel interface {
|
||||
ImageModel
|
||||
GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
// Server holds the model and handles requests
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
model ImageModel
|
||||
modelName string
|
||||
}
|
||||
|
||||
// Execute is the entry point for the image runner subprocess
|
||||
func Execute(args []string) error {
|
||||
fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
|
||||
modelName := fs.String("model", "", "path to image model")
|
||||
port := fs.Int("port", 0, "port to listen on")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *modelName == "" {
|
||||
return fmt.Errorf("--model is required")
|
||||
}
|
||||
if *port == 0 {
|
||||
return fmt.Errorf("--port is required")
|
||||
}
|
||||
|
||||
err := mlx.InitMLX()
|
||||
if err != nil {
|
||||
slog.Error("unable to initialize MLX", "error", err)
|
||||
return err
|
||||
}
|
||||
slog.Info("MLX library initialized")
|
||||
slog.Info("starting image runner", "model", *modelName, "port", *port)
|
||||
|
||||
// Detect model type and load appropriate model
|
||||
modelType := imagegen.DetectModelType(*modelName)
|
||||
slog.Info("detected model type", "type", modelType)
|
||||
|
||||
var model ImageModel
|
||||
switch modelType {
|
||||
case "Flux2KleinPipeline":
|
||||
m := &flux2.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
model = m
|
||||
default:
|
||||
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
|
||||
m := &zimage.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
model = m
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
model: model,
|
||||
modelName: *modelName,
|
||||
}
|
||||
|
||||
// Set up HTTP handlers
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/health", server.healthHandler)
|
||||
mux.HandleFunc("/completion", server.completionHandler)
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// Handle shutdown
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
slog.Info("shutting down image runner")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
httpServer.Shutdown(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
slog.Info("image runner listening", "addr", httpServer.Addr)
|
||||
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate and decode input images
|
||||
const maxInputImages = 2
|
||||
if len(req.Images) > maxInputImages {
|
||||
http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var inputImages []image.Image
|
||||
if len(req.Images) > 0 {
|
||||
// TODO: add memory check for input images
|
||||
|
||||
inputImages = make([]image.Image, len(req.Images))
|
||||
for i, imgBytes := range req.Images {
|
||||
img, err := imagegen.DecodeImage(imgBytes)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
inputImages[i] = img
|
||||
}
|
||||
slog.Info("decoded input images", "count", len(inputImages))
|
||||
|
||||
// Default width/height to first input image dimensions, scaled to max 1024
|
||||
bounds := inputImages[0].Bounds()
|
||||
w, h := bounds.Dx(), bounds.Dy()
|
||||
if w > 1024 || h > 1024 {
|
||||
if w > h {
|
||||
h = h * 1024 / w
|
||||
w = 1024
|
||||
} else {
|
||||
w = w * 1024 / h
|
||||
h = 1024
|
||||
}
|
||||
}
|
||||
req.Width = int32(w)
|
||||
req.Height = int32(h)
|
||||
}
|
||||
|
||||
// Serialize generation requests - MLX model may not handle concurrent generation
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Model applies its own defaults for width/height/steps
|
||||
// Only seed needs to be set here if not provided
|
||||
if req.Seed <= 0 {
|
||||
req.Seed = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// Set up streaming response
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate image using the common interface
|
||||
ctx := r.Context()
|
||||
enc := json.NewEncoder(w)
|
||||
|
||||
// Progress callback streams step updates
|
||||
progress := func(step, total int) {
|
||||
resp := Response{Step: step, Total: total}
|
||||
enc.Encode(resp)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// Use ImageEditModel if available and images provided, otherwise use basic ImageModel
|
||||
var img *mlx.Array
|
||||
var err error
|
||||
if len(inputImages) > 0 {
|
||||
editModel, ok := s.model.(ImageEditModel)
|
||||
if !ok {
|
||||
http.Error(w, "model does not support image editing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress)
|
||||
} else {
|
||||
img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Don't send error for cancellation
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
return
|
||||
}
|
||||
|
||||
// Encode image as base64 PNG
|
||||
imageData, err := imagegen.EncodeImageBase64(img)
|
||||
if err != nil {
|
||||
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
return
|
||||
}
|
||||
|
||||
// Free the generated image array and clean up MLX state
|
||||
img.Free()
|
||||
mlx.ClearCache()
|
||||
mlx.MetalResetPeakMemory()
|
||||
|
||||
// Send final response with image data
|
||||
resp := Response{
|
||||
Image: imageData,
|
||||
Done: true,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
@@ -17,17 +17,31 @@ type WeightSource interface {
|
||||
GetTensor(name string) (*mlx.Array, error)
|
||||
ListTensors() []string
|
||||
HasTensor(name string) bool
|
||||
Quantization() string // Returns "FP4", "FP8", or ""
|
||||
Quantization() string // Returns "NVFP4", "Q4", "Q8", or ""
|
||||
GroupSize() int // Returns quantization group size, or 0 if not specified
|
||||
}
|
||||
|
||||
// quantizationParams returns groupSize, bits, mode for a quantization type.
|
||||
// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
|
||||
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
|
||||
// QuantizationParams returns groupSize, bits, mode for a quantization type.
|
||||
// MLX quantization modes:
|
||||
// - "affine": scale + zero-point bias, group_size=32/64/128
|
||||
// - "nvfp4": NVIDIA FP4 with E4M3 scales, group_size=16 (no bias)
|
||||
// - "mxfp8": Microsoft MX FP8 with E4M3 scales, group_size=32 (no bias)
|
||||
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
|
||||
switch strings.ToUpper(quantization) {
|
||||
case "FP4":
|
||||
case "NVFP4":
|
||||
// NVIDIA FP4: group_size=16, bits=4, E4M3 scales (no qbias)
|
||||
return 16, 4, "nvfp4"
|
||||
case "FP4", "Q4", "INT4":
|
||||
// 4-bit quantization with affine mode (scale + qbias)
|
||||
return 32, 4, "affine"
|
||||
case "MXFP8":
|
||||
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbias)
|
||||
return 32, 8, "mxfp8"
|
||||
case "FP8", "Q8", "INT8", "":
|
||||
// 8-bit quantization with affine mode (default for quantized models)
|
||||
return 64, 8, "affine"
|
||||
default:
|
||||
return 32, 8, "affine" // FP8 or unknown
|
||||
return 32, 8, "affine" // Default to affine
|
||||
}
|
||||
}
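
For quick reference, the concrete mappings implied by the switch above (illustrative listing, not part of this diff):

// QuantizationParams("NVFP4") -> (16, 4, "nvfp4")
// QuantizationParams("Q4")    -> (32, 4, "affine")   (also "FP4", "INT4")
// QuantizationParams("MXFP8") -> (32, 8, "mxfp8")
// QuantizationParams("Q8")    -> (64, 8, "affine")   (also "FP8", "INT8", "")
// QuantizationParams("other") -> (32, 8, "affine")   (fallback)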
|
||||
|
||||
@@ -122,7 +136,8 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
|
||||
}
|
||||
|
||||
// Handle nn.LinearLayer interface fields specially
|
||||
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
|
||||
linearLayerType := reflect.TypeOf((*nn.LinearLayer)(nil)).Elem()
|
||||
if field.Type == linearLayerType {
|
||||
if !hasTag {
|
||||
continue // no tag = skip
|
||||
}
|
||||
@@ -137,6 +152,23 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle nn.MultiLinearLayer interface fields specially
|
||||
multiLinearLayerType := reflect.TypeOf((*nn.MultiLinearLayer)(nil)).Elem()
|
||||
if field.Type == multiLinearLayerType {
|
||||
if !hasTag {
|
||||
continue // no tag = skip
|
||||
}
|
||||
layer, err := LoadMultiLinearLayer(weights, fullPath)
|
||||
if err != nil {
|
||||
if !optional {
|
||||
*errs = append(*errs, fullPath+": "+err.Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
fieldVal.Set(reflect.ValueOf(layer))
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle by kind
|
||||
switch fieldVal.Kind() {
|
||||
case reflect.Ptr:
|
||||
@@ -216,12 +248,86 @@ func joinPath(prefix, suffix string) string {
|
||||
return prefix + "." + suffix
|
||||
}
|
||||
|
||||
// LoadMultiLinearLayer loads a per-head linear layer from weights.
|
||||
// Weight shape should be [num_heads, output_dims, input_dims].
|
||||
// If quantized, always dequantizes since batched quantized matmul isn't supported.
|
||||
func LoadMultiLinearLayer(weights WeightSource, path string) (nn.MultiLinearLayer, error) {
|
||||
// Check if this is a quantized layer by looking for scale tensor
|
||||
scalePath := path + ".weight_scale"
|
||||
hasScale := weights.HasTensor(scalePath)
|
||||
|
||||
weight, err := weights.GetTensor(path + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
|
||||
}
|
||||
|
||||
if hasScale {
|
||||
scales, err := weights.GetTensor(scalePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
|
||||
}
|
||||
|
||||
var qbiases *mlx.Array
|
||||
qbiasPath := path + ".weight_qbias"
|
||||
if weights.HasTensor(qbiasPath) {
|
||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||
}
|
||||
|
||||
// Always dequantize for MultiLinear - no batched quantized matmul support
|
||||
// Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
|
||||
weightShape := weight.Shape()
|
||||
scalesShape := scales.Shape()
|
||||
weightCols := int(weightShape[len(weightShape)-1])
|
||||
scalesCols := int(scalesShape[len(scalesShape)-1])
|
||||
|
||||
// Detect quantization from tensor shapes
|
||||
// groupSize = weightCols * packFactor / scalesCols
|
||||
// Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
var bits, groupSize int
|
||||
// Use metadata to help disambiguate when shapes are ambiguous
|
||||
// (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
|
||||
quantType := strings.ToUpper(weights.Quantization())
|
||||
isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
|
||||
|
||||
if groupSize4 == 32 {
|
||||
// Unambiguous: Q4 with group_size=32
|
||||
bits = 4
|
||||
groupSize = 32
|
||||
} else if groupSize8 == 64 {
|
||||
// Unambiguous: Q8 with group_size=64
|
||||
bits = 8
|
||||
groupSize = 64
|
||||
} else if groupSize4 == 64 && groupSize8 == 32 {
|
||||
// Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
|
||||
if isQ8Type {
|
||||
bits = 8
|
||||
groupSize = 32
|
||||
} else {
|
||||
bits = 4
|
||||
groupSize = 64
|
||||
}
|
||||
} else {
|
||||
// Fallback: use global quantization params
|
||||
_, bits, _ = QuantizationParams(weights.Quantization())
|
||||
packFactor := 32 / bits
|
||||
groupSize = weightCols * packFactor / scalesCols
|
||||
}
|
||||
weight = mlx.Dequantize(weight, scales, qbiases, groupSize, bits, "affine")
|
||||
}
|
||||
|
||||
return nn.NewMultiLinear(weight), nil
|
||||
}
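
A worked example of the shape-based detection above may help (illustrative numbers, not from this diff). Quantized weights pack 32/bits values per uint32 column, so the candidate group sizes follow from the weight and scale column counts:

// weightCols=512,  scalesCols=128: groupSize4 = 512*8/128  = 32 -> unambiguous Q4, group_size=32
// weightCols=2048, scalesCols=128: groupSize8 = 2048*4/128 = 64 -> unambiguous Q8, group_size=64
// weightCols=1024, scalesCols=128: groupSize4 = 64, groupSize8 = 32 -> ambiguous; Quantization() metadata decides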
|
||||
|
||||
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
|
||||
// If {path}.weight_scale exists, dequantizes the weights.
|
||||
// If {path}.weight_scale exists, creates a QuantizedLinear layer (or dequantizes if no kernel support).
|
||||
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
|
||||
// Check if this is a quantized layer by looking for scale tensor
|
||||
scalePath := path + ".weight_scale"
|
||||
if weights.HasTensor(scalePath) {
|
||||
hasScale := weights.HasTensor(scalePath)
|
||||
if hasScale {
|
||||
weight, err := weights.GetTensor(path + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
|
||||
@@ -245,9 +351,52 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
|
||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||
}
|
||||
|
||||
groupSize, bits, mode := quantizationParams(weights.Quantization())
|
||||
// Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
|
||||
weightShape := weight.Shape()
|
||||
scalesShape := scales.Shape()
|
||||
weightCols := int(weightShape[len(weightShape)-1])
|
||||
scalesCols := int(scalesShape[len(scalesShape)-1])
|
||||
|
||||
if mlx.MetalIsAvailable() {
|
||||
// Detect quantization from tensor shapes
|
||||
// groupSize = weightCols * packFactor / scalesCols
|
||||
// Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
var bits, groupSize int
|
||||
mode := "affine"
|
||||
// Use metadata to help disambiguate when shapes are ambiguous
|
||||
// (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
|
||||
quantType := strings.ToUpper(weights.Quantization())
|
||||
isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
|
||||
|
||||
if groupSize4 == 32 {
|
||||
// Unambiguous: Q4 with group_size=32
|
||||
bits = 4
|
||||
groupSize = 32
|
||||
} else if groupSize8 == 64 {
|
||||
// Unambiguous: Q8 with group_size=64
|
||||
bits = 8
|
||||
groupSize = 64
|
||||
} else if groupSize4 == 64 && groupSize8 == 32 {
|
||||
// Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
|
||||
if isQ8Type {
|
||||
bits = 8
|
||||
groupSize = 32
|
||||
} else {
|
||||
bits = 4
|
||||
groupSize = 64
|
||||
}
|
||||
} else {
|
||||
// Fallback: use global quantization params
|
||||
_, bits, mode = QuantizationParams(weights.Quantization())
|
||||
packFactor := 32 / bits
|
||||
groupSize = weightCols * packFactor / scalesCols
|
||||
}
|
||||
|
||||
// NVFP4 and MXFP8 don't have native quantized matmul kernels in MLX,
|
||||
// so we always dequantize at load time. Affine modes (FP4, FP8) have kernel support.
|
||||
if mlx.MetalIsAvailable() && mode != "nvfp4" && mode != "mxfp8" {
|
||||
return &nn.QuantizedLinear{
|
||||
Weight: weight,
|
||||
Scales: scales,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.