Compare commits

...

19 Commits

Author SHA1 Message Date
Jesse Gross
7a0f609241 mlxrunner: Simplify KV cache to single-entry prefix matching
The KV cache previously used a tree structure which could
store multiple divergent sequences, which is good for cache
reuse. However, this is typically used in conjunction with
paged attention so each node in the tree can store just a
chunk of the KV cache and they can be stitched together later.
We don't currently do this, so the cache was storing copies of
the full cache for each past sequence.

This redundancy plus the lack of resource limits, caused significant
memory use as a conversation grew. Instead, this changes to store
a single entry for the cache, which can be prefix matched. Although
it is less ideal for multiple users, it largely matches Ollama's
current behavior. It can be improved as additional pieces are fleshed
out.
2026-02-20 17:01:45 -08:00
Jesse Gross
35c89dc9fd mlxrunner: Fix memory leaks with pin/sweep lifecycle management
The previous approach tracked array lifecycles through reference
counting, where each array recorded its inputs and a reference count
that was decremented as dependents were freed. This is not really
necessary as MLX tracks references internally. It is also error
prone as it is easy to create new arrays and forget to free them
when the Go variable goes out of scope.

Instead, we can pin just the arrays we want (typically outputs and
specific intermediates, like the cache). All other arrays are freed
by default when we run sweep. This avoids most causes of memory leaks
while still giving the freedom to save what we want.
2026-02-20 16:43:19 -08:00
Patrick Devine
97323d1c68 consolidate the tokenizer (#14327)
This change adds a new x/tokenizer package which includes:
  * New BPE and SentencePiece tokenizers
  * Removing the dependency on the imagegen tokenizers
  * Fixes to multibyte decoding in the pipeline
  * Various correctness and benchmark tests

Not included in this PR is the WordPiece tokenizer for BERT models which will be
added when we add embedding models. The imagegen tokenizers will also be removed in
a follow-up PR.
2026-02-19 15:55:45 -08:00
natl-set
458dd1b9d9 mlx: try loading library via rpath before searching directories (#14322)
The existing code manually searches directories for libmlxc.* and passes
full paths to dlopen, bypassing the binary's rpath. This means MLX
libraries installed via package managers (e.g., Homebrew) aren't found
even when rpath is correctly set at link time.

This change adds a fallback that tries loading via rpath first (using
just the library name), before falling back to the existing directory
search. This follows standard Unix/macOS conventions and works with any
installation that sets rpath.

Fixes library loading on macOS with Homebrew-installed mlx-c without
requiring OLLAMA_LIBRARY_PATH environment variable.

Co-authored-by: Natl <nat@MacBook-Pro.local>
2026-02-19 10:55:02 -08:00
Bruce MacDonald
9d02d1d767 install: prevent partial download script execution (#14311)
Wrap script in main function so that a truncated partial download doesn't end up executing half a script.
2026-02-18 18:32:45 -08:00
Bruce MacDonald
1a636fb47a cmd: set codex env vars on launch and handle zstd request bodies (#14122)
The Codex runner was not setting OPENAI_BASE_URL or OPENAI_API_KEY, this prevents Codex from sending requests to api.openai.com instead of the local Ollama server. This mirrors the approach used by the Claude runner.

Codex v0.98.0 sends zstd-compressed request bodies to the /v1/responses endpoint. Add decompression support in ResponsesMiddleware with an 8MB max decompressed size limit to prevent resource exhaustion.
2026-02-18 17:19:36 -08:00
Patrick Devine
0759fface9 Revert "chore: update mlx-c bindings to 0.5.0 (#14303)" (#14316)
This reverts commit f01a9a7859.
2026-02-18 17:01:25 -08:00
Parth Sareen
325b72bc31 cmd/tui: default to single-select for editor integrations (#14302) 2026-02-17 18:17:27 -08:00
Patrick Devine
f01a9a7859 chore: update mlx-c bindings to 0.5.0 (#14303) 2026-02-17 16:48:16 -08:00
Patrick Devine
9aefd2dfee model: add qwen3 support to mlxrunner (#14293) 2026-02-17 13:58:49 -08:00
Patrick Devine
d07e4a1dd3 bugfix: better mlx model scheduling (#14290)
This fixes a bug with current MLX based models which don't get loaded/unloaded correctly. The first model currently gets loaded and then subsequent model starts get shunted to the first runner which results in the wrong model being run.
2026-02-17 13:57:05 -08:00
Parth Sareen
8a257ec00a docs: make integrations more discoverable (#14301)
* docs: add Pi integration page

* docs: flatten integration sidebar with expanded subheadings

* docs: add OpenClaw and Claude Code to quickstart
2026-02-17 13:27:25 -08:00
Parth Sareen
2f4de1acf7 cmd: ollama launch always show model picker (#14299) 2026-02-17 12:02:14 -08:00
Parth Sareen
ec95c45f70 cmd/config: ollama launch cline CLI (#14294) 2026-02-17 11:37:53 -08:00
Patrick Devine
3a88f7eb20 bugfix: add missing linear layer factory (#14289) 2026-02-16 17:22:20 -08:00
Patrick Devine
0d5da826d4 bugfix: display the parameter count correctly in mlx for ollama show (#14285) 2026-02-16 13:03:34 -08:00
Patrick Devine
9b795698b8 model: add llama3 architecture to mlxrunner (#14277) 2026-02-15 23:06:28 -08:00
Patrick Devine
041fb77639 model: add gemma3 to the mlxrunner (#14276)
This change adds the gemma3 model to the mlxrunner and simplifies some of the quantization
code for loading weights.
2026-02-15 22:47:59 -08:00
Saumil Shah
8224cce583 readme: update download link for macOS (#1) (#14271) 2026-02-15 15:25:15 -08:00
64 changed files with 5663 additions and 467 deletions

View File

@@ -16,7 +16,7 @@ Start building with open models.
curl -fsSL https://ollama.com/install.sh | sh
```
or [download manually](http://localhost:8080/download/Ollama.dmg)
or [download manually](https://ollama.com/download/Ollama.dmg)
### Windows

View File

@@ -57,9 +57,9 @@ import (
func init() {
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingle(title, tuiItems)
result, err := tui.SelectSingle(title, tuiItems, current)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
}
@@ -182,6 +182,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
mfConfig.Renderer = cmd.Args
}
}
@@ -1897,9 +1901,9 @@ func runInteractiveTUI(cmd *cobra.Command) {
}
// Selector adapters for tui
singleSelector := func(title string, items []config.ModelItem) (string, error) {
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingle(title, tuiItems)
result, err := tui.SelectSingle(title, tuiItems, current)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
}

View File

@@ -126,7 +126,7 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
if aliases["primary"] == "" || force {
primary, err := DefaultSingleSelector("Select model:", items)
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
if err != nil {
return nil, false, err
}

123
cmd/config/cline.go Normal file
View File

@@ -0,0 +1,123 @@
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/ollama/ollama/envconfig"
)
// Cline implements Runner and Editor for the Cline CLI integration
type Cline struct{}
func (c *Cline) String() string { return "Cline" }
func (c *Cline) Run(model string, args []string) error {
if _, err := exec.LookPath("cline"); err != nil {
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
}
models := []string{model}
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
return selectModels(context.Background(), "cline", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("cline", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (c *Cline) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
p := filepath.Join(home, ".cline", "data", "globalState.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (c *Cline) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
if err := json.Unmarshal(data, &config); err != nil {
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
}
}
// Set Ollama as the provider for both act and plan modes
baseURL := envconfig.Host().String()
config["ollamaBaseUrl"] = baseURL
config["actModeApiProvider"] = "ollama"
config["actModeOllamaModelId"] = models[0]
config["actModeOllamaBaseUrl"] = baseURL
config["planModeApiProvider"] = "ollama"
config["planModeOllamaModelId"] = models[0]
config["planModeOllamaBaseUrl"] = baseURL
config["welcomeViewCompleted"] = true
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
return writeWithBackup(configPath, data)
}
func (c *Cline) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
if err != nil {
return nil
}
if config["actModeApiProvider"] != "ollama" {
return nil
}
modelID, _ := config["actModeOllamaModelId"].(string)
if modelID == "" {
return nil
}
return []string{modelID}
}

204
cmd/config/cline_test.go Normal file
View File

@@ -0,0 +1,204 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestClineIntegration(t *testing.T) {
c := &Cline{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Cline" {
t.Errorf("String() = %q, want %q", got, "Cline")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = c
})
}
func TestClineEdit(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".cline", "data")
configPath := filepath.Join(configDir, "globalState.json")
readConfig := func() map[string]any {
data, _ := os.ReadFile(configPath)
var config map[string]any
json.Unmarshal(data, &config)
return config
}
t.Run("creates config from scratch", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeApiProvider"] != "ollama" {
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
}
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
}
if config["planModeApiProvider"] != "ollama" {
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
}
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
}
if config["welcomeViewCompleted"] != true {
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
}
})
t.Run("preserves existing fields", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
os.MkdirAll(configDir, 0o755)
existing := map[string]any{
"remoteRulesToggles": map[string]any{},
"remoteWorkflowToggles": map[string]any{},
"customSetting": "keep-me",
}
data, _ := json.Marshal(existing)
os.WriteFile(configPath, data, 0o644)
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["customSetting"] != "keep-me" {
t.Errorf("customSetting was not preserved")
}
if config["actModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
}
})
t.Run("updates model on re-edit", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
t.Fatal(err)
}
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
}
if config["planModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
}
})
t.Run("empty models is no-op", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit(nil); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
t.Error("expected no config file to be created for empty models")
}
})
t.Run("uses first model as primary", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
}
})
}
func TestClineModels(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".cline", "data")
configPath := filepath.Join(configDir, "globalState.json")
t.Run("returns nil when no config", func(t *testing.T) {
if models := c.Models(); models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
os.MkdirAll(configDir, 0o755)
config := map[string]any{
"actModeApiProvider": "anthropic",
"actModeOllamaModelId": "some-model",
}
data, _ := json.Marshal(config)
os.WriteFile(configPath, data, 0o644)
if models := c.Models(); models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns model when ollama is configured", func(t *testing.T) {
os.MkdirAll(configDir, 0o755)
config := map[string]any{
"actModeApiProvider": "ollama",
"actModeOllamaModelId": "kimi-k2.5:cloud",
}
data, _ := json.Marshal(config)
os.WriteFile(configPath, data, 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
}
})
}
func TestClinePaths(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns nil when no config exists", func(t *testing.T) {
if paths := c.Paths(); paths != nil {
t.Errorf("Paths() = %v, want nil", paths)
}
})
t.Run("returns path when config exists", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".cline", "data")
os.MkdirAll(configDir, 0o755)
configPath := filepath.Join(configDir, "globalState.json")
os.WriteFile(configPath, []byte("{}"), 0o644)
paths := c.Paths()
if len(paths) != 1 || paths[0] != configPath {
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
}
})
}

View File

@@ -6,6 +6,7 @@ import (
"os/exec"
"strings"
"github.com/ollama/ollama/envconfig"
"golang.org/x/mod/semver"
)
@@ -32,6 +33,10 @@ func (c *Codex) Run(model string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
"OPENAI_API_KEY=ollama",
)
return cmd.Run()
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"maps"
"net/http"
"os"
"os/exec"
@@ -54,6 +53,7 @@ type AliasConfigurer interface {
var integrations = map[string]Runner{
"claude": &Claude{},
"clawdbot": &Openclaw{},
"cline": &Cline{},
"codex": &Codex{},
"moltbot": &Openclaw{},
"droid": &Droid{},
@@ -102,16 +102,17 @@ var recommendedVRAM = map[string]string{
var integrationAliases = map[string]bool{
"clawdbot": true,
"moltbot": true,
"pi": true,
}
// integrationInstallHints maps integration names to install URLs.
var integrationInstallHints = map[string]string{
"claude": "https://code.claude.com/docs/en/quickstart",
"cline": "https://cline.bot/cli",
"openclaw": "https://docs.openclaw.ai",
"codex": "https://developers.openai.com/codex/cli/",
"droid": "https://docs.factory.ai/cli/getting-started/quickstart",
"opencode": "https://opencode.ai",
"pi": "https://github.com/badlogic/pi-mono",
}
// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
@@ -129,13 +130,21 @@ type IntegrationInfo struct {
// integrationDescriptions maps integration names to short descriptions.
var integrationDescriptions = map[string]string{
"claude": "Anthropic's coding tool with subagents",
"cline": "Autonomous coding agent with parallel execution",
"codex": "OpenAI's open-source coding agent",
"openclaw": "Personal AI with 100+ skills",
"droid": "Factory's coding agent across terminal and IDEs",
"opencode": "Anomaly's open-source coding agent",
"pi": "Minimal AI agent toolkit with plugin support",
}
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name.
// integrationOrder defines a custom display order for integrations.
// Integrations listed here are placed at the end in the given order;
// all others appear first, sorted alphabetically.
var integrationOrder = []string{"opencode", "droid", "pi", "cline"}
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name
// with integrationOrder entries placed at the end.
func ListIntegrationInfos() []IntegrationInfo {
var result []IntegrationInfo
for name, r := range integrations {
@@ -148,7 +157,26 @@ func ListIntegrationInfos() []IntegrationInfo {
Description: integrationDescriptions[name],
})
}
orderRank := make(map[string]int, len(integrationOrder))
for i, name := range integrationOrder {
orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list"
}
slices.SortFunc(result, func(a, b IntegrationInfo) int {
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
// Both have custom order: sort by their rank
if aRank > 0 && bRank > 0 {
return aRank - bRank
}
// Only one has custom order: it goes last
if aRank > 0 {
return 1
}
if bRank > 0 {
return -1
}
// Neither has custom order: alphabetical
return strings.Compare(a.Name, b.Name)
})
return result
@@ -186,9 +214,15 @@ func IsIntegrationInstalled(name string) bool {
case "droid":
_, err := exec.LookPath("droid")
return err == nil
case "cline":
_, err := exec.LookPath("cline")
return err == nil
case "opencode":
_, err := exec.LookPath("opencode")
return err == nil
case "pi":
_, err := exec.LookPath("pi")
return err == nil
default:
return true // Assume installed for unknown integrations
}
@@ -214,7 +248,8 @@ type ModelItem struct {
}
// SingleSelector is a function type for single item selection.
type SingleSelector func(title string, items []ModelItem) (string, error)
// current is the name of the previously selected item to highlight; empty means no pre-selection.
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
@@ -257,7 +292,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
selected, err := selector("Select model to run:", items)
selected, err := selector("Select model to run:", items, "")
if err != nil {
return "", err
}
@@ -367,13 +402,11 @@ func selectIntegration() (string, error) {
return "", fmt.Errorf("no integrations available")
}
names := slices.Sorted(maps.Keys(integrations))
var items []ModelItem
for _, name := range names {
for name, r := range integrations {
if integrationAliases[name] {
continue
}
r := integrations[name]
description := r.String()
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
@@ -381,7 +414,25 @@ func selectIntegration() (string, error) {
items = append(items, ModelItem{Name: name, Description: description})
}
return DefaultSingleSelector("Select integration:", items)
orderRank := make(map[string]int, len(integrationOrder))
for i, name := range integrationOrder {
orderRank[name] = i + 1
}
slices.SortFunc(items, func(a, b ModelItem) int {
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
if aRank > 0 && bRank > 0 {
return aRank - bRank
}
if aRank > 0 {
return 1
}
if bRank > 0 {
return -1
}
return strings.Compare(a.Name, b.Name)
})
return DefaultSingleSelector("Select integration:", items, "")
}
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
@@ -439,7 +490,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
if _, ok := r.(AliasConfigurer); ok {
prompt = fmt.Sprintf("Select Primary model for %s:", r)
}
model, err := single(prompt, items)
model, err := single(prompt, items, current)
if err != nil {
return nil, err
}
@@ -812,10 +863,12 @@ Without arguments, this is equivalent to running 'ollama' directly.
Supported integrations:
claude Claude Code
cline Cline
codex Codex
droid Droid
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
pi Pi
Examples:
ollama launch
@@ -915,11 +968,9 @@ Examples:
}
// Validate saved model still exists
cloudCleared := false
if model != "" && modelFlag == "" {
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) {
model = ""
cloudCleared = true
} else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
if err := ShowOrPull(cmd.Context(), client, model); err != nil {
@@ -928,18 +979,16 @@ Examples:
}
}
// If no valid model or --config flag, show picker
if model == "" || configFlag {
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag || cloudCleared)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
// Show picker so user can change model (skip when --model flag provided)
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, modelFlag == "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
// Ensure cloud models are authenticated
if isCloudModel(cmd.Context(), client, model) {
@@ -1001,27 +1050,13 @@ Examples:
return err
}
}
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
savedModels := filterDisabledCloudModels(saved.Models)
if len(savedModels) != len(saved.Models) {
_ = SaveIntegration(name, savedModels)
}
if len(savedModels) == 0 {
// All saved models were cloud — fall through to picker
models, err = selectModels(cmd.Context(), name, "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
} else {
models = savedModels
return runIntegration(name, models[0], passArgs)
}
} else {
current := ""
if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 {
current = saved.Models[0]
}
var err error
models, err = selectModels(cmd.Context(), name, "")
models, err = selectModels(cmd.Context(), name, current)
if errors.Is(err, errCancelled) {
return nil
}

View File

@@ -1248,10 +1248,26 @@ func TestListIntegrationInfos(t *testing.T) {
}
})
t.Run("sorted by name", func(t *testing.T) {
t.Run("sorted with custom order at end", func(t *testing.T) {
// integrationOrder entries (cline, opencode) should appear last, in that order.
// All other entries should be sorted alphabetically before them.
orderRank := make(map[string]int)
for i, name := range integrationOrder {
orderRank[name] = i + 1
}
for i := 1; i < len(infos); i++ {
if infos[i-1].Name >= infos[i].Name {
t.Errorf("not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
switch {
case aRank == 0 && bRank == 0:
if infos[i-1].Name >= infos[i].Name {
t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
}
case aRank > 0 && bRank == 0:
t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
case aRank > 0 && bRank > 0:
if aRank >= bRank {
t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
}
}
}
})

View File

@@ -365,14 +365,27 @@ func (m selectorModel) View() string {
return s
}
func SelectSingle(title string, items []SelectItem) (string, error) {
// cursorForCurrent returns the item index matching current, or 0 if not found.
func cursorForCurrent(items []SelectItem, current string) int {
if current != "" {
for i, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
return i
}
}
}
return 0
}
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
m := selectorModel{
title: title,
items: items,
title: title,
items: items,
cursor: cursorForCurrent(items, current),
}
p := tea.NewProgram(m)
@@ -402,6 +415,12 @@ type multiSelectorModel struct {
cancelled bool
confirmed bool
width int
// multi enables full multi-select editing mode. The zero value (false)
// shows a single-select picker where Enter adds the chosen model to
// the existing list. Tab toggles between modes.
multi bool
singleAdd string // model picked in single mode
}
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
@@ -416,13 +435,23 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
m.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := m.itemIndex[name]; ok {
// Reverse order so preChecked[0] (the current default) ends up last
// in checkOrder, matching the "last checked = default" convention.
for i := len(preChecked) - 1; i >= 0; i-- {
if idx, ok := m.itemIndex[preChecked[i]]; ok {
m.checked[idx] = true
m.checkOrder = append(m.checkOrder, idx)
}
}
// Position cursor on the current default model
if len(preChecked) > 0 {
if idx, ok := m.itemIndex[preChecked[0]]; ok {
m.cursor = idx
m.updateScroll(m.otherStart())
}
}
return m
}
@@ -533,14 +562,25 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.cancelled = true
return m, tea.Quit
case tea.KeyTab:
m.multi = !m.multi
case tea.KeyEnter:
if len(m.checkOrder) > 0 {
if !m.multi {
if len(filtered) > 0 && m.cursor < len(filtered) {
m.singleAdd = filtered[m.cursor].Name
m.confirmed = true
return m, tea.Quit
}
} else if len(m.checkOrder) > 0 {
m.confirmed = true
return m, tea.Quit
}
case tea.KeySpace:
m.toggleItem()
if m.multi {
m.toggleItem()
}
case tea.KeyUp:
if m.cursor > 0 {
@@ -579,7 +619,9 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// On some terminals (e.g. Windows PowerShell), space arrives as
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
m.toggleItem()
if m.multi {
m.toggleItem()
}
} else {
m.filter += string(msg.Runes)
m.cursor = 0
@@ -591,6 +633,19 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
}
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
s.WriteString("\n")
if item.Description != "" {
s.WriteString(selectorDescLineStyle.Render(item.Description))
s.WriteString("\n")
}
}
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
origIdx := m.itemIndex[item.Name]
@@ -602,7 +657,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
}
suffix := ""
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
suffix = " " + selectorDefaultTagStyle.Render("(default)")
}
@@ -624,6 +679,11 @@ func (m multiSelectorModel) View() string {
return ""
}
renderItem := m.renderSingleItem
if m.multi {
renderItem = m.renderMultiItem
}
var s strings.Builder
s.WriteString(selectorTitleStyle.Render(m.title))
@@ -648,7 +708,7 @@ func (m multiSelectorModel) View() string {
if idx >= len(filtered) {
break
}
m.renderMultiItem(&s, filtered[idx], idx)
renderItem(&s, filtered[idx], idx)
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
@@ -671,7 +731,7 @@ func (m multiSelectorModel) View() string {
s.WriteString(sectionHeaderStyle.Render("Recommended"))
s.WriteString("\n")
for _, idx := range recItems {
m.renderMultiItem(&s, filtered[idx], idx)
renderItem(&s, filtered[idx], idx)
}
}
@@ -691,7 +751,7 @@ func (m multiSelectorModel) View() string {
if idx >= len(otherItems) {
break
}
m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
}
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
@@ -703,15 +763,18 @@ func (m multiSelectorModel) View() string {
s.WriteString("\n")
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
if !m.multi {
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
result := s.String()
if m.width > 0 {
@@ -734,18 +797,28 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
}
fm := finalModel.(multiSelectorModel)
if fm.cancelled {
if fm.cancelled || !fm.confirmed {
return nil, ErrCancelled
}
if !fm.confirmed {
return nil, ErrCancelled
// Single-add mode: prepend the picked model, keep existing models deduped
if fm.singleAdd != "" {
result := []string{fm.singleAdd}
for _, name := range preChecked {
if name != fm.singleAdd {
result = append(result, name)
}
}
return result, nil
}
var result []string
// Multi-edit mode: last checked is default (first in result)
last := fm.checkOrder[len(fm.checkOrder)-1]
result := []string{fm.items[last].Name}
for _, idx := range fm.checkOrder {
result = append(result, fm.items[idx].Name)
if idx != last {
result = append(result, fm.items[idx].Name)
}
}
return result, nil
}

View File

@@ -382,6 +382,42 @@ func TestUpdateNavigation_Backspace(t *testing.T) {
}
}
// --- cursorForCurrent ---
func TestCursorForCurrent(t *testing.T) {
testItems := []SelectItem{
{Name: "llama3.2", Recommended: true},
{Name: "qwen3:8b", Recommended: true},
{Name: "gemma3:latest"},
{Name: "deepseek-r1"},
{Name: "glm-5:cloud"},
}
tests := []struct {
name string
current string
want int
}{
{"empty current", "", 0},
{"exact match", "qwen3:8b", 1},
{"no match returns 0", "nonexistent", 0},
{"bare name matches with :latest suffix", "gemma3", 2},
{"full tag matches bare item", "llama3.2:latest", 0},
{"cloud model exact match", "glm-5:cloud", 4},
{"cloud model bare name", "glm-5", 4},
{"recommended item exact match", "llama3.2", 0},
{"recommended item with tag", "qwen3", 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := cursorForCurrent(testItems, tt.current); got != tt.want {
t.Errorf("cursorForCurrent(%q) = %d, want %d", tt.current, got, tt.want)
}
})
}
}
// --- ReorderItems ---
func TestReorderItems(t *testing.T) {
@@ -503,6 +539,7 @@ func TestMultiView_CursorIndicator(t *testing.T) {
func TestMultiView_CheckedItemShowsX(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
m.multi = true
content := m.View()
if !strings.Contains(content, "[x]") {
@@ -514,11 +551,18 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
}
func TestMultiView_DefaultTag(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
m.multi = true
content := m.View()
if !strings.Contains(content, "(default)") {
t.Error("first checked item should have (default) tag")
t.Error("should have (default) tag")
}
// preChecked[0] ("a") should be the default (last in checkOrder)
aIdx := strings.Index(content, "a")
defaultIdx := strings.Index(content, "(default)")
if defaultIdx < aIdx {
t.Error("(default) tag should appear after 'a' (the current default)")
}
}
@@ -549,6 +593,7 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.multi = true
m.cursor = 1
// Simulate space delivered as tea.KeySpace
@@ -565,6 +610,7 @@ func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.multi = true
m.cursor = 1
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
@@ -582,6 +628,161 @@ func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
}
}
// --- Single-add mode ---
func TestMulti_StartsInSingleMode(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
if m.multi {
t.Error("should start in single mode (multi=false)")
}
}
func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
content := m.View()
if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
t.Error("single mode should not show checkboxes")
}
if !strings.Contains(content, "▸") {
t.Error("single mode should show cursor indicator")
}
}
func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.cursor = 1
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
m = updated.(multiSelectorModel)
if m.singleAdd != "b" {
t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
}
if !m.confirmed {
t.Error("should set confirmed")
}
}
func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
m.cursor = 0
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
m = updated.(multiSelectorModel)
if len(m.checked) != 0 {
t.Error("space in single mode should not toggle items")
}
}
func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
m.cursor = 0
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
m = updated.(multiSelectorModel)
if len(m.checked) != 0 {
t.Error("space rune in single mode should not toggle items")
}
if m.filter != "" {
t.Error("space rune in single mode should not add to filter")
}
}
func TestMulti_TabTogglesMode(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
if m.multi {
t.Fatal("should start in single mode")
}
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = updated.(multiSelectorModel)
if !m.multi {
t.Error("tab should switch to multi mode")
}
updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = updated.(multiSelectorModel)
if m.multi {
t.Error("tab should switch back to single mode")
}
}
func TestMulti_SingleModeHelpText(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a"), nil)
content := m.View()
if !strings.Contains(content, "tab add multiple") {
t.Error("single mode should show 'tab add multiple' in help")
}
}
func TestMulti_MultiModeHelpText(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a"), nil)
m.multi = true
content := m.View()
if !strings.Contains(content, "tab select single") {
t.Error("multi mode should show 'tab select single' in help")
}
}
// --- preChecked initialization order ---
func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
// preChecked[0] ("a") is the current default and should end up
// last in checkOrder so it gets the (default) tag.
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
if len(m.checkOrder) != 3 {
t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
}
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "a" {
t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
}
}
func TestMulti_CursorOnDefaultModel(t *testing.T) {
// preChecked[0] ("b") is the default; cursor should start on it
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
if m.cursor != 1 {
t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
}
}
// --- Multi-mode last-checked is default ---
func TestMulti_LastCheckedIsDefault(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
m.multi = true
// Check "alpha" then "gamma"
m.cursor = 0
m.toggleItem()
m.cursor = 2
m.toggleItem()
// Last checked ("gamma") should be at the end of checkOrder
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "gamma" {
t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
}
// The (default) tag renders based on checkOrder[len-1]
content := m.View()
if !strings.Contains(content, "(default)") {
t.Fatal("should show (default) tag")
}
// "alpha" line should NOT have the default tag
for _, line := range strings.Split(content, "\n") {
if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
t.Error("'alpha' (first checked) should not have (default) tag")
}
}
}
// Key message helpers for testing
type keyType = int

View File

@@ -429,8 +429,24 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
if m.multiModalSelector.confirmed {
var selected []string
for _, idx := range m.multiModalSelector.checkOrder {
selected = append(selected, m.multiModalSelector.items[idx].Name)
if m.multiModalSelector.singleAdd != "" {
// Single-add mode: prepend picked model, keep existing deduped
selected = []string{m.multiModalSelector.singleAdd}
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
if name != m.multiModalSelector.singleAdd {
selected = append(selected, name)
}
}
} else {
// Last checked is default (first in result)
co := m.multiModalSelector.checkOrder
last := co[len(co)-1]
selected = []string{m.multiModalSelector.items[last].Name}
for _, idx := range co {
if idx != last {
selected = append(selected, m.multiModalSelector.items[idx].Name)
}
}
}
if len(selected) > 0 {
m.changeModels = selected

View File

@@ -106,20 +106,23 @@
"group": "Integrations",
"pages": [
"/integrations/index",
{
"group": "Assistants",
"expanded": true,
"pages": [
"/integrations/openclaw"
]
},
{
"group": "Coding",
"expanded": true,
"pages": [
"/integrations/claude-code",
"/integrations/codex",
"/integrations/opencode",
"/integrations/droid",
"/integrations/goose"
]
},
{
"group": "Assistants",
"pages": [
"/integrations/openclaw"
"/integrations/goose",
"/integrations/pi"
]
},
{

View File

@@ -13,6 +13,7 @@ Coding assistants that can read, modify, and execute code in your projects.
- [OpenCode](/integrations/opencode)
- [Droid](/integrations/droid)
- [Goose](/integrations/goose)
- [Pi](/integrations/pi)
## Assistants

57
docs/integrations/pi.mdx Normal file
View File

@@ -0,0 +1,57 @@
---
title: Pi
---
Pi is a minimal AI agent toolkit with plugin support.
## Install
Install [Pi](https://github.com/badlogic/pi-mono):
```bash
npm install -g @mariozechner/pi-coding-agent
```
## Usage with Ollama
### Quick setup
```bash
ollama launch pi
```
To configure without launching:
```shell
ollama launch pi --config
```
### Manual setup
Add a configuration block to `~/.pi/agent/models.json`:
```json
{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{
"id": "qwen3-coder"
}
]
}
}
}
```
Update `~/.pi/agent/settings.json` to set the default provider:
```json
{
"defaultProvider": "ollama",
"defaultModel": "qwen3-coder"
}
```

View File

@@ -27,9 +27,17 @@ The menu provides quick access to:
- **Launch tools** - Claude Code, Codex, OpenClaw, and more
- **Additional integrations** - Available under "More..."
## Assistants
Launch [OpenClaw](/integrations/openclaw), a personal AI with 100+ skills:
```sh
ollama launch openclaw
```
## Coding
Launch coding tools with Ollama models:
Launch [Claude Code](/integrations/claude-code) and other coding tools with Ollama models:
```sh
ollama launch claude

1
go.mod
View File

@@ -26,6 +26,7 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/klauspost/compress v1.18.3
github.com/mattn/go-runewidth v0.0.16
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

4
go.sum
View File

@@ -122,7 +122,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
@@ -150,8 +149,9 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -496,6 +497,17 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
func ResponsesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader("Content-Encoding") == "zstd" {
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
return
}
defer reader.Close()
c.Request.Body = io.NopCloser(reader)
c.Request.Header.Del("Content-Encoding")
}
var req openai.ResponsesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))

View File

@@ -14,6 +14,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -1238,3 +1239,102 @@ func TestImageEditsMiddleware(t *testing.T) {
})
}
}
func zstdCompress(t *testing.T, data []byte) []byte {
t.Helper()
var buf bytes.Buffer
w, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(data); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
return buf.Bytes()
}
func TestResponsesMiddlewareZstd(t *testing.T) {
tests := []struct {
name string
body string
useZstd bool
oversized bool
wantCode int
wantModel string
wantMessage string
}{
{
name: "plain JSON",
body: `{"model": "test-model", "input": "Hello"}`,
wantCode: http.StatusOK,
wantModel: "test-model",
wantMessage: "Hello",
},
{
name: "zstd compressed",
body: `{"model": "test-model", "input": "Hello"}`,
useZstd: true,
wantCode: http.StatusOK,
wantModel: "test-model",
wantMessage: "Hello",
},
{
name: "zstd over max decompressed size",
oversized: true,
useZstd: true,
wantCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedRequest *api.ChatRequest
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
c.Status(http.StatusOK)
})
var bodyReader io.Reader
if tt.oversized {
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
} else if tt.useZstd {
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
} else {
bodyReader = strings.NewReader(tt.body)
}
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
req.Header.Set("Content-Type", "application/json")
if tt.useZstd || tt.oversized {
req.Header.Set("Content-Encoding", "zstd")
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.wantCode {
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
}
if tt.wantCode != http.StatusOK {
return
}
if capturedRequest == nil {
t.Fatal("expected captured request, got nil")
}
if capturedRequest.Model != tt.wantModel {
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
}
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
}
})
}
}

View File

@@ -45,6 +45,10 @@ func ParserForName(name string) Parser {
var p Parser
switch name {
case "qwen3":
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
case "qwen3-thinking":
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
case "qwen3-coder":
p = &Qwen3CoderParser{}
case "qwen3-vl-instruct":

View File

@@ -54,6 +54,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
name string
}{
{"passthrough"},
{"qwen3"},
{"qwen3-thinking"},
{"qwen3-coder"},
{"harmony"},
}

335
model/parsers/qwen3.go Normal file
View File

@@ -0,0 +1,335 @@
package parsers
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type qwen3ParserState int
const (
qwen3ParserStateLookingForThinkingOpen qwen3ParserState = iota
qwen3ParserStateThinkingStartedEatingWhitespace
qwen3ParserStateCollectingThinking
qwen3ParserStateThinkingDoneEatingWhitespace
qwen3ParserStateCollectingContent
qwen3ParserStateToolStartedEatingWhitespace
qwen3ParserStateCollectingToolContent
)
const (
qwen3ThinkingOpenTag = "<think>"
qwen3ThinkingCloseTag = "</think>"
qwen3ToolOpenTag = "<tool_call>"
qwen3ToolCloseTag = "</tool_call>"
)
// Qwen3Parser parses Qwen3 output to extract thinking and tool calls.
// Qwen3 prompts end with <think> when thinking is enabled, so output begins
// with thinking content directly (without an opening tag).
type Qwen3Parser struct {
state qwen3ParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
defaultThinking bool
maybeThinkingOpenAtBOL bool
}
func (p *Qwen3Parser) HasToolSupport() bool {
return true
}
func (p *Qwen3Parser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.buffer.Reset()
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
if thinkValue == nil {
thinkingEnabled = p.defaultThinking
}
if p.hasThinkingSupport && thinkingEnabled {
p.state = qwen3ParserStateCollectingThinking
p.maybeThinkingOpenAtBOL = true
} else {
p.state = qwen3ParserStateCollectingContent
p.maybeThinkingOpenAtBOL = false
}
return tools
}
type qwen3Event interface {
isQwen3Event()
}
type qwen3EventContent struct {
content string
}
func (qwen3EventContent) isQwen3Event() {}
type qwen3EventRawToolCall struct {
raw string
}
func (qwen3EventRawToolCall) isQwen3Event() {}
type qwen3EventThinkingContent struct {
content string
}
func (qwen3EventThinkingContent) isQwen3Event() {}
func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case qwen3EventRawToolCall:
toolCall, err := parseQwen3ToolCall(event, p.tools)
if err != nil {
slog.Warn("qwen3 tool call parsing failed", "error", err)
return "", "", nil, err
}
calls = append(calls, toolCall)
case qwen3EventThinkingContent:
thinkingSb.WriteString(event.content)
case qwen3EventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), calls, nil
}
func (p *Qwen3Parser) parseEvents() []qwen3Event {
var all []qwen3Event
keepLooping := true
for keepLooping {
var events []qwen3Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
func (p *Qwen3Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen3ParserState) ([]qwen3Event, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true
}
func (p *Qwen3Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
return splitAtTag(&p.buffer, tag, trimAfter)
}
func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
var events []qwen3Event
switch p.state {
case qwen3ParserStateLookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = qwen3ParserStateThinkingStartedEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingThinking
}
return events, true
} else if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
return events, false
} else if trimmed == "" {
return events, false
}
p.state = qwen3ParserStateCollectingContent
return events, true
case qwen3ParserStateThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingThinking)
case qwen3ParserStateCollectingThinking:
acc := p.buffer.String()
// Some qwen3 checkpoints emit an explicit opening <think> tag even
// though the prompt already ended with <think>. Strip exactly one
// leading opening tag if present.
if p.maybeThinkingOpenAtBOL {
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
return events, false
}
p.maybeThinkingOpenAtBOL = false
return events, true
}
if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
return events, false
}
p.maybeThinkingOpenAtBOL = false
}
if strings.Contains(acc, qwen3ThinkingCloseTag) {
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, qwen3EventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = qwen3ParserStateThinkingDoneEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingContent
}
return events, true
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventThinkingContent{content: unambiguous})
}
return events, false
}
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventThinkingContent{content: unambiguous})
}
return events, false
case qwen3ParserStateThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingContent)
case qwen3ParserStateCollectingContent:
acc := p.buffer.String()
if strings.Contains(acc, qwen3ToolOpenTag) {
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
if len(before) > 0 {
events = append(events, qwen3EventContent{content: before})
}
if after == "" {
p.state = qwen3ParserStateToolStartedEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingToolContent
}
return events, true
} else if overlapLen := overlap(acc, qwen3ToolOpenTag); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventContent{content: unambiguous})
}
return events, false
}
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventContent{content: unambiguous})
}
return events, false
case qwen3ParserStateToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingToolContent)
case qwen3ParserStateCollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, qwen3ToolCloseTag) {
toolContent, _ := p.splitAtTag(qwen3ToolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("qwen3 tool call closing tag found but no content before it")
}
events = append(events, qwen3EventRawToolCall{raw: toolContent})
p.state = qwen3ParserStateCollectingContent
return events, true
}
return events, false
default:
panic("unreachable")
}
}
func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
var parsed struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse JSON: %w", err)
}
if parsed.Name == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
_ = tools // qwen3 uses direct JSON args and does not require schema coercion here.
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: parsed.Name,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for key, value := range parsed.Arguments {
toolCall.Function.Arguments.Set(key, value)
}
return toolCall, nil
}

147
model/parsers/qwen3_test.go Normal file
View File

@@ -0,0 +1,147 @@
package parsers
import (
"testing"
"github.com/ollama/ollama/api"
)
func TestQwen3ParserThinkingEnabled(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("Let me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingEnabledWithExplicitOpeningTag(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingEnabledWithSplitOpeningTag(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("<thi", false)
if err != nil {
t.Fatalf("parse failed on first chunk: %v", err)
}
if content != "" || thinking != "" || len(calls) != 0 {
t.Fatalf("expected no output for first chunk, got content=%q thinking=%q calls=%d", content, thinking, len(calls))
}
content, thinking, calls, err = parser.Add("nk>Let me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed on second chunk: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingDisabled(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
content, thinking, calls, err := parser.Add("Direct answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected no thinking, got %q", thinking)
}
if content != "Direct answer" {
t.Fatalf("expected content %q, got %q", "Direct answer", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserNilThinkDefaultsToContentForInstructParser(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, nil)
content, thinking, calls, err := parser.Add("Direct answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected no thinking, got %q", thinking)
}
if content != "Direct answer" {
t.Fatalf("expected content %q, got %q", "Direct answer", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserToolCall(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
input := "<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
content, thinking, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
}
location, ok := calls[0].Function.Arguments.Get("location")
if !ok || location != "San Francisco" {
t.Fatalf("expected location %q, got %v", "San Francisco", location)
}
unit, ok := calls[0].Function.Arguments.Get("unit")
if !ok || unit != "celsius" {
t.Fatalf("expected unit %q, got %v", "celsius", unit)
}
}

View File

@@ -2,6 +2,10 @@
# This script installs Ollama on Linux and macOS.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
# Wrap script in main function so that a truncated partial download doesn't end
# up executing half a script.
main() {
set -eu
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
@@ -446,3 +450,6 @@ fi
status "NVIDIA GPU ready."
install_success
}
main

View File

@@ -2371,30 +2371,6 @@ func TestImageGenerateStreamFalse(t *testing.T) {
return nil
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
"": {
llama: &mock,
Options: &opts,
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
isImagegen: true,
numParallel: 1,
},
},
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
// Create model manifest with image capability
n := model.ParseName("test-image")
cfg := model.ConfigV2{Capabilities: []string{"image"}}
@@ -2410,6 +2386,35 @@ func TestImageGenerateStreamFalse(t *testing.T) {
t.Fatal(err)
}
loadedModel, err := GetModel("test-image")
if err != nil {
t.Fatal(err)
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
schedulerModelKey(loadedModel): {
llama: &mock,
Options: &opts,
model: loadedModel,
isImagegen: true,
numParallel: 1,
},
},
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
streamFalse := false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-image",

View File

@@ -83,6 +83,28 @@ func InitScheduler(ctx context.Context) *Scheduler {
return sched
}
// schedulerModelKey returns the scheduler map key for a model.
// GGUF-backed models use ModelPath; safetensors/image models without a
// ModelPath use manifest digest so distinct models don't collide.
func schedulerModelKey(m *Model) string {
if m == nil {
return ""
}
if m.ModelPath != "" {
return m.ModelPath
}
if m.Digest != "" {
return "digest:" + m.Digest
}
if m.Name != "" {
return "name:" + m.Name
}
if m.ShortName != "" {
return "short:" + m.ShortName
}
return ""
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
@@ -104,8 +126,9 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
useImagegen: useImagegen,
}
key := schedulerModelKey(req.model)
s.loadedMu.Lock()
runner := s.loaded[req.model.ModelPath]
runner := s.loaded[key]
s.loadedMu.Unlock()
if runner != nil && !runner.needsReload(c, req) {
req.useLoadedRunner(runner, s.finishedReqCh)
@@ -151,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
for {
var runnerToExpire *runnerRef
pendingKey := schedulerModelKey(pending.model)
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
runner := s.loaded[pendingKey]
loadedCount := len(s.loaded)
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
@@ -166,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
runnerToExpire = runner
} else {
// Runner is usable, return it
logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath)
logutil.Trace("using existing loaded runner", "model", pendingKey)
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
@@ -292,11 +316,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
finishedKey := schedulerModelKey(finished.model)
s.loadedMu.Lock()
runner := s.loaded[finished.model.ModelPath]
runner := s.loaded[finishedKey]
s.loadedMu.Unlock()
if runner == nil {
slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
continue
}
runner.refMu.Lock()
@@ -347,7 +372,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
s.loadedMu.Lock()
slog.Debug("got lock to unload expired event", "runner", runner)
runnerToUnload := s.loaded[runner.modelPath]
runnerToUnload := s.loaded[runner.modelKey]
if runnerToUnload == nil {
// If runnerToUnload is nil, we already processed an event and
// unloaded it. This double unload can happen if the initial
@@ -376,7 +401,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
}
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
runner.unload()
delete(s.loaded, runner.modelPath)
delete(s.loaded, runner.modelKey)
s.loadedMu.Unlock()
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
<-finished
@@ -514,6 +539,7 @@ iGPUScan:
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: llama,
Options: &req.opts,
sessionDuration: sessionDuration,
@@ -528,7 +554,7 @@ iGPUScan:
runner.refMu.Lock() // hold lock until running or aborted
s.loadedMu.Lock()
if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
if oldRunner, ok := s.loaded[runner.modelKey]; ok {
// Shouldn't happen, but safeguard against leaking a runner
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
oldRunner.refMu.Lock()
@@ -536,7 +562,7 @@ iGPUScan:
oldRunner.refMu.Unlock()
}
s.activeLoading = nil
s.loaded[req.model.ModelPath] = runner
s.loaded[runner.modelKey] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
@@ -596,6 +622,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: server,
Options: &req.opts,
loading: false,
@@ -606,7 +633,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loaded[runner.modelKey] = runner
s.loadedMu.Unlock()
// Set up expiration timer
@@ -684,6 +711,7 @@ type runnerRef struct {
model *Model
modelPath string
modelKey string
numParallel int
*api.Options
}
@@ -703,7 +731,7 @@ func (runner *runnerRef) unload() {
}
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
runner.refMu.Lock()
defer runner.refMu.Unlock()
@@ -814,6 +842,10 @@ func (runner *runnerRef) LogValue() slog.Value {
if runner == nil {
return slog.StringValue("nil")
}
modelID := runner.modelPath
if modelID == "" {
modelID = runner.modelKey
}
attrs := []slog.Attr{}
if runner.model != nil {
attrs = append(attrs, slog.String("name", runner.model.Name))
@@ -828,7 +860,7 @@ func (runner *runnerRef) LogValue() slog.Value {
slog.String("vram", format.HumanBytes2(runner.vramSize)),
slog.Int("parallel", runner.numParallel),
slog.Int("pid", runner.pid),
slog.String("model", runner.modelPath),
slog.String("model", modelID),
)
if runner.Options != nil {
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
@@ -873,8 +905,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
if d1 != d2 {
return d1 < d2
}
// Secondary sort by model path lex order
return a[i].modelPath < a[j].modelPath
// Secondary sort by model key/path lex order
n1 := a[i].modelPath
if n1 == "" {
n1 = a[i].modelKey
}
n2 := a[j].modelPath
if n2 == "" {
n2 = a[j].modelKey
}
return n1 < n2
}
// TODO - future consideration to pick runners based on size
@@ -934,8 +974,9 @@ func (s *Scheduler) unloadAllRunners() {
}
func (s *Scheduler) expireRunner(model *Model) {
modelKey := schedulerModelKey(model)
s.loadedMu.Lock()
runner, ok := s.loaded[model.ModelPath]
runner, ok := s.loaded[modelKey]
s.loadedMu.Unlock()
if ok {
runner.refMu.Lock()

View File

@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
b.ctxDone()
}
func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
s := InitScheduler(ctx)
opts := api.DefaultOptions()
opts.NumCtx = 4
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
loadedRunner := &runnerRef{
model: loadedModel,
modelKey: schedulerModelKey(loadedModel),
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
Options: &opts,
numParallel: 1,
}
s.loadedMu.Lock()
s.loaded[loadedRunner.modelKey] = loadedRunner
s.loadedMu.Unlock()
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
require.Empty(t, successCh)
require.Empty(t, errCh)
require.Len(t, s.pendingReqCh, 1)
}
func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
s := InitScheduler(ctx)
opts := api.DefaultOptions()
opts.NumCtx = 4
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
loadedRunner := &runnerRef{
model: loadedModel,
modelKey: schedulerModelKey(loadedModel),
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
Options: &opts,
numParallel: 1,
}
s.loadedMu.Lock()
s.loaded[loadedRunner.modelKey] = loadedRunner
s.loadedMu.Unlock()
reqCtx, cancelReq := context.WithCancel(ctx)
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
cancelReq()
select {
case runner := <-successCh:
require.Equal(t, loadedRunner, runner)
default:
t.Fatal("expected existing runner to be reused")
}
require.Empty(t, errCh)
require.Empty(t, s.pendingReqCh)
}
func TestSchedExpireRunner(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
defer done()

View File

@@ -30,6 +30,8 @@ type ModelfileConfig struct {
Template string
System string
License string
Parser string
Renderer string
}
// CreateOptions holds all options for model creation.
@@ -37,7 +39,7 @@ type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
}
// CreateModel imports a model from a local directory.
@@ -267,8 +269,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
ModelFormat: "safetensors",
Capabilities: caps,
Requires: MinOllamaVersion,
Parser: parserName,
Renderer: rendererName,
Parser: resolveParserName(opts.Modelfile, parserName),
Renderer: resolveRendererName(opts.Modelfile, rendererName),
}
configJSON, err := json.Marshal(configData)
if err != nil {
@@ -305,6 +307,22 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
}
}
func resolveParserName(mf *ModelfileConfig, inferred string) string {
if mf != nil && mf.Parser != "" {
return mf.Parser
}
return inferred
}
func resolveRendererName(mf *ModelfileConfig, inferred string) string {
if mf != nil && mf.Renderer != "" {
return mf.Renderer
}
return inferred
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
var layers []manifest.Layer
@@ -410,7 +428,7 @@ func getParserName(modelDir string) string {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3-coder"
return "qwen3"
}
}
@@ -424,7 +442,7 @@ func getParserName(modelDir string) string {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
return "qwen3"
}
}

View File

@@ -10,6 +10,8 @@ func TestModelfileConfig(t *testing.T) {
Template: "{{ .Prompt }}",
System: "You are a helpful assistant.",
License: "MIT",
Parser: "qwen3",
Renderer: "qwen3",
}
if config.Template != "{{ .Prompt }}" {
@@ -21,6 +23,12 @@ func TestModelfileConfig(t *testing.T) {
if config.License != "MIT" {
t.Errorf("License = %q, want %q", config.License, "MIT")
}
if config.Parser != "qwen3" {
t.Errorf("Parser = %q, want %q", config.Parser, "qwen3")
}
if config.Renderer != "qwen3" {
t.Errorf("Renderer = %q, want %q", config.Renderer, "qwen3")
}
}
func TestModelfileConfig_Empty(t *testing.T) {
@@ -35,6 +43,12 @@ func TestModelfileConfig_Empty(t *testing.T) {
if config.License != "" {
t.Errorf("License should be empty, got %q", config.License)
}
if config.Parser != "" {
t.Errorf("Parser should be empty, got %q", config.Parser)
}
if config.Renderer != "" {
t.Errorf("Renderer should be empty, got %q", config.Renderer)
}
}
func TestModelfileConfig_PartialFields(t *testing.T) {
@@ -53,6 +67,12 @@ func TestModelfileConfig_PartialFields(t *testing.T) {
if config.License != "" {
t.Error("License should be empty")
}
if config.Parser != "" {
t.Error("Parser should be empty")
}
if config.Renderer != "" {
t.Error("Renderer should be empty")
}
}
func TestMinOllamaVersion(t *testing.T) {
@@ -98,6 +118,8 @@ func TestCreateOptions(t *testing.T) {
Template: "test",
System: "system",
License: "MIT",
Parser: "qwen3-thinking",
Renderer: "qwen3",
},
}
@@ -116,6 +138,92 @@ func TestCreateOptions(t *testing.T) {
if opts.Modelfile.Template != "test" {
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
}
if opts.Modelfile.Parser != "qwen3-thinking" {
t.Errorf("Modelfile.Parser = %q, want %q", opts.Modelfile.Parser, "qwen3-thinking")
}
if opts.Modelfile.Renderer != "qwen3" {
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
}
}
func TestResolveParserName(t *testing.T) {
tests := []struct {
name string
mf *ModelfileConfig
inferred string
want string
}{
{
name: "nil modelfile uses inferred",
mf: nil,
inferred: "qwen3",
want: "qwen3",
},
{
name: "empty parser uses inferred",
mf: &ModelfileConfig{
Parser: "",
},
inferred: "qwen3",
want: "qwen3",
},
{
name: "explicit parser overrides inferred",
mf: &ModelfileConfig{
Parser: "qwen3-thinking",
},
inferred: "qwen3",
want: "qwen3-thinking",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveParserName(tt.mf, tt.inferred); got != tt.want {
t.Fatalf("resolveParserName() = %q, want %q", got, tt.want)
}
})
}
}
func TestResolveRendererName(t *testing.T) {
tests := []struct {
name string
mf *ModelfileConfig
inferred string
want string
}{
{
name: "nil modelfile uses inferred",
mf: nil,
inferred: "qwen3-coder",
want: "qwen3-coder",
},
{
name: "empty renderer uses inferred",
mf: &ModelfileConfig{
Renderer: "",
},
inferred: "qwen3-coder",
want: "qwen3-coder",
},
{
name: "explicit renderer overrides inferred",
mf: &ModelfileConfig{
Renderer: "qwen3",
},
inferred: "qwen3-coder",
want: "qwen3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveRendererName(tt.mf, tt.inferred); got != tt.want {
t.Fatalf("resolveRendererName() = %q, want %q", got, tt.want)
}
})
}
}
func TestCreateOptions_Defaults(t *testing.T) {

View File

@@ -3,94 +3,65 @@
package mlxrunner
import (
"fmt"
"log/slog"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// CacheEntry stores a single sequence
type CacheEntry struct {
Caches []cache.Cache
Count int
Entries map[int32]*CacheEntry
Tokens []int32
Caches []cache.Cache
}
func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
current := &CacheEntry{Entries: s.CacheEntries}
index, cacheIndex := 0, -1
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
break
}
current = current.Entries[token]
if len(current.Caches) > 0 {
cacheIndex = index
}
index += 1
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
if r.cache == nil {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
if cacheIndex == len(tokens)-1 {
slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
return current.Caches, []int32{}
} else if cacheIndex > 1 {
slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
return current.Caches, tokens[cacheIndex+1:]
} else if index > 0 && cacheIndex < 0 {
type stackItem struct {
entry *CacheEntry
tokens []int32
}
var best, item stackItem
stack := []stackItem{{entry: current, tokens: []int32{}}}
for len(stack) > 0 {
item, stack = stack[len(stack)-1], stack[:len(stack)-1]
if len(item.entry.Caches) > 0 {
if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
best = item
}
} else {
for token, entry := range item.entry.Entries {
stack = append(stack, stackItem{
entry: entry,
tokens: append(item.tokens, token),
})
}
}
}
prefix := min(len(tokens)-1, index)
caches := make([]cache.Cache, len(best.entry.Caches))
trim := len(best.tokens)+1
for i := range caches {
caches[i] = best.entry.Caches[i].Clone()
caches[i].Trim(trim)
}
slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
return caches, tokens[prefix:]
// Find longest common prefix
prefix := 0
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
prefix++
}
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
switch {
case prefix == 0:
for _, c := range r.cache.Caches {
c.Free()
}
r.cache = nil
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
case prefix < len(r.cache.Tokens):
trim := len(r.cache.Tokens) - prefix
for _, c := range r.cache.Caches {
c.Trim(trim)
}
r.cache.Tokens = r.cache.Tokens[:prefix]
}
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
return r.cache.Caches, tokens[prefix:]
}
func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
current := &CacheEntry{Entries: s.CacheEntries}
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
current.Entries[token] = &CacheEntry{
Entries: make(map[int32]*CacheEntry),
}
}
current = current.Entries[token]
}
if len(current.Caches) > 0 {
current.Count += 1
} else {
current.Caches = caches
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
r.cache = &CacheEntry{
Tokens: tokens,
Caches: caches,
}
}
func (c *CacheEntry) LogCache() {
var totalBytes int
for _, kv := range c.Caches {
k, v := kv.State()
totalBytes += k.NumBytes() + v.NumBytes()
}
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
}

View File

@@ -13,6 +13,7 @@ type Cache interface {
State() (keys, values *mlx.Array)
Trim(int) int
Clone() Cache
Free()
Offset() int
Len() int
}
@@ -47,6 +48,7 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
}
@@ -73,12 +75,19 @@ func (c *KVCache) Trim(n int) int {
}
func (c *KVCache) Clone() Cache {
return &KVCache{
clone := &KVCache{
keys: c.keys.Clone(),
values: c.values.Clone(),
offset: c.offset,
step: c.step,
}
mlx.Pin(clone.keys, clone.values)
return clone
}
func (c *KVCache) Free() {
mlx.Unpin(c.keys, c.values)
c.keys, c.values = nil, nil
}
func (c *KVCache) Offset() int { return c.offset }
@@ -106,7 +115,8 @@ func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
if c.keys == nil {
c.keys, c.values = keys, values
c.keys, c.values = keys.Clone(), values.Clone()
mlx.Pin(c.keys, c.values)
} else {
if c.idx < c.keys.Dim(2) {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
@@ -145,6 +155,7 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
c.idx = prev
}

View File

@@ -3,5 +3,8 @@
package mlxrunner
import (
_ "github.com/ollama/ollama/x/models/gemma3"
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"
)

View File

@@ -7,48 +7,29 @@ import "C"
import (
"encoding/binary"
"fmt"
"log/slog"
"reflect"
"sort"
"strings"
"time"
"unsafe"
"github.com/ollama/ollama/logutil"
)
type tensorDesc struct {
name string
inputs []*Array
numRefs int
}
func (d tensorDesc) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", d.name),
slog.Int("inputs", len(d.inputs)),
slog.Int("num_refs", d.numRefs),
)
}
type Array struct {
ctx C.mlx_array
desc tensorDesc
ctx C.mlx_array
name string
pinned bool
}
var arrays []*Array
// constructor utilities
func New(name string, inputs ...*Array) *Array {
t := &Array{
desc: tensorDesc{
name: name,
inputs: inputs,
},
}
for _, input := range inputs {
input.desc.numRefs++
}
logutil.Trace("New", "t", t)
func New(name string) *Array {
t := &Array{name: name}
arrays = append(arrays, t)
return t
}
@@ -133,18 +114,51 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
}
func (t *Array) Set(other *Array) {
Free(t.desc.inputs...)
other.desc.numRefs++
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx)
}
func (t *Array) Clone() *Array {
tt := New(t.desc.name, t.desc.inputs...)
tt := New(t.name)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
}
// lifecycle utilities
// Pin marks arrays as in-use so they are retained during Sweep.
func Pin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned = true
}
}
}
// Unpin marks arrays as no longer in-use, allowing Sweep to free them.
func Unpin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned = false
}
}
}
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
// free them when there are no other references, including dependencies in the graph.
func Sweep() {
n := 0
for _, t := range arrays {
if t.pinned && t.Valid() {
arrays[n] = t
n++
} else if t.Valid() {
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
}
}
arrays = arrays[:n]
}
// misc. utilities
func (t *Array) Valid() bool {
@@ -159,7 +173,10 @@ func (t *Array) String() string {
}
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{slog.Any("", t.desc)}
attrs := []slog.Attr{
slog.String("name", t.name),
slog.Bool("pinned", t.pinned),
}
if t.Valid() {
attrs = append(attrs,
slog.Any("dtype", t.DType()),
@@ -238,37 +255,15 @@ func (t Array) Save(name string) error {
return nil
}
func Free(s ...*Array) (n int) {
now := time.Now()
defer func() {
if n > 0 {
logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
}
}()
// LogArrays logs all live arrays, sorted by size
func LogArrays() {
sort.Slice(arrays, func(i, j int) bool {
return arrays[i].NumBytes() > arrays[j].NumBytes()
})
free := make([]*Array, 0, 8192)
fn := func(t *Array) {
if t.Valid() {
t.desc.numRefs--
if t.desc.numRefs <= 0 {
free = append(free, t.desc.inputs...)
logutil.Trace("Free", "t", t)
n += t.NumBytes()
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
}
}
for _, t := range arrays {
nb := t.NumBytes()
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims()))
}
for _, t := range s {
fn(t)
}
for len(free) > 0 {
tail := free[len(free)-1]
free = free[:len(free)-1]
fn(tail)
}
return n
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
}

View File

@@ -55,6 +55,30 @@ func tryLoadFromDir(dir string) bool {
return false
}
// tryLoadByName attempts to load the library using just its name,
// allowing the system to use rpath, LD_LIBRARY_PATH, or standard search paths.
// Returns true if the library was successfully loaded.
func tryLoadByName() bool {
libraryName := "libmlxc.dylib"
if runtime.GOOS == "linux" {
libraryName = "libmlxc.so"
}
cPath := C.CString(libraryName)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
return false
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
C.mlx_dynamic_unload(&handle)
return false
}
return true
}
func init() {
switch runtime.GOOS {
case "darwin":
@@ -73,6 +97,11 @@ func init() {
}
}
// Try loading via rpath/standard library search
if tryLoadByName() {
return
}
// Build search paths: executable directory, then build directories
var searchDirs []string
if exe, err := os.Executable(); err == nil {

View File

@@ -20,7 +20,7 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", query, key, value, mask, sinks)
out := New("FAST_SDPA")
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
@@ -31,7 +31,7 @@ type LayerNorm struct {
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_LAYERNORM", x)
out := New("FAST_LAYERNORM")
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
return out
}
@@ -41,7 +41,7 @@ type RMSNorm struct {
}
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
@@ -55,7 +55,7 @@ type RoPE struct {
func (r RoPE) Forward(t *Array, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", t, freqs)
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
t.ctx,

View File

@@ -37,7 +37,9 @@ func Load(path string) iter.Seq2[string, *Array] {
}
name := C.GoString(key)
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
arr := New(name)
arr.ctx = value
if !yield(name, arr) {
break
}
}

View File

@@ -10,43 +10,43 @@ import (
)
func (t *Array) Abs() *Array {
out := New("ABS", t)
out := New("ABS")
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Add(other *Array) *Array {
out := New("ADD", t, other)
out := New("ADD")
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
out := New("ADDMM", t, a, b)
out := New("ADDMM")
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
return out
}
func (t *Array) Argmax(axis int, keepDims bool) *Array {
out := New("ARGMAX", t)
out := New("ARGMAX")
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
out := New("ARGPARTITION", t)
out := New("ARGPARTITION")
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) ArgsortAxis(axis int) *Array {
out := New("ARGSORT_AXIS", t)
out := New("ARGSORT_AXIS")
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) AsType(dtype DType) *Array {
out := New("AS_TYPE", t)
out := New("AS_TYPE")
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
@@ -62,7 +62,7 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
cStrides[i] = C.int64_t(s)
}
out := New("AS_STRIDED", t)
out := New("AS_STRIDED")
C.mlx_as_strided(
&out.ctx, t.ctx,
unsafe.SliceData(cShape), C.size_t(len(shape)),
@@ -82,31 +82,31 @@ func (t *Array) Concatenate(axis int, others ...*Array) *Array {
C.mlx_vector_array_append_value(vector, other.ctx)
}
out := New("CONCATENATE", s...)
out := New("CONCATENATE")
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Divide(other *Array) *Array {
out := New("DIVIDE", t, other)
out := New("DIVIDE")
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS", t)
out := New("EXPAND_DIMS")
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Flatten(startAxis, endAxis int) *Array {
out := New("FLATTEN", t)
out := New("FLATTEN")
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
return out
}
func (t *Array) FloorDivide(other *Array) *Array {
out := New("FLOOR_DIVIDE", t, other)
out := New("FLOOR_DIVIDE")
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
@@ -118,43 +118,43 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
if rhs == nil {
rhs = New("")
}
out := New("GATHER_MM", t, other, lhs, rhs)
out := New("GATHER_MM")
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
return out
}
func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP", t)
out := New("LOGSUMEXP")
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL", t, other)
out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Multiply(other *Array) *Array {
out := New("MULTIPLY", t, other)
out := New("MULTIPLY")
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Negative() *Array {
out := New("NEGATIVE", t)
out := New("NEGATIVE")
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Power(exponent *Array) *Array {
out := New("POWER", t, exponent)
out := New("POWER")
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
return out
}
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS", t, indices, values)
out := New("PUT_ALONG_AXIS")
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
@@ -165,25 +165,25 @@ func (t *Array) Reshape(axes ...int) *Array {
cAxes[i] = C.int(axes[i])
}
out := New("RESHAPE", t)
out := New("RESHAPE")
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func (t *Array) Sigmoid() *Array {
out := New("SIGMOID", t)
out := New("SIGMOID")
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Sqrt() *Array {
out := New("SQRT", t)
out := New("SQRT")
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Squeeze(axis int) *Array {
out := New("SQUEEZE", t)
out := New("SQUEEZE")
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
@@ -198,37 +198,37 @@ func (t *Array) StackAxis(axis int, others ...*Array) *Array {
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK_AXIS", append(others, t)...)
out := New("STACK_AXIS")
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Subtract(other *Array) *Array {
out := New("SUBTRACT", t, other)
out := New("SUBTRACT")
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
out := New("SUM_AXIS", t)
out := New("SUM_AXIS")
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
out := New("TAKE_AXIS", t, indices)
out := New("TAKE_AXIS")
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
out := New("TAKE_ALONG_AXIS", t, indices)
out := New("TAKE_ALONG_AXIS")
C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Tanh() *Array {
out := New("TANH", t)
out := New("TANH")
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
@@ -239,7 +239,7 @@ func (t *Array) Transpose(axes ...int) *Array {
cAxes[i] = C.int(axis)
}
out := New("TRANSPOSE", t)
out := New("TRANSPOSE")
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}

View File

@@ -41,14 +41,12 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
optDtype := C.mlx_optional_dtype{has_value: false}
inputs := []*Array{w, scales}
var b C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
out := New("DEQUANTIZE", inputs...)
out := New("DEQUANTIZE")
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
return out
}
@@ -59,14 +57,12 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
inputs := []*Array{x, w, scales}
var b C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
out := New("QUANTIZED_MATMUL", inputs...)
out := New("QUANTIZED_MATMUL")
C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
return out
}
@@ -77,22 +73,18 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
inputs := []*Array{x, w, scales}
var b, lhs, rhs C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
if lhsIndices != nil {
lhs = lhsIndices.ctx
inputs = append(inputs, lhsIndices)
}
if rhsIndices != nil {
rhs = rhsIndices.ctx
inputs = append(inputs, rhsIndices)
}
out := New("GATHER_QMM", inputs...)
out := New("GATHER_QMM")
C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
return out
}
@@ -104,7 +96,7 @@ func Tile(a *Array, reps []int32) *Array {
for i, r := range reps {
cReps[i] = C.int(r)
}
out := New("TILE", a)
out := New("TILE")
C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
return out
}
@@ -116,7 +108,7 @@ func Tri(n, m int32, k int) *Array {
}
func Where(condition, a, b *Array) *Array {
out := New("WHERE", condition, a, b)
out := New("WHERE")
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
@@ -131,7 +123,7 @@ func Stack(arrays []*Array, axis int) *Array {
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK", arrays...)
out := New("STACK")
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
@@ -153,13 +145,13 @@ func Take(a *Array, indices *Array, axis int) *Array {
}
func RSqrt(a *Array) *Array {
out := New("RSQRT", a)
out := New("RSQRT")
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Mean(a *Array, axis int, keepDims bool) *Array {
out := New("MEAN_AXIS", a)
out := New("MEAN_AXIS")
C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
@@ -235,7 +227,7 @@ func SliceStartStop(a *Array, start, stop []int32) *Array {
cStop[i] = C.int(stop[i])
cStrides[i] = 1
}
out := New("SLICE", a)
out := New("SLICE")
C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
return out
}
@@ -257,7 +249,7 @@ func SiLU(a *Array) *Array {
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", x, freqs)
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
x.ctx,
@@ -289,13 +281,13 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", q, k, v, mask, sinks)
out := New("FAST_SDPA")
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
@@ -322,7 +314,7 @@ func scalarWithDtype(s float32, a *Array) C.mlx_array {
func AddScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("ADD_SCALAR", a)
out := New("ADD_SCALAR")
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
@@ -330,7 +322,7 @@ func AddScalar(a *Array, s float32) *Array {
func MulScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("MUL_SCALAR", a)
out := New("MUL_SCALAR")
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
@@ -338,7 +330,7 @@ func MulScalar(a *Array, s float32) *Array {
func DivScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("DIV_SCALAR", a)
out := New("DIV_SCALAR")
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out

View File

@@ -7,7 +7,7 @@ import "C"
func (t *Array) Categorical(axis int) *Array {
key := New("")
out := New("", t, key)
out := New("")
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
return out
}

View File

@@ -61,7 +61,7 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
func (t *Array) Slice(slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE", t)
out := New("SLICE")
C.mlx_slice(
&out.ctx, t.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
@@ -74,7 +74,7 @@ func (t *Array) Slice(slices ...slice) *Array {
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE_UPDATE", t, other)
out := New("SLICE_UPDATE")
C.mlx_slice_update(
&out.ctx, t.ctx, other.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),

View File

@@ -8,10 +8,10 @@ import (
"log/slog"
"sync"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/tokenizer"
)
// Model is the interface that model implementations must satisfy.
@@ -78,8 +78,21 @@ func New(root *model.Root) (Model, error) {
return fn(root)
}
// Weights returns the model's LoadWeights method, which encapsulates all
// weight assignment and post-processing (MLA absorption, expert stacking).
// Weights returns a function that loads model weights, then pins all
// arrays reachable from the model struct and sweeps everything else.
func Weights(m Model) func(map[string]*mlx.Array) error {
return m.LoadWeights
return func(tensors map[string]*mlx.Array) error {
if err := m.LoadWeights(tensors); err != nil {
return err
}
collected := mlx.Collect(m)
for _, arr := range collected {
mlx.Pin(arr)
}
mlx.Sweep()
mlx.Eval(collected...)
return nil
}
}

View File

@@ -0,0 +1,92 @@
//go:build mlx
package model
import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// LinearFactory builds linear layers using shared tensor maps and quant defaults.
type LinearFactory struct {
tensors map[string]*mlx.Array
defaultGroupSize int
defaultBits int
defaultMode string
tensorQuant map[string]*TensorQuantInfo
}
// NewLinearFactory creates a reusable constructor for model linear layers.
func NewLinearFactory(
tensors map[string]*mlx.Array,
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
) LinearFactory {
return LinearFactory{
tensors: tensors,
defaultGroupSize: defaultGroupSize,
defaultBits: defaultBits,
defaultMode: defaultMode,
tensorQuant: tensorQuant,
}
}
// Make constructs a linear layer at path.
func (f LinearFactory) Make(path string) nn.LinearLayer {
return MakeLinearLayer(
f.tensors,
path,
f.defaultGroupSize,
f.defaultBits,
f.defaultMode,
f.tensorQuant,
)
}
// MakeLinearLayer constructs a linear layer from a tensor map.
//
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
// quant params via TensorQuant metadata (with shape-based affine fallback).
// For non-quantized tensors, it returns a standard nn.Linear.
func MakeLinearLayer(
tensors map[string]*mlx.Array,
path string,
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
) nn.LinearLayer {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
bias := tensors[path+".bias"]
groupSize, bits, mode := ResolveLinearQuantParams(
defaultGroupSize,
defaultBits,
defaultMode,
tensorQuant,
path+".weight",
w,
scales,
)
return &nn.QuantizedLinear{
Weight: w,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
bias := tensors[path+".bias"]
return nn.NewLinear(w, bias)
}

130
x/mlxrunner/model/quant.go Normal file
View File

@@ -0,0 +1,130 @@
//go:build mlx
package model
import (
"strings"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// QuantizationParams returns default groupSize, bits, and mode for a quantization type.
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "NVFP4":
return 16, 4, "nvfp4"
case "FP4", "Q4", "INT4":
return 32, 4, "affine"
case "MXFP8":
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
return 64, 8, "affine"
default:
return 32, 8, "affine"
}
}
// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
// when available, otherwise falling back to the provided model defaults.
func TensorQuantParams(
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
tensorName string,
) (groupSize, bits int, mode string, fromTensor bool) {
if tensorQuant != nil {
if tq := tensorQuant[tensorName]; tq != nil {
groupSize, bits, mode = QuantizationParams(tq.QuantType)
if tq.GroupSize > 0 {
groupSize = tq.GroupSize
}
return groupSize, bits, mode, true
}
}
return defaultGroupSize, defaultBits, defaultMode, false
}
// ResolveLinearQuantParams resolves quantization params for a quantized linear
// tensor, preferring per-tensor metadata and falling back to shape-based
// inference for affine packed tensors.
func ResolveLinearQuantParams(
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
tensorName string,
weight, scales *mlx.Array,
) (groupSize, bits int, mode string) {
groupSize, bits, mode, fromTensor := TensorQuantParams(
defaultGroupSize,
defaultBits,
defaultMode,
tensorQuant,
tensorName,
)
if mode == "affine" {
if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
if !fromTensor || groupSize == 0 || bits == 0 {
groupSize = inferredGroupSize
bits = inferredBits
}
}
}
return groupSize, bits, mode
}
// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized
// tensors from packed weight and scale shapes.
func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
if weight == nil || scales == nil {
return 0, 0, false
}
weightShape := weight.Dims()
scaleShape := scales.Dims()
if len(weightShape) == 0 || len(scaleShape) == 0 {
return 0, 0, false
}
weightCols := weightShape[len(weightShape)-1]
scalesCols := scaleShape[len(scaleShape)-1]
if weightCols <= 0 || scalesCols <= 0 {
return 0, 0, false
}
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
switch {
case groupSize4 == 32:
return 32, 4, true
case groupSize8 == 64:
return 64, 8, true
case groupSize4 == 64 && groupSize8 == 32:
if hintBits == 8 {
return 32, 8, true
}
if hintBits == 4 {
return 64, 4, true
}
}
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
return groupSize4, 4, true
}
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
return groupSize8, 8, true
}
return 0, 0, false
}
func isCommonGroupSize(v int) bool {
switch v {
case 16, 32, 64, 128:
return true
default:
return false
}
}

View File

@@ -8,42 +8,63 @@ import (
"fmt"
"io"
"os"
"sort"
"strconv"
"strings"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// Root wraps a ModelManifest with pre-scanned quantization metadata.
type Root struct {
Manifest *manifest.ModelManifest
quantType string
groupSize int
// TensorQuantInfo describes per-tensor quantization metadata.
type TensorQuantInfo struct {
QuantType string
GroupSize int
}
// Open loads a manifest for the given model name and pre-scans the first
// tensor blob for quantization metadata (quant_type, group_size).
// Root wraps a ModelManifest with pre-scanned quantization metadata.
type Root struct {
Manifest *manifest.ModelManifest
// Backwards-compatible model-level quant metadata (first tensor blob).
quantType string
groupSize int
// Per-tensor quantization metadata.
tensorQuant map[string]*TensorQuantInfo
}
// Open loads a manifest for the given model name and scans tensor blobs for
// quantization metadata.
func Open(modelName string) (*Root, error) {
m, err := manifest.LoadManifest(modelName)
if err != nil {
return nil, err
}
root := &Root{Manifest: m}
root := &Root{
Manifest: m,
tensorQuant: make(map[string]*TensorQuantInfo),
}
// Pre-scan first tensor blob for quantization metadata
for _, layer := range m.GetTensorLayers("") {
blobPath := m.BlobPath(layer.Digest)
meta, err := readBlobMetadata(blobPath)
if err != nil || meta == nil {
infos, blobQuantType, blobGroupSize, err := readBlobTensorQuantInfo(blobPath)
if err != nil {
continue
}
if qt := meta["quant_type"]; qt != "" {
root.quantType = strings.ToUpper(qt)
for name, info := range infos {
root.tensorQuant[name] = info
}
if gs := meta["group_size"]; gs != "" {
fmt.Sscanf(gs, "%d", &root.groupSize)
if root.quantType == "" && blobQuantType != "" {
root.quantType = strings.ToUpper(blobQuantType)
root.groupSize = blobGroupSize
if root.groupSize == 0 {
root.groupSize = defaultGroupSize(root.quantType)
}
}
break // only check the first tensor blob
}
return root, nil
@@ -52,46 +73,180 @@ func Open(modelName string) (*Root, error) {
// Close is a no-op for now (future: release resources).
func (r *Root) Close() {}
// QuantType returns the quantization type detected from tensor metadata.
// QuantType returns the quantization type detected from the first tensor blob metadata.
func (r *Root) QuantType() string { return r.quantType }
// GroupSize returns the quantization group size detected from tensor metadata.
// GroupSize returns the quantization group size detected from the first tensor blob metadata.
func (r *Root) GroupSize() int { return r.groupSize }
// readBlobMetadata reads the __metadata__ from a safetensors blob header.
func readBlobMetadata(path string) (map[string]string, error) {
// TensorQuant returns per-tensor quantization metadata if available.
func (r *Root) TensorQuant(name string) *TensorQuantInfo {
if r == nil {
return nil
}
return r.tensorQuant[name]
}
// AllTensorQuant returns a copy of the per-tensor quantization metadata.
func (r *Root) AllTensorQuant() map[string]*TensorQuantInfo {
out := make(map[string]*TensorQuantInfo, len(r.tensorQuant))
for k, v := range r.tensorQuant {
if v == nil {
continue
}
copy := *v
out[k] = &copy
}
return out
}
func defaultGroupSize(quantType string) int {
groupSize, _, _ := QuantizationParams(quantType)
return groupSize
}
func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, int, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
return nil, "", 0, err
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return nil, err
return nil, "", 0, err
}
if headerSize > 1024*1024 {
return nil, fmt.Errorf("header too large: %d", headerSize)
if headerSize > 100*1024*1024 {
return nil, "", 0, fmt.Errorf("header too large: %d", headerSize)
}
data := make([]byte, headerSize)
if _, err := io.ReadFull(f, data); err != nil {
return nil, err
return nil, "", 0, err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data, &header); err != nil {
return nil, err
return nil, "", 0, err
}
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
globalQuantType = strings.ToUpper(globalQuantType)
mainNames := mainTensorNames(header)
infos := make(map[string]*TensorQuantInfo)
for _, name := range mainNames {
if _, ok := header[name+".scale"]; !ok {
continue
}
quantType := globalQuantType
groupSize := globalGroupSize
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
if quantType == "" {
quantType = inferredType
}
if groupSize == 0 {
groupSize = inferredGroup
}
if quantType == "" {
continue
}
if groupSize == 0 {
groupSize = defaultGroupSize(quantType)
}
infos[name] = &TensorQuantInfo{QuantType: quantType, GroupSize: groupSize}
}
return infos, globalQuantType, globalGroupSize, nil
}
func parseGlobalQuantMetadata(header map[string]json.RawMessage) (quantType string, groupSize int) {
metaRaw, ok := header["__metadata__"]
if !ok {
return nil, nil
return "", 0
}
var meta map[string]string
if err := json.Unmarshal(metaRaw, &meta); err != nil {
return nil, err
return "", 0
}
return meta, nil
quantType = meta["quant_type"]
if gs := meta["group_size"]; gs != "" {
groupSize, _ = strconv.Atoi(gs)
}
return quantType, groupSize
}
func mainTensorNames(header map[string]json.RawMessage) []string {
names := make([]string, 0, len(header))
for name := range header {
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
continue
}
names = append(names, name)
}
sort.Strings(names)
return names
}
func inferQuantTypeFromShapes(header map[string]json.RawMessage, tensorName string, hintQuantType string) (string, int) {
type tensorShape struct {
Shape []int64 `json:"shape"`
}
mainRaw, ok := header[tensorName]
if !ok {
return "", 0
}
scaleRaw, ok := header[tensorName+".scale"]
if !ok {
return "", 0
}
var mainInfo tensorShape
if err := json.Unmarshal(mainRaw, &mainInfo); err != nil || len(mainInfo.Shape) == 0 {
return "", 0
}
var scaleInfo tensorShape
if err := json.Unmarshal(scaleRaw, &scaleInfo); err != nil || len(scaleInfo.Shape) == 0 {
return "", 0
}
weightCols := int(mainInfo.Shape[len(mainInfo.Shape)-1])
scalesCols := int(scaleInfo.Shape[len(scaleInfo.Shape)-1])
if weightCols <= 0 || scalesCols <= 0 {
return "", 0
}
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
switch {
case groupSize4 == 32:
return "INT4", 32
case groupSize8 == 64:
return "INT8", 64
case groupSize4 == 64 && groupSize8 == 32:
h := strings.ToUpper(hintQuantType)
if strings.Contains(h, "8") {
return "INT8", 32
}
if strings.Contains(h, "4") {
return "INT4", 64
}
}
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
return "INT4", groupSize4
}
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
return "INT8", groupSize8
}
return "", 0
}

View File

@@ -4,11 +4,12 @@ package mlxrunner
import (
"bytes"
"context"
"errors"
"log/slog"
"time"
"unicode/utf8"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
@@ -18,15 +19,27 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return errors.New("model not loaded")
}
mlx.EnableCompile()
enableCompile := true
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
enableCompile = modelCompile.EnableCompile()
}
if enableCompile {
mlx.EnableCompile()
} else {
mlx.DisableCompile()
}
inputs := r.Tokenizer.Encode(request.Prompt, true)
caches, tokens := r.FindNearestCache(inputs)
if len(caches) == 0 {
caches = make([]cache.Cache, r.Model.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
caches = cacheFactory.NewCaches()
} else {
caches = make([]cache.Cache, r.Model.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
}
}
}
@@ -34,8 +47,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
n := min(2<<10, total-processed-1)
temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
defer mlx.Free(temp)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
mlx.Sweep()
mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
@@ -54,11 +67,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true))
return request.Sample(logprobs), logprobs
sample := request.Sample(logprobs)
mlx.Pin(sample, logprobs)
mlx.Sweep()
mlx.AsyncEval(sample, logprobs)
return sample, logprobs
}
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
mlx.AsyncEval(sample, logprobs)
var b bytes.Buffer
@@ -67,7 +85,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
nextSample, nextLogprobs := step(sample)
mlx.AsyncEval(nextSample, nextLogprobs)
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
@@ -80,6 +97,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
outputs = append(outputs, output)
if r.Tokenizer.IsEOS(output) {
mlx.Unpin(nextSample, nextLogprobs)
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
@@ -91,7 +109,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
Token: int(output),
}
mlx.Free(sample, logprobs)
mlx.Unpin(sample, logprobs)
if i%256 == 0 {
mlx.ClearCache()
}
@@ -99,10 +117,19 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
sample, logprobs = nextSample, nextLogprobs
}
mlx.Free(sample, logprobs)
mlx.Unpin(sample, logprobs)
final.CompletionTokensDuration = time.Since(now)
request.Responses <- final
r.InsertCache(append(inputs, outputs...), caches)
mlx.Sweep()
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
if r.cache != nil {
r.cache.LogCache()
}
}
return nil
}
@@ -114,13 +141,5 @@ func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
return ""
}
if text := b.String(); utf8.ValidString(text) {
b.Reset()
return text
} else if b.Len() >= utf8.UTFMax {
b.Reset()
return text
}
return ""
return flushValidUTF8Prefix(b)
}

View File

@@ -12,12 +12,12 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
)
type Request struct {
@@ -58,10 +58,10 @@ type Response struct {
}
type Runner struct {
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
CacheEntries map[int32]*CacheEntry
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
cache *CacheEntry
}
func (r *Runner) Load(modelName string) error {

View File

@@ -40,8 +40,7 @@ func Execute(args []string) error {
flagSet.Parse(args)
runner := Runner{
Requests: make(chan Request),
CacheEntries: make(map[int32]*CacheEntry),
Requests: make(chan Request),
}
if err := runner.Load(modelName); err != nil {

View File

@@ -0,0 +1,47 @@
package mlxrunner
import (
"bytes"
"unicode/utf8"
)
// flushValidUTF8Prefix returns and consumes the longest valid UTF-8 prefix
// currently buffered, leaving any incomplete trailing bytes in place.
func flushValidUTF8Prefix(b *bytes.Buffer) string {
data := b.Bytes()
if len(data) == 0 {
return ""
}
prefix := validUTF8PrefixLen(data)
if prefix == 0 {
return ""
}
text := string(data[:prefix])
b.Next(prefix)
return text
}
func validUTF8PrefixLen(data []byte) int {
i := 0
prefix := 0
for i < len(data) {
r, size := utf8.DecodeRune(data[i:])
if r == utf8.RuneError && size == 1 {
if !utf8.FullRune(data[i:]) {
break
}
// Invalid UTF-8 byte; consume one byte to guarantee forward progress.
i++
prefix = i
continue
}
i += size
prefix = i
}
return prefix
}

View File

@@ -0,0 +1,46 @@
package mlxrunner
import (
"bytes"
"testing"
)
func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) {
var b bytes.Buffer
b.Write([]byte{0xE3, 0x81})
if got := flushValidUTF8Prefix(&b); got != "" {
t.Fatalf("first flush = %q, want empty", got)
}
b.Write([]byte{0x93, 0xE3})
if got := flushValidUTF8Prefix(&b); got != "こ" {
t.Fatalf("second flush = %q, want %q", got, "こ")
}
if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) {
t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3})
}
b.Write([]byte{0x82, 0x93})
if got := flushValidUTF8Prefix(&b); got != "ん" {
t.Fatalf("third flush = %q, want %q", got, "ん")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after third flush: %d", b.Len())
}
}
func TestFlushValidUTF8Prefix_ValidText(t *testing.T) {
var b bytes.Buffer
b.WriteString("hello 世界")
if got := flushValidUTF8Prefix(&b); got != "hello 世界" {
t.Fatalf("flush = %q, want %q", got, "hello 世界")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after flush: %d", b.Len())
}
}

518
x/models/gemma3/gemma3.go Normal file
View File

@@ -0,0 +1,518 @@
//go:build mlx
// Package gemma3 provides the Gemma 3 text model implementation for MLX.
package gemma3
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
base.Register("Gemma3ForCausalLM", newModel)
base.Register("Gemma3ForConditionalGeneration", newModel)
}
// TextConfig holds configuration for the Gemma 3 text model.
type TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
HeadDim int32 `json:"head_dim"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
SlidingWindow int32 `json:"sliding_window"`
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
LayerTypes []string `json:"layer_types"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
Scale float32 `json:"-"`
}
// Attention implements Gemma 3 attention with Q/K normalization.
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
QNorm *nn.RMSNorm
KNorm *nn.RMSNorm
// Precomputed (1 + weight) for Gemma-style RMSNorm.
QNormScaled *mlx.Array
KNormScaled *mlx.Array
}
// MLP is the feed-forward network with GELU activation.
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
// DecoderLayer is a single transformer block.
type DecoderLayer struct {
InputNorm *nn.RMSNorm
Attention *Attention
PostAttnNorm *nn.RMSNorm
PreFFNorm *nn.RMSNorm
MLP *MLP
PostFFNorm *nn.RMSNorm
// Precomputed (1 + weight) for Gemma-style RMSNorm.
InputNormScaled *mlx.Array
PostAttnNormScaled *mlx.Array
PreFFNormScaled *mlx.Array
PostFFNormScaled *mlx.Array
// Layer metadata.
IsSliding bool
LayerIdx int32
}
// Model is the Gemma 3 text-only model.
type Model struct {
EmbedTokens *nn.Embedding
Layers []*DecoderLayer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
// Precomputed (1 + weight) for Gemma-style RMSNorm.
NormScaled *mlx.Array
tok *tokenizer.Tokenizer
*TextConfig
weightPrefix string
}
func defaultHeads(numLayers int32) (numHeads, numKVHeads int32) {
switch numLayers {
case 34:
return 8, 4
case 48:
return 16, 8
case 62:
return 32, 16
default:
return 8, 4
}
}
func parseTextConfig(configData []byte) (TextConfig, bool, error) {
var cfg TextConfig
if err := json.Unmarshal(configData, &cfg); err != nil {
return TextConfig{}, false, fmt.Errorf("parse config: %w", err)
}
var wrapped struct {
TextConfig *TextConfig `json:"text_config"`
}
if err := json.Unmarshal(configData, &wrapped); err != nil {
return TextConfig{}, false, fmt.Errorf("parse nested text config: %w", err)
}
fromConditional := wrapped.TextConfig != nil
if fromConditional {
cfg = *wrapped.TextConfig
if cfg.HeadDim == 0 {
cfg.HeadDim = 256
}
if cfg.NumAttentionHeads == 0 {
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.NumKeyValueHeads == 0 {
_, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.VocabSize == 0 {
cfg.VocabSize = 262208
}
if cfg.SlidingWindowPattern == 0 && len(cfg.LayerTypes) == 0 {
cfg.SlidingWindowPattern = 6
}
if cfg.MaxPositionEmbeddings == 0 {
cfg.MaxPositionEmbeddings = 131072
}
}
if cfg.HeadDim == 0 {
cfg.HeadDim = 256
}
if cfg.NumAttentionHeads == 0 {
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.NumKeyValueHeads == 0 {
cfg.NumKeyValueHeads = max(1, cfg.NumAttentionHeads/2)
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
if cfg.RopeLocalBaseFreq == 0 {
cfg.RopeLocalBaseFreq = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
if cfg.VocabSize == 0 {
cfg.VocabSize = 262208
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
return cfg, fromConditional, nil
}
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
func isLayerSliding(layerIdx int32, cfg *TextConfig) bool {
if len(cfg.LayerTypes) > 0 && int(layerIdx) < len(cfg.LayerTypes) {
return cfg.LayerTypes[layerIdx] == "sliding_attention"
}
if cfg.SlidingWindowPattern <= 0 {
return false
}
return (layerIdx+1)%cfg.SlidingWindowPattern != 0
}
func precomputeGemmaScaledWeights(m *Model) {
if m.Norm != nil {
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
}
var scaled []*mlx.Array
if m.NormScaled != nil {
scaled = append(scaled, m.NormScaled)
}
for _, layer := range m.Layers {
if layer == nil || layer.Attention == nil {
continue
}
if layer.InputNorm != nil {
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
scaled = append(scaled, layer.InputNormScaled)
}
if layer.PostAttnNorm != nil {
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
scaled = append(scaled, layer.PostAttnNormScaled)
}
if layer.PreFFNorm != nil {
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
scaled = append(scaled, layer.PreFFNormScaled)
}
if layer.PostFFNorm != nil {
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
scaled = append(scaled, layer.PostFFNormScaled)
}
if layer.Attention.QNorm != nil {
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
scaled = append(scaled, layer.Attention.QNormScaled)
}
if layer.Attention.KNorm != nil {
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
scaled = append(scaled, layer.Attention.KNormScaled)
}
}
if len(scaled) > 0 {
mlx.Eval(scaled...)
}
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
cfg, _, err := parseTextConfig(configData)
if err != nil {
return nil, err
}
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
TextConfig: &cfg,
tok: tok,
}
for i := range m.Layers {
m.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), m.TextConfig),
}
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Gemma usually ties output projection to embeddings.
m.LMHead = nn.NewLinear(embedWeight, nil)
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &DecoderLayer{
LayerIdx: i,
IsSliding: isLayerSliding(i, m.TextConfig),
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.PostAttnNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil {
layer.PreFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil {
layer.PostFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.InputNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.PostAttnNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.PreFFNorm == nil {
return fmt.Errorf("layer %d: missing pre_feedforward_layernorm", i)
}
if layer.PostFFNorm == nil {
return fmt.Errorf("layer %d: missing post_feedforward_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
return fmt.Errorf("layer %d: missing attention q/k norms", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
precomputeGemmaScaledWeights(m)
if m.NormScaled == nil {
return fmt.Errorf("missing precomputed final norm weight")
}
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.TextConfig)
}
return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
// NewCaches creates cache objects for all layers.
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i, layer := range m.Layers {
if m.SlidingWindow > 0 && layer.IsSliding {
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
caches[i] = cache.NewKVCache()
}
}
return caches
}
// FormatPrompt applies the Gemma 3 chat template.
func (m *Model) FormatPrompt(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps)
mlpOut := l.MLP.Forward(normed)
mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
return mlx.Add(h, mlpOut)
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Transpose(v, 0, 2, 1, 3)
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps)
ropeTheta := cfg.RopeTheta
if isSliding {
ropeTheta = cfg.RopeLocalBaseFreq
}
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
if c != nil {
k, v = c.Update(k, v)
}
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.GELUApprox(m.GateProj.Forward(x))
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
}

View File

@@ -8,14 +8,13 @@ import (
"encoding/json"
"fmt"
"math"
"strings"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
@@ -64,9 +63,10 @@ type Config struct {
RopeScaling *RopeScaling `json:"rope_scaling"`
// Quantization parameters (set during load based on model quantization)
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
QuantBits int `json:"-"` // Bits per weight (4 or 8)
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
QuantBits int `json:"-"` // Bits per weight (4 or 8)
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
@@ -372,22 +372,6 @@ func supportsGatherQMM(mode string, bits int) bool {
return mode == "affine" && (bits == 4 || bits == 8)
}
// quantizationParams returns groupSize, bits, mode for a quantization type string.
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "NVFP4":
return 16, 4, "nvfp4"
case "FP4", "Q4", "INT4":
return 32, 4, "affine"
case "MXFP8":
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
return 64, 8, "affine"
default:
return 32, 8, "affine"
}
}
// ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct {
Weight *mlx.Array
@@ -408,7 +392,15 @@ func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized b
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
groupSize, bits, mode := cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode
groupSize, bits, mode := model.ResolveLinearQuantParams(
cfg.QuantGroupSize,
cfg.QuantBits,
cfg.QuantMode,
cfg.TensorQuant,
path+".weight",
w,
scales,
)
if useQuantized && supportsGatherQMM(mode, bits) {
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
@@ -492,7 +484,16 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
// Check if quantized and dequantize
if scales := tensors[path+".weight_scale"]; scales != nil {
qbiases := tensors[path+".weight_qbias"]
w = mlx.Dequantize(w, scales, qbiases, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode)
groupSize, bits, mode := model.ResolveLinearQuantParams(
cfg.QuantGroupSize,
cfg.QuantBits,
cfg.QuantMode,
cfg.TensorQuant,
path+".weight",
w,
scales,
)
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
}
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
@@ -507,32 +508,6 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
return embedQ, unembedOut
}
// makeLinear creates a Linear or QuantizedLinear layer from the tensor map.
func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.LinearLayer {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
bias := tensors[path+".bias"]
return &nn.QuantizedLinear{
Weight: w,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: cfg.QuantGroupSize,
Bits: cfg.QuantBits,
Mode: cfg.QuantMode,
}
}
bias := tensors[path+".bias"]
return nn.NewLinear(w, bias)
}
// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
// no weights loaded yet). Called by the registry via base.New().
func newModel(root *model.Root) (base.Model, error) {
@@ -551,13 +526,14 @@ func newModel(root *model.Root) (base.Model, error) {
// Set up quantization parameters from pre-scanned metadata
if qt := root.QuantType(); qt != "" {
_, cfg.QuantBits, cfg.QuantMode = quantizationParams(qt)
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
} else {
cfg.QuantGroupSize, _, _ = quantizationParams(qt)
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
// Load tokenizer
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
@@ -596,7 +572,20 @@ func newModel(root *model.Root) (base.Model, error) {
// layer creation.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
cfg := m.Config
linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
if !useQuantized && cfg.TensorQuant != nil {
for _, tq := range cfg.TensorQuant {
if tq == nil {
continue
}
_, bits, mode := model.QuantizationParams(tq.QuantType)
if supportsGatherQMM(mode, bits) {
useQuantized = true
break
}
}
}
// Load embedding
if w := tensors["model.embed_tokens.weight"]; w != nil {
@@ -609,7 +598,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
// Load LM head
m.LMHead = makeLinear(tensors, "lm_head", cfg)
m.LMHead = linears.Make("lm_head")
// Load layers
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
@@ -617,16 +606,16 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
// Load attention (same for both block types)
attn := &MLAAttention{}
attn.QAProj = makeLinear(tensors, prefix+".self_attn.q_a_proj", cfg)
attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj")
if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
attn.QBProj = makeLinear(tensors, prefix+".self_attn.q_b_proj", cfg)
attn.KVAProjWithMQA = makeLinear(tensors, prefix+".self_attn.kv_a_proj_with_mqa", cfg)
attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj")
attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa")
if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
attn.OProj = makeLinear(tensors, prefix+".self_attn.o_proj", cfg)
attn.OProj = linears.Make(prefix + ".self_attn.o_proj")
// Sanitize MLA weights for absorbed attention
embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
@@ -647,9 +636,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
block.MLP = &DenseMLP{
GateProj: makeLinear(tensors, prefix+".mlp.gate_proj", cfg),
UpProj: makeLinear(tensors, prefix+".mlp.up_proj", cfg),
DownProj: makeLinear(tensors, prefix+".mlp.down_proj", cfg),
GateProj: linears.Make(prefix + ".mlp.gate_proj"),
UpProj: linears.Make(prefix + ".mlp.up_proj"),
DownProj: linears.Make(prefix + ".mlp.down_proj"),
}
m.Layers[i] = block
@@ -690,7 +679,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
moeGate := &MoEGate{}
moeGate.Gate = makeLinear(tensors, prefix+".mlp.gate", cfg)
moeGate.Gate = linears.Make(prefix + ".mlp.gate")
if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
moeGate.EScoreCorrectionBias = bias
}
@@ -703,9 +692,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
// Load shared experts if present
if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{
GateProj: makeLinear(tensors, prefix+".mlp.shared_experts.gate_proj", cfg),
UpProj: makeLinear(tensors, prefix+".mlp.shared_experts.up_proj", cfg),
DownProj: makeLinear(tensors, prefix+".mlp.shared_experts.down_proj", cfg),
GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"),
UpProj: linears.Make(prefix + ".mlp.shared_experts.up_proj"),
DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"),
}
}
@@ -713,9 +702,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}

320
x/models/llama/llama.go Normal file
View File

@@ -0,0 +1,320 @@
//go:build mlx
// Package llama provides a Llama-style decoder-only transformer for MLX.
package llama
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
base.Register("LlamaForCausalLM", newModel)
}
// Config holds Llama model configuration.
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
HeadDim int32 `json:"-"`
Scale float32 `json:"-"`
}
// Model is a Llama text model.
type Model struct {
EmbedTokens *nn.Embedding
Layers []*Layer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
tok *tokenizer.Tokenizer
*Config
weightPrefix string
}
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm
MLPNorm *nn.RMSNorm
}
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
}
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if cfg.HiddenSize <= 0 {
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
}
if cfg.NumAttentionHeads <= 0 {
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
}
if cfg.NumKeyValueHeads <= 0 {
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
}
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
}
if cfg.HeadDim == 0 {
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
}
if cfg.HeadDim <= 0 {
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
}
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-5
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData,
}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if m.TieWordEmbeddings {
m.LMHead = nn.NewLinear(embedWeight, nil)
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Fallback used by many Llama checkpoints where output is tied.
m.LMHead = nn.NewLinear(embedWeight, nil)
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &Layer{
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.AttentionNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.MLPNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
if c != nil {
k, v = c.Update(k, v)
}
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
}

335
x/models/qwen3/qwen3.go Normal file
View File

@@ -0,0 +1,335 @@
//go:build mlx
// Package qwen3 provides the Qwen3 text model implementation for MLX.
package qwen3
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
base.Register("Qwen3ForCausalLM", newModel)
}
// Config holds Qwen3 model configuration.
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
Scale float32 `json:"-"`
QKNormEps float32 `json:"-"`
}
// Model is the Qwen3 text-only model.
type Model struct {
EmbedTokens *nn.Embedding
Layers []*Layer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
tok *tokenizer.Tokenizer
*Config
weightPrefix string
}
// Layer is a single Qwen3 decoder block.
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm
MLPNorm *nn.RMSNorm
}
// Attention implements Qwen3 attention with Q/K norms.
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
QNorm *nn.RMSNorm
KNorm *nn.RMSNorm
}
// MLP is the feed-forward network with SwiGLU activation.
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if cfg.HiddenSize <= 0 {
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
}
if cfg.NumAttentionHeads <= 0 {
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
}
if cfg.NumKeyValueHeads <= 0 {
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
}
if cfg.HeadDim == 0 {
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
}
if cfg.HeadDim <= 0 {
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
}
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
cfg.QKNormEps = 1e-6
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData,
}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if m.TieWordEmbeddings {
m.LMHead = nn.NewLinear(embedWeight, nil)
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Qwen3 checkpoints commonly tie output projection to embeddings.
m.LMHead = nn.NewLinear(embedWeight, nil)
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &Layer{
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
layer.Attention.QNorm = nn.NewRMSNorm(w, m.QKNormEps)
}
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
layer.Attention.KNorm = nn.NewRMSNorm(w, m.QKNormEps)
}
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.AttentionNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.MLPNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
return fmt.Errorf("layer %d: missing attention q/k norms", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
q = a.QNorm.Forward(q, cfg.QKNormEps)
k = a.KNorm.Forward(k, cfg.QKNormEps)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
if c != nil {
k, v = c.Update(k, v)
}
// MLX SDPA supports grouped-query attention directly (Q heads can be a
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"os"
"sort"
"strings"
@@ -58,7 +59,15 @@ func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
}
}
return buildModelInfo(config, totalBytes, tensorCount), nil
info := buildModelInfo(config, totalBytes, tensorCount)
// For quantized models, byte-based estimation can significantly undercount
// parameters. Prefer exact counting from tensor shapes in safetensors headers.
if paramCount, err := getParameterCountFromManifest(mf); err == nil && paramCount > 0 {
info["general.parameter_count"] = paramCount
}
return info, nil
}
// buildModelInfo constructs the model info map from config and tensor stats.
@@ -151,6 +160,51 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
return info
}
// getParameterCountFromManifest counts model parameters from tensor shapes.
// This accounts for quantized tensors by using unpacked shapes from
// getTensorInfoFromManifest.
func getParameterCountFromManifest(mf *manifest.Manifest) (int64, error) {
tensors, err := getTensorInfoFromManifest(mf)
if err != nil {
return 0, err
}
var total int64
for _, tensor := range tensors {
if len(tensor.Shape) == 0 {
continue
}
elements := int64(1)
for _, dim := range tensor.Shape {
if dim == 0 {
elements = 0
break
}
if dim > uint64(math.MaxInt64) {
return 0, fmt.Errorf("tensor %s dimension too large: %d", tensor.Name, dim)
}
d := int64(dim)
if elements > math.MaxInt64/d {
return 0, fmt.Errorf("tensor %s element count overflow", tensor.Name)
}
elements *= d
}
if elements == 0 {
continue
}
if total > math.MaxInt64-elements {
return 0, fmt.Errorf("total parameter count overflow")
}
total += elements
}
return total, nil
}
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {

View File

@@ -714,6 +714,187 @@ func TestGetTensorInfoFromManifest_Quantized(t *testing.T) {
}
}
func TestGetParameterCountFromManifest(t *testing.T) {
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)
blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}
// Unquantized tensor: [4,5] = 20 params
header1 := map[string]any{
"model.embed_tokens.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{4, 5},
"data_offsets": []int64{0, 40},
},
}
header1JSON, _ := json.Marshal(header1)
var buf1 bytes.Buffer
binary.Write(&buf1, binary.LittleEndian, uint64(len(header1JSON)))
buf1.Write(header1JSON)
digest1 := "sha256:1111111111111111111111111111111111111111111111111111111111111111"
blobPath1, err := manifest.BlobsPath(digest1)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath1, buf1.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob1: %v", err)
}
// Quantized int4 tensor with packed shape [10,2] -> unpacked [10,16] = 160 params
header2 := map[string]any{
"__metadata__": map[string]string{
"quant_type": "int4",
"group_size": "32",
},
"model.layers.0.mlp.up_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{10, 2},
"data_offsets": []int64{0, 80},
},
"model.layers.0.mlp.up_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{10, 1},
"data_offsets": []int64{80, 100},
},
"model.layers.0.mlp.up_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{10, 1},
"data_offsets": []int64{100, 120},
},
}
header2JSON, _ := json.Marshal(header2)
var buf2 bytes.Buffer
binary.Write(&buf2, binary.LittleEndian, uint64(len(header2JSON)))
buf2.Write(header2JSON)
digest2 := "sha256:2222222222222222222222222222222222222222222222222222222222222222"
blobPath2, err := manifest.BlobsPath(digest2)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath2, buf2.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob2: %v", err)
}
mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: []manifest.Layer{
{
MediaType: manifest.MediaTypeImageTensor,
Digest: digest1,
Size: int64(buf1.Len() + 40),
Name: "model.embed_tokens.weight",
},
{
MediaType: manifest.MediaTypeImageTensor,
Digest: digest2,
Size: int64(buf2.Len() + 120),
Name: "model.layers.0.mlp.up_proj.weight",
},
},
}
paramCount, err := getParameterCountFromManifest(mf)
if err != nil {
t.Fatalf("getParameterCountFromManifest() error = %v", err)
}
const want int64 = 180 // 20 + 160
if paramCount != want {
t.Errorf("parameter_count = %d, want %d", paramCount, want)
}
}
func TestGetParameterCountFromManifest_MixedQuantizedPacked(t *testing.T) {
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)
blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}
// Packed mixed-precision blob (no global metadata):
// - gate_proj: int4 packed [5,8] + scale [5,2] => unpacked [5,64] = 320 params
// - down_proj: int8 packed [5,16] + scale [5,1] => unpacked [5,64] = 320 params
header := map[string]any{
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{5, 8},
"data_offsets": []int64{0, 160},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 2},
"data_offsets": []int64{160, 180},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 2},
"data_offsets": []int64{180, 200},
},
"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{5, 16},
"data_offsets": []int64{200, 520},
},
"model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 1},
"data_offsets": []int64{520, 530},
},
"model.layers.0.mlp.experts.0.down_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 1},
"data_offsets": []int64{530, 540},
},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
digest := "sha256:3333333333333333333333333333333333333333333333333333333333333333"
blobPath, err := manifest.BlobsPath(digest)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob: %v", err)
}
mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: []manifest.Layer{
{
MediaType: manifest.MediaTypeImageTensor,
Digest: digest,
Size: int64(buf.Len() + 540),
Name: "model.layers.0.mlp.experts",
},
},
}
paramCount, err := getParameterCountFromManifest(mf)
if err != nil {
t.Fatalf("getParameterCountFromManifest() error = %v", err)
}
const want int64 = 640 // 320 + 320
if paramCount != want {
t.Errorf("parameter_count = %d, want %d", paramCount, want)
}
}
func TestParseSafetensorsAllHeaders(t *testing.T) {
tests := []struct {
name string

108
x/tokenizer/tokenizer.go Normal file
View File

@@ -0,0 +1,108 @@
//go:build mlx
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
//
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
// - GPT-2 byte-level encoding (OpenAI tiktoken)
// - HuggingFace tokenizer.json pretokenizer patterns
// - SentencePiece ▁-style space handling
package tokenizer
import "regexp"
// TokenizerType identifies the tokenization algorithm
type TokenizerType int
const (
TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE
TokenizerSentencePiece // SentencePiece with ▁ for spaces
)
// Vocabulary holds the tokenizer vocabulary and merges
type Vocabulary struct {
Values []string
Reverse map[string]int32
Merges map[string]int
BOS int32
EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has <eos> and <end_of_turn>)
PAD int32 // Padding token (often <|endoftext|> or <pad>)
AddBOS bool
AddEOS bool
// Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found)
byteTokens [256]int32
}
// Tokenizer handles BPE and SentencePiece tokenization
type Tokenizer struct {
vocab *Vocabulary
pretokenizer *regexp.Regexp
specialTokens map[string]int32 // Special tokens for direct lookup
sortedSpecialTokens []string // Special tokens sorted by length, longest first
typ TokenizerType // Algorithm type
}
// Precomputed GPT-2 byte-level encoding table
// Maps byte values to their encoded rune equivalents
var byteToRune [256]rune
func init() {
for b := 0; b < 256; b++ {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
byteToRune[b] = r
}
}
// VocabSize returns the vocabulary size
func (t *Tokenizer) VocabSize() int {
return len(t.vocab.Values)
}
// BOS returns the beginning of sequence token ID
func (t *Tokenizer) BOS() int32 {
return t.vocab.BOS
}
// EOS returns the first end of sequence token ID (for backwards compatibility)
func (t *Tokenizer) EOS() int32 {
if len(t.vocab.EOS) > 0 {
return t.vocab.EOS[0]
}
return -1
}
// EOSTokens returns all end of sequence token IDs
func (t *Tokenizer) EOSTokens() []int32 {
return t.vocab.EOS
}
// PAD returns the padding token ID, or -1 if not set
func (t *Tokenizer) PAD() int32 {
return t.vocab.PAD
}
// IsEOS returns true if the token ID is an end of sequence token
func (t *Tokenizer) IsEOS(id int32) bool {
for _, eos := range t.vocab.EOS {
if id == eos {
return true
}
}
return false
}
// GetSpecialToken returns the token ID for a special token string
func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
id, ok := t.specialTokens[name]
return id, ok
}

View File

@@ -0,0 +1,251 @@
//go:build mlx
package tokenizer
import (
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
var (
benchmarkSinkIDs []int32
benchmarkSinkStr string
benchmarkSinkTok *Tokenizer
)
const benchmarkWordPieceJSON = `{
"model": {
"type": "WordPiece",
"vocab": {
"[UNK]": 0,
"hello": 1,
"##world": 2,
"##ly": 3,
"##hello": 4
}
},
"added_tokens": []
}`
const benchmarkSentencePieceJSON = `{
"model": {
"type": "BPE",
"vocab": {
"\u2581": 0,
"h": 1,
"e": 2,
"l": 3,
"o": 4,
"w": 5,
"r": 6,
"d": 7,
"<0x0A>": 8
},
"merges": []
},
"decoder": {
"type": "Sequence",
"decoders": [
{
"type": "Replace",
"pattern": {
"String": "\u2581"
}
}
]
},
"added_tokens": []
}`
func benchmarkMiniLlamaPath(tb testing.TB) string {
tb.Helper()
_, filename, _, ok := runtime.Caller(0)
if !ok {
tb.Fatal("failed to resolve benchmark file path")
}
return filepath.Join(filepath.Dir(filename), "..", "imagegen", "tokenizer", "testdata", "mini_llama.json")
}
func benchmarkLoadMiniLlama(tb testing.TB) *Tokenizer {
tb.Helper()
data := benchmarkLoadMiniLlamaBytes(tb)
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load mini llama tokenizer: %v", err)
}
return tok
}
func benchmarkLoadMiniLlamaBytes(tb testing.TB) []byte {
tb.Helper()
data, err := os.ReadFile(benchmarkMiniLlamaPath(tb))
if err != nil {
tb.Fatalf("failed to read mini llama tokenizer: %v", err)
}
return data
}
func benchmarkLoadFromBytes(tb testing.TB, data []byte) *Tokenizer {
tb.Helper()
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load tokenizer from bytes: %v", err)
}
return tok
}
func BenchmarkTokenizerEncodeBPE(b *testing.B) {
tok := benchmarkLoadMiniLlama(b)
inputs := []struct {
name string
text string
}{
{name: "short", text: "Hello, world!"},
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
{name: "long_sequential", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 80)},
{name: "long_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
{name: "huge_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)},
{name: "special_tokens", text: "<|begin_of_text|>system\nYou are concise.<|end_of_text|>"},
}
for _, input := range inputs {
b.Run(input.name, func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(input.text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(input.text, false)
}
})
}
}
func BenchmarkTokenizerDecodeBPE(b *testing.B) {
tok := benchmarkLoadMiniLlama(b)
inputs := []struct {
name string
text string
}{
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
{name: "long", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
}
for _, input := range inputs {
ids := tok.Encode(input.text, false)
b.Run(input.name, func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(input.text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
})
}
}
func BenchmarkTokenizerLoadFromBytes(b *testing.B) {
data := benchmarkLoadMiniLlamaBytes(b)
config := &TokenizerConfig{
TokenizerConfigJSON: []byte(`{
"bos_token": {"content": "<|begin_of_text|>"},
"eos_token": {"content": "<|end_of_text|>"},
"add_bos_token": true
}`),
GenerationConfigJSON: []byte(`{"bos_token_id": 128000, "eos_token_id": 128001}`),
}
b.Run("without_config", func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
tok, err := LoadFromBytes(data)
if err != nil {
b.Fatalf("LoadFromBytes failed: %v", err)
}
benchmarkSinkTok = tok
}
})
b.Run("with_config", func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
tok, err := LoadFromBytesWithConfig(data, config)
if err != nil {
b.Fatalf("LoadFromBytesWithConfig failed: %v", err)
}
benchmarkSinkTok = tok
}
})
}
func BenchmarkTokenizerEncodeWordPiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
text := strings.Repeat("helloworldly", 16)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(text, false)
}
}
func BenchmarkTokenizerDecodeWordPiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
text := strings.Repeat("helloworldly", 16)
ids := tok.Encode(text, false)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
}
func BenchmarkTokenizerEncodeSentencePiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
text := strings.Repeat("hello world\n", 64)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(text, false)
}
}
func BenchmarkTokenizerDecodeSentencePiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
text := strings.Repeat("hello world\n", 64)
ids := tok.Encode(text, false)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
}

View File

@@ -0,0 +1,175 @@
//go:build mlx
package tokenizer
import "container/heap"
type bpeMergeNode struct {
prev int
next int
token string
}
type bpePair struct {
left int
right int
rank int
value string
}
type bpePairHeap []*bpePair
func (h bpePairHeap) Len() int { return len(h) }
func (h bpePairHeap) Less(i, j int) bool {
return h[i].rank < h[j].rank || (h[i].rank == h[j].rank && h[i].left < h[j].left)
}
func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *bpePairHeap) Push(x any) {
*h = append(*h, x.(*bpePair))
}
func (h *bpePairHeap) Pop() any {
old := *h
n := len(old)
item := old[n-1]
*h = old[:n-1]
return item
}
// encodeBPEMerge encodes using BPE merge algorithm.
// Uses the heap/linked-list pair merge strategy from tokenizer/bytepairencoding.go:
// merge the lowest-rank valid pair, then only recheck adjacent pairs.
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
runes := []rune(encoded)
if len(runes) == 0 {
return ids
}
nodes := make([]bpeMergeNode, len(runes))
for i := range runes {
nodes[i] = bpeMergeNode{
prev: i - 1,
next: i + 1,
token: string(runes[i]),
}
}
pairwise := func(left, right int) *bpePair {
if left < 0 || right >= len(nodes) {
return nil
}
if nodes[left].token == "" || nodes[right].token == "" {
return nil
}
leftToken, rightToken := nodes[left].token, nodes[right].token
rank, ok := t.vocab.Merges[leftToken+" "+rightToken]
if !ok {
return nil
}
value := leftToken + rightToken
if _, ok := t.vocab.Reverse[value]; !ok {
return nil
}
return &bpePair{
left: left,
right: right,
rank: rank,
value: value,
}
}
pairs := bpePairHeap{}
heap.Init(&pairs)
for i := 0; i < len(runes)-1; i++ {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(&pairs, pair)
}
}
for pairs.Len() > 0 {
pair := heap.Pop(&pairs).(*bpePair)
left, right := nodes[pair.left], nodes[pair.right]
if left.token == "" || right.token == "" {
continue
}
if left.next != pair.right || right.prev != pair.left {
continue
}
if left.token+right.token != pair.value {
continue
}
nodes[pair.left].token = pair.value
nodes[pair.right].token = ""
nodes[pair.left].next = right.next
if right.next < len(nodes) {
nodes[right.next].prev = pair.left
}
if pair := pairwise(nodes[pair.left].prev, pair.left); pair != nil {
heap.Push(&pairs, pair)
}
if pair := pairwise(pair.left, nodes[pair.left].next); pair != nil {
heap.Push(&pairs, pair)
}
}
for _, node := range nodes {
if node.token == "" {
continue
}
if id, ok := t.vocab.Reverse[node.token]; ok {
ids = append(ids, id)
continue
}
ids = t.appendByteFallback(ids, node.token)
}
return ids
}
func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 {
if t.typ == TokenizerBPE {
for _, r := range token {
if b, ok := decodeByteLevelRune(r); ok {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
}
return ids
}
// SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens.
for _, b := range []byte(token) {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
return ids
}
func decodeByteLevelRune(r rune) (byte, bool) {
switch {
case r >= 0x00 && r <= 0xFF:
return byte(r), true
case r == 0x0100:
return 0x00, true
case r == 0x0143:
return 0x00ad, true
case r > 0x0100 && r <= 0x0120:
return byte(r - 0x0100), true
case r > 0x0120 && r <= 0x0142:
return byte(r - 0x00a2), true
default:
return 0, false
}
}

View File

@@ -0,0 +1,137 @@
//go:build mlx
package tokenizer
import (
"runtime"
"strings"
"testing"
)
func equalIDs(a, b []int32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func TestEncodeRoundtripMiniLlama(t *testing.T) {
tok := benchmarkLoadMiniLlama(t)
inputs := []string{
"",
"hello",
"hello world",
" hello world ",
"don't we'll they're",
"1234567890",
"こんにちは世界",
"Hello 世界",
"func main() {}",
"<|begin_of_text|>system\nYou are concise.<|end_of_text|>",
strings.Repeat("The quick brown fox jumps over the lazy dog. ", 32),
}
for _, input := range inputs {
ids := tok.Encode(input, false)
got := tok.Decode(ids)
if got != input {
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
}
}
}
func TestSplitBySpecialTokensGreedyLongest(t *testing.T) {
data := []byte(`{
"model": {
"type": "BPE",
"vocab": {"a": 0, "b": 1},
"merges": []
},
"added_tokens": [
{"id": 2, "content": "<tag>", "special": true},
{"id": 3, "content": "<tag>x", "special": true}
]
}`)
tok, err := LoadFromBytes(data)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
input := "a<tag>xb"
want := []string{"a", "<tag>x", "b"}
got := tok.splitBySpecialTokens(input)
if len(got) != len(want) {
t.Fatalf("split length mismatch: got %v want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
}
}
}
func TestSplitBySpecialTokensFallbackWithoutCache(t *testing.T) {
data := []byte(`{
"model": {
"type": "BPE",
"vocab": {"a": 0, "b": 1},
"merges": []
},
"added_tokens": [
{"id": 2, "content": "<tag>", "special": true},
{"id": 3, "content": "<tag>x", "special": true}
]
}`)
tok, err := LoadFromBytes(data)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
input := "a<tag>xb"
want := []string{"a", "<tag>x", "b"}
// Simulate construction outside loader path where cache is not set.
tok.sortedSpecialTokens = nil
got := tok.splitBySpecialTokens(input)
if len(got) != len(want) {
t.Fatalf("split length mismatch: got %v want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
}
}
}
func TestEncodeDeterministicAcrossGOMAXPROCS(t *testing.T) {
tok := benchmarkLoadMiniLlama(t)
input := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)
prev := runtime.GOMAXPROCS(0)
defer runtime.GOMAXPROCS(prev)
runtime.GOMAXPROCS(1)
seq := tok.Encode(input, false)
if prev < 2 {
runtime.GOMAXPROCS(2)
} else {
runtime.GOMAXPROCS(prev)
}
par := tok.Encode(input, false)
if !equalIDs(seq, par) {
t.Fatalf("encode mismatch between sequential and parallel paths: seq=%d par=%d", len(seq), len(par))
}
}

View File

@@ -0,0 +1,56 @@
//go:build mlx
package tokenizer
import (
"strconv"
"strings"
)
// Decode converts token IDs back to text
func (t *Tokenizer) Decode(ids []int32) string {
var sb strings.Builder
for _, id := range ids {
if int(id) >= len(t.vocab.Values) {
continue
}
token := t.vocab.Values[id]
switch t.typ {
case TokenizerSentencePiece:
// SentencePiece style: replace ▁ with space, decode byte tokens
token = strings.ReplaceAll(token, "▁", " ")
// Handle byte fallback tokens like <0x0D>
if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
sb.WriteByte(byte(v))
continue
}
}
sb.WriteString(token)
default:
// GPT-2 BPE style: decode byte-level encoding
for _, r := range token {
switch {
case r == 0x0100:
// Mirror GGML tokenizer behavior for NULL byte.
// 0x00 is omitted during decode.
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// Write as byte, not UTF-8 encoded rune
sb.WriteByte(byte(r))
}
}
}
return sb.String()
}

View File

@@ -0,0 +1,289 @@
//go:build mlx
package tokenizer
import (
"runtime"
"sort"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
const (
encodeParallelMinInputBytes = 4 * 1024
encodeParallelMinChunksPerWorker = 8
)
type tokenMatch struct {
start int
end int
}
type encodeChunk struct {
text string
isSpecial bool
}
// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines)
func isNonNewlineWhitespace(s string) bool {
if s == "" {
return false
}
for _, r := range s {
if r == '\n' || r == '\r' {
return false
}
if !unicode.IsSpace(r) {
return false
}
}
return true
}
// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
func (t *Tokenizer) splitBySpecialTokens(s string) []string {
if len(t.specialTokens) == 0 {
return []string{s}
}
tokens := t.sortedSpecialTokens
if len(tokens) == 0 {
// Fallback for tokenizers constructed outside the loaders.
tokens = make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
}
var result []string
remaining := s
for len(remaining) > 0 {
found := false
for _, tok := range tokens {
if strings.HasPrefix(remaining, tok) {
result = append(result, tok)
remaining = remaining[len(tok):]
found = true
break
}
}
if !found {
// Find next special token position
nextPos := len(remaining)
for _, tok := range tokens {
if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
nextPos = idx
}
}
if nextPos > 0 {
result = append(result, remaining[:nextPos])
}
remaining = remaining[nextPos:]
}
}
return result
}
func adjustWhitespaceBoundary(part string, curr, next *tokenMatch) {
m := part[curr.start:curr.end]
nextText := part[next.start:next.end]
if !isNonNewlineWhitespace(m) || len(nextText) == 0 {
return
}
firstRune, _ := utf8.DecodeRuneInString(nextText)
if !unicode.IsLetter(firstRune) {
return
}
lastSpaceStart := curr.end
for j := curr.end; j > curr.start; {
r, size := utf8.DecodeLastRuneInString(part[curr.start:j])
if unicode.IsSpace(r) {
lastSpaceStart = j - size
break
}
j -= size
}
if lastSpaceStart > curr.start {
curr.end = lastSpaceStart
next.start = lastSpaceStart
} else {
next.start = curr.start
curr.end = curr.start
}
}
func (t *Tokenizer) forEachPartChunk(part string, fn func(encodeChunk)) {
if _, ok := t.specialTokens[part]; ok {
fn(encodeChunk{text: part, isSpecial: true})
return
}
if t.pretokenizer == nil {
fn(encodeChunk{text: part, isSpecial: false})
return
}
re := t.pretokenizer
offset := 0
loc := re.FindStringIndex(part[offset:])
if loc == nil {
return
}
curr := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
offset += loc[1]
for {
loc = re.FindStringIndex(part[offset:])
if loc == nil {
if curr.end > curr.start {
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
}
return
}
next := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
offset += loc[1]
adjustWhitespaceBoundary(part, &curr, &next)
if curr.end > curr.start {
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
}
curr = next
}
}
func (t *Tokenizer) appendEncodedChunk(ids []int32, c encodeChunk) []int32 {
if c.isSpecial {
if id, ok := t.specialTokens[c.text]; ok {
return append(ids, id)
}
return ids
}
return t.encodeChunkInto(c.text, ids)
}
// Encode tokenizes text to token IDs.
// Parallel encoding is used only for very large inputs with enough chunks per worker.
func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
// First: split by special tokens
parts := t.splitBySpecialTokens(s)
// Fast path: encode sequentially without materializing chunk slices.
if len(s) < encodeParallelMinInputBytes {
var ids []int32
for _, part := range parts {
t.forEachPartChunk(part, func(c encodeChunk) {
ids = t.appendEncodedChunk(ids, c)
})
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// For large inputs collect chunks to enable parallel processing.
var allChunks []encodeChunk
for _, part := range parts {
t.forEachPartChunk(part, func(c encodeChunk) {
allChunks = append(allChunks, c)
})
}
// Encode chunks. Use the parallel path only when the chunk count is
// large enough to amortize goroutine/synchronization overhead.
useParallel := true
numWorkers := runtime.GOMAXPROCS(0)
if numWorkers > len(allChunks) {
numWorkers = len(allChunks)
}
if numWorkers < 2 || len(allChunks) < numWorkers*encodeParallelMinChunksPerWorker {
useParallel = false
}
var ids []int32
if !useParallel {
for _, c := range allChunks {
ids = t.appendEncodedChunk(ids, c)
}
} else {
chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
results := make([][]int32, numWorkers)
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
start := i * chunksPer
end := start + chunksPer
if end > len(allChunks) {
end = len(allChunks)
}
if start >= end {
continue
}
wg.Add(1)
go func(i int, chunks []encodeChunk) {
defer wg.Done()
var r []int32
for _, c := range chunks {
r = t.appendEncodedChunk(r, c)
}
results[i] = r
}(i, allChunks[start:end])
}
wg.Wait()
for _, r := range results {
ids = append(ids, r...)
}
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// encodeChunkInto appends encoded tokens to ids and returns the extended slice.
// Uses BPE merge algorithm for both BPE and SentencePiece tokenization.
func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
if s == "" {
return ids
}
// Apply encoding transformation
// SentencePiece: replace space with ▁
// BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
var encoded string
if t.typ == TokenizerSentencePiece {
encoded = strings.ReplaceAll(s, " ", "▁")
} else {
var sb strings.Builder
sb.Grow(len(s) * 2)
for i := 0; i < len(s); i++ {
sb.WriteRune(byteToRune[s[i]])
}
encoded = sb.String()
}
// Fast path: check if entire chunk is a single token
if id, ok := t.vocab.Reverse[encoded]; ok {
return append(ids, id)
}
return t.encodeBPEMerge(encoded, ids)
}

View File

@@ -0,0 +1,207 @@
//go:build mlx
package tokenizer
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
func llama32GGMLFixturePath(tb testing.TB, file string) string {
tb.Helper()
_, filename, _, ok := runtime.Caller(0)
if !ok {
tb.Fatal("failed to resolve test file path")
}
return filepath.Join(filepath.Dir(filename), "..", "..", "tokenizer", "testdata", "llama3.2", file)
}
func loadLlama32FromGGMLFixture(tb testing.TB) *Tokenizer {
tb.Helper()
f, err := os.Open(llama32GGMLFixturePath(tb, "encoder.json"))
if err != nil {
tb.Fatalf("failed to open encoder.json: %v", err)
}
defer f.Close()
vocab := make(map[string]int32)
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
tb.Fatalf("failed to decode encoder.json: %v", err)
}
type addedToken struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
}
var addedTokens []addedToken
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
if _, ok := vocab[token]; !ok {
id := int32(len(vocab))
vocab[token] = id
addedTokens = append(addedTokens, addedToken{ID: id, Content: token, Special: true})
}
}
mf, err := os.Open(llama32GGMLFixturePath(tb, "vocab.bpe"))
if err != nil {
tb.Fatalf("failed to open vocab.bpe: %v", err)
}
defer mf.Close()
var merges []string
scanner := bufio.NewScanner(mf)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "#") {
continue
}
line = strings.TrimSpace(line)
if line != "" {
merges = append(merges, line)
}
}
if err := scanner.Err(); err != nil {
tb.Fatalf("failed to read vocab.bpe: %v", err)
}
payload := struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges []string `json:"merges"`
} `json:"model"`
PreTokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
} `json:"pre_tokenizer"`
AddedTokens []addedToken `json:"added_tokens"`
}{}
payload.Model.Type = "BPE"
payload.Model.Vocab = vocab
payload.Model.Merges = merges
payload.PreTokenizer.Type = "Sequence"
payload.PreTokenizer.Pretokenizers = []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}{
{
Type: "Split",
Pattern: struct {
Regex string `json:"Regex"`
}{
Regex: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
},
},
}
payload.AddedTokens = addedTokens
data, err := json.Marshal(payload)
if err != nil {
tb.Fatalf("failed to marshal synthetic tokenizer.json: %v", err)
}
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load tokenizer from fixture data: %v", err)
}
return tok
}
func TestGGMLLlamaKnownEncodings(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := map[string][]int32{
"hello world": {15339, 1917},
"hello <|end_of_text|>": {15339, 220, 128001},
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
}
for input, want := range cases {
got := tok.Encode(input, false)
if !equalIDs(got, want) {
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
}
}
}
func TestGGMLLlamaRepeatedZeros(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := map[int][]int32{
1: {15},
2: {410},
3: {931},
4: {931, 15},
5: {931, 410},
6: {931, 931},
7: {931, 931, 15},
8: {931, 931, 410},
9: {931, 931, 931},
10: {931, 931, 931, 15},
11: {931, 931, 931, 410},
12: {931, 931, 931, 931},
13: {931, 931, 931, 931, 15},
14: {931, 931, 931, 931, 410},
15: {931, 931, 931, 931, 931},
16: {931, 931, 931, 931, 931, 15},
17: {931, 931, 931, 931, 931, 410},
}
for n, want := range cases {
input := strings.Repeat("0", n)
got := tok.Encode(input, false)
if !equalIDs(got, want) {
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
}
}
}
func TestGGMLLlamaRoundtripAndByteBehavior(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
}
for _, input := range cases {
ids := tok.Encode(input, false)
got := tok.Decode(ids)
if got != input {
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
}
}
// Match GGML tokenizer behavior: 0x00 is omitted when decoding.
ids := tok.Encode(string(rune(0x00)), false)
got := tok.Decode(ids)
if got != "" {
t.Fatalf("expected empty decode for 0x00, got %q (ids=%v)", got, ids)
}
}

View File

@@ -0,0 +1,458 @@
//go:build mlx
package tokenizer
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
)
// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
type TokenizerConfig struct {
TokenizerConfigJSON []byte // tokenizer_config.json content
GenerationConfigJSON []byte // generation_config.json content
SpecialTokensMapJSON []byte // special_tokens_map.json content
ConfigJSON []byte // config.json content
}
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
// This is useful when loading from blob storage where the file content is already in memory.
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
return loadFromTokenizerJSON(data)
}
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
// This is useful when loading from blob storage where companion config files are also blobs.
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
t, err := loadFromTokenizerJSON(data)
if err != nil {
return nil, err
}
if config == nil {
return t, nil
}
// Apply special token configs from provided data
loadSpecialTokenConfigFromBytes(t, config)
return t, nil
}
// loadFromTokenizerJSON parses tokenizer.json content from bytes.
func loadFromTokenizerJSON(data []byte) (*Tokenizer, error) {
var raw struct {
Model struct {
Type string `json:"type"` // "BPE"
Vocab map[string]int32 `json:"vocab"`
Merges json.RawMessage `json:"merges"` // Can be []string or [][]string (BPE only)
} `json:"model"`
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
Decoder json.RawMessage `json:"decoder"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
}
// Covers SentencePiece and BPE models
if raw.Model.Type != "BPE" {
return nil, fmt.Errorf("unsupported tokenizer type: %s", raw.Model.Type)
}
// Parse merges - can be []string (Llama) or [][]string (GPT-OSS).
var mergesStrings []string
if raw.Model.Merges != nil {
var mergesArrays [][]string
if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
// Try array of arrays format
if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
return nil, fmt.Errorf("failed to parse merges: %w", err)
}
// Convert [][]string to []string
mergesStrings = make([]string, len(mergesArrays))
for i, pair := range mergesArrays {
if len(pair) != 2 {
return nil, fmt.Errorf("failed to parse merges: expected merge pair of length 2, got %d", len(pair))
}
mergesStrings[i] = pair[0] + " " + pair[1]
}
}
}
// Build tokenizer
t := &Tokenizer{
vocab: &Vocabulary{
Values: make([]string, len(raw.Model.Vocab)),
Reverse: raw.Model.Vocab,
Merges: make(map[string]int, len(mergesStrings)),
BOS: -1,
PAD: -1,
},
specialTokens: make(map[string]int32),
}
// Build values array
for token, id := range raw.Model.Vocab {
if int(id) >= len(t.vocab.Values) {
newValues := make([]string, id+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[id] = token
}
// Build merges map
for i, merge := range mergesStrings {
t.vocab.Merges[merge] = i
}
// Add all added_tokens to vocabulary and special tokens map.
// HuggingFace treats ALL added_tokens as special for tokenization purposes -
// they bypass BPE and get their own token ID. The "special" flag just indicates
// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
// to treat all added_tokens as special to match HuggingFace behavior.
for _, tok := range raw.AddedTokens {
if int(tok.ID) >= len(t.vocab.Values) {
newValues := make([]string, tok.ID+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[tok.ID] = tok.Content
t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
}
// Precompute byte token IDs for <0xNN> fallback
initByteTokens(t)
// Determine tokenizer type
switch {
case detectSentencePiece(raw.Decoder):
t.typ = TokenizerSentencePiece
default:
t.typ = TokenizerBPE
}
// Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
if t.typ == TokenizerBPE {
pattern := extractPretokenizer(raw.PreTokenizer)
if pattern == "" {
pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
}
re, err := regexp.Compile(rewritePatternForRE2(pattern))
if err != nil {
return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
}
t.pretokenizer = re
}
cacheSortedSpecialTokens(t)
return t, nil
}
func cacheSortedSpecialTokens(t *Tokenizer) {
if len(t.specialTokens) == 0 {
t.sortedSpecialTokens = nil
return
}
tokens := make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
t.sortedSpecialTokens = tokens
}
type specialTokenConfigData struct {
tokenizerConfigJSON []byte
generationConfigJSON []byte
specialTokensMapJSON []byte
configJSON []byte
}
func applySpecialTokenConfig(t *Tokenizer, config specialTokenConfigData) {
parseTokenIDs := func(v interface{}) []int32 {
switch val := v.(type) {
case float64:
return []int32{int32(val)}
case []interface{}:
ids := make([]int32, 0, len(val))
for _, id := range val {
if f, ok := id.(float64); ok {
ids = append(ids, int32(f))
}
}
return ids
}
return nil
}
// Priority 1: generation_config.json
if len(config.generationConfigJSON) > 0 {
var genConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.generationConfigJSON, &genConfig); err == nil {
if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
// Priority 2: config.json
if len(config.configJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
var modelConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.configJSON, &modelConfig); err == nil {
if len(t.vocab.EOS) == 0 {
if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
}
if t.vocab.BOS < 0 {
if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
}
// Priority 3: tokenizer_config.json
if len(config.tokenizerConfigJSON) > 0 {
var tokConfig struct {
BOSToken interface{} `json:"bos_token"`
EOSToken interface{} `json:"eos_token"`
PADToken interface{} `json:"pad_token"`
AddBOSToken *bool `json:"add_bos_token"`
AddEOSToken *bool `json:"add_eos_token"`
}
if err := json.Unmarshal(config.tokenizerConfigJSON, &tokConfig); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
if tokConfig.AddBOSToken != nil {
t.vocab.AddBOS = *tokConfig.AddBOSToken
}
if tokConfig.AddEOSToken != nil {
t.vocab.AddEOS = *tokConfig.AddEOSToken
}
}
}
// Priority 4: special_tokens_map.json
if len(config.specialTokensMapJSON) > 0 {
var tokensMap map[string]interface{}
if err := json.Unmarshal(config.specialTokensMapJSON, &tokensMap); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
}
}
}
// extractTokenString extracts the token string from various formats used in HuggingFace configs.
// Tokens can be represented as:
// - string: "token"
// - object: {"content": "token", ...}
func extractTokenString(v interface{}) string {
if v == nil {
return ""
}
// Direct string
if s, ok := v.(string); ok {
return s
}
// Object with content field
if m, ok := v.(map[string]interface{}); ok {
if content, ok := m["content"].(string); ok {
return content
}
}
return ""
}
// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be
// compatible with Go's regexp package (RE2). HuggingFace patterns use PCRE features:
// - (?!\S) negative lookahead - RE2 doesn't support this
// - (?i:...) inline case-insensitive groups - RE2 doesn't support this
//
// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries in encodeWithRegex().
// The lookahead version splits "a b" into ["a", " ", " b"] (space prepended to word).
// Simple \s+ would give ["a", " ", "b"]. We post-process to match Python's behavior.
func rewritePatternForRE2(pattern string) string {
// Replace lookahead pattern with simple \s+ - we fix boundaries in encodeWithRegex()
pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`)
// Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style)
// IMPORTANT: Must be done before the non-optional version to avoid partial replacement
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`)
// Expand case-insensitive contraction pattern to explicit alternations
// (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`)
return pattern
}
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
applySpecialTokenConfig(t, specialTokenConfigData{
tokenizerConfigJSON: config.TokenizerConfigJSON,
generationConfigJSON: config.GenerationConfigJSON,
specialTokensMapJSON: config.SpecialTokensMapJSON,
configJSON: config.ConfigJSON,
})
}
// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces)
// vs GPT-2 byte-level encoding
func detectSentencePiece(data json.RawMessage) bool {
if data == nil {
return false
}
// Check for Sequence decoder with Replace step (SentencePiece style)
var seq struct {
Type string `json:"type"`
Decoders []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"String"`
} `json:"pattern"`
} `json:"decoders"`
}
if err := json.Unmarshal(data, &seq); err == nil {
if seq.Type == "Sequence" {
for _, dec := range seq.Decoders {
// Look for Replace decoder that converts ▁ to space
if dec.Type == "Replace" && dec.Pattern.String == "▁" {
return true
}
}
}
}
// Check for direct ByteLevel decoder (GPT-2 style)
var simple struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &simple); err == nil {
if simple.Type == "ByteLevel" {
return false
}
}
return false
}
// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
func initByteTokens(t *Tokenizer) {
for i := range t.vocab.byteTokens {
t.vocab.byteTokens[i] = -1
}
for b := 0; b < 256; b++ {
token := fmt.Sprintf("<0x%02X>", b)
if id, ok := t.vocab.Reverse[token]; ok {
t.vocab.byteTokens[b] = id
}
}
}
// extractPretokenizer extracts the regex pattern from the pre_tokenizer config
func extractPretokenizer(data json.RawMessage) string {
if data == nil {
return ""
}
// Try to parse as a single Split pretokenizer
var single struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}
if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" {
return single.Pattern.Regex
}
// Try to parse as Sequence of pretokenizers - use first Split pattern
var seq struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
}
if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
for _, pt := range seq.Pretokenizers {
if pt.Type == "Split" && pt.Pattern.Regex != "" {
return pt.Pattern.Regex
}
}
}
return ""
}

View File

@@ -0,0 +1,26 @@
//go:build mlx
package tokenizer
import (
"strings"
"testing"
)
func TestLoadFromBytesRejectsWordPiece(t *testing.T) {
data := []byte(`{
"model": {
"type": "WordPiece",
"vocab": {"[UNK]": 0, "hello": 1}
},
"added_tokens": []
}`)
_, err := LoadFromBytes(data)
if err == nil {
t.Fatal("expected WordPiece load to fail")
}
if !strings.Contains(err.Error(), "unsupported tokenizer type: WordPiece") {
t.Fatalf("unexpected error: %v", err)
}
}