mirror of
https://github.com/ollama/ollama.git
synced 2026-02-23 10:45:08 -05:00
Compare commits
18 Commits
v0.16.2
...
pdevine/qw
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00f67e807a | ||
|
|
97323d1c68 | ||
|
|
458dd1b9d9 | ||
|
|
9d02d1d767 | ||
|
|
1a636fb47a | ||
|
|
0759fface9 | ||
|
|
325b72bc31 | ||
|
|
f01a9a7859 | ||
|
|
9aefd2dfee | ||
|
|
d07e4a1dd3 | ||
|
|
8a257ec00a | ||
|
|
2f4de1acf7 | ||
|
|
ec95c45f70 | ||
|
|
3a88f7eb20 | ||
|
|
0d5da826d4 | ||
|
|
9b795698b8 | ||
|
|
041fb77639 | ||
|
|
8224cce583 |
@@ -16,7 +16,7 @@ Start building with open models.
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
```
|
||||
|
||||
or [download manually](http://localhost:8080/download/Ollama.dmg)
|
||||
or [download manually](https://ollama.com/download/Ollama.dmg)
|
||||
|
||||
### Windows
|
||||
|
||||
|
||||
12
cmd/cmd.go
12
cmd/cmd.go
@@ -57,9 +57,9 @@ import (
|
||||
|
||||
func init() {
|
||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
|
||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems)
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
@@ -182,6 +182,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
case "parser":
|
||||
mfConfig.Parser = cmd.Args
|
||||
case "renderer":
|
||||
mfConfig.Renderer = cmd.Args
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1897,9 +1901,9 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
}
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem) (string, error) {
|
||||
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems)
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := DefaultSingleSelector("Select model:", items)
|
||||
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
123
cmd/config/cline.go
Normal file
123
cmd/config/cline.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Cline is the integration for the Cline CLI. It satisfies both the Runner
// and Editor interfaces.
type Cline struct{}

// String returns the human-readable name of the integration.
func (c *Cline) String() string {
	return "Cline"
}
|
||||
|
||||
func (c *Cline) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("cline"); err != nil {
|
||||
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "cline", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("cline", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (c *Cline) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cline) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
|
||||
}
|
||||
}
|
||||
|
||||
// Set Ollama as the provider for both act and plan modes
|
||||
baseURL := envconfig.Host().String()
|
||||
config["ollamaBaseUrl"] = baseURL
|
||||
config["actModeApiProvider"] = "ollama"
|
||||
config["actModeOllamaModelId"] = models[0]
|
||||
config["actModeOllamaBaseUrl"] = baseURL
|
||||
config["planModeApiProvider"] = "ollama"
|
||||
config["planModeOllamaModelId"] = models[0]
|
||||
config["planModeOllamaBaseUrl"] = baseURL
|
||||
|
||||
config["welcomeViewCompleted"] = true
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (c *Cline) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if config["actModeApiProvider"] != "ollama" {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelID, _ := config["actModeOllamaModelId"].(string)
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{modelID}
|
||||
}
|
||||
204
cmd/config/cline_test.go
Normal file
204
cmd/config/cline_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClineIntegration(t *testing.T) {
|
||||
c := &Cline{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Cline" {
|
||||
t.Errorf("String() = %q, want %q", got, "Cline")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClineEdit(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
|
||||
readConfig := func() map[string]any {
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var config map[string]any
|
||||
json.Unmarshal(data, &config)
|
||||
return config
|
||||
}
|
||||
|
||||
t.Run("creates config from scratch", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeApiProvider"] != "ollama" {
|
||||
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
|
||||
}
|
||||
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
if config["planModeApiProvider"] != "ollama" {
|
||||
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
|
||||
}
|
||||
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
|
||||
}
|
||||
if config["welcomeViewCompleted"] != true {
|
||||
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves existing fields", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existing := map[string]any{
|
||||
"remoteRulesToggles": map[string]any{},
|
||||
"remoteWorkflowToggles": map[string]any{},
|
||||
"customSetting": "keep-me",
|
||||
}
|
||||
data, _ := json.Marshal(existing)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["customSetting"] != "keep-me" {
|
||||
t.Errorf("customSetting was not preserved")
|
||||
}
|
||||
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates model on re-edit", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
if config["planModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit(nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||
t.Error("expected no config file to be created for empty models")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses first model as primary", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClineModels(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
|
||||
t.Run("returns nil when no config", func(t *testing.T) {
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
config := map[string]any{
|
||||
"actModeApiProvider": "anthropic",
|
||||
"actModeOllamaModelId": "some-model",
|
||||
}
|
||||
data, _ := json.Marshal(config)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns model when ollama is configured", func(t *testing.T) {
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
config := map[string]any{
|
||||
"actModeApiProvider": "ollama",
|
||||
"actModeOllamaModelId": "kimi-k2.5:cloud",
|
||||
}
|
||||
data, _ := json.Marshal(config)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClinePaths(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns nil when no config exists", func(t *testing.T) {
|
||||
if paths := c.Paths(); paths != nil {
|
||||
t.Errorf("Paths() = %v, want nil", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 || paths[0] != configPath {
|
||||
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
@@ -32,6 +33,10 @@ func (c *Codex) Run(model string, args []string) error {
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
|
||||
"OPENAI_API_KEY=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -54,6 +53,7 @@ type AliasConfigurer interface {
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
"clawdbot": &Openclaw{},
|
||||
"cline": &Cline{},
|
||||
"codex": &Codex{},
|
||||
"moltbot": &Openclaw{},
|
||||
"droid": &Droid{},
|
||||
@@ -102,16 +102,17 @@ var recommendedVRAM = map[string]string{
|
||||
var integrationAliases = map[string]bool{
|
||||
"clawdbot": true,
|
||||
"moltbot": true,
|
||||
"pi": true,
|
||||
}
|
||||
|
||||
// integrationInstallHints maps integration names to install URLs.
|
||||
var integrationInstallHints = map[string]string{
|
||||
"claude": "https://code.claude.com/docs/en/quickstart",
|
||||
"cline": "https://cline.bot/cli",
|
||||
"openclaw": "https://docs.openclaw.ai",
|
||||
"codex": "https://developers.openai.com/codex/cli/",
|
||||
"droid": "https://docs.factory.ai/cli/getting-started/quickstart",
|
||||
"opencode": "https://opencode.ai",
|
||||
"pi": "https://github.com/badlogic/pi-mono",
|
||||
}
|
||||
|
||||
// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
|
||||
@@ -129,13 +130,21 @@ type IntegrationInfo struct {
|
||||
// integrationDescriptions maps integration names to short descriptions.
|
||||
var integrationDescriptions = map[string]string{
|
||||
"claude": "Anthropic's coding tool with subagents",
|
||||
"cline": "Autonomous coding agent with parallel execution",
|
||||
"codex": "OpenAI's open-source coding agent",
|
||||
"openclaw": "Personal AI with 100+ skills",
|
||||
"droid": "Factory's coding agent across terminal and IDEs",
|
||||
"opencode": "Anomaly's open-source coding agent",
|
||||
"pi": "Minimal AI agent toolkit with plugin support",
|
||||
}
|
||||
|
||||
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name.
|
||||
// integrationOrder defines a custom display order for integrations.
|
||||
// Integrations listed here are placed at the end in the given order;
|
||||
// all others appear first, sorted alphabetically.
|
||||
var integrationOrder = []string{"opencode", "droid", "pi", "cline"}
|
||||
|
||||
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name
|
||||
// with integrationOrder entries placed at the end.
|
||||
func ListIntegrationInfos() []IntegrationInfo {
|
||||
var result []IntegrationInfo
|
||||
for name, r := range integrations {
|
||||
@@ -148,7 +157,26 @@ func ListIntegrationInfos() []IntegrationInfo {
|
||||
Description: integrationDescriptions[name],
|
||||
})
|
||||
}
|
||||
|
||||
orderRank := make(map[string]int, len(integrationOrder))
|
||||
for i, name := range integrationOrder {
|
||||
orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list"
|
||||
}
|
||||
|
||||
slices.SortFunc(result, func(a, b IntegrationInfo) int {
|
||||
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||
// Both have custom order: sort by their rank
|
||||
if aRank > 0 && bRank > 0 {
|
||||
return aRank - bRank
|
||||
}
|
||||
// Only one has custom order: it goes last
|
||||
if aRank > 0 {
|
||||
return 1
|
||||
}
|
||||
if bRank > 0 {
|
||||
return -1
|
||||
}
|
||||
// Neither has custom order: alphabetical
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
return result
|
||||
@@ -186,9 +214,15 @@ func IsIntegrationInstalled(name string) bool {
|
||||
case "droid":
|
||||
_, err := exec.LookPath("droid")
|
||||
return err == nil
|
||||
case "cline":
|
||||
_, err := exec.LookPath("cline")
|
||||
return err == nil
|
||||
case "opencode":
|
||||
_, err := exec.LookPath("opencode")
|
||||
return err == nil
|
||||
case "pi":
|
||||
_, err := exec.LookPath("pi")
|
||||
return err == nil
|
||||
default:
|
||||
return true // Assume installed for unknown integrations
|
||||
}
|
||||
@@ -214,7 +248,8 @@ type ModelItem struct {
|
||||
}
|
||||
|
||||
// SingleSelector is a function type for single item selection.
|
||||
type SingleSelector func(title string, items []ModelItem) (string, error)
|
||||
// current is the name of the previously selected item to highlight; empty means no pre-selection.
|
||||
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
|
||||
|
||||
// MultiSelector is a function type for multi item selection.
|
||||
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||
@@ -257,7 +292,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
|
||||
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
selected, err := selector("Select model to run:", items)
|
||||
selected, err := selector("Select model to run:", items, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -367,13 +402,11 @@ func selectIntegration() (string, error) {
|
||||
return "", fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
names := slices.Sorted(maps.Keys(integrations))
|
||||
var items []ModelItem
|
||||
for _, name := range names {
|
||||
for name, r := range integrations {
|
||||
if integrationAliases[name] {
|
||||
continue
|
||||
}
|
||||
r := integrations[name]
|
||||
description := r.String()
|
||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
||||
@@ -381,7 +414,25 @@ func selectIntegration() (string, error) {
|
||||
items = append(items, ModelItem{Name: name, Description: description})
|
||||
}
|
||||
|
||||
return DefaultSingleSelector("Select integration:", items)
|
||||
orderRank := make(map[string]int, len(integrationOrder))
|
||||
for i, name := range integrationOrder {
|
||||
orderRank[name] = i + 1
|
||||
}
|
||||
slices.SortFunc(items, func(a, b ModelItem) int {
|
||||
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||
if aRank > 0 && bRank > 0 {
|
||||
return aRank - bRank
|
||||
}
|
||||
if aRank > 0 {
|
||||
return 1
|
||||
}
|
||||
if bRank > 0 {
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
return DefaultSingleSelector("Select integration:", items, "")
|
||||
}
|
||||
|
||||
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
|
||||
@@ -439,7 +490,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
||||
if _, ok := r.(AliasConfigurer); ok {
|
||||
prompt = fmt.Sprintf("Select Primary model for %s:", r)
|
||||
}
|
||||
model, err := single(prompt, items)
|
||||
model, err := single(prompt, items, current)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -812,10 +863,12 @@ Without arguments, this is equivalent to running 'ollama' directly.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
cline Cline
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||
pi Pi
|
||||
|
||||
Examples:
|
||||
ollama launch
|
||||
@@ -915,11 +968,9 @@ Examples:
|
||||
}
|
||||
|
||||
// Validate saved model still exists
|
||||
cloudCleared := false
|
||||
if model != "" && modelFlag == "" {
|
||||
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) {
|
||||
model = ""
|
||||
cloudCleared = true
|
||||
} else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
|
||||
if err := ShowOrPull(cmd.Context(), client, model); err != nil {
|
||||
@@ -928,18 +979,16 @@ Examples:
|
||||
}
|
||||
}
|
||||
|
||||
// If no valid model or --config flag, show picker
|
||||
if model == "" || configFlag {
|
||||
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag || cloudCleared)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model = aliases["primary"]
|
||||
existingAliases = aliases
|
||||
// Show picker so user can change model (skip when --model flag provided)
|
||||
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, modelFlag == "")
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model = aliases["primary"]
|
||||
existingAliases = aliases
|
||||
|
||||
// Ensure cloud models are authenticated
|
||||
if isCloudModel(cmd.Context(), client, model) {
|
||||
@@ -1001,27 +1050,13 @@ Examples:
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
|
||||
savedModels := filterDisabledCloudModels(saved.Models)
|
||||
if len(savedModels) != len(saved.Models) {
|
||||
_ = SaveIntegration(name, savedModels)
|
||||
}
|
||||
if len(savedModels) == 0 {
|
||||
// All saved models were cloud — fall through to picker
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
models = savedModels
|
||||
return runIntegration(name, models[0], passArgs)
|
||||
}
|
||||
} else {
|
||||
current := ""
|
||||
if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 {
|
||||
current = saved.Models[0]
|
||||
}
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
models, err = selectModels(cmd.Context(), name, current)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1248,10 +1248,26 @@ func TestListIntegrationInfos(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sorted by name", func(t *testing.T) {
|
||||
t.Run("sorted with custom order at end", func(t *testing.T) {
|
||||
// integrationOrder entries (opencode, droid, pi, cline) should appear last, in that order.
|
||||
// All other entries should be sorted alphabetically before them.
|
||||
orderRank := make(map[string]int)
|
||||
for i, name := range integrationOrder {
|
||||
orderRank[name] = i + 1
|
||||
}
|
||||
for i := 1; i < len(infos); i++ {
|
||||
if infos[i-1].Name >= infos[i].Name {
|
||||
t.Errorf("not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
||||
aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
|
||||
switch {
|
||||
case aRank == 0 && bRank == 0:
|
||||
if infos[i-1].Name >= infos[i].Name {
|
||||
t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
||||
}
|
||||
case aRank > 0 && bRank == 0:
|
||||
t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
|
||||
case aRank > 0 && bRank > 0:
|
||||
if aRank >= bRank {
|
||||
t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -365,14 +365,27 @@ func (m selectorModel) View() string {
|
||||
return s
|
||||
}
|
||||
|
||||
func SelectSingle(title string, items []SelectItem) (string, error) {
|
||||
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
||||
func cursorForCurrent(items []SelectItem, current string) int {
|
||||
if current != "" {
|
||||
for i, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
title: title,
|
||||
items: items,
|
||||
cursor: cursorForCurrent(items, current),
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
@@ -402,6 +415,12 @@ type multiSelectorModel struct {
|
||||
cancelled bool
|
||||
confirmed bool
|
||||
width int
|
||||
|
||||
// multi enables full multi-select editing mode. The zero value (false)
|
||||
// shows a single-select picker where Enter adds the chosen model to
|
||||
// the existing list. Tab toggles between modes.
|
||||
multi bool
|
||||
singleAdd string // model picked in single mode
|
||||
}
|
||||
|
||||
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
|
||||
@@ -416,13 +435,23 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
|
||||
m.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := m.itemIndex[name]; ok {
|
||||
// Reverse order so preChecked[0] (the current default) ends up last
|
||||
// in checkOrder, matching the "last checked = default" convention.
|
||||
for i := len(preChecked) - 1; i >= 0; i-- {
|
||||
if idx, ok := m.itemIndex[preChecked[i]]; ok {
|
||||
m.checked[idx] = true
|
||||
m.checkOrder = append(m.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
// Position cursor on the current default model
|
||||
if len(preChecked) > 0 {
|
||||
if idx, ok := m.itemIndex[preChecked[0]]; ok {
|
||||
m.cursor = idx
|
||||
m.updateScroll(m.otherStart())
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -533,14 +562,25 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyTab:
|
||||
m.multi = !m.multi
|
||||
|
||||
case tea.KeyEnter:
|
||||
if len(m.checkOrder) > 0 {
|
||||
if !m.multi {
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
m.singleAdd = filtered[m.cursor].Name
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
} else if len(m.checkOrder) > 0 {
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case tea.KeySpace:
|
||||
m.toggleItem()
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
@@ -579,7 +619,9 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// On some terminals (e.g. Windows PowerShell), space arrives as
|
||||
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
|
||||
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
|
||||
m.toggleItem()
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
} else {
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
@@ -591,6 +633,19 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
s.WriteString("\n")
|
||||
if item.Description != "" {
|
||||
s.WriteString(selectorDescLineStyle.Render(item.Description))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
@@ -602,7 +657,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
|
||||
}
|
||||
|
||||
suffix := ""
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
|
||||
suffix = " " + selectorDefaultTagStyle.Render("(default)")
|
||||
}
|
||||
|
||||
@@ -624,6 +679,11 @@ func (m multiSelectorModel) View() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
renderItem := m.renderSingleItem
|
||||
if m.multi {
|
||||
renderItem = m.renderMultiItem
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
@@ -648,7 +708,7 @@ func (m multiSelectorModel) View() string {
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
m.renderMultiItem(&s, filtered[idx], idx)
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
@@ -671,7 +731,7 @@ func (m multiSelectorModel) View() string {
|
||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||
s.WriteString("\n")
|
||||
for _, idx := range recItems {
|
||||
m.renderMultiItem(&s, filtered[idx], idx)
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -691,7 +751,7 @@ func (m multiSelectorModel) View() string {
|
||||
if idx >= len(otherItems) {
|
||||
break
|
||||
}
|
||||
m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
}
|
||||
|
||||
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
@@ -703,15 +763,18 @@ func (m multiSelectorModel) View() string {
|
||||
|
||||
s.WriteString("\n")
|
||||
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
if !m.multi {
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
|
||||
} else {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
} else {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
|
||||
|
||||
result := s.String()
|
||||
if m.width > 0 {
|
||||
@@ -734,18 +797,28 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
|
||||
}
|
||||
|
||||
fm := finalModel.(multiSelectorModel)
|
||||
if fm.cancelled {
|
||||
if fm.cancelled || !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
if !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
// Single-add mode: prepend the picked model, keep existing models deduped
|
||||
if fm.singleAdd != "" {
|
||||
result := []string{fm.singleAdd}
|
||||
for _, name := range preChecked {
|
||||
if name != fm.singleAdd {
|
||||
result = append(result, name)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
var result []string
|
||||
// Multi-edit mode: last checked is default (first in result)
|
||||
last := fm.checkOrder[len(fm.checkOrder)-1]
|
||||
result := []string{fm.items[last].Name}
|
||||
for _, idx := range fm.checkOrder {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
if idx != last {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -382,6 +382,42 @@ func TestUpdateNavigation_Backspace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- cursorForCurrent ---
|
||||
|
||||
func TestCursorForCurrent(t *testing.T) {
|
||||
testItems := []SelectItem{
|
||||
{Name: "llama3.2", Recommended: true},
|
||||
{Name: "qwen3:8b", Recommended: true},
|
||||
{Name: "gemma3:latest"},
|
||||
{Name: "deepseek-r1"},
|
||||
{Name: "glm-5:cloud"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
current string
|
||||
want int
|
||||
}{
|
||||
{"empty current", "", 0},
|
||||
{"exact match", "qwen3:8b", 1},
|
||||
{"no match returns 0", "nonexistent", 0},
|
||||
{"bare name matches with :latest suffix", "gemma3", 2},
|
||||
{"full tag matches bare item", "llama3.2:latest", 0},
|
||||
{"cloud model exact match", "glm-5:cloud", 4},
|
||||
{"cloud model bare name", "glm-5", 4},
|
||||
{"recommended item exact match", "llama3.2", 0},
|
||||
{"recommended item with tag", "qwen3", 1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := cursorForCurrent(testItems, tt.current); got != tt.want {
|
||||
t.Errorf("cursorForCurrent(%q) = %d, want %d", tt.current, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReorderItems ---
|
||||
|
||||
func TestReorderItems(t *testing.T) {
|
||||
@@ -503,6 +539,7 @@ func TestMultiView_CursorIndicator(t *testing.T) {
|
||||
|
||||
func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
|
||||
if !strings.Contains(content, "[x]") {
|
||||
@@ -514,11 +551,18 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMultiView_DefaultTag(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Error("first checked item should have (default) tag")
|
||||
t.Error("should have (default) tag")
|
||||
}
|
||||
// preChecked[0] ("a") should be the default (last in checkOrder)
|
||||
aIdx := strings.Index(content, "a")
|
||||
defaultIdx := strings.Index(content, "(default)")
|
||||
if defaultIdx < aIdx {
|
||||
t.Error("(default) tag should appear after 'a' (the current default)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -549,6 +593,7 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
|
||||
|
||||
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.multi = true
|
||||
m.cursor = 1
|
||||
|
||||
// Simulate space delivered as tea.KeySpace
|
||||
@@ -565,6 +610,7 @@ func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
||||
|
||||
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.multi = true
|
||||
m.cursor = 1
|
||||
|
||||
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
|
||||
@@ -582,6 +628,161 @@ func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Single-add mode ---
|
||||
|
||||
func TestMulti_StartsInSingleMode(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
if m.multi {
|
||||
t.Error("should start in single mode (multi=false)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
content := m.View()
|
||||
if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
|
||||
t.Error("single mode should not show checkboxes")
|
||||
}
|
||||
if !strings.Contains(content, "▸") {
|
||||
t.Error("single mode should show cursor indicator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.cursor = 1
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if m.singleAdd != "b" {
|
||||
t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
|
||||
}
|
||||
if !m.confirmed {
|
||||
t.Error("should set confirmed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
m.cursor = 0
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if len(m.checked) != 0 {
|
||||
t.Error("space in single mode should not toggle items")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
m.cursor = 0
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if len(m.checked) != 0 {
|
||||
t.Error("space rune in single mode should not toggle items")
|
||||
}
|
||||
if m.filter != "" {
|
||||
t.Error("space rune in single mode should not add to filter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_TabTogglesMode(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
|
||||
if m.multi {
|
||||
t.Fatal("should start in single mode")
|
||||
}
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
m = updated.(multiSelectorModel)
|
||||
if !m.multi {
|
||||
t.Error("tab should switch to multi mode")
|
||||
}
|
||||
|
||||
updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
m = updated.(multiSelectorModel)
|
||||
if m.multi {
|
||||
t.Error("tab should switch back to single mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeHelpText(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "tab add multiple") {
|
||||
t.Error("single mode should show 'tab add multiple' in help")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_MultiModeHelpText(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "tab select single") {
|
||||
t.Error("multi mode should show 'tab select single' in help")
|
||||
}
|
||||
}
|
||||
|
||||
// --- preChecked initialization order ---
|
||||
|
||||
func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
|
||||
// preChecked[0] ("a") is the current default and should end up
|
||||
// last in checkOrder so it gets the (default) tag.
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
|
||||
|
||||
if len(m.checkOrder) != 3 {
|
||||
t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
|
||||
}
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "a" {
|
||||
t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_CursorOnDefaultModel(t *testing.T) {
|
||||
// preChecked[0] ("b") is the default; cursor should start on it
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
|
||||
|
||||
if m.cursor != 1 {
|
||||
t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Multi-mode last-checked is default ---
|
||||
|
||||
func TestMulti_LastCheckedIsDefault(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
|
||||
m.multi = true
|
||||
|
||||
// Check "alpha" then "gamma"
|
||||
m.cursor = 0
|
||||
m.toggleItem()
|
||||
m.cursor = 2
|
||||
m.toggleItem()
|
||||
|
||||
// Last checked ("gamma") should be at the end of checkOrder
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "gamma" {
|
||||
t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
|
||||
// The (default) tag renders based on checkOrder[len-1]
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Fatal("should show (default) tag")
|
||||
}
|
||||
// "alpha" line should NOT have the default tag
|
||||
for _, line := range strings.Split(content, "\n") {
|
||||
if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
|
||||
t.Error("'alpha' (first checked) should not have (default) tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Key message helpers for testing
|
||||
|
||||
type keyType = int
|
||||
|
||||
@@ -429,8 +429,24 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
if m.multiModalSelector.confirmed {
|
||||
var selected []string
|
||||
for _, idx := range m.multiModalSelector.checkOrder {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
if m.multiModalSelector.singleAdd != "" {
|
||||
// Single-add mode: prepend picked model, keep existing deduped
|
||||
selected = []string{m.multiModalSelector.singleAdd}
|
||||
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
|
||||
if name != m.multiModalSelector.singleAdd {
|
||||
selected = append(selected, name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Last checked is default (first in result)
|
||||
co := m.multiModalSelector.checkOrder
|
||||
last := co[len(co)-1]
|
||||
selected = []string{m.multiModalSelector.items[last].Name}
|
||||
for _, idx := range co {
|
||||
if idx != last {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(selected) > 0 {
|
||||
m.changeModels = selected
|
||||
|
||||
@@ -106,20 +106,23 @@
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/index",
|
||||
{
|
||||
"group": "Assistants",
|
||||
"expanded": true,
|
||||
"pages": [
|
||||
"/integrations/openclaw"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Coding",
|
||||
"expanded": true,
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/codex",
|
||||
"/integrations/opencode",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Assistants",
|
||||
"pages": [
|
||||
"/integrations/openclaw"
|
||||
"/integrations/goose",
|
||||
"/integrations/pi"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -13,6 +13,7 @@ Coding assistants that can read, modify, and execute code in your projects.
|
||||
- [OpenCode](/integrations/opencode)
|
||||
- [Droid](/integrations/droid)
|
||||
- [Goose](/integrations/goose)
|
||||
- [Pi](/integrations/pi)
|
||||
|
||||
## Assistants
|
||||
|
||||
|
||||
57
docs/integrations/pi.mdx
Normal file
57
docs/integrations/pi.mdx
Normal file
@@ -0,0 +1,57 @@
|
||||
---
|
||||
title: Pi
|
||||
---
|
||||
|
||||
Pi is a minimal AI agent toolkit with plugin support.
|
||||
|
||||
## Install
|
||||
|
||||
Install [Pi](https://github.com/badlogic/pi-mono):
|
||||
|
||||
```bash
|
||||
npm install -g @mariozechner/pi-coding-agent
|
||||
```
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch pi
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch pi --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a configuration block to `~/.pi/agent/models.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{
|
||||
"id": "qwen3-coder"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Update `~/.pi/agent/settings.json` to set the default provider:
|
||||
|
||||
```json
|
||||
{
|
||||
"defaultProvider": "ollama",
|
||||
"defaultModel": "qwen3-coder"
|
||||
}
|
||||
```
|
||||
@@ -27,9 +27,17 @@ The menu provides quick access to:
|
||||
- **Launch tools** - Claude Code, Codex, OpenClaw, and more
|
||||
- **Additional integrations** - Available under "More..."
|
||||
|
||||
## Assistants
|
||||
|
||||
Launch [OpenClaw](/integrations/openclaw), a personal AI with 100+ skills:
|
||||
|
||||
```sh
|
||||
ollama launch openclaw
|
||||
```
|
||||
|
||||
## Coding
|
||||
|
||||
Launch coding tools with Ollama models:
|
||||
Launch [Claude Code](/integrations/claude-code) and other coding tools with Ollama models:
|
||||
|
||||
```sh
|
||||
ollama launch claude
|
||||
|
||||
1
go.mod
1
go.mod
@@ -26,6 +26,7 @@ require (
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/klauspost/compress v1.18.3
|
||||
github.com/mattn/go-runewidth v0.0.16
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
|
||||
4
go.sum
4
go.sum
@@ -122,7 +122,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
|
||||
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
|
||||
@@ -150,8 +149,9 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
|
||||
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
|
||||
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
|
||||
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
|
||||
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -496,6 +497,17 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
|
||||
|
||||
func ResponsesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.GetHeader("Content-Encoding") == "zstd" {
|
||||
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
c.Request.Body = io.NopCloser(reader)
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
var req openai.ResponsesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -1238,3 +1239,102 @@ func TestImageEditsMiddleware(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func zstdCompress(t *testing.T, data []byte) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
w, err := zstd.NewWriter(&buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestResponsesMiddlewareZstd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
useZstd bool
|
||||
oversized bool
|
||||
wantCode int
|
||||
wantModel string
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "plain JSON",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd compressed",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd over max decompressed size",
|
||||
oversized: true,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedRequest *api.ChatRequest
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
var bodyReader io.Reader
|
||||
if tt.oversized {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
|
||||
} else if tt.useZstd {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
|
||||
} else {
|
||||
bodyReader = strings.NewReader(tt.body)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if tt.useZstd || tt.oversized {
|
||||
req.Header.Set("Content-Encoding", "zstd")
|
||||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != tt.wantCode {
|
||||
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
if capturedRequest == nil {
|
||||
t.Fatal("expected captured request, got nil")
|
||||
}
|
||||
if capturedRequest.Model != tt.wantModel {
|
||||
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
|
||||
}
|
||||
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
|
||||
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,6 +45,10 @@ func ParserForName(name string) Parser {
|
||||
var p Parser
|
||||
|
||||
switch name {
|
||||
case "qwen3":
|
||||
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
case "qwen3-thinking":
|
||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
case "qwen3-coder":
|
||||
p = &Qwen3CoderParser{}
|
||||
case "qwen3-vl-instruct":
|
||||
|
||||
@@ -54,6 +54,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
|
||||
name string
|
||||
}{
|
||||
{"passthrough"},
|
||||
{"qwen3"},
|
||||
{"qwen3-thinking"},
|
||||
{"qwen3-coder"},
|
||||
{"harmony"},
|
||||
}
|
||||
|
||||
335
model/parsers/qwen3.go
Normal file
335
model/parsers/qwen3.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// qwen3ParserState identifies which phase of the model output Qwen3Parser is
// currently consuming.
type qwen3ParserState int

// Parser phases. The numeric values follow declaration order.
const (
	// Deciding whether the output opens with a <think> tag.
	qwen3ParserStateLookingForThinkingOpen qwen3ParserState = iota
	// An opening <think> was consumed; skipping whitespace before thinking text.
	qwen3ParserStateThinkingStartedEatingWhitespace
	// Accumulating thinking text until </think> shows up.
	qwen3ParserStateCollectingThinking
	// </think> was consumed; skipping whitespace before regular content.
	qwen3ParserStateThinkingDoneEatingWhitespace
	// Accumulating regular content until <tool_call> shows up.
	qwen3ParserStateCollectingContent
	// <tool_call> was consumed; skipping whitespace before the call payload.
	qwen3ParserStateToolStartedEatingWhitespace
	// Accumulating the tool call payload until </tool_call> shows up.
	qwen3ParserStateCollectingToolContent
)

// Literal tags that delimit thinking sections and tool calls in Qwen3 output.
const (
	qwen3ThinkingOpenTag  = "<think>"
	qwen3ThinkingCloseTag = "</think>"
	qwen3ToolOpenTag      = "<tool_call>"
	qwen3ToolCloseTag     = "</tool_call>"
)
|
||||
|
||||
// Qwen3Parser parses Qwen3 output to extract thinking and tool calls.
|
||||
// Qwen3 prompts end with <think> when thinking is enabled, so output begins
|
||||
// with thinking content directly (without an opening tag).
|
||||
type Qwen3Parser struct {
|
||||
state qwen3ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
hasThinkingSupport bool
|
||||
defaultThinking bool
|
||||
maybeThinkingOpenAtBOL bool
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) HasThinkingSupport() bool {
|
||||
return p.hasThinkingSupport
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.buffer.Reset()
|
||||
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
if thinkValue == nil {
|
||||
thinkingEnabled = p.defaultThinking
|
||||
}
|
||||
|
||||
if p.hasThinkingSupport && thinkingEnabled {
|
||||
p.state = qwen3ParserStateCollectingThinking
|
||||
p.maybeThinkingOpenAtBOL = true
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// qwen3Event is implemented by every item the Qwen3 parser's state machine
// emits. The unexported marker method keeps the set of implementations local
// to this package.
type qwen3Event interface {
	isQwen3Event()
}

type (
	// qwen3EventContent carries a piece of regular response content.
	qwen3EventContent struct {
		content string
	}

	// qwen3EventRawToolCall carries the unparsed text found between the
	// tool call tags.
	qwen3EventRawToolCall struct {
		raw string
	}

	// qwen3EventThinkingContent carries a piece of thinking text.
	qwen3EventThinkingContent struct {
		content string
	}
)

func (qwen3EventContent) isQwen3Event()         {}
func (qwen3EventRawToolCall) isQwen3Event()     {}
func (qwen3EventThinkingContent) isQwen3Event() {}
|
||||
|
||||
func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case qwen3EventRawToolCall:
|
||||
toolCall, err := parseQwen3ToolCall(event, p.tools)
|
||||
if err != nil {
|
||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
calls = append(calls, toolCall)
|
||||
case qwen3EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case qwen3EventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), calls, nil
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) parseEvents() []qwen3Event {
|
||||
var all []qwen3Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []qwen3Event
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen3ParserState) ([]qwen3Event, bool) {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
if trimmed == "" {
|
||||
return nil, false
|
||||
}
|
||||
p.state = nextState
|
||||
p.buffer.WriteString(trimmed)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
|
||||
return splitAtTag(&p.buffer, tag, trimAfter)
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
var events []qwen3Event
|
||||
|
||||
switch p.state {
|
||||
case qwen3ParserStateLookingForThinkingOpen:
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
|
||||
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if after == "" {
|
||||
p.state = qwen3ParserStateThinkingStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingThinking
|
||||
}
|
||||
return events, true
|
||||
} else if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
|
||||
return events, false
|
||||
} else if trimmed == "" {
|
||||
return events, false
|
||||
}
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
return events, true
|
||||
|
||||
case qwen3ParserStateThinkingStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingThinking)
|
||||
|
||||
case qwen3ParserStateCollectingThinking:
|
||||
acc := p.buffer.String()
|
||||
|
||||
// Some qwen3 checkpoints emit an explicit opening <think> tag even
|
||||
// though the prompt already ended with <think>. Strip exactly one
|
||||
// leading opening tag if present.
|
||||
if p.maybeThinkingOpenAtBOL {
|
||||
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
|
||||
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if after == "" {
|
||||
return events, false
|
||||
}
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
return events, true
|
||||
}
|
||||
if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
|
||||
return events, false
|
||||
}
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
}
|
||||
|
||||
if strings.Contains(acc, qwen3ThinkingCloseTag) {
|
||||
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: thinking})
|
||||
}
|
||||
if remaining == "" {
|
||||
p.state = qwen3ParserStateThinkingDoneEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case qwen3ParserStateThinkingDoneEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingContent)
|
||||
|
||||
case qwen3ParserStateCollectingContent:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, qwen3ToolOpenTag) {
|
||||
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwen3EventContent{content: before})
|
||||
}
|
||||
if after == "" {
|
||||
p.state = qwen3ParserStateToolStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingToolContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, qwen3ToolOpenTag); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case qwen3ParserStateToolStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingToolContent)
|
||||
|
||||
case qwen3ParserStateCollectingToolContent:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, qwen3ToolCloseTag) {
|
||||
toolContent, _ := p.splitAtTag(qwen3ToolCloseTag, true)
|
||||
if len(toolContent) == 0 {
|
||||
slog.Warn("qwen3 tool call closing tag found but no content before it")
|
||||
}
|
||||
events = append(events, qwen3EventRawToolCall{raw: toolContent})
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
return events, true
|
||||
}
|
||||
return events, false
|
||||
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
var parsed struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse JSON: %w", err)
|
||||
}
|
||||
|
||||
if parsed.Name == "" {
|
||||
return api.ToolCall{}, fmt.Errorf("empty function name")
|
||||
}
|
||||
|
||||
_ = tools // qwen3 uses direct JSON args and does not require schema coercion here.
|
||||
|
||||
toolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: parsed.Name,
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
}
|
||||
|
||||
for key, value := range parsed.Arguments {
|
||||
toolCall.Function.Arguments.Set(key, value)
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
147
model/parsers/qwen3_test.go
Normal file
147
model/parsers/qwen3_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestQwen3ParserThinkingEnabled(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("Let me think...</think>Answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "Let me think..." {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||
}
|
||||
if content != "Answer." {
|
||||
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingEnabledWithExplicitOpeningTag(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "Let me think..." {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||
}
|
||||
if content != "Answer." {
|
||||
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingEnabledWithSplitOpeningTag(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("<thi", false)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on first chunk: %v", err)
|
||||
}
|
||||
if content != "" || thinking != "" || len(calls) != 0 {
|
||||
t.Fatalf("expected no output for first chunk, got content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add("nk>Let me think...</think>Answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on second chunk: %v", err)
|
||||
}
|
||||
if thinking != "Let me think..." {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||
}
|
||||
if content != "Answer." {
|
||||
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingDisabled(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
content, thinking, calls, err := parser.Add("Direct answer", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no thinking, got %q", thinking)
|
||||
}
|
||||
if content != "Direct answer" {
|
||||
t.Fatalf("expected content %q, got %q", "Direct answer", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserNilThinkDefaultsToContentForInstructParser(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add("Direct answer", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no thinking, got %q", thinking)
|
||||
}
|
||||
if content != "Direct answer" {
|
||||
t.Fatalf("expected content %q, got %q", "Direct answer", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCall(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
input := "<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
|
||||
content, thinking, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||
}
|
||||
|
||||
location, ok := calls[0].Function.Arguments.Get("location")
|
||||
if !ok || location != "San Francisco" {
|
||||
t.Fatalf("expected location %q, got %v", "San Francisco", location)
|
||||
}
|
||||
unit, ok := calls[0].Function.Arguments.Get("unit")
|
||||
if !ok || unit != "celsius" {
|
||||
t.Fatalf("expected unit %q, got %v", "celsius", unit)
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,10 @@
|
||||
# This script installs Ollama on Linux and macOS.
|
||||
# It detects the current operating system architecture and installs the appropriate version of Ollama.
|
||||
|
||||
# Wrap script in main function so that a truncated partial download doesn't end
|
||||
# up executing half a script.
|
||||
main() {
|
||||
|
||||
set -eu
|
||||
|
||||
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||
@@ -446,3 +450,6 @@ fi
|
||||
|
||||
status "NVIDIA GPU ready."
|
||||
install_success
|
||||
}
|
||||
|
||||
main
|
||||
|
||||
@@ -2371,30 +2371,6 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: map[string]*runnerRef{
|
||||
"": {
|
||||
llama: &mock,
|
||||
Options: &opts,
|
||||
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
|
||||
isImagegen: true,
|
||||
numParallel: 1,
|
||||
},
|
||||
},
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create model manifest with image capability
|
||||
n := model.ParseName("test-image")
|
||||
cfg := model.ConfigV2{Capabilities: []string{"image"}}
|
||||
@@ -2410,6 +2386,35 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loadedModel, err := GetModel("test-image")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: map[string]*runnerRef{
|
||||
schedulerModelKey(loadedModel): {
|
||||
llama: &mock,
|
||||
Options: &opts,
|
||||
model: loadedModel,
|
||||
isImagegen: true,
|
||||
numParallel: 1,
|
||||
},
|
||||
},
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
streamFalse := false
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-image",
|
||||
|
||||
@@ -83,6 +83,28 @@ func InitScheduler(ctx context.Context) *Scheduler {
|
||||
return sched
|
||||
}
|
||||
|
||||
// schedulerModelKey returns the scheduler map key for a model.
|
||||
// GGUF-backed models use ModelPath; safetensors/image models without a
|
||||
// ModelPath use manifest digest so distinct models don't collide.
|
||||
func schedulerModelKey(m *Model) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
if m.ModelPath != "" {
|
||||
return m.ModelPath
|
||||
}
|
||||
if m.Digest != "" {
|
||||
return "digest:" + m.Digest
|
||||
}
|
||||
if m.Name != "" {
|
||||
return "name:" + m.Name
|
||||
}
|
||||
if m.ShortName != "" {
|
||||
return "short:" + m.ShortName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// context must be canceled to decrement ref count and release the runner
|
||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
|
||||
if opts.NumCtx < 4 {
|
||||
@@ -104,8 +126,9 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
|
||||
useImagegen: useImagegen,
|
||||
}
|
||||
|
||||
key := schedulerModelKey(req.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[req.model.ModelPath]
|
||||
runner := s.loaded[key]
|
||||
s.loadedMu.Unlock()
|
||||
if runner != nil && !runner.needsReload(c, req) {
|
||||
req.useLoadedRunner(runner, s.finishedReqCh)
|
||||
@@ -151,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
|
||||
for {
|
||||
var runnerToExpire *runnerRef
|
||||
pendingKey := schedulerModelKey(pending.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[pending.model.ModelPath]
|
||||
runner := s.loaded[pendingKey]
|
||||
loadedCount := len(s.loaded)
|
||||
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
|
||||
for _, r := range s.loaded {
|
||||
@@ -166,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
runnerToExpire = runner
|
||||
} else {
|
||||
// Runner is usable, return it
|
||||
logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath)
|
||||
logutil.Trace("using existing loaded runner", "model", pendingKey)
|
||||
pending.useLoadedRunner(runner, s.finishedReqCh)
|
||||
break
|
||||
}
|
||||
@@ -292,11 +316,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
slog.Debug("shutting down scheduler completed loop")
|
||||
return
|
||||
case finished := <-s.finishedReqCh:
|
||||
finishedKey := schedulerModelKey(finished.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[finished.model.ModelPath]
|
||||
runner := s.loaded[finishedKey]
|
||||
s.loadedMu.Unlock()
|
||||
if runner == nil {
|
||||
slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
|
||||
slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
|
||||
continue
|
||||
}
|
||||
runner.refMu.Lock()
|
||||
@@ -347,7 +372,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
|
||||
s.loadedMu.Lock()
|
||||
slog.Debug("got lock to unload expired event", "runner", runner)
|
||||
runnerToUnload := s.loaded[runner.modelPath]
|
||||
runnerToUnload := s.loaded[runner.modelKey]
|
||||
if runnerToUnload == nil {
|
||||
// If runnerToUnload is nil, we already processed an event and
|
||||
// unloaded it. This double unload can happen if the initial
|
||||
@@ -376,7 +401,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
}
|
||||
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
|
||||
runner.unload()
|
||||
delete(s.loaded, runner.modelPath)
|
||||
delete(s.loaded, runner.modelKey)
|
||||
s.loadedMu.Unlock()
|
||||
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
|
||||
<-finished
|
||||
@@ -514,6 +539,7 @@ iGPUScan:
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
modelKey: schedulerModelKey(req.model),
|
||||
llama: llama,
|
||||
Options: &req.opts,
|
||||
sessionDuration: sessionDuration,
|
||||
@@ -528,7 +554,7 @@ iGPUScan:
|
||||
runner.refMu.Lock() // hold lock until running or aborted
|
||||
|
||||
s.loadedMu.Lock()
|
||||
if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
|
||||
if oldRunner, ok := s.loaded[runner.modelKey]; ok {
|
||||
// Shouldn't happen, but safeguard against leaking a runner
|
||||
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
|
||||
oldRunner.refMu.Lock()
|
||||
@@ -536,7 +562,7 @@ iGPUScan:
|
||||
oldRunner.refMu.Unlock()
|
||||
}
|
||||
s.activeLoading = nil
|
||||
s.loaded[req.model.ModelPath] = runner
|
||||
s.loaded[runner.modelKey] = runner
|
||||
slog.Info("loaded runners", "count", len(s.loaded))
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
@@ -596,6 +622,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
modelKey: schedulerModelKey(req.model),
|
||||
llama: server,
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
@@ -606,7 +633,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[req.model.ModelPath] = runner
|
||||
s.loaded[runner.modelKey] = runner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Set up expiration timer
|
||||
@@ -684,6 +711,7 @@ type runnerRef struct {
|
||||
|
||||
model *Model
|
||||
modelPath string
|
||||
modelKey string
|
||||
numParallel int
|
||||
*api.Options
|
||||
}
|
||||
@@ -703,7 +731,7 @@ func (runner *runnerRef) unload() {
|
||||
}
|
||||
|
||||
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
|
||||
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
|
||||
slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
|
||||
runner.refMu.Lock()
|
||||
defer runner.refMu.Unlock()
|
||||
|
||||
@@ -814,6 +842,10 @@ func (runner *runnerRef) LogValue() slog.Value {
|
||||
if runner == nil {
|
||||
return slog.StringValue("nil")
|
||||
}
|
||||
modelID := runner.modelPath
|
||||
if modelID == "" {
|
||||
modelID = runner.modelKey
|
||||
}
|
||||
attrs := []slog.Attr{}
|
||||
if runner.model != nil {
|
||||
attrs = append(attrs, slog.String("name", runner.model.Name))
|
||||
@@ -828,7 +860,7 @@ func (runner *runnerRef) LogValue() slog.Value {
|
||||
slog.String("vram", format.HumanBytes2(runner.vramSize)),
|
||||
slog.Int("parallel", runner.numParallel),
|
||||
slog.Int("pid", runner.pid),
|
||||
slog.String("model", runner.modelPath),
|
||||
slog.String("model", modelID),
|
||||
)
|
||||
if runner.Options != nil {
|
||||
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
|
||||
@@ -873,8 +905,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
|
||||
if d1 != d2 {
|
||||
return d1 < d2
|
||||
}
|
||||
// Secondary sort by model path lex order
|
||||
return a[i].modelPath < a[j].modelPath
|
||||
// Secondary sort by model key/path lex order
|
||||
n1 := a[i].modelPath
|
||||
if n1 == "" {
|
||||
n1 = a[i].modelKey
|
||||
}
|
||||
n2 := a[j].modelPath
|
||||
if n2 == "" {
|
||||
n2 = a[j].modelKey
|
||||
}
|
||||
return n1 < n2
|
||||
}
|
||||
|
||||
// TODO - future consideration to pick runners based on size
|
||||
@@ -934,8 +974,9 @@ func (s *Scheduler) unloadAllRunners() {
|
||||
}
|
||||
|
||||
func (s *Scheduler) expireRunner(model *Model) {
|
||||
modelKey := schedulerModelKey(model)
|
||||
s.loadedMu.Lock()
|
||||
runner, ok := s.loaded[model.ModelPath]
|
||||
runner, ok := s.loaded[modelKey]
|
||||
s.loadedMu.Unlock()
|
||||
if ok {
|
||||
runner.refMu.Lock()
|
||||
|
||||
@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
|
||||
b.ctxDone()
|
||||
}
|
||||
|
||||
func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
opts := api.DefaultOptions()
|
||||
opts.NumCtx = 4
|
||||
|
||||
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
|
||||
loadedRunner := &runnerRef{
|
||||
model: loadedModel,
|
||||
modelKey: schedulerModelKey(loadedModel),
|
||||
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
|
||||
Options: &opts,
|
||||
numParallel: 1,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[loadedRunner.modelKey] = loadedRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
|
||||
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
|
||||
|
||||
require.Empty(t, successCh)
|
||||
require.Empty(t, errCh)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
}
|
||||
|
||||
func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
opts := api.DefaultOptions()
|
||||
opts.NumCtx = 4
|
||||
|
||||
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
|
||||
loadedRunner := &runnerRef{
|
||||
model: loadedModel,
|
||||
modelKey: schedulerModelKey(loadedModel),
|
||||
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
|
||||
Options: &opts,
|
||||
numParallel: 1,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[loadedRunner.modelKey] = loadedRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
reqCtx, cancelReq := context.WithCancel(ctx)
|
||||
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
|
||||
cancelReq()
|
||||
|
||||
select {
|
||||
case runner := <-successCh:
|
||||
require.Equal(t, loadedRunner, runner)
|
||||
default:
|
||||
t.Fatal("expected existing runner to be reused")
|
||||
}
|
||||
|
||||
require.Empty(t, errCh)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
}
|
||||
|
||||
func TestSchedExpireRunner(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
@@ -30,6 +30,8 @@ type ModelfileConfig struct {
|
||||
Template string
|
||||
System string
|
||||
License string
|
||||
Parser string
|
||||
Renderer string
|
||||
}
|
||||
|
||||
// CreateOptions holds all options for model creation.
|
||||
@@ -37,7 +39,7 @@ type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
|
||||
}
|
||||
|
||||
// CreateModel imports a model from a local directory.
|
||||
@@ -267,8 +269,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: caps,
|
||||
Requires: MinOllamaVersion,
|
||||
Parser: parserName,
|
||||
Renderer: rendererName,
|
||||
Parser: resolveParserName(opts.Modelfile, parserName),
|
||||
Renderer: resolveRendererName(opts.Modelfile, rendererName),
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
@@ -305,6 +307,22 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
|
||||
}
|
||||
}
|
||||
|
||||
func resolveParserName(mf *ModelfileConfig, inferred string) string {
|
||||
if mf != nil && mf.Parser != "" {
|
||||
return mf.Parser
|
||||
}
|
||||
|
||||
return inferred
|
||||
}
|
||||
|
||||
func resolveRendererName(mf *ModelfileConfig, inferred string) string {
|
||||
if mf != nil && mf.Renderer != "" {
|
||||
return mf.Renderer
|
||||
}
|
||||
|
||||
return inferred
|
||||
}
|
||||
|
||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
var layers []manifest.Layer
|
||||
@@ -410,7 +428,7 @@ func getParserName(modelDir string) string {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
return "qwen3"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -424,7 +442,7 @@ func getParserName(modelDir string) string {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
return "qwen3"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ func TestModelfileConfig(t *testing.T) {
|
||||
Template: "{{ .Prompt }}",
|
||||
System: "You are a helpful assistant.",
|
||||
License: "MIT",
|
||||
Parser: "qwen3",
|
||||
Renderer: "qwen3",
|
||||
}
|
||||
|
||||
if config.Template != "{{ .Prompt }}" {
|
||||
@@ -21,6 +23,12 @@ func TestModelfileConfig(t *testing.T) {
|
||||
if config.License != "MIT" {
|
||||
t.Errorf("License = %q, want %q", config.License, "MIT")
|
||||
}
|
||||
if config.Parser != "qwen3" {
|
||||
t.Errorf("Parser = %q, want %q", config.Parser, "qwen3")
|
||||
}
|
||||
if config.Renderer != "qwen3" {
|
||||
t.Errorf("Renderer = %q, want %q", config.Renderer, "qwen3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_Empty(t *testing.T) {
|
||||
@@ -35,6 +43,12 @@ func TestModelfileConfig_Empty(t *testing.T) {
|
||||
if config.License != "" {
|
||||
t.Errorf("License should be empty, got %q", config.License)
|
||||
}
|
||||
if config.Parser != "" {
|
||||
t.Errorf("Parser should be empty, got %q", config.Parser)
|
||||
}
|
||||
if config.Renderer != "" {
|
||||
t.Errorf("Renderer should be empty, got %q", config.Renderer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_PartialFields(t *testing.T) {
|
||||
@@ -53,6 +67,12 @@ func TestModelfileConfig_PartialFields(t *testing.T) {
|
||||
if config.License != "" {
|
||||
t.Error("License should be empty")
|
||||
}
|
||||
if config.Parser != "" {
|
||||
t.Error("Parser should be empty")
|
||||
}
|
||||
if config.Renderer != "" {
|
||||
t.Error("Renderer should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinOllamaVersion(t *testing.T) {
|
||||
@@ -98,6 +118,8 @@ func TestCreateOptions(t *testing.T) {
|
||||
Template: "test",
|
||||
System: "system",
|
||||
License: "MIT",
|
||||
Parser: "qwen3-thinking",
|
||||
Renderer: "qwen3",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -116,6 +138,92 @@ func TestCreateOptions(t *testing.T) {
|
||||
if opts.Modelfile.Template != "test" {
|
||||
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
|
||||
}
|
||||
if opts.Modelfile.Parser != "qwen3-thinking" {
|
||||
t.Errorf("Modelfile.Parser = %q, want %q", opts.Modelfile.Parser, "qwen3-thinking")
|
||||
}
|
||||
if opts.Modelfile.Renderer != "qwen3" {
|
||||
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveParserName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mf *ModelfileConfig
|
||||
inferred string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil modelfile uses inferred",
|
||||
mf: nil,
|
||||
inferred: "qwen3",
|
||||
want: "qwen3",
|
||||
},
|
||||
{
|
||||
name: "empty parser uses inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Parser: "",
|
||||
},
|
||||
inferred: "qwen3",
|
||||
want: "qwen3",
|
||||
},
|
||||
{
|
||||
name: "explicit parser overrides inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Parser: "qwen3-thinking",
|
||||
},
|
||||
inferred: "qwen3",
|
||||
want: "qwen3-thinking",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveParserName(tt.mf, tt.inferred); got != tt.want {
|
||||
t.Fatalf("resolveParserName() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRendererName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mf *ModelfileConfig
|
||||
inferred string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil modelfile uses inferred",
|
||||
mf: nil,
|
||||
inferred: "qwen3-coder",
|
||||
want: "qwen3-coder",
|
||||
},
|
||||
{
|
||||
name: "empty renderer uses inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Renderer: "",
|
||||
},
|
||||
inferred: "qwen3-coder",
|
||||
want: "qwen3-coder",
|
||||
},
|
||||
{
|
||||
name: "explicit renderer overrides inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Renderer: "qwen3",
|
||||
},
|
||||
inferred: "qwen3-coder",
|
||||
want: "qwen3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveRendererName(tt.mf, tt.inferred); got != tt.want {
|
||||
t.Fatalf("resolveRendererName() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions_Defaults(t *testing.T) {
|
||||
|
||||
20
x/mlxrunner/cache/cache.go
vendored
20
x/mlxrunner/cache/cache.go
vendored
@@ -4,13 +4,19 @@ package cache
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// kvCacheGrowDebugEnabled reports whether the OLLAMA_MLX_DEBUG_CACHE_GROW
// environment variable is set to a non-empty value. An unset variable and a
// variable set to the empty string are both treated as disabled.
func kvCacheGrowDebugEnabled() bool {
	const envVar = "OLLAMA_MLX_DEBUG_CACHE_GROW"
	value := os.Getenv(envVar)
	return len(value) > 0
}
|
||||
|
||||
type Cache interface {
|
||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||
State() (keys, values *mlx.Array)
|
||||
Materialize() []*mlx.Array
|
||||
Trim(int) int
|
||||
Clone() Cache
|
||||
Offset() int
|
||||
@@ -48,6 +54,9 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
} else {
|
||||
c.keys, c.values = newKeys, newValues
|
||||
}
|
||||
if kvCacheGrowDebugEnabled() {
|
||||
slog.Info("KVCache grow", "prev", prev, "new_capacity", c.keys.Dim(2), "step", c.step)
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += L
|
||||
@@ -66,6 +75,17 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
|
||||
func (c *KVCache) Materialize() []*mlx.Array {
|
||||
out := make([]*mlx.Array, 0, 2)
|
||||
if c.keys != nil && c.keys.Valid() {
|
||||
out = append(out, c.keys)
|
||||
}
|
||||
if c.values != nil && c.values.Valid() {
|
||||
out = append(out, c.values)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *KVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
|
||||
17
x/mlxrunner/cache/cache_test.go
vendored
Normal file
17
x/mlxrunner/cache/cache_test.go
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestKVCacheGrowDebugEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "")
|
||||
if kvCacheGrowDebugEnabled() {
|
||||
t.Fatal("kvCacheGrowDebugEnabled() = true, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "1")
|
||||
if !kvCacheGrowDebugEnabled() {
|
||||
t.Fatal("kvCacheGrowDebugEnabled() = false, want true")
|
||||
}
|
||||
}
|
||||
162
x/mlxrunner/cache/recurrent.go
vendored
Normal file
162
x/mlxrunner/cache/recurrent.go
vendored
Normal file
@@ -0,0 +1,162 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
|
||||
// RecurrentCache stores state for linear-recurrent layers.
|
||||
//
|
||||
// Conv state shape: [B, convTail, convDim]
|
||||
// Delta state shape: [B, numVHeads, headVDim, headKDim]
|
||||
type RecurrentCache struct {
|
||||
convState *mlx.Array
|
||||
deltaState *mlx.Array
|
||||
offset int
|
||||
|
||||
convTail int
|
||||
convDim int
|
||||
numVHeads int
|
||||
headVDim int
|
||||
headKDim int
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
if *dst == v {
|
||||
return
|
||||
}
|
||||
|
||||
// Break dependency chains so recurrent state does not retain the full
|
||||
// per-token compute graph over time.
|
||||
snap := mlx.Snapshot(v)
|
||||
mlx.Eval(snap)
|
||||
|
||||
old := *dst
|
||||
*dst = snap
|
||||
|
||||
// Release previous cached state root, then recursively free the transient
|
||||
// incoming graph root now that a detached snapshot is retained in cache.
|
||||
if old != nil && old != snap {
|
||||
mlx.Release(old)
|
||||
}
|
||||
if v != snap && v != old {
|
||||
mlx.Free(v)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
old := *dst
|
||||
*dst = v
|
||||
if old != nil && old != v {
|
||||
mlx.Release(old)
|
||||
}
|
||||
}
|
||||
|
||||
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||
return &RecurrentCache{
|
||||
convTail: int(convTail),
|
||||
convDim: int(convDim),
|
||||
numVHeads: int(numVHeads),
|
||||
headVDim: int(headVDim),
|
||||
headKDim: int(headKDim),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
|
||||
if batch <= 0 {
|
||||
batch = 1
|
||||
}
|
||||
|
||||
if c.convState == nil || c.convState.DType() != dtype ||
|
||||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim {
|
||||
c.setStateRaw(&c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
|
||||
}
|
||||
|
||||
if c.deltaState == nil || c.deltaState.DType() != dtype ||
|
||||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim {
|
||||
c.setStateRaw(&c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.convState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
|
||||
c.setStateMaterialized(&c.convState, v)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.deltaState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
|
||||
c.setStateMaterialized(&c.deltaState, v)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Advance(n int) {
|
||||
c.offset += n
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return keys, values
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
|
||||
c.ensure(1, mlx.DTypeFloat32)
|
||||
return c.convState, c.deltaState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Materialize() []*mlx.Array {
|
||||
out := make([]*mlx.Array, 0, 2)
|
||||
if c.convState != nil && c.convState.Valid() {
|
||||
out = append(out, c.convState)
|
||||
}
|
||||
if c.deltaState != nil && c.deltaState.Valid() {
|
||||
out = append(out, c.deltaState)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
// Recurrent state cannot be reversed cheaply; reset to a clean state when trimming.
|
||||
if n > 0 {
|
||||
if c.convState != nil {
|
||||
c.setStateRaw(&c.convState, mlx.Zeros(c.convState.DType(), c.convState.Dim(0), c.convState.Dim(1), c.convState.Dim(2)))
|
||||
}
|
||||
if c.deltaState != nil {
|
||||
c.setStateRaw(&c.deltaState, mlx.Zeros(c.deltaState.DType(), c.deltaState.Dim(0), c.deltaState.Dim(1), c.deltaState.Dim(2), c.deltaState.Dim(3)))
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Clone() Cache {
|
||||
clone := &RecurrentCache{
|
||||
offset: c.offset,
|
||||
convTail: c.convTail,
|
||||
convDim: c.convDim,
|
||||
numVHeads: c.numVHeads,
|
||||
headVDim: c.headVDim,
|
||||
headKDim: c.headKDim,
|
||||
}
|
||||
if c.convState != nil {
|
||||
clone.convState = c.convState.Clone()
|
||||
}
|
||||
if c.deltaState != nil {
|
||||
clone.deltaState = c.deltaState.Clone()
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
// Offset returns the cache's current offset.
func (c *RecurrentCache) Offset() int { return c.offset }
|
||||
// Len returns the same value as Offset: the cache's current offset.
func (c *RecurrentCache) Len() int { return c.offset }
|
||||
@@ -3,5 +3,10 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
_ "github.com/ollama/ollama/x/models/gemma3"
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
|
||||
)
|
||||
|
||||
@@ -272,3 +272,39 @@ func Free(s ...*Array) (n int) {
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// Release forcibly frees arrays regardless of reference accounting.
|
||||
// Use only for arrays that are known to be unreachable by any live model state.
|
||||
func Release(s ...*Array) (n int) {
|
||||
seen := make(map[*Array]bool, len(s))
|
||||
for _, t := range s {
|
||||
if t == nil || !t.Valid() || seen[t] {
|
||||
continue
|
||||
}
|
||||
seen[t] = true
|
||||
n += t.NumBytes()
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.ctx.ctx = nil
|
||||
t.desc.inputs = nil
|
||||
t.desc.numRefs = 0
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
const pinnedNumRefs = 1 << 30
|
||||
|
||||
// Pin keeps arrays alive for the process lifetime by setting a very high
|
||||
// reference count floor. Use for model parameter tensors shared across many
|
||||
// decode steps, where recursive Free traversals must never reclaim them.
|
||||
func Pin(s ...*Array) {
|
||||
seen := make(map[*Array]bool, len(s))
|
||||
for _, t := range s {
|
||||
if t == nil || !t.Valid() || seen[t] {
|
||||
continue
|
||||
}
|
||||
seen[t] = true
|
||||
if t.desc.numRefs < pinnedNumRefs {
|
||||
t.desc.numRefs = pinnedNumRefs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,6 +55,30 @@ func tryLoadFromDir(dir string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// tryLoadByName attempts to load the library using just its name,
|
||||
// allowing the system to use rpath, LD_LIBRARY_PATH, or standard search paths.
|
||||
// Returns true if the library was successfully loaded.
|
||||
func tryLoadByName() bool {
|
||||
libraryName := "libmlxc.dylib"
|
||||
if runtime.GOOS == "linux" {
|
||||
libraryName = "libmlxc.so"
|
||||
}
|
||||
|
||||
cPath := C.CString(libraryName)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
var handle C.mlx_dynamic_handle
|
||||
if C.mlx_dynamic_load(&handle, cPath) != 0 {
|
||||
return false
|
||||
}
|
||||
if C.mlx_dynamic_load_symbols(handle) != 0 {
|
||||
C.mlx_dynamic_unload(&handle)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func init() {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
@@ -73,6 +97,11 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// Try loading via rpath/standard library search
|
||||
if tryLoadByName() {
|
||||
return
|
||||
}
|
||||
|
||||
// Build search paths: executable directory, then build directories
|
||||
var searchDirs []string
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
|
||||
@@ -279,6 +279,24 @@ func Sigmoid(a *Array) *Array {
|
||||
return a.Sigmoid()
|
||||
}
|
||||
|
||||
func Exp(a *Array) *Array {
|
||||
out := New("EXP", a)
|
||||
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Log(a *Array) *Array {
|
||||
out := New("LOG", a)
|
||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
||||
out := New("SOFTMAX_AXIS", a)
|
||||
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
|
||||
mask := New("")
|
||||
sinks := New("")
|
||||
@@ -386,6 +404,52 @@ func Collect(v any) []*Array {
|
||||
return arrays
|
||||
}
|
||||
|
||||
// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
|
||||
func Snapshot(a *Array) *Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return a
|
||||
}
|
||||
out := New("SNAPSHOT")
|
||||
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// CollectReachable collects arrays from v and all transitive graph inputs.
|
||||
func CollectReachable(v any) []*Array {
|
||||
roots := Collect(v)
|
||||
if len(roots) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[*Array]bool, len(roots))
|
||||
out := make([]*Array, 0, len(roots))
|
||||
stack := append([]*Array(nil), roots...)
|
||||
for len(stack) > 0 {
|
||||
a := stack[len(stack)-1]
|
||||
stack = stack[:len(stack)-1]
|
||||
|
||||
if a == nil || !a.Valid() || seen[a] {
|
||||
continue
|
||||
}
|
||||
seen[a] = true
|
||||
out = append(out, a)
|
||||
stack = append(stack, a.desc.inputs...)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Detach returns a new Array handle that shares the same MLX value but does
|
||||
// not retain Go-side graph input references.
|
||||
func Detach(a *Array) *Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return a
|
||||
}
|
||||
out := New("DETACH")
|
||||
C.mlx_array_set(&out.ctx, a.ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
||||
if !v.IsValid() {
|
||||
return
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
// Model is the interface that model implementations must satisfy.
|
||||
|
||||
92
x/mlxrunner/model/linear.go
Normal file
92
x/mlxrunner/model/linear.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build mlx
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// LinearFactory builds linear layers using shared tensor maps and quant defaults.
|
||||
type LinearFactory struct {
|
||||
tensors map[string]*mlx.Array
|
||||
defaultGroupSize int
|
||||
defaultBits int
|
||||
defaultMode string
|
||||
tensorQuant map[string]*TensorQuantInfo
|
||||
}
|
||||
|
||||
// NewLinearFactory creates a reusable constructor for model linear layers.
|
||||
func NewLinearFactory(
|
||||
tensors map[string]*mlx.Array,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) LinearFactory {
|
||||
return LinearFactory{
|
||||
tensors: tensors,
|
||||
defaultGroupSize: defaultGroupSize,
|
||||
defaultBits: defaultBits,
|
||||
defaultMode: defaultMode,
|
||||
tensorQuant: tensorQuant,
|
||||
}
|
||||
}
|
||||
|
||||
// Make constructs a linear layer at path.
|
||||
func (f LinearFactory) Make(path string) nn.LinearLayer {
|
||||
return MakeLinearLayer(
|
||||
f.tensors,
|
||||
path,
|
||||
f.defaultGroupSize,
|
||||
f.defaultBits,
|
||||
f.defaultMode,
|
||||
f.tensorQuant,
|
||||
)
|
||||
}
|
||||
|
||||
// MakeLinearLayer constructs a linear layer from a tensor map.
|
||||
//
|
||||
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
|
||||
// quant params via TensorQuant metadata (with shape-based affine fallback).
|
||||
// For non-quantized tensors, it returns a standard nn.Linear.
|
||||
func MakeLinearLayer(
|
||||
tensors map[string]*mlx.Array,
|
||||
path string,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) nn.LinearLayer {
|
||||
w := tensors[path+".weight"]
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
scales := tensors[path+".weight_scale"]
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
bias := tensors[path+".bias"]
|
||||
|
||||
groupSize, bits, mode := ResolveLinearQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
|
||||
return &nn.QuantizedLinear{
|
||||
Weight: w,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
bias := tensors[path+".bias"]
|
||||
return nn.NewLinear(w, bias)
|
||||
}
|
||||
130
x/mlxrunner/model/quant.go
Normal file
130
x/mlxrunner/model/quant.go
Normal file
@@ -0,0 +1,130 @@
|
||||
//go:build mlx
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// QuantizationParams maps a quantization type name (case-insensitive) to its
// default group size, bit width and mode:
//
//	NVFP4              -> 16, 4, "nvfp4"
//	FP4, Q4, INT4      -> 32, 4, "affine"
//	MXFP8              -> 32, 8, "mxfp8"
//	FP8, Q8, INT8, ""  -> 64, 8, "affine"
//	anything else      -> 32, 8, "affine"
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
	switch name := strings.ToUpper(quantization); name {
	case "NVFP4":
		groupSize, bits, mode = 16, 4, "nvfp4"
	case "FP4", "Q4", "INT4":
		groupSize, bits, mode = 32, 4, "affine"
	case "MXFP8":
		groupSize, bits, mode = 32, 8, "mxfp8"
	case "FP8", "Q8", "INT8", "":
		groupSize, bits, mode = 64, 8, "affine"
	default:
		groupSize, bits, mode = 32, 8, "affine"
	}
	return groupSize, bits, mode
}
|
||||
|
||||
// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
|
||||
// when available, otherwise falling back to the provided model defaults.
|
||||
func TensorQuantParams(
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
tensorName string,
|
||||
) (groupSize, bits int, mode string, fromTensor bool) {
|
||||
if tensorQuant != nil {
|
||||
if tq := tensorQuant[tensorName]; tq != nil {
|
||||
groupSize, bits, mode = QuantizationParams(tq.QuantType)
|
||||
if tq.GroupSize > 0 {
|
||||
groupSize = tq.GroupSize
|
||||
}
|
||||
return groupSize, bits, mode, true
|
||||
}
|
||||
}
|
||||
return defaultGroupSize, defaultBits, defaultMode, false
|
||||
}
|
||||
|
||||
// ResolveLinearQuantParams resolves quantization params for a quantized linear
|
||||
// tensor, preferring per-tensor metadata and falling back to shape-based
|
||||
// inference for affine packed tensors.
|
||||
func ResolveLinearQuantParams(
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
tensorName string,
|
||||
weight, scales *mlx.Array,
|
||||
) (groupSize, bits int, mode string) {
|
||||
groupSize, bits, mode, fromTensor := TensorQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
tensorName,
|
||||
)
|
||||
|
||||
if mode == "affine" {
|
||||
if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
|
||||
if !fromTensor || groupSize == 0 || bits == 0 {
|
||||
groupSize = inferredGroupSize
|
||||
bits = inferredBits
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return groupSize, bits, mode
|
||||
}
|
||||
|
||||
// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized
|
||||
// tensors from packed weight and scale shapes.
|
||||
func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
|
||||
if weight == nil || scales == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
weightShape := weight.Dims()
|
||||
scaleShape := scales.Dims()
|
||||
if len(weightShape) == 0 || len(scaleShape) == 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
weightCols := weightShape[len(weightShape)-1]
|
||||
scalesCols := scaleShape[len(scaleShape)-1]
|
||||
if weightCols <= 0 || scalesCols <= 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
switch {
|
||||
case groupSize4 == 32:
|
||||
return 32, 4, true
|
||||
case groupSize8 == 64:
|
||||
return 64, 8, true
|
||||
case groupSize4 == 64 && groupSize8 == 32:
|
||||
if hintBits == 8 {
|
||||
return 32, 8, true
|
||||
}
|
||||
if hintBits == 4 {
|
||||
return 64, 4, true
|
||||
}
|
||||
}
|
||||
|
||||
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
|
||||
return groupSize4, 4, true
|
||||
}
|
||||
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
|
||||
return groupSize8, 8, true
|
||||
}
|
||||
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// isCommonGroupSize reports whether v is one of 16, 32, 64 or 128.
func isCommonGroupSize(v int) bool {
	return v == 16 || v == 32 || v == 64 || v == 128
}
|
||||
@@ -8,42 +8,63 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// Root wraps a ModelManifest with pre-scanned quantization metadata.
|
||||
type Root struct {
|
||||
Manifest *manifest.ModelManifest
|
||||
quantType string
|
||||
groupSize int
|
||||
// TensorQuantInfo describes per-tensor quantization metadata.
|
||||
type TensorQuantInfo struct {
|
||||
QuantType string
|
||||
GroupSize int
|
||||
}
|
||||
|
||||
// Open loads a manifest for the given model name and pre-scans the first
|
||||
// tensor blob for quantization metadata (quant_type, group_size).
|
||||
// Root wraps a ModelManifest with pre-scanned quantization metadata.
|
||||
type Root struct {
|
||||
Manifest *manifest.ModelManifest
|
||||
|
||||
// Backwards-compatible model-level quant metadata (first tensor blob).
|
||||
quantType string
|
||||
groupSize int
|
||||
|
||||
// Per-tensor quantization metadata.
|
||||
tensorQuant map[string]*TensorQuantInfo
|
||||
}
|
||||
|
||||
// Open loads a manifest for the given model name and scans tensor blobs for
|
||||
// quantization metadata.
|
||||
func Open(modelName string) (*Root, error) {
|
||||
m, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
root := &Root{Manifest: m}
|
||||
root := &Root{
|
||||
Manifest: m,
|
||||
tensorQuant: make(map[string]*TensorQuantInfo),
|
||||
}
|
||||
|
||||
// Pre-scan first tensor blob for quantization metadata
|
||||
for _, layer := range m.GetTensorLayers("") {
|
||||
blobPath := m.BlobPath(layer.Digest)
|
||||
meta, err := readBlobMetadata(blobPath)
|
||||
if err != nil || meta == nil {
|
||||
|
||||
infos, blobQuantType, blobGroupSize, err := readBlobTensorQuantInfo(blobPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if qt := meta["quant_type"]; qt != "" {
|
||||
root.quantType = strings.ToUpper(qt)
|
||||
|
||||
for name, info := range infos {
|
||||
root.tensorQuant[name] = info
|
||||
}
|
||||
if gs := meta["group_size"]; gs != "" {
|
||||
fmt.Sscanf(gs, "%d", &root.groupSize)
|
||||
|
||||
if root.quantType == "" && blobQuantType != "" {
|
||||
root.quantType = strings.ToUpper(blobQuantType)
|
||||
root.groupSize = blobGroupSize
|
||||
if root.groupSize == 0 {
|
||||
root.groupSize = defaultGroupSize(root.quantType)
|
||||
}
|
||||
}
|
||||
break // only check the first tensor blob
|
||||
}
|
||||
|
||||
return root, nil
|
||||
@@ -52,46 +73,180 @@ func Open(modelName string) (*Root, error) {
|
||||
// Close is a no-op for now (future: release resources).
|
||||
func (r *Root) Close() {}
|
||||
|
||||
// QuantType returns the quantization type detected from tensor metadata.
|
||||
// QuantType returns the quantization type detected from the first tensor blob metadata.
|
||||
func (r *Root) QuantType() string { return r.quantType }
|
||||
|
||||
// GroupSize returns the quantization group size detected from tensor metadata.
|
||||
// GroupSize returns the quantization group size detected from the first tensor blob metadata.
|
||||
func (r *Root) GroupSize() int { return r.groupSize }
|
||||
|
||||
// readBlobMetadata reads the __metadata__ from a safetensors blob header.
|
||||
func readBlobMetadata(path string) (map[string]string, error) {
|
||||
// TensorQuant returns per-tensor quantization metadata if available.
|
||||
func (r *Root) TensorQuant(name string) *TensorQuantInfo {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return r.tensorQuant[name]
|
||||
}
|
||||
|
||||
// AllTensorQuant returns a copy of the per-tensor quantization metadata.
|
||||
func (r *Root) AllTensorQuant() map[string]*TensorQuantInfo {
|
||||
out := make(map[string]*TensorQuantInfo, len(r.tensorQuant))
|
||||
for k, v := range r.tensorQuant {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
copy := *v
|
||||
out[k] = ©
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func defaultGroupSize(quantType string) int {
|
||||
groupSize, _, _ := QuantizationParams(quantType)
|
||||
return groupSize
|
||||
}
|
||||
|
||||
func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, int, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
return nil, err
|
||||
return nil, "", 0, err
|
||||
}
|
||||
if headerSize > 1024*1024 {
|
||||
return nil, fmt.Errorf("header too large: %d", headerSize)
|
||||
if headerSize > 100*1024*1024 {
|
||||
return nil, "", 0, fmt.Errorf("header too large: %d", headerSize)
|
||||
}
|
||||
|
||||
data := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(f, data); err != nil {
|
||||
return nil, err
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
var header map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &header); err != nil {
|
||||
return nil, err
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
|
||||
globalQuantType = strings.ToUpper(globalQuantType)
|
||||
|
||||
mainNames := mainTensorNames(header)
|
||||
infos := make(map[string]*TensorQuantInfo)
|
||||
for _, name := range mainNames {
|
||||
if _, ok := header[name+".scale"]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
quantType := globalQuantType
|
||||
groupSize := globalGroupSize
|
||||
|
||||
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
|
||||
if quantType == "" {
|
||||
quantType = inferredType
|
||||
}
|
||||
if groupSize == 0 {
|
||||
groupSize = inferredGroup
|
||||
}
|
||||
if quantType == "" {
|
||||
continue
|
||||
}
|
||||
if groupSize == 0 {
|
||||
groupSize = defaultGroupSize(quantType)
|
||||
}
|
||||
|
||||
infos[name] = &TensorQuantInfo{QuantType: quantType, GroupSize: groupSize}
|
||||
}
|
||||
|
||||
return infos, globalQuantType, globalGroupSize, nil
|
||||
}
|
||||
|
||||
func parseGlobalQuantMetadata(header map[string]json.RawMessage) (quantType string, groupSize int) {
|
||||
metaRaw, ok := header["__metadata__"]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal(metaRaw, &meta); err != nil {
|
||||
return nil, err
|
||||
return "", 0
|
||||
}
|
||||
return meta, nil
|
||||
|
||||
quantType = meta["quant_type"]
|
||||
if gs := meta["group_size"]; gs != "" {
|
||||
groupSize, _ = strconv.Atoi(gs)
|
||||
}
|
||||
return quantType, groupSize
|
||||
}
|
||||
|
||||
// mainTensorNames returns the header keys sorted in ascending order, leaving
// out "__metadata__" and any key ending in ".scale" or ".bias".
func mainTensorNames(header map[string]json.RawMessage) []string {
	mains := make([]string, 0, len(header))
	for key := range header {
		switch {
		case key == "__metadata__":
		case strings.HasSuffix(key, ".scale"):
		case strings.HasSuffix(key, ".bias"):
		default:
			mains = append(mains, key)
		}
	}
	sort.Strings(mains)
	return mains
}
|
||||
|
||||
func inferQuantTypeFromShapes(header map[string]json.RawMessage, tensorName string, hintQuantType string) (string, int) {
|
||||
type tensorShape struct {
|
||||
Shape []int64 `json:"shape"`
|
||||
}
|
||||
|
||||
mainRaw, ok := header[tensorName]
|
||||
if !ok {
|
||||
return "", 0
|
||||
}
|
||||
scaleRaw, ok := header[tensorName+".scale"]
|
||||
if !ok {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var mainInfo tensorShape
|
||||
if err := json.Unmarshal(mainRaw, &mainInfo); err != nil || len(mainInfo.Shape) == 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var scaleInfo tensorShape
|
||||
if err := json.Unmarshal(scaleRaw, &scaleInfo); err != nil || len(scaleInfo.Shape) == 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
weightCols := int(mainInfo.Shape[len(mainInfo.Shape)-1])
|
||||
scalesCols := int(scaleInfo.Shape[len(scaleInfo.Shape)-1])
|
||||
if weightCols <= 0 || scalesCols <= 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
switch {
|
||||
case groupSize4 == 32:
|
||||
return "INT4", 32
|
||||
case groupSize8 == 64:
|
||||
return "INT8", 64
|
||||
case groupSize4 == 64 && groupSize8 == 32:
|
||||
h := strings.ToUpper(hintQuantType)
|
||||
if strings.Contains(h, "8") {
|
||||
return "INT8", 32
|
||||
}
|
||||
if strings.Contains(h, "4") {
|
||||
return "INT4", 64
|
||||
}
|
||||
}
|
||||
|
||||
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
|
||||
return "INT4", groupSize4
|
||||
}
|
||||
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
|
||||
return "INT8", groupSize8
|
||||
}
|
||||
|
||||
return "", 0
|
||||
}
|
||||
|
||||
@@ -6,47 +6,134 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// prefillChunkSize returns how many prompt tokens to process per prefill step.
//
// A positive integer in OLLAMA_MLX_PREFILL_CHUNK always wins. Otherwise the
// size is 32 when lowMemoryDecode is set and 2048 when it is not.
func prefillChunkSize(lowMemoryDecode bool) int {
	const (
		defaultChunk = 2 << 10
		// The recurrent/no-prompt-cache path favors lower peak memory over
		// prefill throughput. The value is kept conservative to avoid
		// transient prefill spikes and allocator thrash.
		lowMemoryChunk = 32
	)

	if raw := os.Getenv("OLLAMA_MLX_PREFILL_CHUNK"); raw != "" {
		n, err := strconv.Atoi(raw)
		if err == nil && n > 0 {
			return n
		}
	}

	if !lowMemoryDecode {
		return defaultChunk
	}
	return lowMemoryChunk
}
|
||||
|
||||
// mlxDebugMemoryEnabled reports whether OLLAMA_MLX_DEBUG_MEMORY is set to a
// non-empty value.
func mlxDebugMemoryEnabled() bool {
	value := os.Getenv("OLLAMA_MLX_DEBUG_MEMORY")
	return len(value) > 0
}
|
||||
|
||||
// finalizeRequestCaches runs the end-of-request cache step and then logs it.
//
// With usePromptCache set it calls insertCache and logs the phase
// "request_done_cached". Otherwise it calls freeCaches and logs
// "request_done_freed". logMemory is always called with token -1.
func finalizeRequestCaches(usePromptCache bool, insertCache func(), freeCaches func(), logMemory func(string, int)) {
	action, phase := freeCaches, "request_done_freed"
	if usePromptCache {
		action, phase = insertCache, "request_done_cached"
	}
	action()
	logMemory(phase, -1)
}
|
||||
|
||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
if r.Model == nil {
|
||||
return errors.New("model not loaded")
|
||||
}
|
||||
|
||||
mlx.EnableCompile()
|
||||
enableCompile := true
|
||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||
enableCompile = modelCompile.EnableCompile()
|
||||
}
|
||||
if enableCompile {
|
||||
mlx.EnableCompile()
|
||||
} else {
|
||||
mlx.DisableCompile()
|
||||
}
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
|
||||
caches, tokens := r.FindNearestCache(inputs)
|
||||
usePromptCache := true
|
||||
if m, ok := r.Model.(interface{ DisablePromptCache() bool }); ok && m.DisablePromptCache() {
|
||||
usePromptCache = false
|
||||
}
|
||||
lowMemoryDecode := !usePromptCache
|
||||
prefillChunk := prefillChunkSize(lowMemoryDecode)
|
||||
|
||||
var caches []cache.Cache
|
||||
var tokens []int32
|
||||
if usePromptCache {
|
||||
caches, tokens = r.FindNearestCache(inputs)
|
||||
} else {
|
||||
tokens = inputs
|
||||
}
|
||||
|
||||
if len(caches) == 0 {
|
||||
caches = make([]cache.Cache, r.Model.NumLayers())
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
caches = cacheFactory.NewCaches()
|
||||
} else {
|
||||
caches = make([]cache.Cache, r.Model.NumLayers())
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
materializeCaches := func() {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
state = append(state, c.Materialize()...)
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return
|
||||
}
|
||||
mlx.Eval(state...)
|
||||
}
|
||||
freeCaches := func() {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
state = append(state, c.Materialize()...)
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return
|
||||
}
|
||||
// Non-prompt-cache requests allocate fresh caches every generation.
|
||||
// Explicitly free cache roots so graph chains are reclaimed promptly.
|
||||
mlx.Free(state...)
|
||||
mlx.ClearCache()
|
||||
}
|
||||
debugMemory := mlxDebugMemoryEnabled()
|
||||
logMemory := func(phase string, token int) {
|
||||
if !debugMemory {
|
||||
return
|
||||
}
|
||||
if token >= 0 {
|
||||
slog.Info("MLX memory", "phase", phase, "token", token, "memory", mlx.Memory{})
|
||||
return
|
||||
}
|
||||
slog.Info("MLX memory", "phase", phase, "memory", mlx.Memory{})
|
||||
}
|
||||
logMemory("prefill_start", -1)
|
||||
|
||||
total, processed := len(tokens), 0
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
for total-processed > 1 {
|
||||
n := min(2<<10, total-processed-1)
|
||||
n := min(prefillChunk, total-processed-1)
|
||||
temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
defer mlx.Free(temp)
|
||||
mlx.Eval(func() []*mlx.Array {
|
||||
s := make([]*mlx.Array, 2*len(caches))
|
||||
for i, c := range caches {
|
||||
s[2*i], s[2*i+1] = c.State()
|
||||
}
|
||||
return s
|
||||
}()...)
|
||||
materializeCaches()
|
||||
mlx.Free(temp)
|
||||
processed += n
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
mlx.ClearCache()
|
||||
}
|
||||
logMemory("prefill_done", -1)
|
||||
|
||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
||||
@@ -58,7 +145,13 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
|
||||
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
|
||||
mlx.AsyncEval(sample, logprobs)
|
||||
if !lowMemoryDecode {
|
||||
mlx.AsyncEval(sample, logprobs)
|
||||
} else {
|
||||
// Materialize cache updates to prevent transform graph growth.
|
||||
materializeCaches()
|
||||
}
|
||||
logMemory("decode_init", -1)
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
@@ -66,12 +159,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
||||
outputs := make([]int32, 0, request.Options.MaxTokens)
|
||||
for i := range request.Options.MaxTokens {
|
||||
nextSample, nextLogprobs := step(sample)
|
||||
mlx.AsyncEval(nextSample, nextLogprobs)
|
||||
|
||||
if i == 0 {
|
||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
||||
mlx.Eval(sample)
|
||||
logMemory("decode_first_eval", i)
|
||||
final.PromptTokensDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
@@ -83,6 +174,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
final.Token = int(output)
|
||||
final.DoneReason = 0
|
||||
final.CompletionTokens = i
|
||||
mlx.Free(sample, logprobs)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -91,18 +183,43 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
Token: int(output),
|
||||
}
|
||||
|
||||
// For recurrent linear-attention models, avoid async prefetch to reduce
|
||||
// peak memory and clear allocator cache every token.
|
||||
if lowMemoryDecode {
|
||||
mlx.Free(sample, logprobs)
|
||||
if i+1 >= request.Options.MaxTokens {
|
||||
break
|
||||
}
|
||||
mlx.ClearCache()
|
||||
sample, logprobs = step(mlx.FromValues([]int32{output}, 1))
|
||||
// Materialize cache updates to avoid unbounded transform chains.
|
||||
materializeCaches()
|
||||
if i%32 == 0 {
|
||||
logMemory("decode_lowmem_step", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
nextSample, nextLogprobs := step(sample)
|
||||
mlx.AsyncEval(nextSample, nextLogprobs)
|
||||
|
||||
mlx.Free(sample, logprobs)
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
if i%64 == 0 {
|
||||
logMemory("decode_async_step", i)
|
||||
}
|
||||
|
||||
sample, logprobs = nextSample, nextLogprobs
|
||||
}
|
||||
|
||||
mlx.Free(sample, logprobs)
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
request.Responses <- final
|
||||
r.InsertCache(append(inputs, outputs...), caches)
|
||||
finalizeRequestCaches(usePromptCache,
|
||||
func() { r.InsertCache(append(inputs, outputs...), caches) },
|
||||
freeCaches,
|
||||
logMemory,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -114,13 +231,5 @@ func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
if text := b.String(); utf8.ValidString(text) {
|
||||
b.Reset()
|
||||
return text
|
||||
} else if b.Len() >= utf8.UTFMax {
|
||||
b.Reset()
|
||||
return text
|
||||
}
|
||||
|
||||
return ""
|
||||
return flushValidUTF8Prefix(b)
|
||||
}
|
||||
|
||||
83
x/mlxrunner/pipeline_helpers_test.go
Normal file
83
x/mlxrunner/pipeline_helpers_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPrefillChunkSize(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "")
|
||||
if got := prefillChunkSize(false); got != 2<<10 {
|
||||
t.Fatalf("prefillChunkSize(false) = %d, want %d", got, 2<<10)
|
||||
}
|
||||
if got := prefillChunkSize(true); got != 32 {
|
||||
t.Fatalf("prefillChunkSize(true) = %d, want %d", got, 32)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefillChunkSizeEnvOverride(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "96")
|
||||
if got := prefillChunkSize(false); got != 96 {
|
||||
t.Fatalf("prefillChunkSize(false) with env = %d, want %d", got, 96)
|
||||
}
|
||||
if got := prefillChunkSize(true); got != 96 {
|
||||
t.Fatalf("prefillChunkSize(true) with env = %d, want %d", got, 96)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMLXDebugMemoryEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "")
|
||||
if mlxDebugMemoryEnabled() {
|
||||
t.Fatal("mlxDebugMemoryEnabled() = true, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "1")
|
||||
if !mlxDebugMemoryEnabled() {
|
||||
t.Fatal("mlxDebugMemoryEnabled() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRequestCachesUsesPromptCachePath(t *testing.T) {
|
||||
insertCalls := 0
|
||||
freeCalls := 0
|
||||
logPhase := ""
|
||||
|
||||
finalizeRequestCaches(
|
||||
true,
|
||||
func() { insertCalls++ },
|
||||
func() { freeCalls++ },
|
||||
func(phase string, _ int) { logPhase = phase },
|
||||
)
|
||||
|
||||
if insertCalls != 1 {
|
||||
t.Fatalf("insert calls = %d, want 1", insertCalls)
|
||||
}
|
||||
if freeCalls != 0 {
|
||||
t.Fatalf("free calls = %d, want 0", freeCalls)
|
||||
}
|
||||
if logPhase != "request_done_cached" {
|
||||
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRequestCachesUsesFreePath(t *testing.T) {
|
||||
insertCalls := 0
|
||||
freeCalls := 0
|
||||
logPhase := ""
|
||||
|
||||
finalizeRequestCaches(
|
||||
false,
|
||||
func() { insertCalls++ },
|
||||
func() { freeCalls++ },
|
||||
func(phase string, _ int) { logPhase = phase },
|
||||
)
|
||||
|
||||
if insertCalls != 0 {
|
||||
t.Fatalf("insert calls = %d, want 0", insertCalls)
|
||||
}
|
||||
if freeCalls != 1 {
|
||||
t.Fatalf("free calls = %d, want 1", freeCalls)
|
||||
}
|
||||
if logPhase != "request_done_freed" {
|
||||
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_freed")
|
||||
}
|
||||
}
|
||||
@@ -12,12 +12,12 @@ import (
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
@@ -64,6 +64,38 @@ type Runner struct {
|
||||
CacheEntries map[int32]*CacheEntry
|
||||
}
|
||||
|
||||
func releaseTensorMap(tensors map[string]*mlx.Array, keep map[*mlx.Array]struct{}) (count int, bytes int) {
|
||||
if len(tensors) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
seen := make(map[*mlx.Array]bool, len(tensors))
|
||||
toRelease := make([]*mlx.Array, 0, len(tensors))
|
||||
for name, arr := range tensors {
|
||||
if arr == nil || !arr.Valid() {
|
||||
delete(tensors, name)
|
||||
continue
|
||||
}
|
||||
if keep != nil {
|
||||
if _, ok := keep[arr]; ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
delete(tensors, name)
|
||||
if seen[arr] {
|
||||
continue
|
||||
}
|
||||
seen[arr] = true
|
||||
toRelease = append(toRelease, arr)
|
||||
}
|
||||
|
||||
if len(toRelease) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
return len(toRelease), mlx.Release(toRelease...)
|
||||
}
|
||||
|
||||
func (r *Runner) Load(modelName string) error {
|
||||
root, err := model.Open(modelName)
|
||||
if err != nil {
|
||||
@@ -85,9 +117,29 @@ func (r *Runner) Load(modelName string) error {
|
||||
// Assign weights to model (model-specific logic)
|
||||
loadWeights := base.Weights(m)
|
||||
if err := loadWeights(tensors); err != nil {
|
||||
if count, bytes := releaseTensorMap(tensors, nil); count > 0 {
|
||||
slog.Info("Released tensors after load failure", "count", count, "bytes", mlx.PrettyBytes(bytes))
|
||||
mlx.ClearCache()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Pin only model-owned tensor roots. Pinning the full transitive graph can
|
||||
// retain large load-time intermediates and inflate steady-state memory.
|
||||
roots := mlx.Collect(m)
|
||||
mlx.Pin(roots...)
|
||||
|
||||
keep := make(map[*mlx.Array]struct{})
|
||||
for _, arr := range roots {
|
||||
if arr != nil && arr.Valid() {
|
||||
keep[arr] = struct{}{}
|
||||
}
|
||||
}
|
||||
if count, bytes := releaseTensorMap(tensors, keep); count > 0 {
|
||||
slog.Info("Released unused model tensors", "count", count, "bytes", mlx.PrettyBytes(bytes))
|
||||
mlx.ClearCache()
|
||||
}
|
||||
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
return nil
|
||||
|
||||
47
x/mlxrunner/utf8_buffer.go
Normal file
47
x/mlxrunner/utf8_buffer.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// flushValidUTF8Prefix consumes and returns every buffered byte that can
// already be decoded, leaving only an incomplete trailing multi-byte sequence
// in b so it can be completed by a later write.
//
// Bytes that can never start or continue a valid encoding are not held back:
// they are flushed unchanged as part of the returned string.
func flushValidUTF8Prefix(b *bytes.Buffer) string {
	n := validUTF8PrefixLen(b.Bytes())
	if n == 0 {
		return ""
	}
	// Next advances the buffer and returns the consumed bytes; converting to
	// string copies them before the buffer can be reused.
	return string(b.Next(n))
}

// validUTF8PrefixLen returns the number of leading bytes of data that are
// ready to flush: every complete rune and every definitively invalid byte,
// stopping at the first sequence that is merely incomplete.
func validUTF8PrefixLen(data []byte) int {
	n := 0
	for n < len(data) {
		rest := data[n:]
		r, size := utf8.DecodeRune(rest)
		// (RuneError, 1) means a decode failure. If the bytes are not yet a
		// full rune, it is a truncated sequence: stop and wait for more input.
		if r == utf8.RuneError && size == 1 && !utf8.FullRune(rest) {
			return n
		}
		// Either a well-formed rune or a single invalid byte (size == 1);
		// consuming it guarantees forward progress.
		n += size
	}
	return n
}
|
||||
46
x/mlxrunner/utf8_buffer_test.go
Normal file
46
x/mlxrunner/utf8_buffer_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
|
||||
b.Write([]byte{0xE3, 0x81})
|
||||
if got := flushValidUTF8Prefix(&b); got != "" {
|
||||
t.Fatalf("first flush = %q, want empty", got)
|
||||
}
|
||||
|
||||
b.Write([]byte{0x93, 0xE3})
|
||||
if got := flushValidUTF8Prefix(&b); got != "こ" {
|
||||
t.Fatalf("second flush = %q, want %q", got, "こ")
|
||||
}
|
||||
|
||||
if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) {
|
||||
t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3})
|
||||
}
|
||||
|
||||
b.Write([]byte{0x82, 0x93})
|
||||
if got := flushValidUTF8Prefix(&b); got != "ん" {
|
||||
t.Fatalf("third flush = %q, want %q", got, "ん")
|
||||
}
|
||||
|
||||
if b.Len() != 0 {
|
||||
t.Fatalf("buffer not empty after third flush: %d", b.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushValidUTF8Prefix_ValidText(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
b.WriteString("hello 世界")
|
||||
|
||||
if got := flushValidUTF8Prefix(&b); got != "hello 世界" {
|
||||
t.Fatalf("flush = %q, want %q", got, "hello 世界")
|
||||
}
|
||||
|
||||
if b.Len() != 0 {
|
||||
t.Fatalf("buffer not empty after flush: %d", b.Len())
|
||||
}
|
||||
}
|
||||
521
x/models/gemma3/gemma3.go
Normal file
521
x/models/gemma3/gemma3.go
Normal file
@@ -0,0 +1,521 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package gemma3 provides the Gemma 3 text model implementation for MLX.
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
// init registers newModel with the base registry under both Gemma 3
// architecture names, so either name resolves to this implementation.
func init() {
|
||||
base.Register("Gemma3ForCausalLM", newModel)
|
||||
base.Register("Gemma3ForConditionalGeneration", newModel)
|
||||
}
|
||||
|
||||
// TextConfig holds configuration for the Gemma 3 text model.
// The JSON-tagged fields are decoded by parseTextConfig, which also fills in
// defaults for several fields that are left at zero.
type TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
// RopeTheta is the RoPE base used by non-sliding layers.
RopeTheta float32 `json:"rope_theta"`
|
||||
// RopeLocalBaseFreq is the RoPE base used by sliding layers.
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
// SlidingWindow, when > 0, is the rotating KV cache size NewCaches uses for sliding layers.
SlidingWindow int32 `json:"sliding_window"`
|
||||
// SlidingWindowPattern is the fallback used by isLayerSliding when
// LayerTypes has no entry for a layer: a layer is sliding unless
// (index+1) is a multiple of this value; values <= 0 mean no layer is sliding.
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
|
||||
// LayerTypes, when it covers a layer index, marks that layer as sliding
// iff its entry equals "sliding_attention".
LayerTypes []string `json:"layer_types"`
|
||||
// NOTE(review): TieWordEmbeddings is decoded but not read by the code visible here.
TieWordEmbeddings bool `json:"tie_word_embeddings"`
|
||||
|
||||
// Quantization parameters (set during load based on model quantization).
|
||||
QuantGroupSize int `json:"-"`
|
||||
QuantBits int `json:"-"`
|
||||
QuantMode string `json:"-"`
|
||||
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
||||
|
||||
// Computed fields.
|
||||
// Scale is 1/sqrt(HeadDim), computed by parseTextConfig and passed to attention.
Scale float32 `json:"-"`
|
||||
}
|
||||
|
||||
// Attention implements Gemma 3 attention with Q/K normalization.
// The projections and norms are assigned in Model.LoadWeights; the *Scaled
// arrays are derived afterwards by precomputeGemmaScaledWeights.
type Attention struct {
|
||||
QProj nn.LinearLayer
|
||||
KProj nn.LinearLayer
|
||||
VProj nn.LinearLayer
|
||||
OProj nn.LinearLayer
|
||||
|
||||
QNorm *nn.RMSNorm
|
||||
KNorm *nn.RMSNorm
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm.
// Forward normalizes q and k with these arrays rather than with QNorm/KNorm.
|
||||
QNormScaled *mlx.Array
|
||||
KNormScaled *mlx.Array
|
||||
}
|
||||
|
||||
// MLP is the feed-forward network with GELU activation.
// Forward computes DownProj(GELUApprox(GateProj(x)) * UpProj(x)).
type MLP struct {
|
||||
GateProj nn.LinearLayer
|
||||
UpProj nn.LinearLayer
|
||||
DownProj nn.LinearLayer
|
||||
}
|
||||
|
||||
// DecoderLayer is a single transformer block.
// Forward applies attention and then the MLP, each wrapped in a pre-norm and
// a post-norm and added back onto its input as a residual.
type DecoderLayer struct {
|
||||
InputNorm *nn.RMSNorm
|
||||
Attention *Attention
|
||||
PostAttnNorm *nn.RMSNorm
|
||||
PreFFNorm *nn.RMSNorm
|
||||
MLP *MLP
|
||||
PostFFNorm *nn.RMSNorm
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm.
// Forward uses these arrays; they are filled by precomputeGemmaScaledWeights.
|
||||
InputNormScaled *mlx.Array
|
||||
PostAttnNormScaled *mlx.Array
|
||||
PreFFNormScaled *mlx.Array
|
||||
PostFFNormScaled *mlx.Array
|
||||
|
||||
// Layer metadata.
// IsSliding comes from isLayerSliding; it selects the RoPE base in attention
// and the cache type in NewCaches.
|
||||
IsSliding bool
|
||||
LayerIdx int32
|
||||
}
|
||||
|
||||
// Model is the Gemma 3 text-only model.
// newModel creates it with configuration and tokenizer only; LoadWeights
// later fills in the embedding, layers, final norm and output head.
type Model struct {
|
||||
EmbedTokens *nn.Embedding
|
||||
Layers []*DecoderLayer
|
||||
Norm *nn.RMSNorm
|
||||
// LMHead is built from an lm_head tensor when one exists; otherwise
// LoadWeights wraps the embedding weight in a linear layer.
LMHead nn.LinearLayer
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm.
|
||||
NormScaled *mlx.Array
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*TextConfig
|
||||
|
||||
// weightPrefix is "" or "language_model.", as chosen by resolveWeightPrefix.
weightPrefix string
|
||||
}
|
||||
|
||||
// defaultHeads returns the fallback attention and key/value head counts for
// a model with numLayers hidden layers: 16/8 for 48 layers, 32/16 for 62
// layers, and 8/4 for every other layer count (including 34).
func defaultHeads(numLayers int32) (numHeads, numKVHeads int32) {
	numHeads = 8
	switch numLayers {
	case 48:
		numHeads = 16
	case 62:
		numHeads = 32
	}
	// Every entry in the table uses half as many KV heads as attention heads.
	return numHeads, numHeads / 2
}
|
||||
|
||||
func parseTextConfig(configData []byte) (TextConfig, bool, error) {
|
||||
var cfg TextConfig
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return TextConfig{}, false, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
var wrapped struct {
|
||||
TextConfig *TextConfig `json:"text_config"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &wrapped); err != nil {
|
||||
return TextConfig{}, false, fmt.Errorf("parse nested text config: %w", err)
|
||||
}
|
||||
|
||||
fromConditional := wrapped.TextConfig != nil
|
||||
if fromConditional {
|
||||
cfg = *wrapped.TextConfig
|
||||
|
||||
if cfg.HeadDim == 0 {
|
||||
cfg.HeadDim = 256
|
||||
}
|
||||
if cfg.NumAttentionHeads == 0 {
|
||||
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
|
||||
}
|
||||
if cfg.NumKeyValueHeads == 0 {
|
||||
_, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
|
||||
}
|
||||
if cfg.VocabSize == 0 {
|
||||
cfg.VocabSize = 262208
|
||||
}
|
||||
if cfg.SlidingWindowPattern == 0 && len(cfg.LayerTypes) == 0 {
|
||||
cfg.SlidingWindowPattern = 6
|
||||
}
|
||||
if cfg.MaxPositionEmbeddings == 0 {
|
||||
cfg.MaxPositionEmbeddings = 131072
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.HeadDim == 0 {
|
||||
cfg.HeadDim = 256
|
||||
}
|
||||
if cfg.NumAttentionHeads == 0 {
|
||||
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
|
||||
}
|
||||
if cfg.NumKeyValueHeads == 0 {
|
||||
cfg.NumKeyValueHeads = max(1, cfg.NumAttentionHeads/2)
|
||||
}
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 1000000
|
||||
}
|
||||
if cfg.RopeLocalBaseFreq == 0 {
|
||||
cfg.RopeLocalBaseFreq = 10000
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
if cfg.VocabSize == 0 {
|
||||
cfg.VocabSize = 262208
|
||||
}
|
||||
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
return cfg, fromConditional, nil
|
||||
}
|
||||
|
||||
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
|
||||
for _, prefix := range []string{"", "language_model."} {
|
||||
if tensors[prefix+"model.embed_tokens.weight"] != nil {
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isLayerSliding(layerIdx int32, cfg *TextConfig) bool {
|
||||
if len(cfg.LayerTypes) > 0 && int(layerIdx) < len(cfg.LayerTypes) {
|
||||
return cfg.LayerTypes[layerIdx] == "sliding_attention"
|
||||
}
|
||||
if cfg.SlidingWindowPattern <= 0 {
|
||||
return false
|
||||
}
|
||||
return (layerIdx+1)%cfg.SlidingWindowPattern != 0
|
||||
}
|
||||
|
||||
func precomputeGemmaScaledWeights(m *Model) {
|
||||
if m.Norm != nil {
|
||||
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
|
||||
}
|
||||
|
||||
var scaled []*mlx.Array
|
||||
if m.NormScaled != nil {
|
||||
scaled = append(scaled, m.NormScaled)
|
||||
}
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
if layer == nil || layer.Attention == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if layer.InputNorm != nil {
|
||||
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.InputNormScaled)
|
||||
}
|
||||
if layer.PostAttnNorm != nil {
|
||||
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.PostAttnNormScaled)
|
||||
}
|
||||
if layer.PreFFNorm != nil {
|
||||
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.PreFFNormScaled)
|
||||
}
|
||||
if layer.PostFFNorm != nil {
|
||||
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.PostFFNormScaled)
|
||||
}
|
||||
|
||||
if layer.Attention.QNorm != nil {
|
||||
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.Attention.QNormScaled)
|
||||
}
|
||||
if layer.Attention.KNorm != nil {
|
||||
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.Attention.KNormScaled)
|
||||
}
|
||||
}
|
||||
|
||||
if len(scaled) > 0 {
|
||||
mlx.Eval(scaled...)
|
||||
}
|
||||
}
|
||||
|
||||
func newModel(root *model.Root) (base.Model, error) {
|
||||
configData, err := root.Manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
cfg, _, err := parseTextConfig(configData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if qt := root.QuantType(); qt != "" {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||
if gs := root.GroupSize(); gs > 0 {
|
||||
cfg.QuantGroupSize = gs
|
||||
}
|
||||
} else {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
||||
}
|
||||
cfg.TensorQuant = root.AllTensorQuant()
|
||||
|
||||
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||
}
|
||||
|
||||
tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData}
|
||||
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = genConfigData
|
||||
}
|
||||
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||
}
|
||||
|
||||
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
|
||||
TextConfig: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
for i := range m.Layers {
|
||||
m.Layers[i] = &DecoderLayer{
|
||||
LayerIdx: int32(i),
|
||||
IsSliding: isLayerSliding(int32(i), m.TextConfig),
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns them
|
||||
// to model fields.
|
||||
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.weightPrefix = resolveWeightPrefix(tensors)
|
||||
prefix := m.weightPrefix
|
||||
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
|
||||
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
|
||||
if embedWeight == nil {
|
||||
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
|
||||
}
|
||||
m.EmbedTokens = nn.NewEmbedding(embedWeight)
|
||||
|
||||
normWeight := tensors[prefix+"model.norm.weight"]
|
||||
if normWeight == nil {
|
||||
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
|
||||
}
|
||||
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
|
||||
|
||||
if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else {
|
||||
// Gemma usually ties output projection to embeddings.
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
}
|
||||
|
||||
for i := int32(0); i < m.NumHiddenLayers; i++ {
|
||||
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
|
||||
|
||||
layer := &DecoderLayer{
|
||||
LayerIdx: i,
|
||||
IsSliding: isLayerSliding(i, m.TextConfig),
|
||||
Attention: &Attention{},
|
||||
MLP: &MLP{},
|
||||
}
|
||||
|
||||
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
|
||||
layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
|
||||
layer.PostAttnNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil {
|
||||
layer.PreFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil {
|
||||
layer.PostFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
|
||||
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
|
||||
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
|
||||
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
|
||||
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
|
||||
|
||||
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
|
||||
layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
|
||||
layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
|
||||
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
|
||||
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
|
||||
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
|
||||
|
||||
if layer.InputNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing input_layernorm", i)
|
||||
}
|
||||
if layer.PostAttnNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
|
||||
}
|
||||
if layer.PreFFNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing pre_feedforward_layernorm", i)
|
||||
}
|
||||
if layer.PostFFNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing post_feedforward_layernorm", i)
|
||||
}
|
||||
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
|
||||
return fmt.Errorf("layer %d: missing attention projections", i)
|
||||
}
|
||||
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing attention q/k norms", i)
|
||||
}
|
||||
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
|
||||
return fmt.Errorf("layer %d: missing mlp projections", i)
|
||||
}
|
||||
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
precomputeGemmaScaledWeights(m)
|
||||
if m.NormScaled == nil {
|
||||
return fmt.Errorf("missing precomputed final norm weight")
|
||||
}
|
||||
collected := mlx.Collect(m)
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.TextConfig)
|
||||
}
|
||||
|
||||
return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
|
||||
}
|
||||
|
||||
// Unembed applies the LMHead projection to x and returns the result.
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||
return m.LMHead.Forward(x)
|
||||
}
|
||||
|
||||
// NumLayers returns the number of decoder layers held by the model.
func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
// Tokenizer returns the tokenizer that newModel loaded for this model.
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
// NewCaches creates cache objects for all layers.
|
||||
func (m *Model) NewCaches() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i, layer := range m.Layers {
|
||||
if m.SlidingWindow > 0 && layer.IsSliding {
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
||||
} else {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// FormatPrompt applies the Gemma 3 chat template: it wraps prompt as a single
// user turn and appends the opening of the model turn.
|
||||
func (m *Model) FormatPrompt(prompt string) string {
|
||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
|
||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
|
||||
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
|
||||
mlpOut := l.MLP.Forward(normed)
|
||||
mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
|
||||
return mlx.Add(h, mlpOut)
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
|
||||
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
|
||||
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
|
||||
k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps)
|
||||
|
||||
ropeTheta := cfg.RopeTheta
|
||||
if isSliding {
|
||||
ropeTheta = cfg.RopeLocalBaseFreq
|
||||
}
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
}
|
||||
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = nn.RepeatKV(k, repeatFactor)
|
||||
v = nn.RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.GELUApprox(m.GateProj.Forward(x))
|
||||
up := m.UpProj.Forward(x)
|
||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
@@ -8,14 +8,13 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -64,9 +63,10 @@ type Config struct {
|
||||
RopeScaling *RopeScaling `json:"rope_scaling"`
|
||||
|
||||
// Quantization parameters (set during load based on model quantization)
|
||||
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
|
||||
QuantBits int `json:"-"` // Bits per weight (4 or 8)
|
||||
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
|
||||
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
|
||||
QuantBits int `json:"-"` // Bits per weight (4 or 8)
|
||||
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
|
||||
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
||||
|
||||
// Computed fields
|
||||
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
|
||||
@@ -372,22 +372,6 @@ func supportsGatherQMM(mode string, bits int) bool {
|
||||
return mode == "affine" && (bits == 4 || bits == 8)
|
||||
}
|
||||
|
||||
// quantizationParams returns groupSize, bits, mode for a quantization type string.
|
||||
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
|
||||
switch strings.ToUpper(quantization) {
|
||||
case "NVFP4":
|
||||
return 16, 4, "nvfp4"
|
||||
case "FP4", "Q4", "INT4":
|
||||
return 32, 4, "affine"
|
||||
case "MXFP8":
|
||||
return 32, 8, "mxfp8"
|
||||
case "FP8", "Q8", "INT8", "":
|
||||
return 64, 8, "affine"
|
||||
default:
|
||||
return 32, 8, "affine"
|
||||
}
|
||||
}
|
||||
|
||||
// ExpertWeight holds a single expert's weight with optional quantization components.
|
||||
type ExpertWeight struct {
|
||||
Weight *mlx.Array
|
||||
@@ -408,7 +392,15 @@ func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized b
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
|
||||
groupSize, bits, mode := cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode
|
||||
groupSize, bits, mode := model.ResolveLinearQuantParams(
|
||||
cfg.QuantGroupSize,
|
||||
cfg.QuantBits,
|
||||
cfg.QuantMode,
|
||||
cfg.TensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
|
||||
if useQuantized && supportsGatherQMM(mode, bits) {
|
||||
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
|
||||
@@ -492,7 +484,16 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
|
||||
// Check if quantized and dequantize
|
||||
if scales := tensors[path+".weight_scale"]; scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
w = mlx.Dequantize(w, scales, qbiases, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode)
|
||||
groupSize, bits, mode := model.ResolveLinearQuantParams(
|
||||
cfg.QuantGroupSize,
|
||||
cfg.QuantBits,
|
||||
cfg.QuantMode,
|
||||
cfg.TensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
|
||||
@@ -507,32 +508,6 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
|
||||
return embedQ, unembedOut
|
||||
}
|
||||
|
||||
// makeLinear creates a Linear or QuantizedLinear layer from the tensor map.
|
||||
func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.LinearLayer {
|
||||
w := tensors[path+".weight"]
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
scales := tensors[path+".weight_scale"]
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
bias := tensors[path+".bias"]
|
||||
return &nn.QuantizedLinear{
|
||||
Weight: w,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GroupSize: cfg.QuantGroupSize,
|
||||
Bits: cfg.QuantBits,
|
||||
Mode: cfg.QuantMode,
|
||||
}
|
||||
}
|
||||
|
||||
bias := tensors[path+".bias"]
|
||||
return nn.NewLinear(w, bias)
|
||||
}
|
||||
|
||||
// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
|
||||
// no weights loaded yet). Called by the registry via base.New().
|
||||
func newModel(root *model.Root) (base.Model, error) {
|
||||
@@ -551,13 +526,14 @@ func newModel(root *model.Root) (base.Model, error) {
|
||||
|
||||
// Set up quantization parameters from pre-scanned metadata
|
||||
if qt := root.QuantType(); qt != "" {
|
||||
_, cfg.QuantBits, cfg.QuantMode = quantizationParams(qt)
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||
if gs := root.GroupSize(); gs > 0 {
|
||||
cfg.QuantGroupSize = gs
|
||||
} else {
|
||||
cfg.QuantGroupSize, _, _ = quantizationParams(qt)
|
||||
}
|
||||
} else {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
||||
}
|
||||
cfg.TensorQuant = root.AllTensorQuant()
|
||||
|
||||
// Load tokenizer
|
||||
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||
@@ -596,7 +572,20 @@ func newModel(root *model.Root) (base.Model, error) {
|
||||
// layer creation.
|
||||
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
cfg := m.Config
|
||||
linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
|
||||
useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
|
||||
if !useQuantized && cfg.TensorQuant != nil {
|
||||
for _, tq := range cfg.TensorQuant {
|
||||
if tq == nil {
|
||||
continue
|
||||
}
|
||||
_, bits, mode := model.QuantizationParams(tq.QuantType)
|
||||
if supportsGatherQMM(mode, bits) {
|
||||
useQuantized = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load embedding
|
||||
if w := tensors["model.embed_tokens.weight"]; w != nil {
|
||||
@@ -609,7 +598,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
}
|
||||
|
||||
// Load LM head
|
||||
m.LMHead = makeLinear(tensors, "lm_head", cfg)
|
||||
m.LMHead = linears.Make("lm_head")
|
||||
|
||||
// Load layers
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
@@ -617,16 +606,16 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
|
||||
// Load attention (same for both block types)
|
||||
attn := &MLAAttention{}
|
||||
attn.QAProj = makeLinear(tensors, prefix+".self_attn.q_a_proj", cfg)
|
||||
attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj")
|
||||
if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
|
||||
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
||||
}
|
||||
attn.QBProj = makeLinear(tensors, prefix+".self_attn.q_b_proj", cfg)
|
||||
attn.KVAProjWithMQA = makeLinear(tensors, prefix+".self_attn.kv_a_proj_with_mqa", cfg)
|
||||
attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj")
|
||||
attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa")
|
||||
if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
|
||||
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
||||
}
|
||||
attn.OProj = makeLinear(tensors, prefix+".self_attn.o_proj", cfg)
|
||||
attn.OProj = linears.Make(prefix + ".self_attn.o_proj")
|
||||
|
||||
// Sanitize MLA weights for absorbed attention
|
||||
embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
|
||||
@@ -647,9 +636,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
}
|
||||
|
||||
block.MLP = &DenseMLP{
|
||||
GateProj: makeLinear(tensors, prefix+".mlp.gate_proj", cfg),
|
||||
UpProj: makeLinear(tensors, prefix+".mlp.up_proj", cfg),
|
||||
DownProj: makeLinear(tensors, prefix+".mlp.down_proj", cfg),
|
||||
GateProj: linears.Make(prefix + ".mlp.gate_proj"),
|
||||
UpProj: linears.Make(prefix + ".mlp.up_proj"),
|
||||
DownProj: linears.Make(prefix + ".mlp.down_proj"),
|
||||
}
|
||||
|
||||
m.Layers[i] = block
|
||||
@@ -690,7 +679,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
}
|
||||
|
||||
moeGate := &MoEGate{}
|
||||
moeGate.Gate = makeLinear(tensors, prefix+".mlp.gate", cfg)
|
||||
moeGate.Gate = linears.Make(prefix + ".mlp.gate")
|
||||
if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
|
||||
moeGate.EScoreCorrectionBias = bias
|
||||
}
|
||||
@@ -703,9 +692,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
// Load shared experts if present
|
||||
if cfg.NSharedExperts > 0 {
|
||||
block.MoE.SharedExperts = &SharedExperts{
|
||||
GateProj: makeLinear(tensors, prefix+".mlp.shared_experts.gate_proj", cfg),
|
||||
UpProj: makeLinear(tensors, prefix+".mlp.shared_experts.up_proj", cfg),
|
||||
DownProj: makeLinear(tensors, prefix+".mlp.shared_experts.down_proj", cfg),
|
||||
GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"),
|
||||
UpProj: linears.Make(prefix + ".mlp.shared_experts.up_proj"),
|
||||
DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
323
x/models/llama/llama.go
Normal file
323
x/models/llama/llama.go
Normal file
@@ -0,0 +1,323 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package llama provides a Llama-style decoder-only transformer for MLX.
|
||||
package llama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
base.Register("LlamaForCausalLM", newModel)
|
||||
}
|
||||
|
||||
// Config holds Llama model configuration.
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
TieWordEmbeddings bool `json:"tie_word_embeddings"`
|
||||
|
||||
// Quantization parameters (set during load based on model quantization).
|
||||
QuantGroupSize int `json:"-"`
|
||||
QuantBits int `json:"-"`
|
||||
QuantMode string `json:"-"`
|
||||
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
||||
|
||||
// Computed fields.
|
||||
HeadDim int32 `json:"-"`
|
||||
Scale float32 `json:"-"`
|
||||
}
|
||||
|
||||
// Model is a Llama text model.
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding
|
||||
Layers []*Layer
|
||||
Norm *nn.RMSNorm
|
||||
LMHead nn.LinearLayer
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
|
||||
weightPrefix string
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
Attention *Attention
|
||||
MLP *MLP
|
||||
AttentionNorm *nn.RMSNorm
|
||||
MLPNorm *nn.RMSNorm
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
QProj nn.LinearLayer
|
||||
KProj nn.LinearLayer
|
||||
VProj nn.LinearLayer
|
||||
OProj nn.LinearLayer
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
GateProj nn.LinearLayer
|
||||
UpProj nn.LinearLayer
|
||||
DownProj nn.LinearLayer
|
||||
}
|
||||
|
||||
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
|
||||
for _, prefix := range []string{"", "language_model."} {
|
||||
if tensors[prefix+"model.embed_tokens.weight"] != nil {
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func newModel(root *model.Root) (base.Model, error) {
|
||||
configData, err := root.Manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.HiddenSize <= 0 {
|
||||
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
|
||||
}
|
||||
if cfg.NumAttentionHeads <= 0 {
|
||||
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
|
||||
}
|
||||
if cfg.NumKeyValueHeads <= 0 {
|
||||
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
|
||||
}
|
||||
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
|
||||
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
|
||||
}
|
||||
if cfg.HeadDim == 0 {
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
}
|
||||
if cfg.HeadDim <= 0 {
|
||||
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
||||
}
|
||||
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
|
||||
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
|
||||
}
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 10000
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-5
|
||||
}
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
if qt := root.QuantType(); qt != "" {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||
if gs := root.GroupSize(); gs > 0 {
|
||||
cfg.QuantGroupSize = gs
|
||||
}
|
||||
} else {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
||||
}
|
||||
cfg.TensorQuant = root.AllTensorQuant()
|
||||
|
||||
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||
}
|
||||
|
||||
tokConfig := &tokenizer.TokenizerConfig{
|
||||
ConfigJSON: configData,
|
||||
}
|
||||
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = genConfigData
|
||||
}
|
||||
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||
}
|
||||
|
||||
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*Layer, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns them
|
||||
// to model fields.
|
||||
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.weightPrefix = resolveWeightPrefix(tensors)
|
||||
prefix := m.weightPrefix
|
||||
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
|
||||
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
|
||||
if embedWeight == nil {
|
||||
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
|
||||
}
|
||||
m.EmbedTokens = nn.NewEmbedding(embedWeight)
|
||||
|
||||
normWeight := tensors[prefix+"model.norm.weight"]
|
||||
if normWeight == nil {
|
||||
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
|
||||
}
|
||||
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
|
||||
|
||||
if m.TieWordEmbeddings {
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else {
|
||||
// Fallback used by many Llama checkpoints where output is tied.
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
}
|
||||
|
||||
for i := int32(0); i < m.NumHiddenLayers; i++ {
|
||||
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
|
||||
|
||||
layer := &Layer{
|
||||
Attention: &Attention{},
|
||||
MLP: &MLP{},
|
||||
}
|
||||
|
||||
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
|
||||
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
|
||||
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
|
||||
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
|
||||
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
|
||||
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
|
||||
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
|
||||
|
||||
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
|
||||
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
|
||||
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
|
||||
|
||||
if layer.AttentionNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing input_layernorm", i)
|
||||
}
|
||||
if layer.MLPNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
|
||||
}
|
||||
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
|
||||
return fmt.Errorf("layer %d: missing attention projections", i)
|
||||
}
|
||||
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
|
||||
return fmt.Errorf("layer %d: missing mlp projections", i)
|
||||
}
|
||||
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
collected := mlx.Collect(m)
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
return m.Norm.Forward(h, m.RMSNormEps)
|
||||
}
|
||||
|
||||
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||
return m.LMHead.Forward(x)
|
||||
}
|
||||
|
||||
func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
func (m *Model) NewCaches() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
|
||||
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
|
||||
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
}
|
||||
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = nn.RepeatKV(k, repeatFactor)
|
||||
v = nn.RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
}
|
||||
338
x/models/qwen3/qwen3.go
Normal file
338
x/models/qwen3/qwen3.go
Normal file
@@ -0,0 +1,338 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen3 provides the Qwen3 text model implementation for MLX.
|
||||
package qwen3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
base.Register("Qwen3ForCausalLM", newModel)
|
||||
}
|
||||
|
||||
// Config holds Qwen3 model configuration.
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
TieWordEmbeddings bool `json:"tie_word_embeddings"`
|
||||
|
||||
// Quantization parameters (set during load based on model quantization).
|
||||
QuantGroupSize int `json:"-"`
|
||||
QuantBits int `json:"-"`
|
||||
QuantMode string `json:"-"`
|
||||
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
||||
|
||||
// Computed fields.
|
||||
Scale float32 `json:"-"`
|
||||
QKNormEps float32 `json:"-"`
|
||||
}
|
||||
|
||||
// Model is the Qwen3 text-only model.
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding
|
||||
Layers []*Layer
|
||||
Norm *nn.RMSNorm
|
||||
LMHead nn.LinearLayer
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
|
||||
weightPrefix string
|
||||
}
|
||||
|
||||
// Layer is a single Qwen3 decoder block.
|
||||
type Layer struct {
|
||||
Attention *Attention
|
||||
MLP *MLP
|
||||
AttentionNorm *nn.RMSNorm
|
||||
MLPNorm *nn.RMSNorm
|
||||
}
|
||||
|
||||
// Attention implements Qwen3 attention with Q/K norms.
|
||||
type Attention struct {
|
||||
QProj nn.LinearLayer
|
||||
KProj nn.LinearLayer
|
||||
VProj nn.LinearLayer
|
||||
OProj nn.LinearLayer
|
||||
QNorm *nn.RMSNorm
|
||||
KNorm *nn.RMSNorm
|
||||
}
|
||||
|
||||
// MLP is the feed-forward network with SwiGLU activation.
|
||||
type MLP struct {
|
||||
GateProj nn.LinearLayer
|
||||
UpProj nn.LinearLayer
|
||||
DownProj nn.LinearLayer
|
||||
}
|
||||
|
||||
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
|
||||
for _, prefix := range []string{"", "language_model."} {
|
||||
if tensors[prefix+"model.embed_tokens.weight"] != nil {
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func newModel(root *model.Root) (base.Model, error) {
|
||||
configData, err := root.Manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.HiddenSize <= 0 {
|
||||
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
|
||||
}
|
||||
if cfg.NumAttentionHeads <= 0 {
|
||||
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
|
||||
}
|
||||
if cfg.NumKeyValueHeads <= 0 {
|
||||
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
|
||||
}
|
||||
if cfg.HeadDim == 0 {
|
||||
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
|
||||
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
|
||||
}
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
}
|
||||
if cfg.HeadDim <= 0 {
|
||||
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
||||
}
|
||||
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
|
||||
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 1000000
|
||||
}
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
cfg.QKNormEps = 1e-6
|
||||
|
||||
if qt := root.QuantType(); qt != "" {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||
if gs := root.GroupSize(); gs > 0 {
|
||||
cfg.QuantGroupSize = gs
|
||||
}
|
||||
} else {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
||||
}
|
||||
cfg.TensorQuant = root.AllTensorQuant()
|
||||
|
||||
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||
}
|
||||
|
||||
tokConfig := &tokenizer.TokenizerConfig{
|
||||
ConfigJSON: configData,
|
||||
}
|
||||
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = genConfigData
|
||||
}
|
||||
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||
}
|
||||
|
||||
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*Layer, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns them
|
||||
// to model fields.
|
||||
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.weightPrefix = resolveWeightPrefix(tensors)
|
||||
prefix := m.weightPrefix
|
||||
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
|
||||
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
|
||||
if embedWeight == nil {
|
||||
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
|
||||
}
|
||||
m.EmbedTokens = nn.NewEmbedding(embedWeight)
|
||||
|
||||
normWeight := tensors[prefix+"model.norm.weight"]
|
||||
if normWeight == nil {
|
||||
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
|
||||
}
|
||||
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
|
||||
|
||||
if m.TieWordEmbeddings {
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else {
|
||||
// Qwen3 checkpoints commonly tie output projection to embeddings.
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
}
|
||||
|
||||
for i := int32(0); i < m.NumHiddenLayers; i++ {
|
||||
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
|
||||
|
||||
layer := &Layer{
|
||||
Attention: &Attention{},
|
||||
MLP: &MLP{},
|
||||
}
|
||||
|
||||
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
|
||||
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
|
||||
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
|
||||
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
|
||||
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
|
||||
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
|
||||
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
|
||||
|
||||
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
|
||||
layer.Attention.QNorm = nn.NewRMSNorm(w, m.QKNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
|
||||
layer.Attention.KNorm = nn.NewRMSNorm(w, m.QKNormEps)
|
||||
}
|
||||
|
||||
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
|
||||
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
|
||||
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
|
||||
|
||||
if layer.AttentionNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing input_layernorm", i)
|
||||
}
|
||||
if layer.MLPNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
|
||||
}
|
||||
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
|
||||
return fmt.Errorf("layer %d: missing attention projections", i)
|
||||
}
|
||||
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing attention q/k norms", i)
|
||||
}
|
||||
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
|
||||
return fmt.Errorf("layer %d: missing mlp projections", i)
|
||||
}
|
||||
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
collected := mlx.Collect(m)
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
return m.Norm.Forward(h, m.RMSNormEps)
|
||||
}
|
||||
|
||||
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||
return m.LMHead.Forward(x)
|
||||
}
|
||||
|
||||
func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
func (m *Model) NewCaches() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
|
||||
q = a.QNorm.Forward(q, cfg.QKNormEps)
|
||||
k = a.KNorm.Forward(k, cfg.QKNormEps)
|
||||
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
}
|
||||
|
||||
// MLX SDPA supports grouped-query attention directly (Q heads can be a
|
||||
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
|
||||
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
}
|
||||
1254
x/models/qwen3_5/qwen3_5.go
Normal file
1254
x/models/qwen3_5/qwen3_5.go
Normal file
File diff suppressed because it is too large
Load Diff
120
x/models/qwen3_5/qwen3_5_test.go
Normal file
120
x/models/qwen3_5/qwen3_5_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
)
|
||||
|
||||
func TestParseConfigNestedDefaults(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model_type": "Qwen3_5MoeForConditionalGeneration",
|
||||
"text_config": {
|
||||
"hidden_size": 4096,
|
||||
"intermediate_size": 14336,
|
||||
"num_hidden_layers": 8,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"linear_num_value_heads": 64,
|
||||
"linear_num_key_heads": 16,
|
||||
"linear_key_head_dim": 128,
|
||||
"linear_value_head_dim": 128,
|
||||
"linear_conv_kernel_dim": 4,
|
||||
"num_experts": 16,
|
||||
"num_experts_per_tok": 4,
|
||||
"moe_intermediate_size": 2048,
|
||||
"shared_expert_intermediate_size": 4096,
|
||||
"rope_parameters": {
|
||||
"rope_theta": 500000,
|
||||
"partial_rotary_factor": 0.5
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.RopeTheta != 500000 {
|
||||
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
|
||||
}
|
||||
if cfg.RopeDim != 64 {
|
||||
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
|
||||
}
|
||||
if cfg.FullAttentionInterval != 4 {
|
||||
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
|
||||
}
|
||||
if !cfg.NormTopKProb {
|
||||
t.Fatalf("norm_topk_prob should default to true for MoE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerSelectionHelpers(t *testing.T) {
|
||||
cfg := &Config{
|
||||
NumHiddenLayers: 6,
|
||||
FullAttentionInterval: 3,
|
||||
NumExperts: 8,
|
||||
DecoderSparseStep: 2,
|
||||
MLPOnlyLayers: []int32{1},
|
||||
}
|
||||
|
||||
if !layerIsLinear(cfg, 0) {
|
||||
t.Fatalf("layer 0 should be linear")
|
||||
}
|
||||
if layerIsLinear(cfg, 2) {
|
||||
t.Fatalf("layer 2 should be full attention")
|
||||
}
|
||||
|
||||
if layerUsesMoE(cfg, 1) {
|
||||
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
|
||||
}
|
||||
if !layerUsesMoE(cfg, 3) {
|
||||
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRuntimeToggles(t *testing.T) {
|
||||
m := &Model{}
|
||||
if !m.DisablePromptCache() {
|
||||
t.Fatal("DisablePromptCache() = false, want true")
|
||||
}
|
||||
if m.EnableCompile() {
|
||||
t.Fatal("EnableCompile() = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesLayout(t *testing.T) {
|
||||
m := &Model{
|
||||
Config: &Config{
|
||||
LinearConvKernelDim: 4,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearKeyHeadDim: 8,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 16,
|
||||
},
|
||||
Layers: []*Layer{
|
||||
{IsLinear: true},
|
||||
{IsLinear: false},
|
||||
{IsLinear: true},
|
||||
},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if len(caches) != len(m.Layers) {
|
||||
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
|
||||
}
|
||||
|
||||
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
|
||||
}
|
||||
if _, ok := caches[1].(*cache.KVCache); !ok {
|
||||
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
|
||||
}
|
||||
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
|
||||
}
|
||||
}
|
||||
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
|
||||
package qwen3_5_moe
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||
)
|
||||
|
||||
func init() {
|
||||
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -58,7 +59,15 @@ func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return buildModelInfo(config, totalBytes, tensorCount), nil
|
||||
info := buildModelInfo(config, totalBytes, tensorCount)
|
||||
|
||||
// For quantized models, byte-based estimation can significantly undercount
|
||||
// parameters. Prefer exact counting from tensor shapes in safetensors headers.
|
||||
if paramCount, err := getParameterCountFromManifest(mf); err == nil && paramCount > 0 {
|
||||
info["general.parameter_count"] = paramCount
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// buildModelInfo constructs the model info map from config and tensor stats.
|
||||
@@ -151,6 +160,51 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
|
||||
return info
|
||||
}
|
||||
|
||||
// getParameterCountFromManifest counts model parameters from tensor shapes.
|
||||
// This accounts for quantized tensors by using unpacked shapes from
|
||||
// getTensorInfoFromManifest.
|
||||
func getParameterCountFromManifest(mf *manifest.Manifest) (int64, error) {
|
||||
tensors, err := getTensorInfoFromManifest(mf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var total int64
|
||||
for _, tensor := range tensors {
|
||||
if len(tensor.Shape) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
elements := int64(1)
|
||||
for _, dim := range tensor.Shape {
|
||||
if dim == 0 {
|
||||
elements = 0
|
||||
break
|
||||
}
|
||||
|
||||
if dim > uint64(math.MaxInt64) {
|
||||
return 0, fmt.Errorf("tensor %s dimension too large: %d", tensor.Name, dim)
|
||||
}
|
||||
|
||||
d := int64(dim)
|
||||
if elements > math.MaxInt64/d {
|
||||
return 0, fmt.Errorf("tensor %s element count overflow", tensor.Name)
|
||||
}
|
||||
elements *= d
|
||||
}
|
||||
|
||||
if elements == 0 {
|
||||
continue
|
||||
}
|
||||
if total > math.MaxInt64-elements {
|
||||
return 0, fmt.Errorf("total parameter count overflow")
|
||||
}
|
||||
total += elements
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
|
||||
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
|
||||
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
|
||||
|
||||
@@ -714,6 +714,187 @@ func TestGetTensorInfoFromManifest_Quantized(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetParameterCountFromManifest(t *testing.T) {
|
||||
// Create a temp directory for blobs and set OLLAMA_MODELS
|
||||
tempDir := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", tempDir)
|
||||
|
||||
blobDir := filepath.Join(tempDir, "blobs")
|
||||
if err := os.MkdirAll(blobDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create blobs dir: %v", err)
|
||||
}
|
||||
|
||||
// Unquantized tensor: [4,5] = 20 params
|
||||
header1 := map[string]any{
|
||||
"model.embed_tokens.weight": map[string]any{
|
||||
"dtype": "BF16",
|
||||
"shape": []int64{4, 5},
|
||||
"data_offsets": []int64{0, 40},
|
||||
},
|
||||
}
|
||||
header1JSON, _ := json.Marshal(header1)
|
||||
var buf1 bytes.Buffer
|
||||
binary.Write(&buf1, binary.LittleEndian, uint64(len(header1JSON)))
|
||||
buf1.Write(header1JSON)
|
||||
|
||||
digest1 := "sha256:1111111111111111111111111111111111111111111111111111111111111111"
|
||||
blobPath1, err := manifest.BlobsPath(digest1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get blob path: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(blobPath1, buf1.Bytes(), 0o644); err != nil {
|
||||
t.Fatalf("failed to write blob1: %v", err)
|
||||
}
|
||||
|
||||
// Quantized int4 tensor with packed shape [10,2] -> unpacked [10,16] = 160 params
|
||||
header2 := map[string]any{
|
||||
"__metadata__": map[string]string{
|
||||
"quant_type": "int4",
|
||||
"group_size": "32",
|
||||
},
|
||||
"model.layers.0.mlp.up_proj.weight": map[string]any{
|
||||
"dtype": "U32",
|
||||
"shape": []int64{10, 2},
|
||||
"data_offsets": []int64{0, 80},
|
||||
},
|
||||
"model.layers.0.mlp.up_proj.weight.scale": map[string]any{
|
||||
"dtype": "BF16",
|
||||
"shape": []int64{10, 1},
|
||||
"data_offsets": []int64{80, 100},
|
||||
},
|
||||
"model.layers.0.mlp.up_proj.weight.bias": map[string]any{
|
||||
"dtype": "BF16",
|
||||
"shape": []int64{10, 1},
|
||||
"data_offsets": []int64{100, 120},
|
||||
},
|
||||
}
|
||||
header2JSON, _ := json.Marshal(header2)
|
||||
var buf2 bytes.Buffer
|
||||
binary.Write(&buf2, binary.LittleEndian, uint64(len(header2JSON)))
|
||||
buf2.Write(header2JSON)
|
||||
|
||||
digest2 := "sha256:2222222222222222222222222222222222222222222222222222222222222222"
|
||||
blobPath2, err := manifest.BlobsPath(digest2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get blob path: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(blobPath2, buf2.Bytes(), 0o644); err != nil {
|
||||
t.Fatalf("failed to write blob2: %v", err)
|
||||
}
|
||||
|
||||
mf := &manifest.Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Layers: []manifest.Layer{
|
||||
{
|
||||
MediaType: manifest.MediaTypeImageTensor,
|
||||
Digest: digest1,
|
||||
Size: int64(buf1.Len() + 40),
|
||||
Name: "model.embed_tokens.weight",
|
||||
},
|
||||
{
|
||||
MediaType: manifest.MediaTypeImageTensor,
|
||||
Digest: digest2,
|
||||
Size: int64(buf2.Len() + 120),
|
||||
Name: "model.layers.0.mlp.up_proj.weight",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
paramCount, err := getParameterCountFromManifest(mf)
|
||||
if err != nil {
|
||||
t.Fatalf("getParameterCountFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
const want int64 = 180 // 20 + 160
|
||||
if paramCount != want {
|
||||
t.Errorf("parameter_count = %d, want %d", paramCount, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetParameterCountFromManifest_MixedQuantizedPacked(t *testing.T) {
|
||||
// Create a temp directory for blobs and set OLLAMA_MODELS
|
||||
tempDir := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", tempDir)
|
||||
|
||||
blobDir := filepath.Join(tempDir, "blobs")
|
||||
if err := os.MkdirAll(blobDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create blobs dir: %v", err)
|
||||
}
|
||||
|
||||
// Packed mixed-precision blob (no global metadata):
|
||||
// - gate_proj: int4 packed [5,8] + scale [5,2] => unpacked [5,64] = 320 params
|
||||
// - down_proj: int8 packed [5,16] + scale [5,1] => unpacked [5,64] = 320 params
|
||||
header := map[string]any{
|
||||
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
|
||||
"dtype": "U32",
|
||||
"shape": []int64{5, 8},
|
||||
"data_offsets": []int64{0, 160},
|
||||
},
|
||||
"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
|
||||
"dtype": "BF16",
|
||||
"shape": []int64{5, 2},
|
||||
"data_offsets": []int64{160, 180},
|
||||
},
|
||||
"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
|
||||
"dtype": "BF16",
|
||||
"shape": []int64{5, 2},
|
||||
"data_offsets": []int64{180, 200},
|
||||
},
|
||||
"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
|
||||
"dtype": "U32",
|
||||
"shape": []int64{5, 16},
|
||||
"data_offsets": []int64{200, 520},
|
||||
},
|
||||
"model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
|
||||
"dtype": "BF16",
|
||||
"shape": []int64{5, 1},
|
||||
"data_offsets": []int64{520, 530},
|
||||
},
|
||||
"model.layers.0.mlp.experts.0.down_proj.weight.bias": map[string]any{
|
||||
"dtype": "BF16",
|
||||
"shape": []int64{5, 1},
|
||||
"data_offsets": []int64{530, 540},
|
||||
},
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||
buf.Write(headerJSON)
|
||||
|
||||
digest := "sha256:3333333333333333333333333333333333333333333333333333333333333333"
|
||||
blobPath, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get blob path: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
|
||||
t.Fatalf("failed to write blob: %v", err)
|
||||
}
|
||||
|
||||
mf := &manifest.Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Layers: []manifest.Layer{
|
||||
{
|
||||
MediaType: manifest.MediaTypeImageTensor,
|
||||
Digest: digest,
|
||||
Size: int64(buf.Len() + 540),
|
||||
Name: "model.layers.0.mlp.experts",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
paramCount, err := getParameterCountFromManifest(mf)
|
||||
if err != nil {
|
||||
t.Fatalf("getParameterCountFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
const want int64 = 640 // 320 + 320
|
||||
if paramCount != want {
|
||||
t.Errorf("parameter_count = %d, want %d", paramCount, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSafetensorsAllHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
108
x/tokenizer/tokenizer.go
Normal file
108
x/tokenizer/tokenizer.go
Normal file
@@ -0,0 +1,108 @@
|
||||
//go:build mlx
|
||||
|
||||
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
|
||||
//
|
||||
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
|
||||
// - GPT-2 byte-level encoding (OpenAI tiktoken)
|
||||
// - HuggingFace tokenizer.json pretokenizer patterns
|
||||
// - SentencePiece ▁-style space handling
|
||||
|
||||
package tokenizer
|
||||
|
||||
import "regexp"
|
||||
|
||||
// TokenizerType selects which tokenization algorithm a Tokenizer applies.
type TokenizerType int

const (
	// TokenizerBPE is GPT-2 style byte-level BPE. It is the zero value.
	TokenizerBPE TokenizerType = iota
	// TokenizerSentencePiece is SentencePiece with ▁ for spaces.
	TokenizerSentencePiece
)
|
||||
|
||||
// Vocabulary holds the token table, merge rules, and special token IDs used
// by a Tokenizer.
type Vocabulary struct {
	// Values maps a token ID (the slice index) to its token string.
	Values []string
	// Reverse maps a token string back to its ID.
	Reverse map[string]int32
	// Merges maps a space-joined token pair ("left right") to its merge rank.
	Merges map[string]int

	// BOS is the beginning-of-sequence token ID.
	BOS int32
	// EOS lists end-of-sequence token IDs. Multiple EOS tokens are supported
	// (e.g., Gemma has <eos> and <end_of_turn>).
	EOS []int32
	// PAD is the padding token ID (often <|endoftext|> or <pad>).
	PAD int32
	// AddBOS and AddEOS presumably record whether BOS/EOS should be added
	// automatically when encoding; their use is not shown here (TODO confirm).
	AddBOS bool
	AddEOS bool

	// byteTokens holds precomputed byte token IDs for <0xNN> fallback
	// (256 entries, -1 if not found).
	byteTokens [256]int32
}
|
||||
|
||||
// Tokenizer handles BPE and SentencePiece tokenization
|
||||
type Tokenizer struct {
|
||||
vocab *Vocabulary
|
||||
pretokenizer *regexp.Regexp
|
||||
specialTokens map[string]int32 // Special tokens for direct lookup
|
||||
sortedSpecialTokens []string // Special tokens sorted by length, longest first
|
||||
typ TokenizerType // Algorithm type
|
||||
}
|
||||
|
||||
// byteToRune is the precomputed GPT-2 byte-level encoding table: it maps each
// byte value to the rune that represents it in encoded text.
var byteToRune [256]rune

// init fills byteToRune. Most bytes map to the rune with the same value; the
// ranges below are shifted to other code points.
func init() {
	// Start from the identity mapping.
	for b := range byteToRune {
		byteToRune[b] = rune(b)
	}

	// 0x00..0x20 -> 0x0100..0x0120
	for b := 0x00; b <= 0x20; b++ {
		byteToRune[b] = rune(b) + 0x0100
	}

	// 0x7f..0xa0 -> 0x0121..0x0142
	for b := 0x7f; b <= 0xa0; b++ {
		byteToRune[b] = rune(b) + 0x00a2
	}

	// 0xad -> 0x0143
	byteToRune[0xad] = 0x0143
}
|
||||
|
||||
// VocabSize returns the vocabulary size
|
||||
func (t *Tokenizer) VocabSize() int {
|
||||
return len(t.vocab.Values)
|
||||
}
|
||||
|
||||
// BOS returns the beginning of sequence token ID
|
||||
func (t *Tokenizer) BOS() int32 {
|
||||
return t.vocab.BOS
|
||||
}
|
||||
|
||||
// EOS returns the first end of sequence token ID (for backwards compatibility)
|
||||
func (t *Tokenizer) EOS() int32 {
|
||||
if len(t.vocab.EOS) > 0 {
|
||||
return t.vocab.EOS[0]
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// EOSTokens returns all end of sequence token IDs
|
||||
func (t *Tokenizer) EOSTokens() []int32 {
|
||||
return t.vocab.EOS
|
||||
}
|
||||
|
||||
// PAD returns the padding token ID, or -1 if not set
|
||||
func (t *Tokenizer) PAD() int32 {
|
||||
return t.vocab.PAD
|
||||
}
|
||||
|
||||
// IsEOS returns true if the token ID is an end of sequence token
|
||||
func (t *Tokenizer) IsEOS(id int32) bool {
|
||||
for _, eos := range t.vocab.EOS {
|
||||
if id == eos {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetSpecialToken returns the token ID for a special token string
|
||||
func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
|
||||
id, ok := t.specialTokens[name]
|
||||
return id, ok
|
||||
}
|
||||
251
x/tokenizer/tokenizer_benchmark_test.go
Normal file
251
x/tokenizer/tokenizer_benchmark_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
benchmarkSinkIDs []int32
|
||||
benchmarkSinkStr string
|
||||
benchmarkSinkTok *Tokenizer
|
||||
)
|
||||
|
||||
// benchmarkWordPieceJSON is a minimal tokenizer.json fixture declaring a
// WordPiece model with a five-entry vocabulary and no added tokens.
const benchmarkWordPieceJSON = `{
	"model": {
		"type": "WordPiece",
		"vocab": {
			"[UNK]": 0,
			"hello": 1,
			"##world": 2,
			"##ly": 3,
			"##hello": 4
		}
	},
	"added_tokens": []
}`
|
||||
|
||||
// benchmarkSentencePieceJSON is a minimal tokenizer.json fixture: a BPE model
// with single-character tokens, a <0x0A> byte token, no merges, and a decoder
// that replaces "\u2581" (▁). The decoder section is presumably what marks it
// as SentencePiece-style for the loader (TODO confirm against the loader).
const benchmarkSentencePieceJSON = `{
	"model": {
		"type": "BPE",
		"vocab": {
			"\u2581": 0,
			"h": 1,
			"e": 2,
			"l": 3,
			"o": 4,
			"w": 5,
			"r": 6,
			"d": 7,
			"<0x0A>": 8
		},
		"merges": []
	},
	"decoder": {
		"type": "Sequence",
		"decoders": [
			{
				"type": "Replace",
				"pattern": {
					"String": "\u2581"
				}
			}
		]
	},
	"added_tokens": []
}`
|
||||
|
||||
// benchmarkMiniLlamaPath returns the path of the mini_llama.json tokenizer
// fixture, resolved relative to the directory of this source file
// (../imagegen/tokenizer/testdata). It fails tb if the source file location
// cannot be determined.
func benchmarkMiniLlamaPath(tb testing.TB) string {
	tb.Helper()

	_, thisFile, _, resolved := runtime.Caller(0)
	if !resolved {
		tb.Fatal("failed to resolve benchmark file path")
	}

	pkgDir := filepath.Dir(thisFile)
	return filepath.Join(pkgDir, "..", "imagegen", "tokenizer", "testdata", "mini_llama.json")
}
|
||||
|
||||
func benchmarkLoadMiniLlama(tb testing.TB) *Tokenizer {
|
||||
tb.Helper()
|
||||
|
||||
data := benchmarkLoadMiniLlamaBytes(tb)
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to load mini llama tokenizer: %v", err)
|
||||
}
|
||||
return tok
|
||||
}
|
||||
|
||||
func benchmarkLoadMiniLlamaBytes(tb testing.TB) []byte {
|
||||
tb.Helper()
|
||||
|
||||
data, err := os.ReadFile(benchmarkMiniLlamaPath(tb))
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to read mini llama tokenizer: %v", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func benchmarkLoadFromBytes(tb testing.TB, data []byte) *Tokenizer {
|
||||
tb.Helper()
|
||||
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to load tokenizer from bytes: %v", err)
|
||||
}
|
||||
return tok
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerEncodeBPE(b *testing.B) {
|
||||
tok := benchmarkLoadMiniLlama(b)
|
||||
|
||||
inputs := []struct {
|
||||
name string
|
||||
text string
|
||||
}{
|
||||
{name: "short", text: "Hello, world!"},
|
||||
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
|
||||
{name: "long_sequential", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 80)},
|
||||
{name: "long_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
|
||||
{name: "huge_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)},
|
||||
{name: "special_tokens", text: "<|begin_of_text|>system\nYou are concise.<|end_of_text|>"},
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
b.Run(input.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(input.text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkIDs = tok.Encode(input.text, false)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerDecodeBPE(b *testing.B) {
|
||||
tok := benchmarkLoadMiniLlama(b)
|
||||
|
||||
inputs := []struct {
|
||||
name string
|
||||
text string
|
||||
}{
|
||||
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
|
||||
{name: "long", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
ids := tok.Encode(input.text, false)
|
||||
b.Run(input.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(input.text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkStr = tok.Decode(ids)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerLoadFromBytes(b *testing.B) {
|
||||
data := benchmarkLoadMiniLlamaBytes(b)
|
||||
|
||||
config := &TokenizerConfig{
|
||||
TokenizerConfigJSON: []byte(`{
|
||||
"bos_token": {"content": "<|begin_of_text|>"},
|
||||
"eos_token": {"content": "<|end_of_text|>"},
|
||||
"add_bos_token": true
|
||||
}`),
|
||||
GenerationConfigJSON: []byte(`{"bos_token_id": 128000, "eos_token_id": 128001}`),
|
||||
}
|
||||
|
||||
b.Run("without_config", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(data)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
b.Fatalf("LoadFromBytes failed: %v", err)
|
||||
}
|
||||
benchmarkSinkTok = tok
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("with_config", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(data)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok, err := LoadFromBytesWithConfig(data, config)
|
||||
if err != nil {
|
||||
b.Fatalf("LoadFromBytesWithConfig failed: %v", err)
|
||||
}
|
||||
benchmarkSinkTok = tok
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerEncodeWordPiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
|
||||
text := strings.Repeat("helloworldly", 16)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkIDs = tok.Encode(text, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerDecodeWordPiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
|
||||
text := strings.Repeat("helloworldly", 16)
|
||||
ids := tok.Encode(text, false)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkStr = tok.Decode(ids)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerEncodeSentencePiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
|
||||
text := strings.Repeat("hello world\n", 64)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkIDs = tok.Encode(text, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerDecodeSentencePiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
|
||||
text := strings.Repeat("hello world\n", 64)
|
||||
ids := tok.Encode(text, false)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkStr = tok.Decode(ids)
|
||||
}
|
||||
}
|
||||
175
x/tokenizer/tokenizer_bpe.go
Normal file
175
x/tokenizer/tokenizer_bpe.go
Normal file
@@ -0,0 +1,175 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import "container/heap"
|
||||
|
||||
// bpeMergeNode is one element of an index-linked list of tokens. prev and
// next are indices into the node slice; token is the node's current text,
// with "" marking a node that has been merged away.
type bpeMergeNode struct {
	prev  int
	next  int
	token string
}

// bpePair is a candidate merge of the nodes at indices left and right.
// rank is the merge's priority (lower sorts first) and value is the
// concatenated token the merge would produce.
type bpePair struct {
	left  int
	right int
	rank  int
	value string
}

// bpePairHeap is a min-heap of merge candidates implementing heap.Interface.
type bpePairHeap []*bpePair

// Len reports the number of queued candidates.
func (h bpePairHeap) Len() int { return len(h) }

// Less orders candidates by rank, breaking ties by the smaller left index.
func (h bpePairHeap) Less(i, j int) bool {
	a, b := h[i], h[j]
	if a.rank != b.rank {
		return a.rank < b.rank
	}
	return a.left < b.left
}

// Swap exchanges the candidates at positions i and j.
func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

// Push appends x, which must be a *bpePair, to the end of the heap slice.
func (h *bpePairHeap) Push(x any) {
	*h = append(*h, x.(*bpePair))
}

// Pop removes and returns the last element of the heap slice.
func (h *bpePairHeap) Pop() any {
	last := len(*h) - 1
	item := (*h)[last]
	*h = (*h)[:last]
	return item
}
|
||||
|
||||
// encodeBPEMerge encodes using BPE merge algorithm.
|
||||
// Uses the heap/linked-list pair merge strategy from tokenizer/bytepairencoding.go:
|
||||
// merge the lowest-rank valid pair, then only recheck adjacent pairs.
|
||||
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
|
||||
runes := []rune(encoded)
|
||||
if len(runes) == 0 {
|
||||
return ids
|
||||
}
|
||||
|
||||
nodes := make([]bpeMergeNode, len(runes))
|
||||
for i := range runes {
|
||||
nodes[i] = bpeMergeNode{
|
||||
prev: i - 1,
|
||||
next: i + 1,
|
||||
token: string(runes[i]),
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(left, right int) *bpePair {
|
||||
if left < 0 || right >= len(nodes) {
|
||||
return nil
|
||||
}
|
||||
if nodes[left].token == "" || nodes[right].token == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
leftToken, rightToken := nodes[left].token, nodes[right].token
|
||||
rank, ok := t.vocab.Merges[leftToken+" "+rightToken]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
value := leftToken + rightToken
|
||||
if _, ok := t.vocab.Reverse[value]; !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &bpePair{
|
||||
left: left,
|
||||
right: right,
|
||||
rank: rank,
|
||||
value: value,
|
||||
}
|
||||
}
|
||||
|
||||
pairs := bpePairHeap{}
|
||||
heap.Init(&pairs)
|
||||
for i := 0; i < len(runes)-1; i++ {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
heap.Push(&pairs, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for pairs.Len() > 0 {
|
||||
pair := heap.Pop(&pairs).(*bpePair)
|
||||
left, right := nodes[pair.left], nodes[pair.right]
|
||||
if left.token == "" || right.token == "" {
|
||||
continue
|
||||
}
|
||||
if left.next != pair.right || right.prev != pair.left {
|
||||
continue
|
||||
}
|
||||
if left.token+right.token != pair.value {
|
||||
continue
|
||||
}
|
||||
|
||||
nodes[pair.left].token = pair.value
|
||||
nodes[pair.right].token = ""
|
||||
nodes[pair.left].next = right.next
|
||||
if right.next < len(nodes) {
|
||||
nodes[right.next].prev = pair.left
|
||||
}
|
||||
|
||||
if pair := pairwise(nodes[pair.left].prev, pair.left); pair != nil {
|
||||
heap.Push(&pairs, pair)
|
||||
}
|
||||
if pair := pairwise(pair.left, nodes[pair.left].next); pair != nil {
|
||||
heap.Push(&pairs, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.token == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if id, ok := t.vocab.Reverse[node.token]; ok {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
ids = t.appendByteFallback(ids, node.token)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 {
|
||||
if t.typ == TokenizerBPE {
|
||||
for _, r := range token {
|
||||
if b, ok := decodeByteLevelRune(r); ok {
|
||||
if id := t.vocab.byteTokens[b]; id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens.
|
||||
for _, b := range []byte(token) {
|
||||
if id := t.vocab.byteTokens[b]; id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// decodeByteLevelRune maps a rune from byte-level encoded text back to a byte.
// The boolean is false when r has no byte mapping.
//
//	0x0000..0x00FF -> same value
//	0x0100..0x0120 -> 0x00..0x20
//	0x0121..0x0142 -> 0x7f..0xa0
//	0x0143         -> 0xad
func decodeByteLevelRune(r rune) (byte, bool) {
	switch {
	case r < 0:
		return 0, false
	case r <= 0xFF:
		return byte(r), true
	case r <= 0x0120:
		return byte(r - 0x0100), true
	case r <= 0x0142:
		return byte(r - 0x00a2), true
	case r == 0x0143:
		return 0xad, true
	}
	return 0, false
}
|
||||
137
x/tokenizer/tokenizer_correctness_test.go
Normal file
137
x/tokenizer/tokenizer_correctness_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// equalIDs reports whether a and b have the same length and the same IDs in
// the same order. A nil slice and an empty slice compare equal.
func equalIDs(a, b []int32) bool {
	if len(a) != len(b) {
		return false
	}
	for i, id := range a {
		if id != b[i] {
			return false
		}
	}
	return true
}
|
||||
|
||||
func TestEncodeRoundtripMiniLlama(t *testing.T) {
|
||||
tok := benchmarkLoadMiniLlama(t)
|
||||
|
||||
inputs := []string{
|
||||
"",
|
||||
"hello",
|
||||
"hello world",
|
||||
" hello world ",
|
||||
"don't we'll they're",
|
||||
"1234567890",
|
||||
"こんにちは世界",
|
||||
"Hello 世界",
|
||||
"func main() {}",
|
||||
"<|begin_of_text|>system\nYou are concise.<|end_of_text|>",
|
||||
strings.Repeat("The quick brown fox jumps over the lazy dog. ", 32),
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
ids := tok.Encode(input, false)
|
||||
got := tok.Decode(ids)
|
||||
if got != input {
|
||||
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitBySpecialTokensGreedyLongest(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"a": 0, "b": 1},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 2, "content": "<tag>", "special": true},
|
||||
{"id": 3, "content": "<tag>x", "special": true}
|
||||
]
|
||||
}`)
|
||||
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
input := "a<tag>xb"
|
||||
want := []string{"a", "<tag>x", "b"}
|
||||
|
||||
got := tok.splitBySpecialTokens(input)
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("split length mismatch: got %v want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitBySpecialTokensFallbackWithoutCache(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"a": 0, "b": 1},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 2, "content": "<tag>", "special": true},
|
||||
{"id": 3, "content": "<tag>x", "special": true}
|
||||
]
|
||||
}`)
|
||||
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
input := "a<tag>xb"
|
||||
want := []string{"a", "<tag>x", "b"}
|
||||
|
||||
// Simulate construction outside loader path where cache is not set.
|
||||
tok.sortedSpecialTokens = nil
|
||||
|
||||
got := tok.splitBySpecialTokens(input)
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("split length mismatch: got %v want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDeterministicAcrossGOMAXPROCS(t *testing.T) {
|
||||
tok := benchmarkLoadMiniLlama(t)
|
||||
|
||||
input := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)
|
||||
|
||||
prev := runtime.GOMAXPROCS(0)
|
||||
defer runtime.GOMAXPROCS(prev)
|
||||
|
||||
runtime.GOMAXPROCS(1)
|
||||
seq := tok.Encode(input, false)
|
||||
|
||||
if prev < 2 {
|
||||
runtime.GOMAXPROCS(2)
|
||||
} else {
|
||||
runtime.GOMAXPROCS(prev)
|
||||
}
|
||||
par := tok.Encode(input, false)
|
||||
|
||||
if !equalIDs(seq, par) {
|
||||
t.Fatalf("encode mismatch between sequential and parallel paths: seq=%d par=%d", len(seq), len(par))
|
||||
}
|
||||
}
|
||||
56
x/tokenizer/tokenizer_decode.go
Normal file
56
x/tokenizer/tokenizer_decode.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Decode converts token IDs back to text
|
||||
func (t *Tokenizer) Decode(ids []int32) string {
|
||||
var sb strings.Builder
|
||||
|
||||
for _, id := range ids {
|
||||
if int(id) >= len(t.vocab.Values) {
|
||||
continue
|
||||
}
|
||||
|
||||
token := t.vocab.Values[id]
|
||||
|
||||
switch t.typ {
|
||||
case TokenizerSentencePiece:
|
||||
// SentencePiece style: replace ▁ with space, decode byte tokens
|
||||
token = strings.ReplaceAll(token, "▁", " ")
|
||||
// Handle byte fallback tokens like <0x0D>
|
||||
if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
|
||||
if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
|
||||
sb.WriteByte(byte(v))
|
||||
continue
|
||||
}
|
||||
}
|
||||
sb.WriteString(token)
|
||||
default:
|
||||
// GPT-2 BPE style: decode byte-level encoding
|
||||
for _, r := range token {
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// Mirror GGML tokenizer behavior for NULL byte.
|
||||
// 0x00 is omitted during decode.
|
||||
continue
|
||||
case r == 0x0143:
|
||||
r = 0x00ad
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
r = r - 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
}
|
||||
|
||||
// Write as byte, not UTF-8 encoded rune
|
||||
sb.WriteByte(byte(r))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
289
x/tokenizer/tokenizer_encode.go
Normal file
289
x/tokenizer/tokenizer_encode.go
Normal file
@@ -0,0 +1,289 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Thresholds that gate Encode's parallel path.
const (
	// encodeParallelMinInputBytes is the input size, in bytes, below which
	// Encode always encodes sequentially.
	encodeParallelMinInputBytes = 4 * 1024
	// encodeParallelMinChunksPerWorker is the minimum number of chunks each
	// worker must have for the parallel path to be used.
	encodeParallelMinChunksPerWorker = 8
)
|
||||
|
||||
// tokenMatch is a half-open byte range [start, end) within the string being
// pretokenized.
type tokenMatch struct {
	start int // byte offset of the first byte of the match
	end   int // byte offset just past the last byte of the match
}
|
||||
|
||||
// encodeChunk is one unit of text handed to the encoder.
type encodeChunk struct {
	// text is the chunk's content.
	text string
	// isSpecial marks a chunk whose text is a special token, to be looked up
	// directly instead of being encoded.
	isSpecial bool
}
|
||||
|
||||
// isNonNewlineWhitespace reports whether s is non-empty and made up entirely
// of whitespace runes other than '\n' and '\r'.
func isNonNewlineWhitespace(s string) bool {
	if s == "" {
		return false
	}
	// A rune disqualifies s if it is a line break or not whitespace at all.
	disqualifies := func(r rune) bool {
		return r == '\n' || r == '\r' || !unicode.IsSpace(r)
	}
	return strings.IndexFunc(s, disqualifies) < 0
}
|
||||
|
||||
// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
|
||||
func (t *Tokenizer) splitBySpecialTokens(s string) []string {
|
||||
if len(t.specialTokens) == 0 {
|
||||
return []string{s}
|
||||
}
|
||||
|
||||
tokens := t.sortedSpecialTokens
|
||||
if len(tokens) == 0 {
|
||||
// Fallback for tokenizers constructed outside the loaders.
|
||||
tokens = make([]string, 0, len(t.specialTokens))
|
||||
for tok := range t.specialTokens {
|
||||
tokens = append(tokens, tok)
|
||||
}
|
||||
sort.Slice(tokens, func(i, j int) bool {
|
||||
return len(tokens[i]) > len(tokens[j])
|
||||
})
|
||||
}
|
||||
|
||||
var result []string
|
||||
remaining := s
|
||||
|
||||
for len(remaining) > 0 {
|
||||
found := false
|
||||
for _, tok := range tokens {
|
||||
if strings.HasPrefix(remaining, tok) {
|
||||
result = append(result, tok)
|
||||
remaining = remaining[len(tok):]
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
// Find next special token position
|
||||
nextPos := len(remaining)
|
||||
for _, tok := range tokens {
|
||||
if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
|
||||
nextPos = idx
|
||||
}
|
||||
}
|
||||
if nextPos > 0 {
|
||||
result = append(result, remaining[:nextPos])
|
||||
}
|
||||
remaining = remaining[nextPos:]
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func adjustWhitespaceBoundary(part string, curr, next *tokenMatch) {
|
||||
m := part[curr.start:curr.end]
|
||||
nextText := part[next.start:next.end]
|
||||
|
||||
if !isNonNewlineWhitespace(m) || len(nextText) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
firstRune, _ := utf8.DecodeRuneInString(nextText)
|
||||
if !unicode.IsLetter(firstRune) {
|
||||
return
|
||||
}
|
||||
|
||||
lastSpaceStart := curr.end
|
||||
for j := curr.end; j > curr.start; {
|
||||
r, size := utf8.DecodeLastRuneInString(part[curr.start:j])
|
||||
if unicode.IsSpace(r) {
|
||||
lastSpaceStart = j - size
|
||||
break
|
||||
}
|
||||
j -= size
|
||||
}
|
||||
if lastSpaceStart > curr.start {
|
||||
curr.end = lastSpaceStart
|
||||
next.start = lastSpaceStart
|
||||
} else {
|
||||
next.start = curr.start
|
||||
curr.end = curr.start
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tokenizer) forEachPartChunk(part string, fn func(encodeChunk)) {
|
||||
if _, ok := t.specialTokens[part]; ok {
|
||||
fn(encodeChunk{text: part, isSpecial: true})
|
||||
return
|
||||
}
|
||||
|
||||
if t.pretokenizer == nil {
|
||||
fn(encodeChunk{text: part, isSpecial: false})
|
||||
return
|
||||
}
|
||||
|
||||
re := t.pretokenizer
|
||||
offset := 0
|
||||
loc := re.FindStringIndex(part[offset:])
|
||||
if loc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
curr := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
|
||||
offset += loc[1]
|
||||
|
||||
for {
|
||||
loc = re.FindStringIndex(part[offset:])
|
||||
if loc == nil {
|
||||
if curr.end > curr.start {
|
||||
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
next := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
|
||||
offset += loc[1]
|
||||
|
||||
adjustWhitespaceBoundary(part, &curr, &next)
|
||||
|
||||
if curr.end > curr.start {
|
||||
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
|
||||
}
|
||||
curr = next
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tokenizer) appendEncodedChunk(ids []int32, c encodeChunk) []int32 {
|
||||
if c.isSpecial {
|
||||
if id, ok := t.specialTokens[c.text]; ok {
|
||||
return append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
return t.encodeChunkInto(c.text, ids)
|
||||
}
|
||||
|
||||
// Encode tokenizes text to token IDs.
|
||||
// Parallel encoding is used only for very large inputs with enough chunks per worker.
|
||||
func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
|
||||
// First: split by special tokens
|
||||
parts := t.splitBySpecialTokens(s)
|
||||
|
||||
// Fast path: encode sequentially without materializing chunk slices.
|
||||
if len(s) < encodeParallelMinInputBytes {
|
||||
var ids []int32
|
||||
for _, part := range parts {
|
||||
t.forEachPartChunk(part, func(c encodeChunk) {
|
||||
ids = t.appendEncodedChunk(ids, c)
|
||||
})
|
||||
}
|
||||
|
||||
if addBOS && t.vocab.BOS >= 0 {
|
||||
ids = append([]int32{t.vocab.BOS}, ids...)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// For large inputs collect chunks to enable parallel processing.
|
||||
var allChunks []encodeChunk
|
||||
for _, part := range parts {
|
||||
t.forEachPartChunk(part, func(c encodeChunk) {
|
||||
allChunks = append(allChunks, c)
|
||||
})
|
||||
}
|
||||
|
||||
// Encode chunks. Use the parallel path only when the chunk count is
|
||||
// large enough to amortize goroutine/synchronization overhead.
|
||||
useParallel := true
|
||||
numWorkers := runtime.GOMAXPROCS(0)
|
||||
if numWorkers > len(allChunks) {
|
||||
numWorkers = len(allChunks)
|
||||
}
|
||||
if numWorkers < 2 || len(allChunks) < numWorkers*encodeParallelMinChunksPerWorker {
|
||||
useParallel = false
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
if !useParallel {
|
||||
for _, c := range allChunks {
|
||||
ids = t.appendEncodedChunk(ids, c)
|
||||
}
|
||||
} else {
|
||||
chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
|
||||
results := make([][]int32, numWorkers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
start := i * chunksPer
|
||||
end := start + chunksPer
|
||||
if end > len(allChunks) {
|
||||
end = len(allChunks)
|
||||
}
|
||||
if start >= end {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(i int, chunks []encodeChunk) {
|
||||
defer wg.Done()
|
||||
var r []int32
|
||||
for _, c := range chunks {
|
||||
r = t.appendEncodedChunk(r, c)
|
||||
}
|
||||
results[i] = r
|
||||
}(i, allChunks[start:end])
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for _, r := range results {
|
||||
ids = append(ids, r...)
|
||||
}
|
||||
}
|
||||
|
||||
if addBOS && t.vocab.BOS >= 0 {
|
||||
ids = append([]int32{t.vocab.BOS}, ids...)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// encodeChunkInto appends encoded tokens to ids and returns the extended slice.
|
||||
// Uses BPE merge algorithm for both BPE and SentencePiece tokenization.
|
||||
func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
|
||||
if s == "" {
|
||||
return ids
|
||||
}
|
||||
|
||||
// Apply encoding transformation
|
||||
// SentencePiece: replace space with ▁
|
||||
// BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
|
||||
var encoded string
|
||||
if t.typ == TokenizerSentencePiece {
|
||||
encoded = strings.ReplaceAll(s, " ", "▁")
|
||||
} else {
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(s) * 2)
|
||||
for i := 0; i < len(s); i++ {
|
||||
sb.WriteRune(byteToRune[s[i]])
|
||||
}
|
||||
encoded = sb.String()
|
||||
}
|
||||
|
||||
// Fast path: check if entire chunk is a single token
|
||||
if id, ok := t.vocab.Reverse[encoded]; ok {
|
||||
return append(ids, id)
|
||||
}
|
||||
|
||||
return t.encodeBPEMerge(encoded, ids)
|
||||
}
|
||||
207
x/tokenizer/tokenizer_ggml_parity_test.go
Normal file
207
x/tokenizer/tokenizer_ggml_parity_test.go
Normal file
@@ -0,0 +1,207 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// llama32GGMLFixturePath returns the path of a llama3.2 fixture file, resolved
// relative to the directory containing this source file
// (../../tokenizer/testdata/llama3.2/<file>). The test fails immediately if the
// source file location cannot be determined.
func llama32GGMLFixturePath(tb testing.TB, file string) string {
	tb.Helper()

	_, self, _, ok := runtime.Caller(0)
	if !ok {
		tb.Fatal("failed to resolve test file path")
	}

	fixtureDir := filepath.Join(filepath.Dir(self), "..", "..", "tokenizer", "testdata", "llama3.2")
	return filepath.Join(fixtureDir, file)
}
|
||||
|
||||
func loadLlama32FromGGMLFixture(tb testing.TB) *Tokenizer {
|
||||
tb.Helper()
|
||||
|
||||
f, err := os.Open(llama32GGMLFixturePath(tb, "encoder.json"))
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to open encoder.json: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
vocab := make(map[string]int32)
|
||||
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||||
tb.Fatalf("failed to decode encoder.json: %v", err)
|
||||
}
|
||||
|
||||
type addedToken struct {
|
||||
ID int32 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Special bool `json:"special"`
|
||||
}
|
||||
var addedTokens []addedToken
|
||||
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
|
||||
if _, ok := vocab[token]; !ok {
|
||||
id := int32(len(vocab))
|
||||
vocab[token] = id
|
||||
addedTokens = append(addedTokens, addedToken{ID: id, Content: token, Special: true})
|
||||
}
|
||||
}
|
||||
|
||||
mf, err := os.Open(llama32GGMLFixturePath(tb, "vocab.bpe"))
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to open vocab.bpe: %v", err)
|
||||
}
|
||||
defer mf.Close()
|
||||
|
||||
var merges []string
|
||||
scanner := bufio.NewScanner(mf)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" {
|
||||
merges = append(merges, line)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
tb.Fatalf("failed to read vocab.bpe: %v", err)
|
||||
}
|
||||
|
||||
payload := struct {
|
||||
Model struct {
|
||||
Type string `json:"type"`
|
||||
Vocab map[string]int32 `json:"vocab"`
|
||||
Merges []string `json:"merges"`
|
||||
} `json:"model"`
|
||||
PreTokenizer struct {
|
||||
Type string `json:"type"`
|
||||
Pretokenizers []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
Regex string `json:"Regex"`
|
||||
} `json:"pattern"`
|
||||
} `json:"pretokenizers"`
|
||||
} `json:"pre_tokenizer"`
|
||||
AddedTokens []addedToken `json:"added_tokens"`
|
||||
}{}
|
||||
|
||||
payload.Model.Type = "BPE"
|
||||
payload.Model.Vocab = vocab
|
||||
payload.Model.Merges = merges
|
||||
payload.PreTokenizer.Type = "Sequence"
|
||||
payload.PreTokenizer.Pretokenizers = []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
Regex string `json:"Regex"`
|
||||
} `json:"pattern"`
|
||||
}{
|
||||
{
|
||||
Type: "Split",
|
||||
Pattern: struct {
|
||||
Regex string `json:"Regex"`
|
||||
}{
|
||||
Regex: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
},
|
||||
},
|
||||
}
|
||||
payload.AddedTokens = addedTokens
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to marshal synthetic tokenizer.json: %v", err)
|
||||
}
|
||||
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to load tokenizer from fixture data: %v", err)
|
||||
}
|
||||
return tok
|
||||
}
|
||||
|
||||
func TestGGMLLlamaKnownEncodings(t *testing.T) {
|
||||
tok := loadLlama32FromGGMLFixture(t)
|
||||
|
||||
cases := map[string][]int32{
|
||||
"hello world": {15339, 1917},
|
||||
"hello <|end_of_text|>": {15339, 220, 128001},
|
||||
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
|
||||
}
|
||||
|
||||
for input, want := range cases {
|
||||
got := tok.Encode(input, false)
|
||||
if !equalIDs(got, want) {
|
||||
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGGMLLlamaRepeatedZeros(t *testing.T) {
|
||||
tok := loadLlama32FromGGMLFixture(t)
|
||||
|
||||
cases := map[int][]int32{
|
||||
1: {15},
|
||||
2: {410},
|
||||
3: {931},
|
||||
4: {931, 15},
|
||||
5: {931, 410},
|
||||
6: {931, 931},
|
||||
7: {931, 931, 15},
|
||||
8: {931, 931, 410},
|
||||
9: {931, 931, 931},
|
||||
10: {931, 931, 931, 15},
|
||||
11: {931, 931, 931, 410},
|
||||
12: {931, 931, 931, 931},
|
||||
13: {931, 931, 931, 931, 15},
|
||||
14: {931, 931, 931, 931, 410},
|
||||
15: {931, 931, 931, 931, 931},
|
||||
16: {931, 931, 931, 931, 931, 15},
|
||||
17: {931, 931, 931, 931, 931, 410},
|
||||
}
|
||||
|
||||
for n, want := range cases {
|
||||
input := strings.Repeat("0", n)
|
||||
got := tok.Encode(input, false)
|
||||
if !equalIDs(got, want) {
|
||||
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGGMLLlamaRoundtripAndByteBehavior(t *testing.T) {
|
||||
tok := loadLlama32FromGGMLFixture(t)
|
||||
|
||||
cases := []string{
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
}
|
||||
|
||||
for _, input := range cases {
|
||||
ids := tok.Encode(input, false)
|
||||
got := tok.Decode(ids)
|
||||
if got != input {
|
||||
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Match GGML tokenizer behavior: 0x00 is omitted when decoding.
|
||||
ids := tok.Encode(string(rune(0x00)), false)
|
||||
got := tok.Decode(ids)
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty decode for 0x00, got %q (ids=%v)", got, ids)
|
||||
}
|
||||
}
|
||||
458
x/tokenizer/tokenizer_load.go
Normal file
458
x/tokenizer/tokenizer_load.go
Normal file
@@ -0,0 +1,458 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TokenizerConfig carries the raw contents of the optional companion files
// that LoadFromBytesWithConfig consults for special-token settings. Any field
// may be left nil/empty, in which case that source is skipped.
type TokenizerConfig struct {
	// TokenizerConfigJSON is the content of tokenizer_config.json.
	TokenizerConfigJSON []byte
	// GenerationConfigJSON is the content of generation_config.json.
	GenerationConfigJSON []byte
	// SpecialTokensMapJSON is the content of special_tokens_map.json.
	SpecialTokensMapJSON []byte
	// ConfigJSON is the content of config.json.
	ConfigJSON []byte
}
|
||||
|
||||
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
|
||||
// This is useful when loading from blob storage where the file content is already in memory.
|
||||
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
|
||||
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
|
||||
func LoadFromBytes(data []byte) (*Tokenizer, error) {
|
||||
return loadFromTokenizerJSON(data)
|
||||
}
|
||||
|
||||
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
|
||||
// This is useful when loading from blob storage where companion config files are also blobs.
|
||||
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
|
||||
t, err := loadFromTokenizerJSON(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Apply special token configs from provided data
|
||||
loadSpecialTokenConfigFromBytes(t, config)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// loadFromTokenizerJSON parses tokenizer.json content from bytes.
|
||||
func loadFromTokenizerJSON(data []byte) (*Tokenizer, error) {
|
||||
|
||||
var raw struct {
|
||||
Model struct {
|
||||
Type string `json:"type"` // "BPE"
|
||||
Vocab map[string]int32 `json:"vocab"`
|
||||
Merges json.RawMessage `json:"merges"` // Can be []string or [][]string (BPE only)
|
||||
} `json:"model"`
|
||||
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
|
||||
Decoder json.RawMessage `json:"decoder"`
|
||||
AddedTokens []struct {
|
||||
ID int32 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Special bool `json:"special"`
|
||||
} `json:"added_tokens"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
|
||||
}
|
||||
|
||||
// Covers SentencePiece and BPE models
|
||||
if raw.Model.Type != "BPE" {
|
||||
return nil, fmt.Errorf("unsupported tokenizer type: %s", raw.Model.Type)
|
||||
}
|
||||
|
||||
// Parse merges - can be []string (Llama) or [][]string (GPT-OSS).
|
||||
var mergesStrings []string
|
||||
if raw.Model.Merges != nil {
|
||||
var mergesArrays [][]string
|
||||
if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
|
||||
// Try array of arrays format
|
||||
if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse merges: %w", err)
|
||||
}
|
||||
// Convert [][]string to []string
|
||||
mergesStrings = make([]string, len(mergesArrays))
|
||||
for i, pair := range mergesArrays {
|
||||
if len(pair) != 2 {
|
||||
return nil, fmt.Errorf("failed to parse merges: expected merge pair of length 2, got %d", len(pair))
|
||||
}
|
||||
mergesStrings[i] = pair[0] + " " + pair[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build tokenizer
|
||||
t := &Tokenizer{
|
||||
vocab: &Vocabulary{
|
||||
Values: make([]string, len(raw.Model.Vocab)),
|
||||
Reverse: raw.Model.Vocab,
|
||||
Merges: make(map[string]int, len(mergesStrings)),
|
||||
BOS: -1,
|
||||
PAD: -1,
|
||||
},
|
||||
specialTokens: make(map[string]int32),
|
||||
}
|
||||
|
||||
// Build values array
|
||||
for token, id := range raw.Model.Vocab {
|
||||
if int(id) >= len(t.vocab.Values) {
|
||||
newValues := make([]string, id+1)
|
||||
copy(newValues, t.vocab.Values)
|
||||
t.vocab.Values = newValues
|
||||
}
|
||||
t.vocab.Values[id] = token
|
||||
}
|
||||
|
||||
// Build merges map
|
||||
for i, merge := range mergesStrings {
|
||||
t.vocab.Merges[merge] = i
|
||||
}
|
||||
|
||||
// Add all added_tokens to vocabulary and special tokens map.
|
||||
// HuggingFace treats ALL added_tokens as special for tokenization purposes -
|
||||
// they bypass BPE and get their own token ID. The "special" flag just indicates
|
||||
// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
|
||||
// to treat all added_tokens as special to match HuggingFace behavior.
|
||||
for _, tok := range raw.AddedTokens {
|
||||
if int(tok.ID) >= len(t.vocab.Values) {
|
||||
newValues := make([]string, tok.ID+1)
|
||||
copy(newValues, t.vocab.Values)
|
||||
t.vocab.Values = newValues
|
||||
}
|
||||
t.vocab.Values[tok.ID] = tok.Content
|
||||
t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
|
||||
}
|
||||
|
||||
// Precompute byte token IDs for <0xNN> fallback
|
||||
initByteTokens(t)
|
||||
|
||||
// Determine tokenizer type
|
||||
switch {
|
||||
case detectSentencePiece(raw.Decoder):
|
||||
t.typ = TokenizerSentencePiece
|
||||
default:
|
||||
t.typ = TokenizerBPE
|
||||
}
|
||||
|
||||
// Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
|
||||
if t.typ == TokenizerBPE {
|
||||
pattern := extractPretokenizer(raw.PreTokenizer)
|
||||
if pattern == "" {
|
||||
pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
|
||||
}
|
||||
re, err := regexp.Compile(rewritePatternForRE2(pattern))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
|
||||
}
|
||||
t.pretokenizer = re
|
||||
}
|
||||
|
||||
cacheSortedSpecialTokens(t)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func cacheSortedSpecialTokens(t *Tokenizer) {
|
||||
if len(t.specialTokens) == 0 {
|
||||
t.sortedSpecialTokens = nil
|
||||
return
|
||||
}
|
||||
|
||||
tokens := make([]string, 0, len(t.specialTokens))
|
||||
for tok := range t.specialTokens {
|
||||
tokens = append(tokens, tok)
|
||||
}
|
||||
sort.Slice(tokens, func(i, j int) bool {
|
||||
return len(tokens[i]) > len(tokens[j])
|
||||
})
|
||||
t.sortedSpecialTokens = tokens
|
||||
}
|
||||
|
||||
// specialTokenConfigData bundles the raw JSON of the companion config files
// read by applySpecialTokenConfig. Empty fields are skipped.
type specialTokenConfigData struct {
	tokenizerConfigJSON  []byte // tokenizer_config.json
	generationConfigJSON []byte // generation_config.json
	specialTokensMapJSON []byte // special_tokens_map.json
	configJSON           []byte // config.json
}
|
||||
|
||||
func applySpecialTokenConfig(t *Tokenizer, config specialTokenConfigData) {
|
||||
parseTokenIDs := func(v interface{}) []int32 {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return []int32{int32(val)}
|
||||
case []interface{}:
|
||||
ids := make([]int32, 0, len(val))
|
||||
for _, id := range val {
|
||||
if f, ok := id.(float64); ok {
|
||||
ids = append(ids, int32(f))
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Priority 1: generation_config.json
|
||||
if len(config.generationConfigJSON) > 0 {
|
||||
var genConfig struct {
|
||||
EOSTokenID interface{} `json:"eos_token_id"`
|
||||
BOSTokenID interface{} `json:"bos_token_id"`
|
||||
}
|
||||
if err := json.Unmarshal(config.generationConfigJSON, &genConfig); err == nil {
|
||||
if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
|
||||
t.vocab.EOS = ids
|
||||
}
|
||||
if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
|
||||
t.vocab.BOS = ids[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: config.json
|
||||
if len(config.configJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
|
||||
var modelConfig struct {
|
||||
EOSTokenID interface{} `json:"eos_token_id"`
|
||||
BOSTokenID interface{} `json:"bos_token_id"`
|
||||
}
|
||||
if err := json.Unmarshal(config.configJSON, &modelConfig); err == nil {
|
||||
if len(t.vocab.EOS) == 0 {
|
||||
if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
|
||||
t.vocab.EOS = ids
|
||||
}
|
||||
}
|
||||
if t.vocab.BOS < 0 {
|
||||
if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
|
||||
t.vocab.BOS = ids[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: tokenizer_config.json
|
||||
if len(config.tokenizerConfigJSON) > 0 {
|
||||
var tokConfig struct {
|
||||
BOSToken interface{} `json:"bos_token"`
|
||||
EOSToken interface{} `json:"eos_token"`
|
||||
PADToken interface{} `json:"pad_token"`
|
||||
AddBOSToken *bool `json:"add_bos_token"`
|
||||
AddEOSToken *bool `json:"add_eos_token"`
|
||||
}
|
||||
if err := json.Unmarshal(config.tokenizerConfigJSON, &tokConfig); err == nil {
|
||||
if t.vocab.BOS < 0 {
|
||||
if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
|
||||
if id, ok := t.specialTokens[bosStr]; ok {
|
||||
t.vocab.BOS = id
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(t.vocab.EOS) == 0 {
|
||||
if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
|
||||
if id, ok := t.specialTokens[eosStr]; ok {
|
||||
t.vocab.EOS = []int32{id}
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.vocab.PAD < 0 {
|
||||
if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
|
||||
if id, ok := t.specialTokens[padStr]; ok {
|
||||
t.vocab.PAD = id
|
||||
}
|
||||
}
|
||||
}
|
||||
if tokConfig.AddBOSToken != nil {
|
||||
t.vocab.AddBOS = *tokConfig.AddBOSToken
|
||||
}
|
||||
if tokConfig.AddEOSToken != nil {
|
||||
t.vocab.AddEOS = *tokConfig.AddEOSToken
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 4: special_tokens_map.json
|
||||
if len(config.specialTokensMapJSON) > 0 {
|
||||
var tokensMap map[string]interface{}
|
||||
if err := json.Unmarshal(config.specialTokensMapJSON, &tokensMap); err == nil {
|
||||
if t.vocab.BOS < 0 {
|
||||
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
|
||||
if id, ok := t.specialTokens[bosStr]; ok {
|
||||
t.vocab.BOS = id
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(t.vocab.EOS) == 0 {
|
||||
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
|
||||
if id, ok := t.specialTokens[eosStr]; ok {
|
||||
t.vocab.EOS = []int32{id}
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.vocab.PAD < 0 {
|
||||
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
|
||||
if id, ok := t.specialTokens[padStr]; ok {
|
||||
t.vocab.PAD = id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractTokenString returns the token text from a decoded JSON value. Two
// representations are understood:
//   - a plain string: "token"
//   - an object with a string "content" field: {"content": "token", ...}
//
// Any other value, including nil, yields the empty string.
func extractTokenString(v interface{}) string {
	switch tok := v.(type) {
	case string:
		return tok
	case map[string]interface{}:
		if content, ok := tok["content"].(string); ok {
			return content
		}
	}
	return ""
}
|
||||
|
||||
// rewritePatternForRE2 rewrites a HuggingFace pretokenizer regex so that it can
// be compiled by Go's regexp package (RE2), which has no lookahead support.
//
// Two constructs are rewritten, everything else is passed through unchanged:
//   - `\s+(?!\S)|\s+` becomes `\s+`. The lookahead form splits "a  b" into
//     ["a", " ", " b"], whereas plain `\s+` yields ["a", "  ", "b"]; the
//     difference is compensated after matching by moving the whitespace
//     boundary between adjacent matches.
//   - the case-insensitive contraction group `(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
//     with or without a trailing `?`, becomes a non-capturing group of explicit
//     upper/lower character classes.
func rewritePatternForRE2(pattern string) string {
	// Applied in order. The optional ("?"-suffixed) contraction form must come
	// before the plain form, otherwise the plain rule would consume its prefix.
	rewrites := [][2]string{
		{`\s+(?!\S)|\s+`, `\s+`},
		{
			`(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
			`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`,
		},
		{
			`(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
			`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`,
		},
	}

	for _, rw := range rewrites {
		pattern = strings.ReplaceAll(pattern, rw[0], rw[1])
	}
	return pattern
}
|
||||
|
||||
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
|
||||
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
|
||||
applySpecialTokenConfig(t, specialTokenConfigData{
|
||||
tokenizerConfigJSON: config.TokenizerConfigJSON,
|
||||
generationConfigJSON: config.GenerationConfigJSON,
|
||||
specialTokensMapJSON: config.SpecialTokensMapJSON,
|
||||
configJSON: config.ConfigJSON,
|
||||
})
|
||||
}
|
||||
|
||||
// detectSentencePiece reports whether the tokenizer.json "decoder" section
// describes a SentencePiece-style decoder (▁ standing for spaces) rather than
// GPT-2 byte-level encoding.
//
// It returns true only for a decoder of type "Sequence" that contains a
// "Replace" step whose pattern string is "▁". Every other input, including a
// missing decoder, a "ByteLevel" decoder, or JSON that does not parse, returns
// false.
func detectSentencePiece(data json.RawMessage) bool {
	if data == nil {
		return false
	}

	var seq struct {
		Type     string `json:"type"`
		Decoders []struct {
			Type    string `json:"type"`
			Pattern struct {
				String string `json:"String"`
			} `json:"pattern"`
		} `json:"decoders"`
	}
	if err := json.Unmarshal(data, &seq); err != nil || seq.Type != "Sequence" {
		return false
	}

	for _, dec := range seq.Decoders {
		// Look for Replace decoder that converts ▁ to space
		if dec.Type == "Replace" && dec.Pattern.String == "▁" {
			return true
		}
	}
	return false
}
|
||||
|
||||
// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
|
||||
func initByteTokens(t *Tokenizer) {
|
||||
for i := range t.vocab.byteTokens {
|
||||
t.vocab.byteTokens[i] = -1
|
||||
}
|
||||
for b := 0; b < 256; b++ {
|
||||
token := fmt.Sprintf("<0x%02X>", b)
|
||||
if id, ok := t.vocab.Reverse[token]; ok {
|
||||
t.vocab.byteTokens[b] = id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractPretokenizer returns the split regex found in the tokenizer.json
// "pre_tokenizer" section, or "" when none is present.
//
// The section is first read as a single pretokenizer with a pattern.Regex
// field. Failing that, it is read as a "Sequence" and the regex of its first
// "Split" entry with a non-empty pattern is returned.
func extractPretokenizer(data json.RawMessage) string {
	if data == nil {
		return ""
	}

	type regexHolder struct {
		Regex string `json:"Regex"`
	}
	type step struct {
		Type    string      `json:"type"`
		Pattern regexHolder `json:"pattern"`
	}

	// Form 1: a single pretokenizer object.
	var lone step
	if json.Unmarshal(data, &lone) == nil && lone.Pattern.Regex != "" {
		return lone.Pattern.Regex
	}

	// Form 2: a Sequence wrapping several pretokenizers.
	var chain struct {
		Type          string `json:"type"`
		Pretokenizers []step `json:"pretokenizers"`
	}
	if json.Unmarshal(data, &chain) != nil || chain.Type != "Sequence" {
		return ""
	}
	for _, s := range chain.Pretokenizers {
		if s.Type == "Split" && s.Pattern.Regex != "" {
			return s.Pattern.Regex
		}
	}
	return ""
}
|
||||
26
x/tokenizer/tokenizer_load_test.go
Normal file
26
x/tokenizer/tokenizer_load_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadFromBytesRejectsWordPiece(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model": {
|
||||
"type": "WordPiece",
|
||||
"vocab": {"[UNK]": 0, "hello": 1}
|
||||
},
|
||||
"added_tokens": []
|
||||
}`)
|
||||
|
||||
_, err := LoadFromBytes(data)
|
||||
if err == nil {
|
||||
t.Fatal("expected WordPiece load to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unsupported tokenizer type: WordPiece") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user