mirror of
https://github.com/ollama/ollama.git
synced 2026-02-04 12:42:58 -05:00
Compare commits
11 Commits
v0.15.5-rc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cefabd79a8 | ||
|
|
df70249520 | ||
|
|
77eb2ca619 | ||
|
|
ee25219edd | ||
|
|
b1fccabb34 | ||
|
|
a6355329bf | ||
|
|
0398b24b42 | ||
|
|
75b1dddf91 | ||
|
|
e1e80ffc3e | ||
|
|
71896485fd | ||
|
|
ef00199fb4 |
3
anthropic/anthropic.go
Normal file → Executable file
3
anthropic/anthropic.go
Normal file → Executable file
@@ -211,6 +211,7 @@ type MessageDelta struct {
|
||||
|
||||
// DeltaUsage contains cumulative token usage
|
||||
type DeltaUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
@@ -721,6 +722,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
})
|
||||
}
|
||||
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
c.outputTokens = r.Metrics.EvalCount
|
||||
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
||||
|
||||
@@ -732,6 +734,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: DeltaUsage{
|
||||
InputTokens: c.inputTokens,
|
||||
OutputTokens: c.outputTokens,
|
||||
},
|
||||
},
|
||||
|
||||
20
anthropic/anthropic_test.go
Normal file → Executable file
20
anthropic/anthropic_test.go
Normal file → Executable file
@@ -642,7 +642,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{EvalCount: 5},
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
|
||||
events2 := conv.Process(resp2)
|
||||
@@ -650,6 +650,24 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||
hasStop := false
|
||||
for _, e := range events2 {
|
||||
if e.Event == "message_delta" {
|
||||
if data, ok := e.Data.(MessageDeltaEvent); ok {
|
||||
if data.Type != "message_delta" {
|
||||
t.Errorf("unexpected data type: %+v", data)
|
||||
}
|
||||
|
||||
if data.Delta.StopReason != "end_turn" {
|
||||
t.Errorf("unexpected stop reason: %+v", data.Delta.StopReason)
|
||||
}
|
||||
|
||||
if data.Usage.InputTokens != 10 || data.Usage.OutputTokens != 5 {
|
||||
t.Errorf("unexpected usage: %+v", data.Usage)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("unexpected data: %+v", e.Data)
|
||||
}
|
||||
}
|
||||
|
||||
if e.Event == "message_stop" {
|
||||
hasStop = true
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/containerd/console"
|
||||
"github.com/mattn/go-runewidth"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/pkg/browser"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -52,7 +53,7 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||
@@ -663,6 +664,7 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println()
|
||||
|
||||
if aErr.SigninURL != "" {
|
||||
_ = browser.OpenURL(aErr.SigninURL)
|
||||
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -15,11 +15,13 @@ type Claude struct{}
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string) []string {
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
return []string{"--model", model}
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
return nil
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
@@ -41,13 +43,13 @@ func (c *Claude) findPath() (string, error) {
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string) error {
|
||||
func (c *Claude) Run(model string, args []string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model)...)
|
||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
@@ -84,17 +84,21 @@ func TestClaudeArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil},
|
||||
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil, nil},
|
||||
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,20 +14,21 @@ type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
func (c *Codex) args(model string) []string {
|
||||
func (c *Codex) args(model string, extra []string) []string {
|
||||
args := []string{"--oss"}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string) error {
|
||||
func (c *Codex) Run(model string, args []string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model)...)
|
||||
cmd := exec.Command("codex", c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
@@ -11,17 +11,20 @@ func TestCodexArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", []string{"--oss"}},
|
||||
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", nil, []string{"--oss"}},
|
||||
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
|
||||
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -20,6 +21,14 @@ type config struct {
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "config.json"), nil
|
||||
}
|
||||
|
||||
func legacyConfigPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -27,6 +36,46 @@ func configPath() (string, error) {
|
||||
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||
}
|
||||
|
||||
// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
|
||||
func migrateConfig() (bool, error) {
|
||||
oldPath, err := legacyConfigPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
oldData, err := os.ReadFile(oldPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
var js json.RawMessage
|
||||
if err := json.Unmarshal(oldData, &js); err != nil {
|
||||
slog.Warn("legacy config has invalid JSON, skipping migration", "path", oldPath, "error", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
newPath, err := configPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
|
||||
return false, fmt.Errorf("write new config: %w", err)
|
||||
}
|
||||
|
||||
_ = os.Remove(oldPath)
|
||||
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
|
||||
|
||||
slog.Info("migrated config", "from", oldPath, "to", newPath)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func load() (*config, error) {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
@@ -34,6 +83,11 @@ func load() (*config, error) {
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
if migrated, merr := migrateConfig(); merr == nil && migrated {
|
||||
data, err = os.ReadFile(path)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &config{Integrations: make(map[string]*integration)}, nil
|
||||
|
||||
@@ -200,12 +200,10 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Create corrupted config.json file
|
||||
dir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
dir := filepath.Join(tmpDir, ".ollama")
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
// Corrupted file is treated as empty, so loadIntegration returns not found
|
||||
_, err := loadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
@@ -267,7 +265,7 @@ func TestConfigPath(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if path != expected {
|
||||
t.Errorf("expected %s, got %s", expected, path)
|
||||
}
|
||||
@@ -322,6 +320,183 @@ func TestLoad(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateConfig(t *testing.T) {
|
||||
t.Run("migrates legacy file to new location", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
newPath, _ := configPath()
|
||||
got, err := os.ReadFile(newPath)
|
||||
if err != nil {
|
||||
t.Fatalf("new config not found: %v", err)
|
||||
}
|
||||
if string(got) != string(data) {
|
||||
t.Errorf("content mismatch: got %s", got)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("legacy file should have been removed")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
|
||||
t.Error("legacy directory should have been removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no-op when no legacy file exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("expected no migration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips corrupt legacy file", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("should not migrate corrupt file")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
|
||||
t.Error("corrupt legacy file should not have been deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("new path takes precedence over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
|
||||
|
||||
newDir := filepath.Join(tmpDir, ".ollama")
|
||||
os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := cfg.Integrations["new"]; !ok {
|
||||
t.Error("expected new-path integration to be loaded")
|
||||
}
|
||||
if _, ok := cfg.Integrations["old"]; ok {
|
||||
t.Error("legacy integration should not have been loaded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("idempotent when called twice", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("second migration should be a no-op")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy directory preserved if not empty", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
|
||||
t.Error("directory with other files should not have been removed")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
|
||||
t.Error("other files in legacy directory should be untouched")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save writes to new path after migration", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
// load triggers migration, then save should write to new path
|
||||
if err := saveIntegration("codex", []string{"qwen2.5"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if _, err := os.Stat(newPath); os.IsNotExist(err) {
|
||||
t.Error("save should write to new path")
|
||||
}
|
||||
|
||||
// old path should not be recreated
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("save should not recreate legacy path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load triggers migration transparently", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
|
||||
t.Error("migration via load() did not preserve data")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
@@ -39,7 +39,7 @@ type modelEntry struct {
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string) error {
|
||||
func (d *Droid) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
@@ -53,7 +53,7 @@ func (d *Droid) Run(model string) error {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid")
|
||||
cmd := exec.Command("droid", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -22,7 +23,7 @@ import (
|
||||
// Runner can run an integration with a model.
|
||||
|
||||
type Runner interface {
|
||||
Run(model string) error
|
||||
Run(model string, args []string) error
|
||||
// String returns the human-readable name of the integration
|
||||
String() string
|
||||
}
|
||||
@@ -49,6 +50,15 @@ var integrations = map[string]Runner{
|
||||
"openclaw": &Openclaw{},
|
||||
}
|
||||
|
||||
// recommendedModels are shown when the user has no models or as suggestions.
|
||||
// Order matters: local models first, then cloud models.
|
||||
var recommendedModels = []selectItem{
|
||||
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
|
||||
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
|
||||
{Name: "glm-4.7:cloud", Description: "Recommended"},
|
||||
{Name: "kimi-k2.5:cloud", Description: "Recommended"},
|
||||
}
|
||||
|
||||
// integrationAliases are hidden from the interactive selector but work as CLI arguments.
|
||||
var integrationAliases = map[string]bool{
|
||||
"clawdbot": true,
|
||||
@@ -94,62 +104,25 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(models.Models) == 0 {
|
||||
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
var items []selectItem
|
||||
cloudModels := make(map[string]bool)
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
if m.RemoteModel != "" {
|
||||
cloudModels[m.Name] = true
|
||||
}
|
||||
items = append(items, selectItem{Name: m.Name})
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
// Get previously configured models (saved config takes precedence)
|
||||
var preChecked []string
|
||||
if saved, err := loadIntegration(name); err == nil {
|
||||
preChecked = saved.Models
|
||||
} else if editor, ok := r.(Editor); ok {
|
||||
preChecked = editor.Models()
|
||||
}
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current)
|
||||
|
||||
// If current model is configured, move to front of preChecked
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no models available")
|
||||
}
|
||||
|
||||
// Sort: checked first, then alphabetical
|
||||
slices.SortFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
|
||||
var selected []string
|
||||
// only editors support multi-model selection
|
||||
if _, ok := r.(Editor); ok {
|
||||
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
if err != nil {
|
||||
@@ -163,7 +136,27 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
selected = []string{model}
|
||||
}
|
||||
|
||||
// if any model in selected is a cloud model, ensure signed in
|
||||
var toPull []string
|
||||
for _, m := range selected {
|
||||
if !existingModels[m] {
|
||||
toPull = append(toPull, m)
|
||||
}
|
||||
}
|
||||
if len(toPull) > 0 {
|
||||
msg := fmt.Sprintf("Download %s?", strings.Join(toPull, ", "))
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return nil, err
|
||||
} else if !ok {
|
||||
return nil, errCancelled
|
||||
}
|
||||
for _, m := range toPull {
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, m); err != nil {
|
||||
return nil, fmt.Errorf("failed to pull %s: %w", m, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
@@ -233,13 +226,13 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string) error {
|
||||
func runIntegration(name, modelName string, args []string) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName)
|
||||
return r.Run(modelName, args)
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
@@ -248,7 +241,7 @@ func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) erro
|
||||
var configFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "launch [INTEGRATION]",
|
||||
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||
Short: "Launch an integration with Ollama",
|
||||
Long: `Launch an integration configured with Ollama models.
|
||||
|
||||
@@ -263,14 +256,37 @@ Examples:
|
||||
ollama launch
|
||||
ollama launch claude
|
||||
ollama launch claude --model <model>
|
||||
ollama launch droid --config (does not auto-launch)`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
ollama launch droid --config (does not auto-launch)
|
||||
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||
ollama launch codex -- --sandbox workspace-write`,
|
||||
Args: cobra.ArbitraryArgs,
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Extract integration name and args to pass through using -- separator
|
||||
var name string
|
||||
if len(args) > 0 {
|
||||
name = args[0]
|
||||
var passArgs []string
|
||||
dashIdx := cmd.ArgsLenAtDash()
|
||||
|
||||
if dashIdx == -1 {
|
||||
// No "--" separator: only allow 0 or 1 args (integration name)
|
||||
if len(args) > 1 {
|
||||
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
|
||||
}
|
||||
if len(args) == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
} else {
|
||||
// "--" was used: args before it = integration name, args after = passthrough
|
||||
if dashIdx > 1 {
|
||||
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||
}
|
||||
if dashIdx == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
passArgs = args[dashIdx:]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
var err error
|
||||
name, err = selectIntegration()
|
||||
if errors.Is(err, errCancelled) {
|
||||
@@ -286,16 +302,14 @@ Examples:
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
// If launching without --model, use saved config if available
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0])
|
||||
return runIntegration(name, config.Models[0], passArgs)
|
||||
}
|
||||
}
|
||||
|
||||
var models []string
|
||||
if modelFlag != "" {
|
||||
// When --model is specified, merge with existing models (new model becomes default)
|
||||
models = []string{modelFlag}
|
||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||
for _, m := range existing.Models {
|
||||
@@ -350,13 +364,13 @@ Examples:
|
||||
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||
return runIntegration(name, models[0])
|
||||
return runIntegration(name, models[0], passArgs)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
return runIntegration(name, models[0])
|
||||
return runIntegration(name, models[0], passArgs)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -364,3 +378,154 @@ Examples:
|
||||
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||
return cmd
|
||||
}
|
||||
|
||||
type modelInfo struct {
|
||||
Name string
|
||||
Remote bool
|
||||
}
|
||||
|
||||
// buildModelList merges existing models with recommendations, sorts them, and returns
|
||||
// the ordered items along with maps of existing and cloud model names.
|
||||
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
existingModels = make(map[string]bool)
|
||||
cloudModels = make(map[string]bool)
|
||||
recommended := make(map[string]bool)
|
||||
var hasLocalModel, hasCloudModel bool
|
||||
|
||||
for _, rec := range recommendedModels {
|
||||
recommended[rec.Name] = true
|
||||
}
|
||||
|
||||
for _, m := range existing {
|
||||
existingModels[m.Name] = true
|
||||
if m.Remote {
|
||||
cloudModels[m.Name] = true
|
||||
hasCloudModel = true
|
||||
} else {
|
||||
hasLocalModel = true
|
||||
}
|
||||
displayName := strings.TrimSuffix(m.Name, ":latest")
|
||||
existingModels[displayName] = true
|
||||
item := selectItem{Name: displayName}
|
||||
if recommended[displayName] {
|
||||
item.Description = "recommended"
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
for _, rec := range recommendedModels {
|
||||
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
|
||||
continue
|
||||
}
|
||||
items = append(items, rec)
|
||||
if isCloudModel(rec.Name) {
|
||||
cloudModels[rec.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
}
|
||||
|
||||
// Non-existing models get "install?" suffix and are pushed to the bottom.
|
||||
// When user has no models, preserve recommended order.
|
||||
notInstalled := make(map[string]bool)
|
||||
for i := range items {
|
||||
if !existingModels[items[i].Name] {
|
||||
notInstalled[items[i].Name] = true
|
||||
if items[i].Description != "" {
|
||||
items[i].Description += ", install?"
|
||||
} else {
|
||||
items[i].Description = "install?"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasLocalModel || hasCloudModel {
|
||||
slices.SortStableFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
|
||||
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if !ac && !bc && aNew != bNew {
|
||||
if aNew {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
}
|
||||
|
||||
return items, preChecked, existingModels, cloudModels
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud")
|
||||
}
|
||||
|
||||
func pullModel(ctx context.Context, client *api.Client, model string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
name = strings.TrimSpace(name)
|
||||
if isDigest {
|
||||
name = name[:min(12, len(name))]
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: model}
|
||||
return client.Pull(ctx, &request, fn)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -90,8 +92,8 @@ func TestLaunchCmd(t *testing.T) {
|
||||
cmd := LaunchCmd(mockCheck)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
@@ -121,7 +123,7 @@ func TestLaunchCmd(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model")
|
||||
err := runIntegration("unknown-integration", "model", nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown integration, got nil")
|
||||
}
|
||||
@@ -174,15 +176,336 @@ func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
||||
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||
for name, r := range integrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test String() doesn't panic and returns non-empty
|
||||
displayName := r.String()
|
||||
if displayName == "" {
|
||||
t.Error("String() should not return empty")
|
||||
}
|
||||
|
||||
// Test Run() exists (we can't call it without actually running the command)
|
||||
// Just verify the method is available
|
||||
var _ func(string) error = r.Run
|
||||
var _ func(string, []string) error = r.Run
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseArgs(t *testing.T) {
|
||||
// Tests reflect cobra's ArgsLenAtDash() semantics:
|
||||
// - cobra strips "--" from args
|
||||
// - ArgsLenAtDash() returns the index where "--" was, or -1
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string // args as cobra delivers them (no "--")
|
||||
dashIdx int // what ArgsLenAtDash() returns
|
||||
wantName string
|
||||
wantArgs []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no extra args, no dash",
|
||||
args: []string{"claude"},
|
||||
dashIdx: -1,
|
||||
wantName: "claude",
|
||||
},
|
||||
{
|
||||
name: "with extra args after --",
|
||||
args: []string{"codex", "-p", "myprofile"},
|
||||
dashIdx: 1,
|
||||
wantName: "codex",
|
||||
wantArgs: []string{"-p", "myprofile"},
|
||||
},
|
||||
{
|
||||
name: "extra args only after --",
|
||||
args: []string{"codex", "--sandbox", "workspace-write"},
|
||||
dashIdx: 1,
|
||||
wantName: "codex",
|
||||
wantArgs: []string{"--sandbox", "workspace-write"},
|
||||
},
|
||||
{
|
||||
name: "-- at end with no args after",
|
||||
args: []string{"claude"},
|
||||
dashIdx: 1,
|
||||
wantName: "claude",
|
||||
},
|
||||
{
|
||||
name: "-- with no integration name",
|
||||
args: []string{"--verbose"},
|
||||
dashIdx: 0,
|
||||
wantName: "",
|
||||
wantArgs: []string{"--verbose"},
|
||||
},
|
||||
{
|
||||
name: "multiple args before -- is error",
|
||||
args: []string{"claude", "codex", "--verbose"},
|
||||
dashIdx: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "multiple args without -- is error",
|
||||
args: []string{"claude", "codex"},
|
||||
dashIdx: -1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no args, no dash",
|
||||
args: []string{},
|
||||
dashIdx: -1,
|
||||
wantName: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the parsing logic from LaunchCmd using dashIdx
|
||||
var name string
|
||||
var parsedArgs []string
|
||||
var err error
|
||||
|
||||
dashIdx := tt.dashIdx
|
||||
args := tt.args
|
||||
|
||||
if dashIdx == -1 {
|
||||
if len(args) > 1 {
|
||||
err = fmt.Errorf("unexpected arguments: %v", args[1:])
|
||||
} else if len(args) == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
} else {
|
||||
if dashIdx > 1 {
|
||||
err = fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||
} else {
|
||||
if dashIdx == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
parsedArgs = args[dashIdx:]
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if name != tt.wantName {
|
||||
t.Errorf("name = %q, want %q", name, tt.wantName)
|
||||
}
|
||||
if !slices.Equal(parsedArgs, tt.wantArgs) {
|
||||
t.Errorf("args = %v, want %v", parsedArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCloudModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{"glm-4.7:cloud", true},
|
||||
{"kimi-k2.5:cloud", true},
|
||||
{"glm-4.7-flash", false},
|
||||
{"glm-4.7-flash:latest", false},
|
||||
{"cloud-model", false},
|
||||
{"model:cloudish", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isCloudModel(tt.name); got != tt.want {
|
||||
t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func names(items []selectItem) []string {
|
||||
var out []string
|
||||
for _, item := range items {
|
||||
out = append(out, item.Name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestBuildModelList_NoExistingModels(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(nil, nil, "")
|
||||
|
||||
want := []string{"glm-4.7-flash", "qwen3:8b", "glm-4.7:cloud", "kimi-k2.5:cloud"}
|
||||
if diff := cmp.Diff(want, names(items)); diff != "" {
|
||||
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("item %q should have description ending with 'install?', got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "qwen2.5:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
want := []string{"llama3.2", "qwen2.5", "glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("cloud recs should be at bottom (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
want := []string{"glm-4.7:cloud", "llama3.2", "glm-4.7-flash", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("mixed models should be alphabetical (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_PreCheckedFirst(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
|
||||
got := names(items)
|
||||
|
||||
if got[0] != "llama3.2" {
|
||||
t.Errorf("pre-checked model should be first, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
switch item.Name {
|
||||
case "glm-4.7-flash", "glm-4.7:cloud":
|
||||
if strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("installed recommended %q should not have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
case "kimi-k2.5:cloud", "qwen3:8b":
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("non-installed recommended %q should have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// glm-4.7-flash and glm-4.7:cloud are installed so they sort normally;
|
||||
// kimi-k2.5:cloud and qwen3:8b are not installed so they go to the bottom
|
||||
want := []string{"glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("existing cloud models should sort normally (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "kimi-k2.5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// kimi-k2.5:cloud is installed so it sorts normally;
|
||||
// the rest of the recommendations are not installed so they go to the bottom
|
||||
want := []string{"kimi-k2.5:cloud", "llama3.2", "glm-4.7-flash", "glm-4.7:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("only non-installed models should be at bottom (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("non-installed %q should have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_LatestTagStripped(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash:latest", Remote: false},
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, existingModels, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// :latest should be stripped from display names
|
||||
for _, name := range got {
|
||||
if strings.HasSuffix(name, ":latest") {
|
||||
t.Errorf("name %q should not have :latest suffix", name)
|
||||
}
|
||||
}
|
||||
|
||||
// glm-4.7-flash should not be duplicated (existing :latest matches the recommendation)
|
||||
count := 0
|
||||
for _, name := range got {
|
||||
if name == "glm-4.7-flash" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("glm-4.7-flash should appear exactly once, got %d in %v", count, got)
|
||||
}
|
||||
|
||||
// Stripped name should be in existingModels so it won't be pulled
|
||||
if !existingModels["glm-4.7-flash"] {
|
||||
t.Error("glm-4.7-flash should be in existingModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
_, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
|
||||
if !existingModels["llama3.2"] {
|
||||
t.Error("llama3.2 should be in existingModels")
|
||||
}
|
||||
if !existingModels["glm-4.7:cloud"] {
|
||||
t.Error("glm-4.7:cloud should be in existingModels")
|
||||
}
|
||||
if existingModels["glm-4.7-flash"] {
|
||||
t.Error("glm-4.7-flash should not be in existingModels (it's a recommendation)")
|
||||
}
|
||||
|
||||
if !cloudModels["glm-4.7:cloud"] {
|
||||
t.Error("glm-4.7:cloud should be in cloudModels")
|
||||
}
|
||||
if !cloudModels["kimi-k2.5:cloud"] {
|
||||
t.Error("kimi-k2.5:cloud should be in cloudModels (recommended cloud)")
|
||||
}
|
||||
if cloudModels["llama3.2"] {
|
||||
t.Error("llama3.2 should not be in cloudModels")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
const ansiGreen = "\033[32m"
|
||||
|
||||
func (c *Openclaw) Run(model string) error {
|
||||
func (c *Openclaw) Run(model string, args []string) error {
|
||||
bin := "openclaw"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
bin = "clawdbot"
|
||||
@@ -52,7 +52,7 @@ func (c *Openclaw) Run(model string) error {
|
||||
}
|
||||
|
||||
// Onboarding completed: run gateway
|
||||
cmd := exec.Command(bin, "gateway")
|
||||
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
// Capture output to detect "already running" message
|
||||
|
||||
@@ -18,7 +18,7 @@ type OpenCode struct{}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string) error {
|
||||
func (o *OpenCode) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("opencode"); err != nil {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func (o *OpenCode) Run(model string) error {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode")
|
||||
cmd := exec.Command("opencode", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
@@ -353,10 +353,15 @@ func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
suffix = " " + ansiGray + "(default)" + ansiReset
|
||||
}
|
||||
|
||||
desc := ""
|
||||
if item.Description != "" {
|
||||
desc = " " + ansiGray + "- " + item.Description + ansiReset
|
||||
}
|
||||
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, desc, suffix)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
|
||||
fmt.Fprintf(w, " %s %s %s%s%s\r\n", prefix, checkbox, item.Name, desc, suffix)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
@@ -317,6 +317,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &glmOcrModel{}
|
||||
case "Lfm2ForCausalLM":
|
||||
conv = &lfm2Model{}
|
||||
case "Qwen3NextForCausalLM":
|
||||
conv = &qwen3NextModel{}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
512
convert/convert_qwen3next.go
Normal file
512
convert/convert_qwen3next.go
Normal file
@@ -0,0 +1,512 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type qwen3NextModel struct {
|
||||
ModelParameters
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
|
||||
// MoE config
|
||||
NumExperts uint32 `json:"num_experts"`
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
NormTopkProb bool `json:"norm_topk_prob"`
|
||||
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
|
||||
|
||||
// Hybrid attention config
|
||||
FullAttentionInterval uint32 `json:"full_attention_interval"`
|
||||
|
||||
// Linear attention (Gated Delta Net) config
|
||||
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
|
||||
LinearKeyHeadDim uint32 `json:"linear_key_head_dim"`
|
||||
LinearNumKeyHeads uint32 `json:"linear_num_key_heads"`
|
||||
LinearNumValueHeads uint32 `json:"linear_num_value_heads"`
|
||||
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
|
||||
|
||||
// RoPE config
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeScaling struct {
|
||||
Type string `json:"type"`
|
||||
Factor ropeFactor `json:"factor"`
|
||||
} `json:"rope_scaling"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*qwen3NextModel)(nil)
|
||||
|
||||
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
|
||||
if q.NumHiddenLayers == 0 {
|
||||
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
|
||||
}
|
||||
if q.NumAttentionHeads == 0 {
|
||||
return fmt.Errorf("qwen3next: num_attention_heads must be set")
|
||||
}
|
||||
if q.NumKeyValueHeads == 0 {
|
||||
return fmt.Errorf("qwen3next: num_key_value_heads must be set")
|
||||
}
|
||||
if q.HeadDim == 0 {
|
||||
return fmt.Errorf("qwen3next: head_dim must be set")
|
||||
}
|
||||
if q.RopeTheta == 0 {
|
||||
return fmt.Errorf("qwen3next: rope_theta must be set")
|
||||
}
|
||||
if q.PartialRotaryFactor <= 0 || q.PartialRotaryFactor > 1 {
|
||||
return fmt.Errorf("qwen3next: partial_rotary_factor must be in (0,1], got %v", q.PartialRotaryFactor)
|
||||
}
|
||||
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
|
||||
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
|
||||
}
|
||||
if q.FullAttentionInterval == 0 {
|
||||
return fmt.Errorf("qwen3next: full_attention_interval must be set")
|
||||
}
|
||||
if q.FullAttentionInterval > q.NumHiddenLayers {
|
||||
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
}
|
||||
|
||||
hasFull := false
|
||||
for i := range q.NumHiddenLayers {
|
||||
if (i+1)%q.FullAttentionInterval == 0 {
|
||||
hasFull = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasFull {
|
||||
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen3next"
|
||||
kv["tokenizer.ggml.pre"] = "qwen2"
|
||||
kv["block_count"] = q.NumHiddenLayers
|
||||
kv["context_length"] = q.MaxPositionEmbeddings
|
||||
kv["embedding_length"] = q.HiddenSize
|
||||
kv["feed_forward_length"] = q.IntermediateSize
|
||||
kv["attention.head_count"] = q.NumAttentionHeads
|
||||
headDim := q.HeadDim
|
||||
if headDim == 0 && q.NumAttentionHeads > 0 {
|
||||
headDim = q.HiddenSize / q.NumAttentionHeads
|
||||
}
|
||||
kv["attention.key_length"] = headDim
|
||||
kv["attention.value_length"] = headDim
|
||||
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
||||
kv["rope.freq_base"] = q.RopeTheta
|
||||
|
||||
// RoPE dimension count (partial rotary)
|
||||
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
|
||||
partialRotary := q.PartialRotaryFactor
|
||||
if partialRotary > 0 && partialRotary <= 1 {
|
||||
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
|
||||
}
|
||||
|
||||
// MoE config
|
||||
if q.NumExperts > 0 {
|
||||
kv["expert_count"] = q.NumExperts
|
||||
kv["expert_used_count"] = q.NumExpertsPerToken
|
||||
kv["norm_top_k_prob"] = q.NormTopkProb
|
||||
if q.MoEIntermediateSize > 0 {
|
||||
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
|
||||
}
|
||||
if q.SharedExpertIntermSize > 0 {
|
||||
kv["expert_shared_feed_forward_length"] = q.SharedExpertIntermSize
|
||||
}
|
||||
}
|
||||
|
||||
// SSM/Linear attention config
|
||||
// d_inner = linear_value_head_dim * linear_num_value_heads
|
||||
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
|
||||
kv["ssm.inner_size"] = dInner
|
||||
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
|
||||
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
|
||||
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
|
||||
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
|
||||
interval := q.FullAttentionInterval
|
||||
kv["full_attention_interval"] = interval
|
||||
|
||||
// Build per-layer KV head count array to identify layer types
|
||||
// 0 = recurrent (linear attention), non-zero = full attention
|
||||
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
|
||||
for i := range q.NumHiddenLayers {
|
||||
// Full attention every full_attention_interval layers (starting at interval-1)
|
||||
if interval > 0 && (i+1)%interval == 0 {
|
||||
kvHeadCounts[i] = q.NumKeyValueHeads
|
||||
}
|
||||
// else stays 0 (recurrent layer)
|
||||
}
|
||||
kv["attention.head_count_kv"] = kvHeadCounts
|
||||
|
||||
// RoPE scaling
|
||||
if q.RopeScaling.Type != "" {
|
||||
kv["rope.scaling.type"] = q.RopeScaling.Type
|
||||
kv["rope.scaling.factor"] = q.RopeScaling.Factor
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
// Create merges for expert tensors - stack individual experts into batched tensors
|
||||
merges := make([]merge, q.NumHiddenLayers*3)
|
||||
for i := range q.NumHiddenLayers {
|
||||
merges[i*3+0] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}
|
||||
merges[i*3+1] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}
|
||||
merges[i*3+2] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}
|
||||
}
|
||||
|
||||
// Merge expert tensors
|
||||
merged, remaining := mergeTensors(ts, merges...)
|
||||
out = append(out, merged...)
|
||||
|
||||
// Process remaining tensors
|
||||
for _, t := range remaining {
|
||||
name := t.Name()
|
||||
shape := t.Shape()
|
||||
|
||||
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
|
||||
if strings.HasSuffix(name, ".ssm_in.weight") {
|
||||
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
|
||||
out = append(out, qkv, gate)
|
||||
continue
|
||||
}
|
||||
panic(fmt.Sprintf("qwen3next: failed to split %s into attn_qkv/attn_gate (shape=%v)", name, shape))
|
||||
}
|
||||
|
||||
switch {
|
||||
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
|
||||
// This matches the Python converter behavior for qwen3next
|
||||
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
|
||||
t.SetRepacker(q.addOne)
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
// Handle linear attention A_log -> ssm_a (negate and exp)
|
||||
// Note: name has already been transformed by Replacements at this point
|
||||
case strings.HasSuffix(name, ".ssm_a"):
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
// Compute -exp(A_log)
|
||||
result := make([]float32, len(data))
|
||||
for i, v := range data {
|
||||
// -exp(v)
|
||||
result[i] = -float32(math.Exp(float64(v)))
|
||||
}
|
||||
return result, nil
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
|
||||
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 3 {
|
||||
if shape[0] == 1 {
|
||||
// [1, D, K] -> [D, K]
|
||||
newShape = []uint64{shape[1], shape[2]}
|
||||
} else if shape[1] == 1 {
|
||||
// [D, 1, K] -> [D, K]
|
||||
newShape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
|
||||
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 2 {
|
||||
if shape[0] == 1 && shape[1] > 1 {
|
||||
newShape = []uint64{shape[1]}
|
||||
} else if shape[1] == 1 && shape[0] > 1 {
|
||||
newShape = []uint64{shape[0]}
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
default:
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
type qkvzSplitSpec struct {
|
||||
hidden int
|
||||
headKDim int
|
||||
headVDim int
|
||||
numKHeads int
|
||||
numVHeads int
|
||||
qkvzDim int
|
||||
qkvOut int
|
||||
gateOut int
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) qkvzSpec(shape []uint64) (qkvzSplitSpec, bool) {
|
||||
if len(shape) != 2 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
numKHeads := int(q.LinearNumKeyHeads)
|
||||
numVHeads := int(q.LinearNumValueHeads)
|
||||
headKDim := int(q.LinearKeyHeadDim)
|
||||
headVDim := int(q.LinearValueHeadDim)
|
||||
if numKHeads == 0 || numVHeads == 0 || headKDim == 0 || headVDim == 0 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
if numVHeads%numKHeads != 0 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
hidden := int(shape[1])
|
||||
vPerHead := headVDim * (numVHeads / numKHeads)
|
||||
qkvzDim := 2*headKDim + 2*vPerHead
|
||||
expectedOut := qkvzDim * numKHeads
|
||||
if int(shape[0]) != expectedOut {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
return qkvzSplitSpec{
|
||||
hidden: hidden,
|
||||
headKDim: headKDim,
|
||||
headVDim: headVDim,
|
||||
numKHeads: numKHeads,
|
||||
numVHeads: numVHeads,
|
||||
qkvzDim: qkvzDim,
|
||||
qkvOut: 2*headKDim*numKHeads + headVDim*numVHeads,
|
||||
gateOut: headVDim * numVHeads,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) splitQKVZTensor(t Tensor) (*ggml.Tensor, *ggml.Tensor, bool) {
|
||||
spec, ok := q.qkvzSpec(t.Shape())
|
||||
if !ok {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qkvTensor := t.Clone()
|
||||
qkvTensor.SetRepacker(q.repackQKVZ(spec, false))
|
||||
|
||||
gateTensor := t.Clone()
|
||||
gateTensor.SetRepacker(q.repackQKVZ(spec, true))
|
||||
|
||||
qkvName := strings.Replace(t.Name(), "ssm_in", "attn_qkv", 1)
|
||||
gateName := strings.Replace(t.Name(), "ssm_in", "attn_gate", 1)
|
||||
|
||||
return &ggml.Tensor{
|
||||
Name: qkvName,
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(spec.qkvOut), uint64(spec.hidden)},
|
||||
WriterTo: qkvTensor,
|
||||
}, &ggml.Tensor{
|
||||
Name: gateName,
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(spec.gateOut), uint64(spec.hidden)},
|
||||
WriterTo: gateTensor,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repacker {
|
||||
vPerHead := spec.headVDim * (spec.numVHeads / spec.numKHeads)
|
||||
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
var err error
|
||||
|
||||
// Convert to [hidden, out_features] layout for slicing
|
||||
tt, err = tensor.Transpose(tt, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
|
||||
if err := tt.Reshape(spec.hidden, spec.numKHeads, spec.qkvzDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
offset := 0
|
||||
qSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += spec.headKDim
|
||||
kSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += spec.headKDim
|
||||
vSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += vPerHead
|
||||
zSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qMat := tensor.Materialize(qSlice).(*tensor.Dense)
|
||||
kMat := tensor.Materialize(kSlice).(*tensor.Dense)
|
||||
vMat := tensor.Materialize(vSlice).(*tensor.Dense)
|
||||
zMat := tensor.Materialize(zSlice).(*tensor.Dense)
|
||||
|
||||
if err := qMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := kMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := vMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := zMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out tensor.Tensor
|
||||
if extractGate {
|
||||
out = zMat
|
||||
} else {
|
||||
out, err = tensor.Concat(1, qMat, kMat, vMat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
out = tensor.Materialize(out)
|
||||
out, err = tensor.Transpose(out, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = tensor.Materialize(out)
|
||||
|
||||
if err := out.Reshape(out.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(out.(*tensor.Dense))
|
||||
}
|
||||
}
|
||||
|
||||
// addOne adds 1.0 to all elements in the tensor (for norm weights)
|
||||
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
|
||||
ones := tensor.Ones(tensor.Float32, int(shape[0]))
|
||||
|
||||
n, err := n.Add(ones)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ts, err := native.SelectF32(n, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
for _, t := range ts {
|
||||
f32s = append(f32s, t...)
|
||||
}
|
||||
|
||||
return f32s, nil
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) Replacements() []string {
|
||||
return []string{
|
||||
// Embeddings and output
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.norm", "output_norm",
|
||||
"model.layers", "blk",
|
||||
|
||||
// Layer norms
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
|
||||
// Full attention (self_attn)
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
|
||||
// Linear attention (Gated Delta Net)
|
||||
"linear_attn.in_proj_qkvz", "ssm_in",
|
||||
"linear_attn.in_proj_ba", "ssm_ba",
|
||||
"linear_attn.conv1d", "ssm_conv1d",
|
||||
"linear_attn.dt_bias", "ssm_dt",
|
||||
"linear_attn.dt_proj", "ssm_dt",
|
||||
"linear_attn.A_log", "ssm_a",
|
||||
"linear_attn.norm", "ssm_norm",
|
||||
"linear_attn.out_proj", "ssm_out",
|
||||
|
||||
// MoE (experts are stacked via mergeTensors, not replaced here)
|
||||
"mlp.gate.weight", "ffn_gate_inp.weight",
|
||||
"mlp.shared_expert.down_proj", "ffn_down_shexp",
|
||||
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
|
||||
"mlp.shared_expert.up_proj", "ffn_up_shexp",
|
||||
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
|
||||
|
||||
// Dense FFN (if any layers use it)
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,7 @@ func (t tensorBase) Kind() uint32 {
|
||||
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
|
||||
strings.HasSuffix(t.name, ".bias") ||
|
||||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
|
||||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
|
||||
@@ -100,6 +100,7 @@ func (st safetensor) Kind() uint32 {
|
||||
!strings.HasPrefix(st.name, "v.") &&
|
||||
!strings.HasPrefix(st.name, "s.") &&
|
||||
!strings.HasPrefix(st.name, "mm.") &&
|
||||
!strings.Contains(st.name, "ffn_gate_inp_shexp.weight") &&
|
||||
kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
@@ -268,6 +268,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"olmo3",
|
||||
"qwen25vl",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3next",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
"glm4moelite",
|
||||
"glmocr",
|
||||
@@ -866,6 +867,7 @@ func (f GGML) FlashAttention() bool {
|
||||
"mistral3",
|
||||
"olmo3",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3next",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
}, f.KV().String("general.architecture"))
|
||||
}
|
||||
|
||||
1
go.mod
1
go.mod
@@ -27,6 +27,7 @@ require (
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
|
||||
3
go.sum
3
go.sum
@@ -174,6 +174,8 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
|
||||
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
|
||||
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -304,6 +306,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
|
||||
@@ -75,3 +75,10 @@ type Cache interface {
|
||||
// removed by calling Remove(seq, 0, math.MaxInt32)
|
||||
Remove(seq int, beginIndex, endIndex int32) error
|
||||
}
|
||||
|
||||
// CheckpointCache optionally supports restoring recurrent state to a prior
|
||||
// position to avoid full prompt reprocessing when a prefix mismatch occurs.
|
||||
// The returned position is the number of tokens that can be kept (prefix length).
|
||||
type CheckpointCache interface {
|
||||
PrepareRestore(seq int, targetPos int32) (int32, bool)
|
||||
}
|
||||
|
||||
276
llama/patches/0033-ggml-metal-solve_tri.patch
Normal file
276
llama/patches/0033-ggml-metal-solve_tri.patch
Normal file
@@ -0,0 +1,276 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Jeffrey Morgan <jmorganca@gmail.com>
|
||||
Date: Tue, 3 Feb 2026 12:00:00 -0800
|
||||
Subject: [PATCH] ggml: metal solve_tri
|
||||
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++++++
|
||||
ggml/src/ggml-metal/ggml-metal-device.h | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal-device.m | 11 ++++
|
||||
ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 63 +++++++++++++++++++++++
|
||||
ggml/src/ggml-metal/ggml-metal-ops.h | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 60 +++++++++++++++++++++
|
||||
7 files changed, 177 insertions(+)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
index 680904d13..83385c9ef 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||
return res;
|
||||
}
|
||||
|
||||
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
+ assert(op->op == GGML_OP_SOLVE_TRI);
|
||||
+
|
||||
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
+
|
||||
+ char base[256];
|
||||
+ char name[256];
|
||||
+
|
||||
+ snprintf(base, 256, "kernel_solve_tri_f32");
|
||||
+ snprintf(name, 256, "%s", base);
|
||||
+
|
||||
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
+ if (!res.pipeline) {
|
||||
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
+ }
|
||||
+
|
||||
+ return res;
|
||||
+}
|
||||
+
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_GROUP_NORM);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
index 0a8b9211a..8a9d17460 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
index 7b5ee968c..4e5acfbe5 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
+ case GGML_OP_SOLVE_TRI:
|
||||
+ return ggml_is_contiguous(op->src[0]) &&
|
||||
+ ggml_is_contiguous(op->src[1]) &&
|
||||
+ op->src[0]->type == GGML_TYPE_F32 &&
|
||||
+ op->src[1]->type == GGML_TYPE_F32 &&
|
||||
+ op->type == GGML_TYPE_F32;
|
||||
+ case GGML_OP_COUNT_EQUAL:
|
||||
+ return has_simdgroup_reduction &&
|
||||
+ op->src[0]->type == GGML_TYPE_I32 &&
|
||||
+ op->src[1]->type == GGML_TYPE_I32 &&
|
||||
+ op->type == GGML_TYPE_I64;
|
||||
case GGML_OP_ARGMAX:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_NORM:
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
index 8944b07e9..cfdea9c07 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
@@ -500,6 +500,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
+typedef struct {
|
||||
+ int32_t ne00;
|
||||
+ int32_t ne01;
|
||||
+ int32_t ne02;
|
||||
+ int32_t ne03;
|
||||
+ uint64_t nb00;
|
||||
+ uint64_t nb01;
|
||||
+ uint64_t nb02;
|
||||
+ uint64_t nb03;
|
||||
+ int32_t ne10;
|
||||
+ int32_t ne11;
|
||||
+ uint64_t nb10;
|
||||
+ uint64_t nb11;
|
||||
+ uint64_t nb12;
|
||||
+ uint64_t nb13;
|
||||
+ uint64_t nb0;
|
||||
+ uint64_t nb1;
|
||||
+ uint64_t nb2;
|
||||
+ uint64_t nb3;
|
||||
+} ggml_metal_kargs_solve_tri;
|
||||
+
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index 80864f303..4ac135603 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
||||
} break;
|
||||
+ case GGML_OP_SOLVE_TRI:
|
||||
+ {
|
||||
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
||||
+ } break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
||||
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
+ ggml_tensor * op = ctx->node(idx);
|
||||
+
|
||||
+ ggml_metal_library_t lib = ctx->lib;
|
||||
+ ggml_metal_encoder_t enc = ctx->enc;
|
||||
+
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
+
|
||||
+ ggml_metal_kargs_solve_tri args = {
|
||||
+ /*.ne00 =*/ ne00,
|
||||
+ /*.ne01 =*/ ne01,
|
||||
+ /*.ne02 =*/ ne02,
|
||||
+ /*.ne03 =*/ ne03,
|
||||
+ /*.nb00 =*/ nb00,
|
||||
+ /*.nb01 =*/ nb01,
|
||||
+ /*.nb02 =*/ nb02,
|
||||
+ /*.nb03 =*/ nb03,
|
||||
+ /*.ne10 =*/ ne10,
|
||||
+ /*.ne11 =*/ ne11,
|
||||
+ /*.nb10 =*/ nb10,
|
||||
+ /*.nb11 =*/ nb11,
|
||||
+ /*.nb12 =*/ nb12,
|
||||
+ /*.nb13 =*/ nb13,
|
||||
+ /*.nb0 =*/ nb0,
|
||||
+ /*.nb1 =*/ nb1,
|
||||
+ /*.nb2 =*/ nb2,
|
||||
+ /*.nb3 =*/ nb3,
|
||||
+ };
|
||||
+
|
||||
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
|
||||
+
|
||||
+ const int64_t ncols = ne10;
|
||||
+ const int64_t n_batches = (int64_t)ne02 * ne03;
|
||||
+ const int64_t nr = n_batches * ncols;
|
||||
+
|
||||
+ int nth = 64;
|
||||
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
+ if (nth < 1) {
|
||||
+ nth = 1;
|
||||
+ }
|
||||
+
|
||||
+ const int64_t n_tg = (nr + nth - 1) / nth;
|
||||
+
|
||||
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
+
|
||||
+ ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
|
||||
+
|
||||
+ return 1;
|
||||
+}
|
||||
+
|
||||
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
index 902b54452..a475183d3 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
||||
+int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index d33c16079..c37447a10 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
+kernel void kernel_solve_tri_f32(
|
||||
+ constant ggml_metal_kargs_solve_tri & args,
|
||||
+ device const char * src0,
|
||||
+ device const char * src1,
|
||||
+ device char * dst,
|
||||
+ uint tgpig[[threadgroup_position_in_grid]],
|
||||
+ ushort tpitg[[thread_position_in_threadgroup]],
|
||||
+ ushort ntg[[threads_per_threadgroup]]) {
|
||||
+ const uint64_t ncols = (uint64_t) args.ne10;
|
||||
+ const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
+ const uint64_t nr = n_batches * ncols;
|
||||
+
|
||||
+ const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
+ if (gid >= nr) {
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
+ const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
+ const uint64_t i02 = rem / ncols;
|
||||
+ const uint64_t i01 = rem - i02 * ncols;
|
||||
+
|
||||
+ const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
+ const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
+ const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
+ const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
+
|
||||
+ const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
+ const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
+ const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
+ const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
+
|
||||
+ const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
+ const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
+ const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
+ const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
+
|
||||
+ device const float * A = (device const float *) src0;
|
||||
+ device const float * B = (device const float *) src1;
|
||||
+ device float * X = (device float *) dst;
|
||||
+
|
||||
+ const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
+ const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
+ const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
+
|
||||
+ const uint64_t n = (uint64_t) args.ne11;
|
||||
+
|
||||
+ for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
+ float sum = 0.0f;
|
||||
+ for (uint64_t t = 0; t < i00; ++t) {
|
||||
+ sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
+ X[X_base + t * sx1 + i01 * sx0];
|
||||
+ }
|
||||
+
|
||||
+ const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
+ X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
+ (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
@@ -207,6 +207,32 @@ type Tensor interface {
|
||||
Stddev(ctx Context) Tensor
|
||||
Sqr(ctx Context) Tensor
|
||||
Sqrt(ctx Context) Tensor
|
||||
Exp(ctx Context) Tensor
|
||||
Neg(ctx Context) Tensor
|
||||
|
||||
// Clamp clamps values to [min, max] range
|
||||
Clamp(ctx Context, min, max float32) Tensor
|
||||
|
||||
// Softplus computes ln(1 + exp(x))
|
||||
Softplus(ctx Context) Tensor
|
||||
|
||||
// CumSum computes cumulative sum along dimension 0
|
||||
CumSum(ctx Context) Tensor
|
||||
|
||||
// Diag creates a diagonal matrix from a 1D tensor
|
||||
Diag(ctx Context) Tensor
|
||||
|
||||
// Tri converts a matrix to triangular form (0=upper+diag, 1=upper, 2=lower+diag, 3=lower)
|
||||
Tri(ctx Context, triType int) Tensor
|
||||
|
||||
// Fill fills a tensor with a constant value (in-place)
|
||||
Fill(ctx Context, value float32) Tensor
|
||||
|
||||
// Repeat4D repeats tensor to match target shape
|
||||
Repeat4D(ctx Context, dim0, dim1, dim2, dim3 int) Tensor
|
||||
|
||||
// SolveTri solves a triangular system Ax = B
|
||||
SolveTri(ctx Context, b Tensor, lower, left, unitDiag bool) Tensor
|
||||
|
||||
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
|
||||
}
|
||||
|
||||
@@ -378,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
}
|
||||
|
||||
maxGraphNodes := max(1024, len(meta.Tensors().Items())*8)
|
||||
maxGraphNodes := max(1024, len(meta.Tensors().Items())*32)
|
||||
|
||||
sched := C.ggml_backend_sched_new_ext(
|
||||
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
|
||||
@@ -1779,6 +1779,76 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_exp(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Softplus(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_softplus(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) CumSum(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cumsum(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Diag(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_diag(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Tri(ctx ml.Context, triType int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tri(ctx.(*Context).ctx, t.t, C.enum_ggml_tri_type(triType)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Fill(ctx ml.Context, value float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_fill_inplace(ctx.(*Context).ctx, t.t, C.float(value)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Repeat4D(ctx ml.Context, dim0, dim1, dim2, dim3 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_repeat_4d(ctx.(*Context).ctx, t.t, C.int64_t(dim0), C.int64_t(dim1), C.int64_t(dim2), C.int64_t(dim3)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SolveTri(ctx ml.Context, b ml.Tensor, lower, left, unitDiag bool) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_solve_tri(ctx.(*Context).ctx, t.t, b.(*Tensor).t, C._Bool(lower), C._Bool(left), C._Bool(unitDiag)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
|
||||
var mode C.uint32_t
|
||||
switch samplingMode {
|
||||
|
||||
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_SOLVE_TRI);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_solve_tri_f32");
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_GROUP_NORM);
|
||||
|
||||
|
||||
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
return has_simdgroup_reduction &&
|
||||
op->src[0]->type == GGML_TYPE_I32 &&
|
||||
op->src[1]->type == GGML_TYPE_I32 &&
|
||||
op->type == GGML_TYPE_I64;
|
||||
case GGML_OP_ARGMAX:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_NORM:
|
||||
|
||||
@@ -2385,6 +2385,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_solve_tri;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
@@ -5813,6 +5834,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_solve_tri_f32(
|
||||
constant ggml_metal_kargs_solve_tri & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const uint64_t ncols = (uint64_t) args.ne10;
|
||||
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
const uint64_t nr = n_batches * ncols;
|
||||
|
||||
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
if (gid >= nr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
const uint64_t i02 = rem / ncols;
|
||||
const uint64_t i01 = rem - i02 * ncols;
|
||||
|
||||
const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
|
||||
const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
|
||||
const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
|
||||
device const float * A = (device const float *) src0;
|
||||
device const float * B = (device const float *) src1;
|
||||
device float * X = (device float *) dst;
|
||||
|
||||
const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
|
||||
const uint64_t n = (uint64_t) args.ne11;
|
||||
|
||||
for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (uint64_t t = 0; t < i00; ++t) {
|
||||
sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
X[X_base + t * sx1 + i01 * sx0];
|
||||
}
|
||||
|
||||
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
|
||||
@@ -500,6 +500,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_solve_tri;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
||||
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
||||
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_solve_tri args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
|
||||
|
||||
const int64_t ncols = ne10;
|
||||
const int64_t n_batches = (int64_t)ne02 * ne03;
|
||||
const int64_t nr = n_batches * ncols;
|
||||
|
||||
int nth = 64;
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
if (nth < 1) {
|
||||
nth = 1;
|
||||
}
|
||||
|
||||
const int64_t n_tg = (nr + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
||||
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_solve_tri_f32(
|
||||
constant ggml_metal_kargs_solve_tri & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const uint64_t ncols = (uint64_t) args.ne10;
|
||||
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
const uint64_t nr = n_batches * ncols;
|
||||
|
||||
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
if (gid >= nr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
const uint64_t i02 = rem / ncols;
|
||||
const uint64_t i01 = rem - i02 * ncols;
|
||||
|
||||
const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
|
||||
const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
|
||||
const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
|
||||
device const float * A = (device const float *) src0;
|
||||
device const float * B = (device const float *) src1;
|
||||
device float * X = (device float *) dst;
|
||||
|
||||
const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
|
||||
const uint64_t n = (uint64_t) args.ne11;
|
||||
|
||||
for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (uint64_t t = 0; t < i00; ++t) {
|
||||
sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
X[X_base + t * sx1 + i01 * sx0];
|
||||
}
|
||||
|
||||
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
|
||||
@@ -56,6 +56,18 @@ type fakeTensor struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// Stub methods to satisfy ml.Tensor interface
|
||||
func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Clamp(ctx ml.Context, _, _ float32) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Softplus(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) CumSum(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Diag(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
|
||||
|
||||
func (m *fakeBackend) Get(name string) ml.Tensor {
|
||||
if slices.Contains(m.names, name) {
|
||||
return &fakeTensor{Name: name}
|
||||
|
||||
@@ -20,5 +20,6 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/qwen2"
|
||||
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3next"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3vl"
|
||||
)
|
||||
|
||||
103
model/models/qwen3next/attention.go
Normal file
103
model/models/qwen3next/attention.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
|
||||
// with the attention layer requirements.
|
||||
var ErrUnsupportedBatchLayout = errors.New("qwen3next: unsupported batch layout")
|
||||
|
||||
// FullAttention implements gated attention with QK normalization and sigmoid-gated output.
|
||||
// Key differences from standard attention:
|
||||
// - Q projection outputs 2x size (Q + gate interleaved)
|
||||
// - Both Q and K have RMSNorm
|
||||
// - Output is gated: attn * sigmoid(gate)
|
||||
type FullAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"` // outputs [n_embd_head * 2, n_head]
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
// Use Dim() instead of Shape() for consistent behavior during graph construction
|
||||
hiddenDim := hiddenStates.Dim(0)
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
nSeqs := hiddenStates.Dim(2) // 0 if 2D tensor
|
||||
|
||||
if cache != nil && cache.IsSupportedForBatch() {
|
||||
seqTokens := cache.seqTokens()
|
||||
seqs := cache.numSeqs()
|
||||
if seqTokens > 0 && seqs > 0 {
|
||||
if nSeqs > 0 {
|
||||
// 3D tensor: [hiddenDim, seqTokens, nSeqs]
|
||||
if batchSize != seqTokens || nSeqs != seqs {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
|
||||
batchSize = seqTokens * seqs
|
||||
} else if batchSize != seqTokens*seqs {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
}
|
||||
}
|
||||
headDim := opts.headDim()
|
||||
numHeads := opts.numHeads
|
||||
|
||||
// Q projection outputs query + gate interleaved
|
||||
qFull := sa.Query.Forward(ctx, hiddenStates)
|
||||
|
||||
// Reshape to [headDim * 2, numHeads, batchSize]
|
||||
qFull = qFull.Reshape(ctx, headDim*2, numHeads, batchSize)
|
||||
|
||||
// Split Q and gate along dimension 0
|
||||
// Q: first headDim elements, gate: second headDim elements
|
||||
query := qFull.Slice(ctx, 0, 0, headDim, 1)
|
||||
gate := qFull.Slice(ctx, 0, headDim, headDim*2, 1)
|
||||
|
||||
// Make query contiguous for further operations
|
||||
query = query.Contiguous(ctx, headDim, numHeads, batchSize)
|
||||
|
||||
// K and V projections
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
|
||||
// Derive numKVHeads from tensor dimensions (per-layer value)
|
||||
numKVHeads := key.Dim(0) / headDim
|
||||
|
||||
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
|
||||
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
|
||||
|
||||
// Apply QK normalization
|
||||
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
// Apply RoPE
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
|
||||
// Standard attention computation
|
||||
scale := opts.attentionScale
|
||||
if scale == 0 {
|
||||
scale = 1.0 / math.Sqrt(float64(headDim))
|
||||
}
|
||||
attention := nn.Attention(ctx, query, key, value, scale, cache)
|
||||
|
||||
// Flatten heads
|
||||
attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
|
||||
|
||||
// Apply sigmoid gate
|
||||
// gate shape: [headDim, numHeads, batchSize] -> [headDim*numHeads, batchSize]
|
||||
gate = gate.Contiguous(ctx, headDim*numHeads, batchSize)
|
||||
gateSigmoid := gate.Sigmoid(ctx)
|
||||
attention = attention.Mul(ctx, gateSigmoid)
|
||||
|
||||
return sa.Output.Forward(ctx, attention), nil
|
||||
}
|
||||
596
model/models/qwen3next/cache.go
Normal file
596
model/models/qwen3next/cache.go
Normal file
@@ -0,0 +1,596 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var _ kvcache.Cache = (*HybridCache)(nil)
|
||||
|
||||
// HybridCache stores:
|
||||
// - a standard causal KV cache for full attention layers
|
||||
// - per-sequence conv state for linear attention layers
|
||||
// - per-sequence delta state for linear attention layers
|
||||
//
|
||||
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
|
||||
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
|
||||
type HybridCache struct {
|
||||
kv *kvcache.Causal
|
||||
|
||||
backend ml.Backend
|
||||
dtype ml.DType
|
||||
maxSequences int
|
||||
|
||||
// Conv state dimensions
|
||||
convDim int // convKernelSize - 1
|
||||
convChannels int // d_inner + 2 * num_k_heads * head_k_dim
|
||||
|
||||
// Delta state dimensions
|
||||
deltaStateSize int // headVDim * headVDim * numVHeads
|
||||
|
||||
// slot mapping for recurrent state (copy-on-write)
|
||||
slotForSeq map[int]int
|
||||
refCount []int
|
||||
freeSlots []int
|
||||
|
||||
// per-layer conv state buffers (allocated lazily)
|
||||
convCtxs map[int]ml.Context
|
||||
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
|
||||
|
||||
// per-layer delta state buffers (allocated lazily)
|
||||
deltaCtxs map[int]ml.Context
|
||||
deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
|
||||
|
||||
// recurrent checkpoints (per slot)
|
||||
checkpointCount int
|
||||
checkpointMinPos int32
|
||||
checkpointInterval int32
|
||||
checkpointCtxSize int
|
||||
checkpoints map[int]*slotCheckpointStore
|
||||
pendingRestore map[int]checkpointRestore
|
||||
curCheckpointPos []int32
|
||||
curCheckpointSlots map[int]int
|
||||
reserveCheckpoints bool
|
||||
checkpointConvCtxs map[int]ml.Context
|
||||
checkpointDeltaCtxs map[int]ml.Context
|
||||
checkpointReserved map[int]struct{}
|
||||
|
||||
// current forward batch (derived in StartForward)
|
||||
curSeqs []int
|
||||
curSlots []int
|
||||
curSlotsInput ml.Tensor
|
||||
curSeqTokens int
|
||||
|
||||
// track if EnsureWritable has been called for this forward pass
|
||||
writableEnsured bool
|
||||
writableError error
|
||||
}
|
||||
|
||||
func NewHybridCache(
|
||||
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
|
||||
convDim, convChannels, deltaStateSize int,
|
||||
) *HybridCache {
|
||||
return &HybridCache{
|
||||
kv: kvcache.NewCausalCache(shift),
|
||||
convDim: convDim,
|
||||
convChannels: convChannels,
|
||||
deltaStateSize: deltaStateSize,
|
||||
slotForSeq: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
deltaCtxs: make(map[int]ml.Context),
|
||||
deltaStates: make(map[int]ml.Tensor),
|
||||
checkpointCount: checkpointCountDefault,
|
||||
checkpointMinPos: checkpointMinPosDefault,
|
||||
checkpointInterval: checkpointIntervalDefault,
|
||||
checkpoints: make(map[int]*slotCheckpointStore),
|
||||
pendingRestore: make(map[int]checkpointRestore),
|
||||
curCheckpointSlots: make(map[int]int),
|
||||
checkpointConvCtxs: make(map[int]ml.Context),
|
||||
checkpointDeltaCtxs: make(map[int]ml.Context),
|
||||
checkpointReserved: make(map[int]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.backend = backend
|
||||
c.dtype = dtype
|
||||
c.maxSequences = maxSequences
|
||||
c.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
c.pendingRestore = make(map[int]checkpointRestore)
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
c.curCheckpointSlots = make(map[int]int)
|
||||
c.checkpointReserved = make(map[int]struct{})
|
||||
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
|
||||
if c.checkpointCtxSize < 8 {
|
||||
c.checkpointCtxSize = 8
|
||||
}
|
||||
|
||||
// initialize slot allocator
|
||||
c.refCount = make([]int, maxSequences)
|
||||
c.freeSlots = c.freeSlots[:0]
|
||||
for i := maxSequences - 1; i >= 0; i-- {
|
||||
c.freeSlots = append(c.freeSlots, i)
|
||||
}
|
||||
|
||||
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Close() {
|
||||
for _, ctx := range c.convCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.deltaCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointConvCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointDeltaCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
c.kv.Close()
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
|
||||
c.kv.SetConfig(config)
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetLayer(layer int) {
|
||||
c.kv.SetLayer(layer)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.kv.Get(ctx)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.kv.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Derive equal-length sequence layout for recurrent layers
|
||||
seqCounts := make(map[int]int)
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
for _, s := range batch.Sequences {
|
||||
if _, ok := seqCounts[s]; !ok {
|
||||
c.curSeqs = append(c.curSeqs, s)
|
||||
}
|
||||
seqCounts[s]++
|
||||
}
|
||||
|
||||
if len(c.curSeqs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
nTokens := len(batch.Sequences)
|
||||
nSeqs := len(c.curSeqs)
|
||||
want := nTokens / nSeqs
|
||||
for _, s := range c.curSeqs {
|
||||
if seqCounts[s] != want {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
c.curSeqTokens = want
|
||||
|
||||
// When reserving memory for estimation, use fake slot assignments
|
||||
if reserve {
|
||||
c.curSlots = c.curSlots[:0]
|
||||
slots := make([]int32, nSeqs)
|
||||
for i := range nSeqs {
|
||||
c.curSlots = append(c.curSlots, i)
|
||||
slots[i] = int32(i)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
c.reserveCheckpoints = true
|
||||
c.planCheckpoints(batch)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure slots exist for sequences in this batch
|
||||
c.curSlots = c.curSlots[:0]
|
||||
var newSlots []int
|
||||
for _, s := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[s]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.slotForSeq[s] = slot
|
||||
c.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
c.curSlots = append(c.curSlots, slot)
|
||||
}
|
||||
|
||||
// Zero state for newly allocated slots
|
||||
if len(newSlots) > 0 {
|
||||
c.zeroSlots(ctx, newSlots)
|
||||
}
|
||||
|
||||
// Create a tensor for the current slots
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
// Reset writable state for new forward pass
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
c.reserveCheckpoints = false
|
||||
c.planCheckpoints(batch)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) allocSlot() (int, error) {
|
||||
if len(c.freeSlots) == 0 {
|
||||
return 0, kvcache.ErrKvCacheFull
|
||||
}
|
||||
slot := c.freeSlots[len(c.freeSlots)-1]
|
||||
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
||||
return slot, nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) freeSlot(slot int) {
|
||||
if slot >= 0 && slot < c.maxSequences {
|
||||
c.freeSlots = append(c.freeSlots, slot)
|
||||
}
|
||||
}
|
||||
|
||||
// zeroSlots zeros the recurrent state for the given slots across all layers.
|
||||
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
|
||||
if len(slots) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
inputCtx := ctx.Input()
|
||||
|
||||
slotIndices := make([]int32, len(slots))
|
||||
for i, s := range slots {
|
||||
slotIndices[i] = int32(s)
|
||||
}
|
||||
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
|
||||
|
||||
// Zero conv states
|
||||
if len(c.convStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
|
||||
for _, buf := range c.convStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
|
||||
// Zero delta states
|
||||
if len(c.deltaStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
|
||||
for _, buf := range c.deltaStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureWritable ensures sequences have private slots (copy-on-write).
|
||||
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
|
||||
for i, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.refCount[slot] <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
c.curSlots[i] = newSlot
|
||||
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
}
|
||||
|
||||
// Rebuild current slots tensor
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
|
||||
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
|
||||
|
||||
for _, buf := range c.convStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
||||
}
|
||||
|
||||
for _, buf := range c.deltaStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
|
||||
// Copy-on-write for recurrent state
|
||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
||||
if c.validSlot(dstSlot) {
|
||||
c.refCount[dstSlot]--
|
||||
if c.refCount[dstSlot] <= 0 {
|
||||
c.refCount[dstSlot] = 0
|
||||
c.freeSlot(dstSlot)
|
||||
}
|
||||
}
|
||||
delete(c.slotForSeq, dstSeq)
|
||||
}
|
||||
|
||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if c.validSlot(srcSlot) {
|
||||
c.slotForSeq[dstSeq] = srcSlot
|
||||
c.refCount[srcSlot]++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CanResume(seq int, pos int32) bool {
|
||||
if !c.kv.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
if pos == 0 {
|
||||
return true
|
||||
}
|
||||
return c.hasCheckpoint(seq, pos)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore, ok := c.pendingRestore[seq]
|
||||
if !ok || restore.pos+1 != beginIndex {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
if !c.restoreComplete(restore) {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
// If the recurrent slot is shared, detach it before applying a restore.
|
||||
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
|
||||
restore.slot = newSlot
|
||||
c.pendingRestore[seq] = restore
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore := c.pendingRestore[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
return c.applyCheckpointRestore(restore)
|
||||
}
|
||||
|
||||
// Removal invalidates recurrent state
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.validSlot(slot) {
|
||||
delete(c.slotForSeq, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.refCount[slot]--
|
||||
if c.refCount[slot] <= 0 {
|
||||
c.refCount[slot] = 0
|
||||
c.clearCheckpoints(slot)
|
||||
c.freeSlot(slot)
|
||||
}
|
||||
delete(c.slotForSeq, seq)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) validSlot(slot int) bool {
|
||||
return slot >= 0 && slot < len(c.refCount)
|
||||
}
|
||||
|
||||
func (c *HybridCache) slotsTensor() ml.Tensor {
|
||||
return c.curSlotsInput
|
||||
}
|
||||
|
||||
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
|
||||
func (c *HybridCache) contiguousSlots() (int, bool) {
|
||||
if len(c.curSlots) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
start := c.curSlots[0]
|
||||
for i, s := range c.curSlots {
|
||||
if s != start+i {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return start, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.curSeqTokens
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return len(c.curSeqs)
|
||||
}
|
||||
|
||||
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.convStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.convCtxs[layer]; !ok {
|
||||
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
|
||||
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
|
||||
c.convStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.deltaStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.deltaCtxs[layer]; !ok {
|
||||
c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
// Recurrent delta state must stay in F32.
|
||||
buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
|
||||
c.deltaStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
|
||||
if !c.writableEnsured {
|
||||
needsWritable := false
|
||||
for _, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
||||
needsWritable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsWritable {
|
||||
if err := c.EnsureWritable(ctx); err != nil {
|
||||
c.writableError = err
|
||||
}
|
||||
}
|
||||
c.writableEnsured = true
|
||||
}
|
||||
}
|
||||
|
||||
// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
|
||||
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
||||
c.ensureWritableOnce(ctx)
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateConvState writes a new conv state for current batch sequences.
|
||||
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
// Fast path: contiguous slots allow a single view + copy
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
|
||||
ctx.Forward(srcF32.Copy(ctx, view))
|
||||
} else {
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
}
|
||||
|
||||
c.captureConvCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
|
||||
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
|
||||
c.ensureWritableOnce(ctx)
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateDeltaState writes a new delta state for current batch sequences.
|
||||
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
// Fast path: contiguous slots allow a single view + copy
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
|
||||
ctx.Forward(srcF32.Copy(ctx, view))
|
||||
} else {
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
}
|
||||
|
||||
c.captureDeltaCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
|
||||
func (c *HybridCache) IsSupportedForBatch() bool {
|
||||
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
||||
}
|
||||
|
||||
// Seqs returns the ordered unique sequences for the current forward pass.
|
||||
func (c *HybridCache) Seqs() []int {
|
||||
return slices.Clone(c.curSeqs)
|
||||
}
|
||||
498
model/models/qwen3next/checkpoints.go
Normal file
498
model/models/qwen3next/checkpoints.go
Normal file
@@ -0,0 +1,498 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
checkpointCountDefault = 32
|
||||
checkpointMinPosDefault = int32(16)
|
||||
checkpointIntervalDefault = int32(1280)
|
||||
)
|
||||
|
||||
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
|
||||
// memory usage while preserving prefix reuse for recurrent state.
|
||||
|
||||
type checkpointEntry struct {
|
||||
pos int32
|
||||
conv map[int]ml.Tensor
|
||||
delta map[int]ml.Tensor
|
||||
}
|
||||
|
||||
type slotCheckpointStore struct {
|
||||
entries []checkpointEntry
|
||||
size int
|
||||
next int
|
||||
lastPos int32
|
||||
}
|
||||
|
||||
type checkpointRestore struct {
|
||||
slot int
|
||||
idx int
|
||||
pos int32
|
||||
}
|
||||
|
||||
func newSlotCheckpointStore(n int) *slotCheckpointStore {
|
||||
entries := make([]checkpointEntry, n)
|
||||
for i := range entries {
|
||||
entries[i].pos = -1
|
||||
}
|
||||
return &slotCheckpointStore{
|
||||
entries: entries,
|
||||
lastPos: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) reset() {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
for i := range s.entries {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) record(pos int32) int {
|
||||
if len(s.entries) == 0 {
|
||||
return -1
|
||||
}
|
||||
idx := s.next
|
||||
s.next = (s.next + 1) % len(s.entries)
|
||||
if s.size < len(s.entries) {
|
||||
s.size++
|
||||
}
|
||||
s.entries[idx].pos = pos
|
||||
s.lastPos = pos
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
|
||||
bestIdx := -1
|
||||
bestPos := int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 || pos >= targetPos {
|
||||
continue
|
||||
}
|
||||
if pos > bestPos {
|
||||
bestPos = pos
|
||||
bestIdx = i
|
||||
}
|
||||
}
|
||||
if bestIdx < 0 {
|
||||
return -1, -1, false
|
||||
}
|
||||
return bestIdx, bestPos, true
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) pruneAfter(pos int32) {
|
||||
if len(s.entries) == 0 {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
|
||||
size := 0
|
||||
next := -1
|
||||
minPos := int32(math.MaxInt32)
|
||||
minIdx := 0
|
||||
for i := range s.entries {
|
||||
if s.entries[i].pos > pos {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
if s.entries[i].pos >= 0 {
|
||||
size++
|
||||
if s.entries[i].pos < minPos {
|
||||
minPos = s.entries[i].pos
|
||||
minIdx = i
|
||||
}
|
||||
} else if next == -1 {
|
||||
next = i
|
||||
}
|
||||
}
|
||||
|
||||
s.size = size
|
||||
if size == 0 {
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
if next != -1 {
|
||||
s.next = next
|
||||
} else {
|
||||
// Full ring: overwrite the oldest checkpoint next.
|
||||
s.next = minIdx
|
||||
}
|
||||
s.lastPos = pos
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
|
||||
minPos = int32(math.MaxInt32)
|
||||
maxPos = int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
size++
|
||||
if pos < minPos {
|
||||
minPos = pos
|
||||
}
|
||||
if pos > maxPos {
|
||||
maxPos = pos
|
||||
}
|
||||
}
|
||||
if size == 0 {
|
||||
minPos = -1
|
||||
maxPos = -1
|
||||
}
|
||||
return size, minPos, maxPos, s.lastPos
|
||||
}
|
||||
|
||||
func (c *HybridCache) planCheckpoints(batch input.Batch) {
|
||||
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cap(c.curCheckpointPos) < len(c.curSeqs) {
|
||||
c.curCheckpointPos = make([]int32, len(c.curSeqs))
|
||||
} else {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
|
||||
}
|
||||
for i := range c.curCheckpointPos {
|
||||
c.curCheckpointPos[i] = -1
|
||||
}
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
|
||||
posMax := make(map[int]int32, len(c.curSeqs))
|
||||
for i, seq := range batch.Sequences {
|
||||
pos := batch.Positions[i]
|
||||
if cur, ok := posMax[seq]; !ok || pos > cur {
|
||||
posMax[seq] = pos
|
||||
}
|
||||
}
|
||||
|
||||
for i, seq := range c.curSeqs {
|
||||
pos, ok := posMax[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if pos < c.checkpointMinPos {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
store := c.checkpointStore(slot)
|
||||
lastPos := store.lastPos
|
||||
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
|
||||
c.curCheckpointPos[i] = pos
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
|
||||
store, ok := c.checkpoints[slot]
|
||||
if ok {
|
||||
return store
|
||||
}
|
||||
store = newSlotCheckpointStore(c.checkpointCount)
|
||||
c.checkpoints[slot] = store
|
||||
return store
|
||||
}
|
||||
|
||||
func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
|
||||
if c.checkpointCount == 0 {
|
||||
return -1
|
||||
}
|
||||
if idx, ok := c.curCheckpointSlots[slot]; ok {
|
||||
return idx
|
||||
}
|
||||
store := c.checkpointStore(slot)
|
||||
idx := store.record(pos)
|
||||
if idx >= 0 {
|
||||
c.curCheckpointSlots[slot] = idx
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
|
||||
if pos <= 0 {
|
||||
return false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, _, ok = store.bestIndex(pos)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
|
||||
if targetPos <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
|
||||
return 0, false
|
||||
}
|
||||
idx, pos, ok := store.bestIndex(targetPos)
|
||||
if !ok {
|
||||
size, minPos, maxPos, lastPos := store.window()
|
||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
|
||||
"min", minPos, "max", maxPos, "last", lastPos)
|
||||
return 0, false
|
||||
}
|
||||
c.pendingRestore[seq] = checkpointRestore{
|
||||
slot: slot,
|
||||
idx: idx,
|
||||
pos: pos,
|
||||
}
|
||||
return pos + 1, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
|
||||
entry, ok := c.restoreEntry(restore)
|
||||
if !ok {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
|
||||
for layer, src := range entry.conv {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
for layer, src := range entry.delta {
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
|
||||
if len(entry.conv) > 0 || len(entry.delta) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
store := c.checkpoints[restore.slot]
|
||||
store.pruneAfter(restore.pos)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
|
||||
_, ok := c.restoreEntry(restore)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
|
||||
store, ok := c.checkpoints[restore.slot]
|
||||
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
|
||||
return nil, false
|
||||
}
|
||||
entry := &store.entries[restore.idx]
|
||||
if entry.pos < 0 {
|
||||
return nil, false
|
||||
}
|
||||
if !c.entryComplete(entry) {
|
||||
return nil, false
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
|
||||
for layer := range c.convStates {
|
||||
if entry.conv == nil || entry.conv[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for layer := range c.deltaStates {
|
||||
if entry.delta == nil || entry.delta[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *HybridCache) clearCheckpoints(slot int) {
|
||||
if store, ok := c.checkpoints[slot]; ok {
|
||||
store.reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
srcStore, ok := c.checkpoints[srcSlot]
|
||||
if !ok || srcStore.size == 0 {
|
||||
return
|
||||
}
|
||||
dstStore := c.checkpointStore(dstSlot)
|
||||
dstStore.size = srcStore.size
|
||||
dstStore.next = srcStore.next
|
||||
dstStore.lastPos = srcStore.lastPos
|
||||
|
||||
for i := range srcStore.entries {
|
||||
srcEntry := &srcStore.entries[i]
|
||||
dstEntry := &dstStore.entries[i]
|
||||
dstEntry.pos = srcEntry.pos
|
||||
if srcEntry.conv != nil {
|
||||
if dstEntry.conv == nil {
|
||||
dstEntry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.conv {
|
||||
dst := c.ensureCheckpointConv(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
if srcEntry.delta != nil {
|
||||
if dstEntry.delta == nil {
|
||||
dstEntry.delta = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.delta {
|
||||
dst := c.ensureCheckpointDelta(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointConv(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointConv(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointDelta(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointDelta(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.conv == nil {
|
||||
entry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.conv[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointConvCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointConvCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
|
||||
entry.conv[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.delta == nil {
|
||||
entry.delta = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.delta[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointDeltaCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointDeltaCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
|
||||
entry.delta[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *HybridCache) reserveCheckpointConv(layer int) {
|
||||
key := checkpointReserveKey(layer, 0)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointConv(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func (c *HybridCache) reserveCheckpointDelta(layer int) {
|
||||
key := checkpointReserveKey(layer, 1)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointDelta(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func checkpointReserveKey(layer int, kind int) int {
|
||||
return layer*2 + kind
|
||||
}
|
||||
300
model/models/qwen3next/checkpoints_test.go
Normal file
300
model/models/qwen3next/checkpoints_test.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func newTestBackend(tb testing.TB) ml.Backend {
|
||||
tb.Helper()
|
||||
|
||||
f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
|
||||
_ = f.Close()
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
tb.Cleanup(func() {
|
||||
b.Close()
|
||||
})
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
|
||||
store := newSlotCheckpointStore(2)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
|
||||
_, pos, ok := store.bestIndex(15)
|
||||
if !ok || pos != 10 {
|
||||
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
store.record(30) // overwrite oldest (10)
|
||||
|
||||
if _, _, ok := store.bestIndex(15); ok {
|
||||
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(40)
|
||||
if !ok || pos != 30 {
|
||||
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCachePrepareRestore(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 1, 1)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
store := cache.checkpointStore(0)
|
||||
store.record(5)
|
||||
store.record(9)
|
||||
store.record(15)
|
||||
|
||||
restorePos, ok := cache.PrepareRestore(1, 12)
|
||||
if !ok {
|
||||
t.Fatalf("expected restore ok")
|
||||
}
|
||||
if restorePos != 10 {
|
||||
t.Fatalf("expected restorePos 10, got %d", restorePos)
|
||||
}
|
||||
rest, ok := cache.pendingRestore[1]
|
||||
if !ok {
|
||||
t.Fatalf("expected pending restore entry")
|
||||
}
|
||||
if rest.pos != 9 {
|
||||
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
store.pruneAfter(20)
|
||||
|
||||
if store.lastPos != 20 {
|
||||
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, pos, ok := store.bestIndex(25)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(35)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
|
||||
backend := newTestBackend(t)
|
||||
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.Init(backend, ml.DTypeF16, 2, 8, 2)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.slotForSeq[2] = 0
|
||||
cache.refCount[0] = 2
|
||||
cache.refCount[1] = 0
|
||||
cache.freeSlots = []int{1}
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
|
||||
t.Fatalf("Remove failed: %v", err)
|
||||
}
|
||||
|
||||
if cache.slotForSeq[1] == cache.slotForSeq[2] {
|
||||
t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
|
||||
}
|
||||
if cache.slotForSeq[1] != 1 {
|
||||
t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
|
||||
}
|
||||
if cache.slotForSeq[2] != 0 {
|
||||
t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
|
||||
}
|
||||
if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
|
||||
t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
|
||||
}
|
||||
if _, ok := cache.pendingRestore[1]; ok {
|
||||
t.Fatalf("expected pending restore to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
|
||||
cache.convStates[0] = nil // placeholder to indicate layer 0 exists
|
||||
cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
entry := &store.entries[idx]
|
||||
// Only set conv checkpoint, not delta - making it incomplete
|
||||
entry.conv = map[int]ml.Tensor{0: nil}
|
||||
// entry.delta is not set, so checkpoint is incomplete
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
err := cache.Remove(1, 10, math.MaxInt32)
|
||||
if !errors.Is(err, kvcache.ErrNotSupported) {
|
||||
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Don't set convStates/deltaStates - with no layers to check,
|
||||
// entryComplete will return true as long as entry.pos >= 0
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
// Test that restoreComplete returns true when no layers need checkpoints
|
||||
restore := cache.pendingRestore[1]
|
||||
if !cache.restoreComplete(restore) {
|
||||
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
|
||||
// Test that ring buffer wrap-around reuses entries without clearing maps.
|
||||
store := newSlotCheckpointStore(3)
|
||||
|
||||
// Fill the buffer
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
// Create fake tensor data in the first entry's maps
|
||||
store.entries[0].conv = make(map[int]ml.Tensor)
|
||||
store.entries[0].conv[0] = nil // Simulated tensor reference
|
||||
store.entries[0].delta = make(map[int]ml.Tensor)
|
||||
store.entries[0].delta[0] = nil // Simulated tensor reference
|
||||
|
||||
// Record another entry, which should wrap around and overwrite entry 0
|
||||
store.record(40)
|
||||
|
||||
// Verify the maps are still present (we reuse tensors)
|
||||
if store.entries[0].conv == nil {
|
||||
t.Fatalf("expected conv map to be preserved on reuse")
|
||||
}
|
||||
if store.entries[0].delta == nil {
|
||||
t.Fatalf("expected delta map to be preserved on reuse")
|
||||
}
|
||||
|
||||
// Verify the new position was recorded
|
||||
if store.entries[0].pos != 40 {
|
||||
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
|
||||
// Test behavior when buffer is exactly at capacity
|
||||
store := newSlotCheckpointStore(2)
|
||||
|
||||
idx1 := store.record(10)
|
||||
idx2 := store.record(20)
|
||||
|
||||
if idx1 != 0 || idx2 != 1 {
|
||||
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
|
||||
}
|
||||
|
||||
if store.size != 2 {
|
||||
t.Fatalf("expected size 2, got %d", store.size)
|
||||
}
|
||||
|
||||
// Verify both checkpoints are accessible
|
||||
_, pos1, ok1 := store.bestIndex(15)
|
||||
_, pos2, ok2 := store.bestIndex(25)
|
||||
|
||||
if !ok1 || pos1 != 10 {
|
||||
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
|
||||
}
|
||||
if !ok2 || pos2 != 20 {
|
||||
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
|
||||
// Test behavior with zero-size buffer
|
||||
store := newSlotCheckpointStore(0)
|
||||
|
||||
idx := store.record(10)
|
||||
if idx != -1 {
|
||||
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(15)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint for empty buffer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
|
||||
// Test pruning that removes all checkpoints
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
// Prune everything by setting threshold below all positions
|
||||
store.pruneAfter(5)
|
||||
|
||||
if store.size != 0 {
|
||||
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
|
||||
}
|
||||
// When all checkpoints are pruned, lastPos is reset to -1
|
||||
if store.lastPos != -1 {
|
||||
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(100)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint after pruning all")
|
||||
}
|
||||
}
|
||||
473
model/models/qwen3next/deltanet.go
Normal file
473
model/models/qwen3next/deltanet.go
Normal file
@@ -0,0 +1,473 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
const chunkSize = 64
|
||||
|
||||
// TriType constants for triangular matrix operations
|
||||
const (
|
||||
TriTypeUpperDiag = 0
|
||||
TriTypeUpper = 1
|
||||
TriTypeLowerDiag = 2
|
||||
TriTypeLower = 3
|
||||
)
|
||||
|
||||
// convKernel wraps the 1D convolution kernel tensor
|
||||
type convKernel struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
// Masks holds pre-computed mask tensors for chunked attention
|
||||
type Masks struct {
|
||||
Causal ml.Tensor // Lower triangular [chunkSize, chunkSize]
|
||||
Identity ml.Tensor // Diagonal [chunkSize, chunkSize]
|
||||
Diag ml.Tensor // causal + identity
|
||||
}
|
||||
|
||||
// GatedDeltaNet implements linear attention with SSM convolution and recurrent state.
|
||||
// It implements the Operator interface directly.
|
||||
type GatedDeltaNet struct {
|
||||
// Optimized path: pre-split QKV and gate
|
||||
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
|
||||
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
|
||||
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
|
||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
||||
SSMOut *nn.Linear `gguf:"ssm_out"`
|
||||
|
||||
// Layer index for cache access (set during model construction)
|
||||
Layer int
|
||||
}
|
||||
|
||||
// createMasks builds the constant mask tensors (called once, reused for all chunks)
|
||||
func createMasks(ctx ml.Context) *Masks {
|
||||
ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
|
||||
ones = ones.Fill(ctx, 1.0)
|
||||
causalMask := ones.Tri(ctx, TriTypeLower)
|
||||
|
||||
onesVec := ctx.Input().Zeros(ml.DTypeF32, chunkSize)
|
||||
onesVec = onesVec.Fill(ctx, 1.0)
|
||||
identity := onesVec.Diag(ctx)
|
||||
|
||||
diagMask := causalMask.Add(ctx, identity)
|
||||
|
||||
return &Masks{
|
||||
Causal: causalMask,
|
||||
Identity: identity,
|
||||
Diag: diagMask,
|
||||
}
|
||||
}
|
||||
|
||||
func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
layer := gdn.Layer
|
||||
nSeqTokens := hiddenStates.Dim(1)
|
||||
nSeqs := hiddenStates.Dim(2)
|
||||
if cache != nil && cache.IsSupportedForBatch() {
|
||||
seqTokens := cache.seqTokens()
|
||||
seqs := cache.numSeqs()
|
||||
if seqTokens > 0 && seqs > 0 {
|
||||
if nSeqs > 1 {
|
||||
if nSeqTokens != seqTokens || nSeqs != seqs {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
} else {
|
||||
if nSeqTokens != seqTokens*seqs {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
|
||||
nSeqTokens = seqTokens
|
||||
nSeqs = seqs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
headKDim := opts.ssmDState
|
||||
numKHeads := opts.ssmNGroup
|
||||
numVHeads := opts.ssmDtRank
|
||||
headVDim := opts.ssmDInner / numVHeads
|
||||
convKernelSize := opts.convKernelSize
|
||||
|
||||
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
|
||||
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
|
||||
|
||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
|
||||
}
|
||||
// Optimized path: pre-split QKV and gate
|
||||
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
|
||||
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
|
||||
|
||||
baNewDim := 2 * numVHeads / numKHeads
|
||||
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
|
||||
|
||||
// Split beta and alpha
|
||||
betaSize := numVHeads / numKHeads
|
||||
alphaSize := numVHeads / numKHeads
|
||||
|
||||
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
|
||||
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
|
||||
|
||||
// Reshape to merge head dimensions
|
||||
beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
|
||||
alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||
alphaSoftplus := alphaBiased.Softplus(ctx)
|
||||
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
|
||||
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
// Get conv state from cache
|
||||
convStates, err := cache.ConvState(ctx, layer)
|
||||
if err != nil {
|
||||
// Log this - if it happens, short-term context will be lost
|
||||
slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
|
||||
convStates = ctx.Input().Zeros(ml.DTypeF32, convKernelSize-1, qkvDim, nSeqs)
|
||||
}
|
||||
|
||||
// Reshape conv states
|
||||
convStates = convStates.Reshape(ctx, convKernelSize-1, qkvDim, nSeqs)
|
||||
|
||||
// Concatenate with input for convolution
|
||||
convInput := convStates.Concat(ctx, qkvMixed, 0)
|
||||
|
||||
// Save new conv state (last convKernelSize-1 tokens)
|
||||
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
|
||||
cache.UpdateConvState(ctx, layer, lastConvStates)
|
||||
|
||||
// Apply SSM convolution (kernel must be F32 for Metal)
|
||||
convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
|
||||
convOutput = convOutput.SILU(ctx)
|
||||
|
||||
// Reshape for extraction
|
||||
convQKVMix := convOutput.Contiguous(ctx, qkvDim, nSeqTokens*nSeqs)
|
||||
|
||||
// Extract convolved Q, K, V
|
||||
qConv := convQKVMix.Slice(ctx, 0, 0, headKDim*numKHeads, 1)
|
||||
kConv := convQKVMix.Slice(ctx, 0, headKDim*numKHeads, 2*headKDim*numKHeads, 1)
|
||||
vConv := convQKVMix.Slice(ctx, 0, 2*headKDim*numKHeads, qkvDim, 1)
|
||||
|
||||
// Reshape to 4D
|
||||
qConv = qConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
|
||||
kConv = kConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
|
||||
vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
// Get delta state from cache
|
||||
state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
|
||||
if err != nil {
|
||||
// Log this - if it happens frequently, context will degrade
|
||||
slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
|
||||
state = ctx.Input().Zeros(ml.DTypeF32, headVDim, headVDim*numVHeads, nSeqs)
|
||||
}
|
||||
state = state.Reshape(ctx, headVDim, headVDim*numVHeads, 1, nSeqs)
|
||||
|
||||
// Repeat interleave Q and K if numKHeads != numVHeads
|
||||
if numKHeads != numVHeads {
|
||||
repeatFactor := numVHeads / numKHeads
|
||||
|
||||
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
|
||||
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
|
||||
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
}
|
||||
|
||||
// Choose computation mode based on sequence length
|
||||
var attnOut ml.Tensor
|
||||
if nSeqTokens == 1 {
|
||||
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
|
||||
} else {
|
||||
// Use pre-computed masks from opts (created once in Model.Forward)
|
||||
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
|
||||
}
|
||||
|
||||
// Apply gated normalization
|
||||
attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
|
||||
z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
|
||||
|
||||
// norm(attnOut, z) = RMSNorm(attnOut) * silu(z)
|
||||
attnOutNorm := gdn.SSMNorm.Forward(ctx, attnOut2D, opts.eps)
|
||||
zSilu := z2D.SILU(ctx)
|
||||
attnOutGated := attnOutNorm.Mul(ctx, zSilu)
|
||||
|
||||
// Reshape for output projection
|
||||
finalOutput := attnOutGated.Reshape(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
out := gdn.SSMOut.Forward(ctx, finalOutput)
|
||||
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
|
||||
}
|
||||
|
||||
// deltaNetAutoregressive implements single-token state update.
|
||||
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
|
||||
func (gdn *GatedDeltaNet) deltaNetAutoregressive(
|
||||
ctx ml.Context,
|
||||
q, k, v, gate, beta, state ml.Tensor,
|
||||
opts *Options,
|
||||
layer int,
|
||||
cache *HybridCache,
|
||||
) ml.Tensor {
|
||||
numVHeads := v.Dim(1)
|
||||
headVDim := v.Dim(0)
|
||||
nSeqs := q.Dim(3)
|
||||
|
||||
// L2 normalize Q and K
|
||||
q = q.L2Norm(ctx, opts.eps)
|
||||
k = k.L2Norm(ctx, opts.eps)
|
||||
|
||||
// Scale Q
|
||||
scale := 1.0 / math.Sqrt(float64(headVDim))
|
||||
q = q.Scale(ctx, scale)
|
||||
|
||||
// Sigmoid beta
|
||||
beta = beta.Sigmoid(ctx)
|
||||
|
||||
// Reshape state: [headVDim, headVDim, numVHeads, nSeqs]
|
||||
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||
|
||||
// Reshape gate and beta for broadcasting
|
||||
gT := gate.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
|
||||
betaT := beta.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
|
||||
|
||||
// Apply exponential to gate
|
||||
gT = gT.Exp(ctx)
|
||||
|
||||
// state = state * g_t
|
||||
state = state.Mul(ctx, gT)
|
||||
|
||||
// kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
|
||||
kTUnsqueezed := k.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
|
||||
kvMem := state.Mul(ctx, kTUnsqueezed)
|
||||
// Sum over dim=-2 (second dimension after permute)
|
||||
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
kvMem = kvMem.SumRows(ctx)
|
||||
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
// v_t with singleton dimension
|
||||
vT := v.Reshape(ctx, headVDim, 1, numVHeads, nSeqs)
|
||||
|
||||
// delta = (v_t - kv_mem) * beta_t
|
||||
vDiff := vT.Sub(ctx, kvMem)
|
||||
delta := vDiff.Mul(ctx, betaT)
|
||||
|
||||
// state = state + k_t.unsqueeze(-1) * delta
|
||||
kTUnsqueezedBroad := kTUnsqueezed.Repeat4D(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||
kTDelta := kTUnsqueezedBroad.Mul(ctx, delta)
|
||||
state = state.Add(ctx, kTDelta)
|
||||
|
||||
// core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2)
|
||||
qTUnsqueezed := q.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
|
||||
stateQ := state.Mul(ctx, qTUnsqueezed)
|
||||
stateQ = stateQ.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
coreAttnOut := stateQ.SumRows(ctx)
|
||||
coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
// Update delta state in cache
|
||||
cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
|
||||
|
||||
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
|
||||
}
|
||||
|
||||
// deltaNetChunked implements chunked computation for prefill.
|
||||
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
|
||||
func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
ctx ml.Context,
|
||||
q, k, v, gate, beta, state ml.Tensor,
|
||||
masks *Masks,
|
||||
opts *Options,
|
||||
layer int,
|
||||
cache *HybridCache,
|
||||
) ml.Tensor {
|
||||
headKDim := q.Dim(0)
|
||||
numVHeads := v.Dim(1)
|
||||
headVDim := v.Dim(0)
|
||||
nTokens := q.Dim(2)
|
||||
nSeqs := q.Dim(3)
|
||||
|
||||
// L2 normalize Q and K
|
||||
q = q.L2Norm(ctx, opts.eps)
|
||||
k = k.L2Norm(ctx, opts.eps)
|
||||
|
||||
// Scale Q
|
||||
scale := 1.0 / math.Sqrt(float64(headVDim))
|
||||
q = q.Scale(ctx, scale)
|
||||
|
||||
// Sigmoid beta
|
||||
beta = beta.Sigmoid(ctx)
|
||||
|
||||
// Permute tensors for chunked computation
|
||||
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
|
||||
gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs)
|
||||
|
||||
beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||
|
||||
// Compute padding
|
||||
pad := (chunkSize - nTokens%chunkSize) % chunkSize
|
||||
nChunks := (nTokens + pad) / chunkSize
|
||||
|
||||
// Pad tensors
|
||||
if pad > 0 {
|
||||
q = q.Pad(ctx, 0, pad, 0, 0)
|
||||
k = k.Pad(ctx, 0, pad, 0, 0)
|
||||
v = v.Pad(ctx, 0, pad, 0, 0)
|
||||
gate = gate.Pad(ctx, pad, 0, 0, 0)
|
||||
beta = beta.Pad(ctx, 0, pad, 0, 0)
|
||||
}
|
||||
|
||||
// Use pre-computed masks (passed in, not recreated)
|
||||
causalMask := masks.Causal
|
||||
identity := masks.Identity
|
||||
diagMask := masks.Diag
|
||||
identity4D := identity.Reshape(ctx, chunkSize, chunkSize, 1, 1)
|
||||
|
||||
// v_beta = v * beta, k_beta = k * beta
|
||||
vBeta := v.Mul(ctx, beta)
|
||||
kBeta := k.Mul(ctx, beta)
|
||||
|
||||
// Reshape for chunked computation
|
||||
q = q.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
k = k.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
|
||||
gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
|
||||
|
||||
// g_cumsum = cumsum(gate)
|
||||
gCumsum := gate.CumSum(ctx)
|
||||
|
||||
// Compute decay mask
|
||||
gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
|
||||
gcsJ := gCumsum.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
gcsBroadcast := gcsJ.Repeat4D(ctx, chunkSize, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
decayMask := gcsBroadcast.Sub(ctx, gcsI)
|
||||
|
||||
decayMask = decayMask.Mul(ctx, diagMask)
|
||||
decayMask = decayMask.Exp(ctx)
|
||||
decayMask = decayMask.Mul(ctx, diagMask)
|
||||
|
||||
// k @ k_beta^T
|
||||
kMulKBeta := k.Mulmat(ctx, kBeta)
|
||||
|
||||
// k_decay = k @ k_beta^T * decay_mask
|
||||
kDecay := kMulKBeta.Mul(ctx, decayMask)
|
||||
|
||||
// attn = -k_decay * causal_mask
|
||||
attn := kDecay.Neg(ctx).Mul(ctx, causalMask)
|
||||
|
||||
// Triangular solve: (I - attn_lower)^-1 @ attn
|
||||
attnLower := attn.Mul(ctx, causalMask)
|
||||
lhs := attnLower.Neg(ctx).Add(ctx, identity4D)
|
||||
linSolve := lhs.SolveTri(ctx, attn, true, true, false)
|
||||
attn = linSolve.Mul(ctx, causalMask)
|
||||
attn = attn.Add(ctx, identity4D)
|
||||
|
||||
// v = v_beta^T @ attn
|
||||
vBetaT := vBeta.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
v = vBetaT.Mulmat(ctx, attn)
|
||||
|
||||
// Compute g_exp for state update
|
||||
gCumsumT := gCumsum.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
gExp := gCumsumT.Exp(ctx)
|
||||
|
||||
// kbeta_gexp = k_beta * g_exp
|
||||
kBetaGExp := kBeta.Mul(ctx, gExp)
|
||||
|
||||
// k_cumdecay = attn @ kbeta_gexp^T
|
||||
kBetaGExpT := kBetaGExp.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
kCumdecay := attn.Mulmat(ctx, kBetaGExpT)
|
||||
kCumdecay = kCumdecay.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Pre-compute attn_kq = (k @ q) * decay_mask * diag_mask
|
||||
attnKQ := k.Mulmat(ctx, q)
|
||||
attnKQ = attnKQ.Mul(ctx, decayMask)
|
||||
attnKQ = attnKQ.Mul(ctx, diagMask)
|
||||
|
||||
// Pre-compute g_last and key_gdiff
|
||||
// g_last = view of last element in g_cumsum along chunk_size dimension
|
||||
// We need to get the last row of gCumsum: shape [chunkSize, 1, nChunks, H*n_seqs] -> [1, 1, nChunks, H*n_seqs]
|
||||
gLast := gCumsum.Slice(ctx, 0, chunkSize-1, chunkSize, 1).Contiguous(ctx, 1, 1, nChunks, numVHeads*nSeqs)
|
||||
gLastExp := gLast.Exp(ctx)
|
||||
|
||||
// g_diff = -(g_cumsum - g_last) = g_last - g_cumsum
|
||||
gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
|
||||
gDiffExp := gDiff.Exp(ctx)
|
||||
|
||||
// key_gdiff = k * exp(g_diff)
|
||||
keyGDiff := k.Mul(ctx, gDiffExp)
|
||||
|
||||
// Process chunks and update state
|
||||
var coreAttnOut ml.Tensor
|
||||
newState := state
|
||||
|
||||
for chunk := range nChunks {
|
||||
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
|
||||
|
||||
// state^T - permute is needed but Contiguous creates a copy
|
||||
stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||
|
||||
// v_prime = k_cumdecay @ state
|
||||
vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
|
||||
|
||||
// v_new = v - v_prime
|
||||
vNew := vChunk.Sub(ctx, vPrime)
|
||||
vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// attn_inter = (q * g_exp) @ state
|
||||
qGExp := qChunk.Mul(ctx, gExpChunk)
|
||||
attnInter := stateT.Mulmat(ctx, qGExp)
|
||||
|
||||
// core_attn_out = attn_inter + attn @ v_new
|
||||
vAttn := vNewT.Mulmat(ctx, attnChunk)
|
||||
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
||||
|
||||
if coreAttnOut == nil {
|
||||
coreAttnOut = coreAttnOutChunk
|
||||
} else {
|
||||
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
|
||||
}
|
||||
|
||||
// Update state for next chunk using pre-computed values
|
||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kGDiffChunk := keyGDiff.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
|
||||
// kgdmulvnew = key_gdiff^T @ v_new
|
||||
kGDiffChunkT := kGDiffChunk.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
|
||||
|
||||
// state = state * g_last + kgdmulvnew
|
||||
gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
|
||||
newState = newState.Mul(ctx, gExpLastReshaped)
|
||||
newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
|
||||
}
|
||||
|
||||
// Final reshape
|
||||
coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||
|
||||
// Slice to remove padding
|
||||
if pad > 0 {
|
||||
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
|
||||
}
|
||||
|
||||
// Update delta state in cache
|
||||
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
|
||||
|
||||
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
|
||||
}
|
||||
383
model/models/qwen3next/model.go
Normal file
383
model/models/qwen3next/model.go
Normal file
@@ -0,0 +1,383 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
// Options contains model configuration
|
||||
type Options struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
numKVHeads int
|
||||
keyLength int
|
||||
valueLength int
|
||||
ropeDim int
|
||||
|
||||
eps float32
|
||||
ropeBase float32
|
||||
ropeScale float32
|
||||
ropeType string
|
||||
originalContextLength int
|
||||
attentionScale float64
|
||||
|
||||
// MoE config
|
||||
numExperts int
|
||||
numExpertsUsed int
|
||||
normTopKProb bool
|
||||
|
||||
// Linear attention (Gated Delta Net) config
|
||||
ssmDInner int // d_inner = head_v_dim * num_v_heads
|
||||
ssmDState int // head_k_dim
|
||||
ssmNGroup int // num_k_heads
|
||||
ssmDtRank int // num_v_heads
|
||||
convKernelSize int // SSM conv kernel size
|
||||
|
||||
// Per-layer type from GGUF metadata
|
||||
isRecurrent []bool
|
||||
|
||||
// Pre-computed masks for chunked attention (created once per forward pass)
|
||||
masks *Masks
|
||||
}
|
||||
|
||||
func (o Options) headDim() int {
|
||||
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
opts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
if o.ropeType == "yarn" {
|
||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||
opts = append(opts,
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(1.),
|
||||
rope.WithAttentionFactor(attnFactor),
|
||||
)
|
||||
}
|
||||
ropeDim := cmp.Or(o.ropeDim, o.headDim())
|
||||
return nn.RoPE(ctx, states, positions, ropeDim, o.ropeBase, 1./o.ropeScale, opts...)
|
||||
}
|
||||
|
||||
// Operator is the interface for attention-like operators
|
||||
type Operator interface {
|
||||
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
|
||||
}
|
||||
|
||||
// MLP is the interface for feedforward networks
|
||||
type MLP interface {
|
||||
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
|
||||
}
|
||||
|
||||
// sparse implements MoE with shared experts
|
||||
type sparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
|
||||
// Shared experts
|
||||
SharedGateInp *nn.Linear `gguf:"ffn_gate_inp_shexp"`
|
||||
SharedGate *nn.Linear `gguf:"ffn_gate_shexp"`
|
||||
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
|
||||
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
|
||||
}
|
||||
|
||||
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
|
||||
if batchSize == 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
|
||||
|
||||
// Router logits
|
||||
routerLogits := mlp.Router.Forward(ctx, hiddenStates2D)
|
||||
|
||||
// Softmax routing weights
|
||||
routingWeights := routerLogits.Softmax(ctx)
|
||||
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
|
||||
if opts.normTopKProb {
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
|
||||
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
|
||||
}
|
||||
|
||||
hiddenStates3D := hiddenStates2D.Reshape(ctx, hiddenStates2D.Dim(0), 1, hiddenStates2D.Dim(1))
|
||||
|
||||
// Expert computation with SILU activation
|
||||
gateOut := mlp.Gate.Forward(ctx, hiddenStates3D, selectedExperts)
|
||||
upOut := mlp.Up.Forward(ctx, hiddenStates3D, selectedExperts)
|
||||
experts := gateOut.SILU(ctx, upOut)
|
||||
experts = mlp.Down.Forward(ctx, experts, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
// Sum over experts
|
||||
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
|
||||
// Add shared experts if present
|
||||
if mlp.SharedUp != nil {
|
||||
sharedGate := mlp.SharedGate.Forward(ctx, hiddenStates2D)
|
||||
sharedUp := mlp.SharedUp.Forward(ctx, hiddenStates2D)
|
||||
sharedOut := sharedGate.SILU(ctx, sharedUp)
|
||||
sharedOut = mlp.SharedDown.Forward(ctx, sharedOut)
|
||||
|
||||
// Apply shared expert gating
|
||||
if mlp.SharedGateInp != nil {
|
||||
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
|
||||
sharedGateVal = sharedGateVal.Sigmoid(ctx)
|
||||
// Broadcast gate to match dimensions
|
||||
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
|
||||
sharedOut = sharedOut.Mul(ctx, sharedGateVal)
|
||||
}
|
||||
|
||||
moeOut = moeOut.Add(ctx, sharedOut)
|
||||
}
|
||||
|
||||
return moeOut
|
||||
}
|
||||
|
||||
// dense implements standard feedforward
|
||||
type dense struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
// Layer represents a single transformer layer
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
AttentionPostNorm *nn.RMSNorm `gguf:"post_attention_norm"` // Post-attention norm before FFN
|
||||
Operator Operator
|
||||
|
||||
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
residual := hiddenStates
|
||||
|
||||
// Pre-attention norm
|
||||
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// Attention (full or linear)
|
||||
var err error
|
||||
hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Output projection for last layer
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
// First residual connection
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
// Save for FFN residual
|
||||
ffnResidual := hiddenStates
|
||||
|
||||
// Post-attention norm (before FFN)
|
||||
hiddenStates = l.AttentionPostNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// FFN
|
||||
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
|
||||
|
||||
// Second residual connection
|
||||
return hiddenStates.Add(ctx, ffnResidual), nil
|
||||
}
|
||||
|
||||
// Model is the main Qwen3-Next model
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Layers []Layer `gguf:"blk"`
|
||||
|
||||
*Options
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
cache := m.Cache.(*HybridCache)
|
||||
|
||||
// Create masks once per forward pass
|
||||
m.Options.masks = createMasks(ctx)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
cache.SetLayer(i)
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
var err error
|
||||
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
var _ model.Model = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
layers := make([]Layer, numLayers)
|
||||
|
||||
// Get per-layer head counts (for detecting layer type)
|
||||
type headCounts interface {
|
||||
HeadCount() []uint64
|
||||
HeadCountKV() []uint64
|
||||
}
|
||||
|
||||
var isRecurrent []bool
|
||||
var headCountKV []uint64
|
||||
if hc, ok := c.(headCounts); ok {
|
||||
headCountKV = hc.HeadCountKV()
|
||||
}
|
||||
|
||||
isRecurrent = make([]bool, numLayers)
|
||||
hasZero := false
|
||||
hasFull := false
|
||||
for i := range numLayers {
|
||||
// If KV head count is 0, it's a recurrent layer
|
||||
if i < len(headCountKV) && headCountKV[i] == 0 {
|
||||
isRecurrent[i] = true
|
||||
hasZero = true
|
||||
} else if i < len(headCountKV) && headCountKV[i] > 0 {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasZero || !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
|
||||
}
|
||||
|
||||
// Determine if MoE
|
||||
isMoE := c.Uint("expert_count") > 0
|
||||
|
||||
for i := range layers {
|
||||
if isRecurrent[i] {
|
||||
layers[i].Operator = &GatedDeltaNet{Layer: i}
|
||||
} else {
|
||||
layers[i].Operator = &FullAttention{}
|
||||
}
|
||||
|
||||
if isMoE {
|
||||
layers[i].MLP = &sparse{}
|
||||
} else {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
}
|
||||
|
||||
opts := &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: func() int {
|
||||
for _, v := range headCountKV {
|
||||
if v > 0 {
|
||||
return int(v)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}(),
|
||||
keyLength: int(c.Uint("attention.key_length")),
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeType: c.String("rope.scaling.type"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
attentionScale: float64(c.Float("attention.scale")),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||
ssmDInner: int(c.Uint("ssm.inner_size")),
|
||||
ssmDState: int(c.Uint("ssm.state_size")),
|
||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||
isRecurrent: isRecurrent,
|
||||
}
|
||||
if opts.numKVHeads == 0 {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
||||
}
|
||||
|
||||
// Calculate cache dimensions
|
||||
convDim := max(0, opts.convKernelSize-1)
|
||||
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
|
||||
headVDim := 0
|
||||
numVHeads := opts.ssmDtRank
|
||||
if numVHeads > 0 {
|
||||
headVDim = opts.ssmDInner / numVHeads
|
||||
}
|
||||
deltaStateSize := headVDim * headVDim * numVHeads
|
||||
|
||||
// Validate dimension assumption: headKDim == headVDim is required for state computations
|
||||
headKDim := opts.ssmDState
|
||||
if headKDim != headVDim && headKDim > 0 && headVDim > 0 {
|
||||
return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
// Qwen3 tokenizers typically set add_bos_token=false and bos_token=null.
|
||||
// Default to false when the GGUF key is missing to avoid injecting a spurious BOS.
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: opts,
|
||||
}
|
||||
|
||||
m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("qwen3next", New)
|
||||
}
|
||||
@@ -167,12 +167,12 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
|
||||
|
||||
// only start a new user block if this is the first tool response
|
||||
if i == 0 || filteredMessages[i-1].Role != "tool" {
|
||||
sb.WriteString(imStartTag + "user\n")
|
||||
sb.WriteString(imStartTag + "user")
|
||||
}
|
||||
|
||||
sb.WriteString("<tool_response>\n")
|
||||
sb.WriteString("\n<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n</tool_response>\n")
|
||||
sb.WriteString("\n</tool_response>")
|
||||
|
||||
// close the user block only if this is the last tool response
|
||||
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -127,8 +128,7 @@ fahrenheit
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>user
|
||||
That sounds nice! What about New York?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
@@ -233,8 +233,7 @@ I'll call double(1) and triple(2) for you.
|
||||
</tool_response>
|
||||
<tool_response>
|
||||
{"number": 6}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
@@ -280,8 +279,7 @@ call tool<|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"payload": {"foo": "bar"}}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
@@ -337,6 +335,31 @@ func TestFormatToolCallArgument(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderRendererToolResponseNoTrailingNewline(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{Role: "user", Content: "call tool"},
|
||||
{Role: "assistant", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{
|
||||
Name: "echo",
|
||||
Arguments: testArgs(map[string]any{"payload": "ok"}),
|
||||
}},
|
||||
}},
|
||||
{Role: "tool", Content: "{\"payload\":\"ok\"}", ToolName: "echo"},
|
||||
}
|
||||
|
||||
rendered, err := (&Qwen3CoderRenderer{}).Render(msgs, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if strings.Contains(rendered, "</tool_response>\n<|im_end|>") {
|
||||
t.Fatalf("expected no newline after </tool_response>, got:\n%s", rendered)
|
||||
}
|
||||
if !strings.Contains(rendered, "</tool_response><|im_end|>") {
|
||||
t.Fatalf("expected </tool_response> to be immediately followed by <|im_end|>, got:\n%s", rendered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ToolDefinitionTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -124,8 +124,17 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*In
|
||||
}
|
||||
|
||||
if c.cache != nil {
|
||||
if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
|
||||
numPast = 0
|
||||
if numPast > 0 {
|
||||
// Recurrent caches use checkpoints to pick a safe resume position.
|
||||
if cc, ok := c.cache.(kvcache.CheckpointCache); ok {
|
||||
if restored, ok := cc.PrepareRestore(slot.Id, numPast); ok {
|
||||
numPast = restored
|
||||
} else {
|
||||
numPast = 0
|
||||
}
|
||||
} else if !c.cache.CanResume(slot.Id, numPast) {
|
||||
numPast = 0
|
||||
}
|
||||
}
|
||||
|
||||
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
|
||||
|
||||
@@ -27,14 +27,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
// Clip images are represented as 768 tokens, each an embedding
|
||||
imageNumTokens := 768
|
||||
|
||||
n := len(msgs) - 1
|
||||
// in reverse, find all messages that fit into context window
|
||||
for i := n; i >= 0; i-- {
|
||||
// always include the last message
|
||||
if i == n {
|
||||
continue
|
||||
}
|
||||
lastMsgIdx := len(msgs) - 1
|
||||
currMsgIdx := 0
|
||||
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
@@ -54,20 +52,26 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, m := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(m.Images)
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
}
|
||||
}
|
||||
|
||||
if truncate && ctxLen > opts.NumCtx {
|
||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
||||
if !truncate || ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
} else {
|
||||
n = i
|
||||
}
|
||||
}
|
||||
|
||||
currMsgIdx := n
|
||||
if currMsgIdx > 0 {
|
||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
||||
}
|
||||
|
||||
for cnt, msg := range msgs[currMsgIdx:] {
|
||||
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -264,3 +265,68 @@ func TestChatPrompt(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptTokenizeCalls(t *testing.T) {
|
||||
tmpl, err := template.Parse(`
|
||||
{{- if .System }}{{ .System }} {{ end }}
|
||||
{{- if .Prompt }}{{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}{{ .Response }} {{ end }}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
model := Model{Template: tmpl}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
limit int
|
||||
msgs []api.Message
|
||||
maxTokenizes int
|
||||
}{
|
||||
{
|
||||
name: "all messages fit",
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "message 1"},
|
||||
{Role: "assistant", Content: "response 1"},
|
||||
{Role: "user", Content: "message 2"},
|
||||
{Role: "assistant", Content: "response 2"},
|
||||
{Role: "user", Content: "message 3"},
|
||||
},
|
||||
maxTokenizes: 1,
|
||||
},
|
||||
{
|
||||
name: "truncate to last message",
|
||||
limit: 5,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "message 1"},
|
||||
{Role: "assistant", Content: "response 1"},
|
||||
{Role: "user", Content: "message 2"},
|
||||
{Role: "assistant", Content: "response 2"},
|
||||
{Role: "user", Content: "message 3"},
|
||||
},
|
||||
maxTokenizes: 5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenizeCount := 0
|
||||
countingTokenize := func(ctx context.Context, s string) ([]int, error) {
|
||||
tokenizeCount++
|
||||
tokens, err := mockRunner{}.Tokenize(ctx, s)
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
think := false
|
||||
_, _, err := chatPrompt(t.Context(), &model, countingTokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tokenizeCount > tt.maxTokenizes {
|
||||
t.Errorf("tokenize called %d times, expected at most %d", tokenizeCount, tt.maxTokenizes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,6 +58,48 @@ func useMoreBits(iLayer, nLayers int) bool {
|
||||
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
|
||||
}
|
||||
|
||||
func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
|
||||
switch {
|
||||
// Full attention
|
||||
case strings.HasSuffix(name, ".attn_q.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".attn_k.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".attn_v.weight"):
|
||||
return fsggml.TensorTypeQ6_K, true
|
||||
case strings.HasSuffix(name, ".attn_output.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
|
||||
// Linear attention (Gated Delta Net) after split
|
||||
case strings.HasSuffix(name, ".attn_qkv.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".attn_gate.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
|
||||
// SSM
|
||||
case strings.HasSuffix(name, ".ssm_ba.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ssm_out.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
|
||||
// MoE experts + shared experts
|
||||
case strings.HasSuffix(name, ".ffn_down_exps.weight"):
|
||||
return fsggml.TensorTypeQ6_K, true
|
||||
case strings.HasSuffix(name, ".ffn_down_shexp.weight"):
|
||||
return fsggml.TensorTypeQ6_K, true
|
||||
case strings.HasSuffix(name, ".ffn_gate_exps.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ffn_gate_shexp.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ffn_up_exps.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ffn_up_shexp.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType {
|
||||
// Ported from llama_tensor_get_type, removed unsupported quantization types
|
||||
nExperts := max(1, kv.Uint("expert_count", 0))
|
||||
@@ -217,6 +259,7 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
|
||||
// do not quantize expert gating tensors
|
||||
quantize = quantize && !strings.Contains(name, "ffn_gate_inp.weight")
|
||||
quantize = quantize && !strings.Contains(name, "ffn_gate_inp_shexp.weight")
|
||||
|
||||
// do not quantize positional embeddings and token types (BERT)
|
||||
quantize = quantize && (name != "position_embd.weight")
|
||||
@@ -244,6 +287,12 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
|
||||
newType := fsggml.TensorType(t.Kind)
|
||||
if quantize {
|
||||
if kv.Architecture() == "qwen3next" && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
|
||||
if qt, ok := qwen3nextQuantType(name); ok {
|
||||
return qt
|
||||
}
|
||||
}
|
||||
|
||||
// get more optimal quantization type based on the tensor shape, layer, etc.
|
||||
newType = getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype)
|
||||
if newType != defaultType {
|
||||
|
||||
@@ -3,13 +3,13 @@ package model
|
||||
type Capability string
|
||||
|
||||
const (
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
CapabilityImage = Capability("image")
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
CapabilityImage = Capability("image")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
|
||||
Reference in New Issue
Block a user