Compare commits

..

1 Commits

Author SHA1 Message Date
Michael Yang
07d944fdfc move tokenizer to separate package 2026-01-28 14:25:31 -08:00
60 changed files with 747 additions and 2010 deletions

View File

@@ -6,8 +6,6 @@ import (
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner for Claude Code integration
@@ -15,15 +13,11 @@ type Claude struct{}
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extraArgs []string) []string {
var args []string
func (c *Claude) args(model string) []string {
if model != "" {
args = append(args, "--model", model)
return []string{"--model", model}
}
if len(extraArgs) > 0 {
args = append(args, extraArgs...)
}
return args
return nil
}
func (c *Claude) findPath() (string, error) {
@@ -45,18 +39,18 @@ func (c *Claude) findPath() (string, error) {
return fallback, nil
}
func (c *Claude) Run(model string, extraArgs []string) error {
func (c *Claude) Run(model string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model, extraArgs)...)
cmd := exec.Command(claudePath, c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_BASE_URL=http://localhost:11434",
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)

View File

@@ -82,23 +82,19 @@ func TestClaudeArgs(t *testing.T) {
c := &Claude{}
tests := []struct {
name string
model string
extraArgs []string
want []string
name string
model string
want []string
}{
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
{"empty model", "", nil, nil},
{"with model and extra args", "llama3.2", []string{"--yolo", "--hi"}, []string{"--model", "llama3.2", "--yolo", "--hi"}},
{"empty model with extra args", "", []string{"--help"}, []string{"--help"}},
{"multiple extra args", "llama3.2", []string{"--flag1", "--flag2", "value"}, []string{"--model", "llama3.2", "--flag1", "--flag2", "value"}},
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
{"empty model", "", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, tt.extraArgs)
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.extraArgs, got, tt.want)
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}

View File

@@ -1,201 +0,0 @@
package config
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/ollama/ollama/envconfig"
)
type Clawdbot struct{}
func (c *Clawdbot) String() string { return "Clawdbot" }
const ansiGreen = "\033[32m"
func (c *Clawdbot) Run(model string, extraArgs []string) error {
if _, err := exec.LookPath("clawdbot"); err != nil {
return fmt.Errorf("clawdbot is not installed, install from https://docs.clawd.bot")
}
models := []string{model}
if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
// Build args: "gateway" first, then any extra args
args := []string{"gateway"}
if len(extraArgs) > 0 {
args = append(args, extraArgs...)
}
cmd := exec.Command("clawdbot", args...)
cmd.Stdin = os.Stdin
// Capture output to detect "already running" message
var outputBuf bytes.Buffer
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
err := cmd.Run()
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
fmt.Fprintf(os.Stderr, "%sClawdbot has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
return nil
}
return err
}
func (c *Clawdbot) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".clawdbot", "clawdbot.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (c *Clawdbot) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
// Read into map[string]any to preserve unknown fields
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
}
// Navigate/create: models.providers.ollama (preserving other providers)
modelsSection, _ := config["models"].(map[string]any)
if modelsSection == nil {
modelsSection = make(map[string]any)
}
providers, _ := modelsSection["providers"].(map[string]any)
if providers == nil {
providers = make(map[string]any)
}
ollama, _ := providers["ollama"].(map[string]any)
if ollama == nil {
ollama = make(map[string]any)
}
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
// needed to register provider
ollama["apiKey"] = "ollama-local"
// TODO(parthsareen): potentially move to responses
ollama["api"] = "openai-completions"
// Build map of existing models to preserve user customizations
existingModels, _ := ollama["models"].([]any)
existingByID := make(map[string]map[string]any)
for _, m := range existingModels {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
existingByID[id] = entry
}
}
}
var newModels []any
for _, model := range models {
entry := map[string]any{
"id": model,
"name": model,
"reasoning": false,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
// TODO(parthsareen): get these values from API
"contextWindow": 131072,
"maxTokens": 16384,
}
// Merge existing fields (user customizations)
if existing, ok := existingByID[model]; ok {
for k, v := range existing {
if _, isNew := entry[k]; !isNew {
entry[k] = v
}
}
}
newModels = append(newModels, entry)
}
ollama["models"] = newModels
providers["ollama"] = ollama
modelsSection["providers"] = providers
config["models"] = modelsSection
// Update agents.defaults.model.primary (preserving other agent settings)
agents, _ := config["agents"].(map[string]any)
if agents == nil {
agents = make(map[string]any)
}
defaults, _ := agents["defaults"].(map[string]any)
if defaults == nil {
defaults = make(map[string]any)
}
modelConfig, _ := defaults["model"].(map[string]any)
if modelConfig == nil {
modelConfig = make(map[string]any)
}
modelConfig["primary"] = "ollama/" + models[0]
defaults["model"] = modelConfig
agents["defaults"] = defaults
config["agents"] = agents
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
return writeWithBackup(configPath, data)
}
func (c *Clawdbot) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
if err != nil {
return nil
}
modelsSection, _ := config["models"].(map[string]any)
providers, _ := modelsSection["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
var result []string
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
result = append(result, id)
}
}
}
return result
}

View File

@@ -1,625 +0,0 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestClawdbotIntegration(t *testing.T) {
c := &Clawdbot{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Clawdbot" {
t.Errorf("String() = %q, want %q", got, "Clawdbot")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = c
})
}
func TestClawdbotEdit(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("fresh install", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("multiple models - first is primary", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotModelExists(t, configPath, "mistral")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("preserve other providers", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider was removed")
}
})
t.Run("preserve top-level keys", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("theme was removed")
}
if cfg["mcp"] == nil {
t.Error("mcp was removed")
}
})
t.Run("preserve user customizations on models", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2"})
// User adds custom field
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
entry["customField"] = "user-value"
configData, _ := json.MarshalIndent(cfg, "", " ")
os.WriteFile(configPath, configData, 0o644)
// Re-run Edit
c.Edit([]string{"llama3.2"})
data, _ = os.ReadFile(configPath)
json.Unmarshal(data, &cfg)
models = cfg["models"].(map[string]any)
providers = models["providers"].(map[string]any)
ollama = providers["ollama"].(map[string]any)
modelList = ollama["models"].([]any)
entry = modelList[0].(map[string]any)
if entry["customField"] != "user-value" {
t.Error("custom field was lost")
}
})
t.Run("edit replaces models list", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2", "mistral"})
c.Edit([]string{"llama3.2"})
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotModelNotExists(t, configPath, "mistral")
})
t.Run("empty models is no-op", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
original := `{"existing":"data"}`
os.WriteFile(configPath, []byte(original), 0o644)
c.Edit([]string{})
data, _ := os.ReadFile(configPath)
if string(data) != original {
t.Error("empty models should not modify file")
}
})
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Error("result should be valid JSON")
}
})
t.Run("wrong type models section", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
})
}
func TestClawdbotModels(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("no config returns nil", func(t *testing.T) {
if models := c.Models(); len(models) > 0 {
t.Errorf("expected nil/empty, got %v", models)
}
})
t.Run("returns all ollama models", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[
{"id":"llama3.2"},
{"id":"mistral"}
]}}}
}`), 0o644)
models := c.Models()
if len(models) != 2 {
t.Errorf("expected 2 models, got %v", models)
}
})
}
// Helper functions
func assertClawdbotModelExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
return
}
}
}
t.Errorf("model %s not found", model)
}
func assertClawdbotModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models, _ := cfg["models"].(map[string]any)
providers, _ := models["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
t.Errorf("model %s should not exist", model)
}
}
}
}
func assertClawdbotPrimaryModel(t *testing.T, path, expected string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
model := defaults["model"].(map[string]any)
if model["primary"] != expected {
t.Errorf("primary model = %v, want %v", model["primary"], expected)
}
}
func TestClawdbotPaths(t *testing.T) {
c := &Clawdbot{}
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns nil when config missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if paths := c.Paths(); paths != nil {
t.Errorf("expected nil, got %v", paths)
}
})
}
func TestClawdbotModelsEdgeCases(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("corrupted JSON returns nil", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at models level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at providers level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at ollama level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("model entry missing id", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for missing id")
}
})
t.Run("model id is not string", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for non-string id")
}
})
}
func TestClawdbotEditSchemaFields(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
// Verify required schema fields
if entry["reasoning"] != false {
t.Error("reasoning should be false")
}
if entry["input"] == nil {
t.Error("input should be set")
}
if entry["contextWindow"] == nil {
t.Error("contextWindow should be set")
}
if entry["maxTokens"] == nil {
t.Error("maxTokens should be set")
}
cost := entry["cost"].(map[string]any)
if cost["cacheRead"] == nil {
t.Error("cost.cacheRead should be set")
}
if cost["cacheWrite"] == nil {
t.Error("cost.cacheWrite should be set")
}
}
func TestClawdbotEditModelNames(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".clawdbot")) }
t.Run("model with colon tag", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2:70b")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2:70b")
})
t.Run("model with slash", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"library/model:tag"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "library/model:tag")
assertClawdbotPrimaryModel(t, configPath, "ollama/library/model:tag")
})
t.Run("model with hyphen", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"test-model"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "test-model")
})
}
func TestClawdbotEditAgentsPreservation(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("preserve other agent defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature setting was lost")
}
})
t.Run("preserve other agents besides defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent was lost")
}
})
}
const testClawdbotFixture = `{
"theme": "dark",
"mcp": {"servers": {"custom": {"enabled": true}}},
"models": {
"providers": {
"anthropic": {"apiKey": "xxx"},
"ollama": {
"baseUrl": "http://127.0.0.1:11434/v1",
"models": [{"id": "old-model", "customField": "preserved"}]
}
}
},
"agents": {
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
"custom-agent": {"foo": "bar"}
}
}`
func TestClawdbotEdit_RoundTrip(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
// Verify top-level preserved
if cfg["theme"] != "dark" {
t.Error("theme not preserved")
}
mcp := cfg["mcp"].(map[string]any)
servers := mcp["servers"].(map[string]any)
if servers["custom"] == nil {
t.Error("mcp.servers.custom not preserved")
}
// Verify other providers preserved
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider not preserved")
}
// Verify agents preserved
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent not preserved")
}
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature not preserved")
}
}
func TestClawdbotEdit_Idempotent(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
c.Edit([]string{"llama3.2", "mistral"})
firstData, _ := os.ReadFile(configPath)
c.Edit([]string{"llama3.2", "mistral"})
secondData, _ := os.ReadFile(configPath)
if string(firstData) != string(secondData) {
t.Error("repeated edits with same models produced different results")
}
}
func TestClawdbotEdit_MultipleConsecutiveEdits(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
for i := range 10 {
models := []string{"model-a", "model-b"}
if i%2 == 0 {
models = []string{"model-x", "model-y", "model-z"}
}
if err := c.Edit(models); err != nil {
t.Fatalf("edit %d failed: %v", i, err)
}
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
}
if cfg["theme"] != "dark" {
t.Error("theme lost after multiple edits")
}
}
func TestClawdbotEdit_BackupCreated(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
os.MkdirAll(configDir, 0o755)
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
os.WriteFile(configPath, []byte(original), 0o644)
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
backups, _ := filepath.Glob(filepath.Join(backupDir, "clawdbot.json.*"))
foundBackup := false
for _, backup := range backups {
data, _ := os.ReadFile(backup)
if string(data) == original {
foundBackup = true
break
}
}
if !foundBackup {
t.Error("backup with original content not found")
}
}
func TestClawdbotEdit_CreatesDirectoryIfMissing(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
t.Fatal("directory should not exist before test")
}
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(configDir); os.IsNotExist(err) {
t.Fatal("directory was not created")
}
}

View File

@@ -14,23 +14,20 @@ type Codex struct{}
func (c *Codex) String() string { return "Codex" }
func (c *Codex) args(model string, extraArgs []string) []string {
func (c *Codex) args(model string) []string {
args := []string{"--oss"}
if model != "" {
args = append(args, "-m", model)
}
if len(extraArgs) > 0 {
args = append(args, extraArgs...)
}
return args
}
func (c *Codex) Run(model string, extraArgs []string) error {
func (c *Codex) Run(model string) error {
if err := checkCodexVersion(); err != nil {
return err
}
cmd := exec.Command("codex", c.args(model, extraArgs)...)
cmd := exec.Command("codex", c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

View File

@@ -9,22 +9,19 @@ func TestCodexArgs(t *testing.T) {
c := &Codex{}
tests := []struct {
name string
model string
extraArgs []string
want []string
name string
model string
want []string
}{
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", nil, []string{"--oss"}},
{"with model and extra args", "qwen3-coder", []string{"--yolo"}, []string{"--oss", "-m", "qwen3-coder", "--yolo"}},
{"empty model with extra args", "", []string{"--help"}, []string{"--oss", "--help"}},
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", []string{"--oss"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, tt.extraArgs)
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.extraArgs, got, tt.want)
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}

View File

@@ -7,8 +7,6 @@ import (
"os/exec"
"path/filepath"
"slices"
"github.com/ollama/ollama/envconfig"
)
// Droid implements Runner and Editor for Droid integration
@@ -39,7 +37,7 @@ type modelEntry struct {
func (d *Droid) String() string { return "Droid" }
func (d *Droid) Run(model string, extraArgs []string) error {
func (d *Droid) Run(model string) error {
if _, err := exec.LookPath("droid"); err != nil {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
}
@@ -53,7 +51,7 @@ func (d *Droid) Run(model string, extraArgs []string) error {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid", extraArgs...)
cmd := exec.Command("droid")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -119,7 +117,7 @@ func (d *Droid) Edit(models []string) error {
newModels = append(newModels, modelEntry{
Model: model,
DisplayName: model,
BaseURL: envconfig.Host().String() + "/v1",
BaseURL: "http://localhost:11434/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: 64000,

View File

@@ -218,7 +218,7 @@ func TestDroidEdit(t *testing.T) {
}
}
if model["baseUrl"] != "http://127.0.0.1:11434/v1" {
if model["baseUrl"] != "http://localhost:11434/v1" {
t.Errorf("unexpected baseUrl: %s", model["baseUrl"])
}
if model["apiKey"] != "ollama" {
@@ -447,7 +447,7 @@ const testDroidSettingsFixture = `{
{
"model": "existing-ollama-model",
"displayName": "existing-ollama-model",
"baseUrl": "http://127.0.0.1:11434/v1",
"baseUrl": "http://localhost:11434/v1",
"apiKey": "ollama",
"provider": "generic-chat-completion-api",
"maxOutputTokens": 64000,

View File

@@ -22,7 +22,7 @@ import (
// Runner can run an integration with a model.
type Runner interface {
Run(model string, extraArgs []string) error
Run(model string) error
// String returns the human-readable name of the integration
String() string
}
@@ -41,7 +41,6 @@ type Editor interface {
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
"clawdbot": &Clawdbot{},
"codex": &Codex{},
"droid": &Droid{},
"opencode": &OpenCode{},
@@ -222,13 +221,13 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return selected, nil
}
func runIntegration(name, modelName string, extraArgs []string) error {
func runIntegration(name, modelName string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName, extraArgs)
return r.Run(modelName)
}
// LaunchCmd returns the cobra command for launching integrations.
@@ -237,13 +236,12 @@ func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) erro
var configFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
Use: "launch [INTEGRATION]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
Supported integrations:
claude Claude Code
clawdbot Clawdbot
codex Codex
droid Droid
opencode OpenCode
@@ -252,17 +250,13 @@ Examples:
ollama launch
ollama launch claude
ollama launch claude --model <model>
ollama launch droid --config (does not auto-launch)
ollama launch claude -- --yolo --hi (pass extra args to integration)`,
Args: cobra.ArbitraryArgs,
ollama launch droid --config (does not auto-launch)`,
Args: cobra.MaximumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
// Extract integration name and pass through remaining args
var name string
var extraArgs []string
if len(args) > 0 {
name = args[0]
extraArgs = args[1:]
} else {
var err error
name, err = selectIntegration()
@@ -282,7 +276,7 @@ Examples:
// If launching without --model, use saved config if available
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0], extraArgs)
return runIntegration(name, config.Models[0])
}
}
@@ -343,13 +337,13 @@ Examples:
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0], extraArgs)
return runIntegration(name, models[0])
}
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
return nil
}
return runIntegration(name, models[0], extraArgs)
return runIntegration(name, models[0])
},
}

View File

@@ -90,8 +90,8 @@ func TestLaunchCmd(t *testing.T) {
cmd := LaunchCmd(mockCheck)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
if cmd.Use != "launch [INTEGRATION]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
@@ -121,7 +121,7 @@ func TestLaunchCmd(t *testing.T) {
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model", nil)
err := runIntegration("unknown-integration", "model")
if err == nil {
t.Error("expected error for unknown integration, got nil")
}
@@ -182,69 +182,7 @@ func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
// Test Run() exists (we can't call it without actually running the command)
// Just verify the method is available
var _ func(string, []string) error = r.Run
})
}
}
func TestParseExtraArgs(t *testing.T) {
tests := []struct {
name string
args []string
wantArgs []string
wantExtraArgs []string
}{
{
name: "no extra args",
args: []string{"claude"},
wantArgs: []string{"claude"},
wantExtraArgs: nil,
},
{
name: "with extra args after --",
args: []string{"claude", "--", "--yolo", "--hi"},
wantArgs: []string{"claude"},
wantExtraArgs: []string{"--yolo", "--hi"},
},
{
name: "extra args only after --",
args: []string{"codex", "--", "--help"},
wantArgs: []string{"codex"},
wantExtraArgs: []string{"--help"},
},
{
name: "-- at end with no args after",
args: []string{"claude", "--"},
wantArgs: []string{"claude", "--"},
wantExtraArgs: nil,
},
{
name: "multiple args after --",
args: []string{"claude", "--", "--flag1", "--flag2", "value", "--flag3"},
wantArgs: []string{"claude"},
wantExtraArgs: []string{"--flag1", "--flag2", "value", "--flag3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the parsing logic from LaunchCmd
args := tt.args
var extraArgs []string
for i, arg := range args {
if arg == "--" && i < len(args)-1 {
extraArgs = args[i+1:]
args = args[:i]
break
}
}
if !slices.Equal(args, tt.wantArgs) {
t.Errorf("args = %v, want %v", args, tt.wantArgs)
}
if !slices.Equal(extraArgs, tt.wantExtraArgs) {
t.Errorf("extraArgs = %v, want %v", extraArgs, tt.wantExtraArgs)
}
var _ func(string) error = r.Run
})
}
}

View File

@@ -9,8 +9,6 @@ import (
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
@@ -18,7 +16,7 @@ type OpenCode struct{}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string, extraArgs []string) error {
func (o *OpenCode) Run(model string) error {
if _, err := exec.LookPath("opencode"); err != nil {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
}
@@ -32,7 +30,7 @@ func (o *OpenCode) Run(model string, extraArgs []string) error {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("opencode", extraArgs...)
cmd := exec.Command("opencode")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -90,7 +88,7 @@ func (o *OpenCode) Edit(modelList []string) error {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)",
"options": map[string]any{
"baseURL": envconfig.Host().String() + "/v1",
"baseURL": "http://localhost:11434/v1",
},
}
}

View File

@@ -102,7 +102,6 @@
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/clawdbot",
"/integrations/cline",
"/integrations/codex",
"/integrations/droid",

View File

@@ -1,48 +0,0 @@
---
title: Clawdbot
---
Clawdbot is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
## Install
Install [Clawdbot](https://clawd.bot/)
```bash
npm install -g clawdbot@latest
```
Then run the onboarding wizard:
```bash
clawdbot onboard --install-daemon
```
<Note>Clawdbot requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch clawdbot
```
This configures Clawdbot to use Ollama and starts the gateway.
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
To configure without launching:
```shell
ollama launch clawdbot --config
```
## Recommended Models
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

View File

@@ -34,6 +34,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
type filteredEnv []string
@@ -115,7 +116,7 @@ type llamaServer struct {
type ollamaServer struct {
llmServer
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
tokenizer tokenizer.Tokenizer // textProcessor handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -141,11 +142,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
var llamaModel *llama.Model
var textProcessor model.TextProcessor
var tokenizer tokenizer.Tokenizer
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath)
tokenizer, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
@@ -154,7 +155,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if textProcessor == nil {
if tokenizer == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
@@ -210,7 +211,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if textProcessor == nil {
if tokenizer == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
@@ -260,7 +261,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
gpuLibs := ml.LibraryPaths(gpus)
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
textProcessor != nil,
tokenizer != nil,
modelPath,
gpuLibs,
status,
@@ -309,8 +310,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}()
if textProcessor != nil {
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
if tokenizer != nil {
return &ollamaServer{llmServer: s, tokenizer: tokenizer}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -1772,7 +1773,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.textProcessor.Encode(content, false)
tokens, err := s.tokenizer.Encode(content, false)
if err != nil {
return nil, err
}
@@ -1807,7 +1808,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
content, err := s.tokenizer.Decode(toks)
if err != nil {
return "", err
}

410
model/ignore_test.go Normal file
View File

File diff suppressed because one or more lines are too long

View File

@@ -23,6 +23,7 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var (
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
tp, ok := m.(TextProcessor)
tp, ok := m.(tokenizer.Tokenizer)
if !ok {
return nil, ErrUnsupportedTokenizer
}

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
}
func New(c fs.Config) (model.Model, error) {
vocab := &model.Vocabulary{
vocab := &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
},
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(vocab, true)
t = tokenizer.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Tokenizer: t,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -134,8 +135,8 @@ func init() {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -43,8 +44,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
poolingType pooling.Type
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
),
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(&vocabulary)
t = tokenizer.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
processor = model.NewSentencePiece(&vocabulary)
t = tokenizer.NewSentencePiece(&vocabulary)
}
m := Model{
TextProcessor: processor,
Tokenizer: t,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),

View File

@@ -6,11 +6,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
}
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Transformer struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}
var processor model.TextProcessor
vocabulary := model.Vocabulary{
var processor tokenizer.Tokenizer
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = model.NewSentencePiece(&vocabulary)
processor = tokenizer.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
ImageProcessor
*VisionModel `gguf:"v"`
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -28,12 +29,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
var _ model.TextProcessor = (*Model)(nil)
var _ tokenizer.Tokenizer = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -32,8 +33,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -178,8 +179,8 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
tokenizer := tokenizer.NewWordPiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -219,8 +220,8 @@ func New(c fs.Config) (model.Model, error) {
}
return &Model{
TextProcessor: processor,
Layers: layers,
Tokenizer: tokenizer,
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
const (
@@ -33,7 +34,7 @@ type Options struct {
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,7 +45,7 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -58,14 +59,14 @@ func New(c fs.Config) (model.Model, error) {
),
}
processor := model.NewBytePairEncoding(
tokenizer := tokenizer.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*Model
poolingType pooling.Type
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
@@ -18,34 +17,12 @@ const (
ministralCollectingToolArgs
)
// ministralEvent represents an event emitted during parsing
type ministralEvent interface {
isMinistralEvent()
}
type ministralEventContent struct {
content string
}
type ministralEventThinking struct {
thinking string
}
type ministralEventToolCall struct {
name string
args string // raw JSON string
}
func (ministralEventContent) isMinistralEvent() {}
func (ministralEventThinking) isMinistralEvent() {}
func (ministralEventToolCall) isMinistralEvent() {}
type MinistralParser struct {
state ministralParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
pendingToolName string // stores tool name while collecting args
currentTool *api.Tool
}
func (p *MinistralParser) HasToolSupport() bool {
@@ -86,251 +63,74 @@ func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
return nil, fmt.Errorf("tool '%s' not found", n)
}
const (
ministralToolCallsTag = "[TOOL_CALLS]"
ministralThinkTag = "[THINK]"
ministralThinkEndTag = "[/THINK]"
ministralArgsTag = "[ARGS]"
)
// eat consumes the parser's buffer, and returns a list of any unambiguous
// events from the current parser state. The second return value indicates
// whether to keep looping (true when state transitions, false when waiting
// for more data).
func (p *MinistralParser) eat() ([]ministralEvent, bool) {
var events []ministralEvent
switch p.state {
case ministralCollectingContent:
bufStr := p.buffer.String()
// Check for [TOOL_CALLS] tag
if strings.Contains(bufStr, ministralToolCallsTag) {
split := strings.SplitN(bufStr, ministralToolCallsTag, 2)
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
if len(before) > 0 {
events = append(events, ministralEventContent{content: before})
}
after := split[1]
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingToolName
return events, true
}
// Check for [THINK] tag
if strings.Contains(bufStr, ministralThinkTag) {
split := strings.SplitN(bufStr, ministralThinkTag, 2)
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
if len(before) > 0 {
events = append(events, ministralEventContent{content: before})
}
after := split[1]
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingThinkingContent
return events, true
}
// Check for partial tag overlap with [TOOL_CALLS] or [THINK]
overlapToolCalls := overlap(bufStr, ministralToolCallsTag)
overlapThink := overlap(bufStr, ministralThinkTag)
maxOverlap := max(overlapToolCalls, overlapThink)
if maxOverlap > 0 {
// Withhold the potential partial tag
beforePartialTag := bufStr[:len(bufStr)-maxOverlap]
trailingWS := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWS
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventContent{content: unambiguous})
}
return events, false
}
// No tag found: emit content but withhold trailing whitespace
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventContent{content: unambiguous})
}
return events, false
case ministralCollectingThinkingContent:
bufStr := p.buffer.String()
if strings.Contains(bufStr, ministralThinkEndTag) {
split := strings.SplitN(bufStr, ministralThinkEndTag, 2)
thinkingContent := split[0]
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if len(thinkingContent) > 0 {
events = append(events, ministralEventThinking{thinking: thinkingContent})
}
p.state = ministralCollectingContent
return events, true
}
// Check for partial overlap with [/THINK]
if overlapLen := overlap(bufStr, ministralThinkEndTag); overlapLen > 0 {
unambiguous := bufStr[:len(bufStr)-overlapLen]
ambiguous := bufStr[len(bufStr)-overlapLen:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventThinking{thinking: unambiguous})
}
return events, false
}
// No tag found: emit all thinking content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, ministralEventThinking{thinking: bufStr})
}
return events, false
case ministralCollectingToolName:
bufStr := p.buffer.String()
if strings.Contains(bufStr, ministralArgsTag) {
split := strings.SplitN(bufStr, ministralArgsTag, 2)
toolName := split[0]
after := split[1]
p.pendingToolName = toolName
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingToolArgs
return events, true
}
// Wait for more data
return events, false
case ministralCollectingToolArgs:
bufStr := p.buffer.String()
jsonEnd := findJSONEnd(bufStr)
if jsonEnd != -1 {
jsonStr := bufStr[:jsonEnd+1]
remaining := bufStr[jsonEnd+1:]
events = append(events, ministralEventToolCall{
name: p.pendingToolName,
args: jsonStr,
})
p.pendingToolName = ""
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = ministralCollectingContent
return events, true
}
// Wait for more data
return events, false
default:
panic("unexpected ministral event")
}
}
// parseEvents loops calling eat() until it returns false
func (p *MinistralParser) parseEvents() []ministralEvent {
var all []ministralEvent
keepLooping := true
for keepLooping {
var events []ministralEvent
events, keepLooping = p.eat()
all = append(all, events...)
}
return all
}
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var contentBuilder, thinkingBuilder strings.Builder
var toolCalls []api.ToolCall
for _, event := range events {
switch e := event.(type) {
case ministralEventContent:
contentBuilder.WriteString(e.content)
case ministralEventThinking:
thinkingBuilder.WriteString(e.thinking)
case ministralEventToolCall:
// Validate tool exists
tool, toolErr := toolByName(p.tools, e.name)
if toolErr != nil {
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, toolErr
switch p.state {
case ministralCollectingContent:
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
if before != "" {
return before, "", calls, nil
}
// Parse JSON arguments
p.state = ministralCollectingToolName
} else if strings.Contains(p.buffer.String(), "[THINK]") {
p.state = ministralCollectingThinkingContent
return "", "", calls, nil
} else {
p.buffer.Reset()
return s, "", calls, nil
}
case ministralCollectingThinkingContent:
if strings.Contains(p.buffer.String(), "[/THINK]") {
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
p.state = ministralCollectingContent
if after != "" {
p.buffer.Reset()
return after, thinkingContent, calls, nil
}
return "", thinkingContent, calls, nil
} else {
p.buffer.Reset()
return "", s, calls, nil
}
case ministralCollectingToolName:
if strings.Contains(p.buffer.String(), "[ARGS]") {
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
t, err := toolByName(p.tools, name)
if err != nil {
return "", "", calls, err
}
p.currentTool = t
p.state = ministralCollectingToolArgs
return "", "", calls, nil
}
return "", "", calls, nil
case ministralCollectingToolArgs:
if strings.Contains(p.buffer.String(), "}") {
before, _ := splitAtTag(&p.buffer, "}", false)
before += "}"
var args api.ToolCallFunctionArguments
if jsonErr := json.Unmarshal([]byte(e.args), &args); jsonErr != nil {
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, jsonErr
if err := json.Unmarshal([]byte(before), &args); err != nil {
// todo - throw a better error
return "", "", calls, err
}
toolCalls = append(toolCalls, api.ToolCall{
p.state = ministralCollectingContent
call := api.ToolCall{
Function: api.ToolCallFunction{
Name: tool.Function.Name,
Name: p.currentTool.Function.Name,
Arguments: args,
},
})
}
calls = append(calls, call)
return "", "", calls, nil
}
return "", "", calls, nil
}
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, nil
}
// findJSONEnd finds the index of the closing brace that completes a JSON object.
// It properly handles nested objects, arrays, and strings (including escaped characters).
// Returns -1 if the JSON is not yet complete.
func findJSONEnd(s string) int {
depth := 0
inString := false
escaped := false
for i, r := range s {
if inString {
switch {
case escaped:
// If the previous character was a backslash, skip this character
escaped = false
case r == '\\':
// Mark the next character as escaped
escaped = true
case r == '"':
// End of string literal
inString = false
}
continue
}
switch r {
case '"':
// Start of string literal
inString = true
case '{', '[':
// Increase nesting level for objects and arrays
depth++
case '}', ']':
// Decrease nesting level
depth--
if depth == 0 {
// Reached the end of the root JSON structure
return i
}
}
}
return -1
return p.buffer.String(), thinking, calls, nil
}

View File

@@ -1,545 +0,0 @@
package parsers
import (
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestMinistralParserStreaming(t *testing.T) {
type step struct {
input string
wantEvents []ministralEvent
}
cases := []struct {
desc string
tools []api.Tool
steps []step
think bool // whether to enable thinking support
}{
// Content streaming
{
desc: "simple content",
steps: []step{
{input: "Hello, how can I help you?", wantEvents: []ministralEvent{
ministralEventContent{content: "Hello, how can I help you?"},
}},
},
},
{
desc: "streaming content word by word",
steps: []step{
{input: "Hello,", wantEvents: []ministralEvent{ministralEventContent{content: "Hello,"}}},
{input: " how", wantEvents: []ministralEvent{ministralEventContent{content: " how"}}},
{input: " can I help?", wantEvents: []ministralEvent{ministralEventContent{content: " can I help?"}}},
},
},
// Simple tool calls
{
desc: "simple tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
steps: []step{
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "San Francisco"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "get_weather", args: `{"location": "San Francisco"}`},
}},
},
},
{
desc: "tool call with nested object",
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
steps: []step{
{input: `[TOOL_CALLS]create_entities[ARGS]{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
}},
},
},
{
desc: "tool call with deeply nested objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "update_config"}}},
steps: []step{
{input: `[TOOL_CALLS]update_config[ARGS]{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "update_config", args: `{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`},
}},
},
},
{
desc: "tool call with array of objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "process_items"}}},
steps: []step{
{input: `[TOOL_CALLS]process_items[ARGS]{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "process_items", args: `{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`},
}},
},
},
{
desc: "tool call with escaped quotes in string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "search"}}},
steps: []step{
{input: `[TOOL_CALLS]search[ARGS]{"query": "say \"hello\""}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "search", args: `{"query": "say \"hello\""}`},
}},
},
},
{
desc: "tool call with braces inside string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "format"}}},
steps: []step{
{input: `[TOOL_CALLS]format[ARGS]{"template": "Hello {name}!"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "format", args: `{"template": "Hello {name}!"}`},
}},
},
},
{
desc: "empty JSON object",
tools: []api.Tool{{Function: api.ToolFunction{Name: "no_args"}}},
steps: []step{
{input: `[TOOL_CALLS]no_args[ARGS]{}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "no_args", args: `{}`},
}},
},
},
{
desc: "JSON with newlines in string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "write"}}},
steps: []step{
{input: `[TOOL_CALLS]write[ARGS]{"content": "line1\nline2\nline3"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "write", args: `{"content": "line1\nline2\nline3"}`},
}},
},
},
{
desc: "backslash in string value",
tools: []api.Tool{{Function: api.ToolFunction{Name: "path"}}},
steps: []step{
{input: `[TOOL_CALLS]path[ARGS]{"dir": "C:\\Users\\test"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "path", args: `{"dir": "C:\\Users\\test"}`},
}},
},
},
// Content after tool call
{
desc: "content after tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
// NOTE: It's unclear if this is valid Ministral output, but the parser
// currently treats text after a tool call as regular content. This test
// documents that behavior so we notice if it changes.
{input: `[TOOL_CALLS]test[ARGS]{"a": 1}some content after`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{"a": 1}`},
ministralEventContent{content: "some content after"},
}},
},
},
// Multiple tool calls
{
desc: "multiple tool calls in sequence",
tools: []api.Tool{
{Function: api.ToolFunction{Name: "get_weather"}},
{Function: api.ToolFunction{Name: "get_time"}},
},
steps: []step{
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "NYC"}[TOOL_CALLS]get_time[ARGS]{"timezone": "EST"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "get_weather", args: `{"location": "NYC"}`},
ministralEventToolCall{name: "get_time", args: `{"timezone": "EST"}`},
}},
},
},
{
desc: "multiple tool calls streamed separately",
tools: []api.Tool{
{Function: api.ToolFunction{Name: "tool_a"}},
{Function: api.ToolFunction{Name: "tool_b"}},
},
steps: []step{
{input: `[TOOL_CALLS]tool_a[ARGS]{"x": 1}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "tool_a", args: `{"x": 1}`},
}},
{input: `[TOOL_CALLS]tool_b[ARGS]{"y": 2}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "tool_b", args: `{"y": 2}`},
}},
},
},
// Streaming tool calls
{
desc: "streaming tool call with nested objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
steps: []step{
{input: "[TOOL_CALLS]create_entities[ARGS]", wantEvents: []ministralEvent{}},
{input: `{"entities": [{"entityType": "Person",`, wantEvents: []ministralEvent{}},
{input: ` "name": "Jack",`, wantEvents: []ministralEvent{}},
{input: ` "observations": ["Works`, wantEvents: []ministralEvent{}},
{input: ` as a baker"]}`, wantEvents: []ministralEvent{}},
{input: `]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
}},
},
},
{
desc: "streaming with incomplete JSON waits for completion",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "[TOOL_CALLS]test[ARGS]{", wantEvents: []ministralEvent{}},
{input: `"a": {`, wantEvents: []ministralEvent{}},
{input: `"b": 1`, wantEvents: []ministralEvent{}},
{input: `}`, wantEvents: []ministralEvent{}},
{input: `}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{"a": {"b": 1}}`},
}},
},
},
// Partial tag handling
{
desc: "partial tool tag fakeout",
steps: []step{
{input: "abc[TOOL", wantEvents: []ministralEvent{ministralEventContent{content: "abc"}}},
{input: " not a tag", wantEvents: []ministralEvent{ministralEventContent{content: "[TOOL not a tag"}}},
},
},
{
desc: "tool call tag split across chunks",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "[TOOL_", wantEvents: []ministralEvent{}},
{input: "CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "content before tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
steps: []step{
{input: "hello [TOOL_CALLS]get_weather[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "hello"},
ministralEventToolCall{name: "get_weather", args: `{}`},
}},
},
},
{
desc: "whitespace between content and tool call is trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "content \n [TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "tabs and newlines before tool call are trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "content\t\n\t[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "non-breaking space before tool call is trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
// \u00a0 is non-breaking space, which unicode.IsSpace considers whitespace
{input: "content\u00a0[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "whitespace before THINK tag is trimmed",
steps: []step{
{input: "content \n [THINK]thinking[/THINK]after", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "after"},
}},
},
},
{
desc: "trailing whitespace withheld then emitted",
steps: []step{
{input: "Hello ", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: " world"}}},
},
},
{
desc: "trailing newline withheld then emitted",
steps: []step{
{input: "Hello\n", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: "\nworld"}}},
},
},
// Thinking support
{
desc: "thinking content",
think: true,
steps: []step{
{input: "thinking here[/THINK]", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "thinking here"},
}},
{input: "content after", wantEvents: []ministralEvent{
ministralEventContent{content: "content after"},
}},
},
},
{
desc: "thinking with whitespace after end tag",
think: true,
steps: []step{
{input: "my thoughts[/THINK] \n response", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "my thoughts"},
ministralEventContent{content: "response"},
}},
},
},
{
desc: "non-breaking space after think end tag is trimmed",
think: true,
steps: []step{
// \u00a0 is non-breaking space
{input: "thinking[/THINK]\u00a0response", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "response"},
}},
},
},
{
desc: "partial think end tag",
think: true,
steps: []step{
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
{input: "NK]after", wantEvents: []ministralEvent{ministralEventContent{content: "after"}}},
},
},
{
desc: "think tag fakeout",
think: true,
steps: []step{
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
{input: "not end tag", wantEvents: []ministralEvent{ministralEventThinking{thinking: "[/THInot end tag"}}},
},
},
{
desc: "thinking then tool call",
think: true,
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "let me think[/THINK][TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "let me think"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
// Content then THINK tag transition
{
desc: "content then think tag",
steps: []step{
{input: "content[THINK]thinking[/THINK]more", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "more"},
}},
},
},
// Unicode handling
{
desc: "unicode content",
steps: []step{
{input: "你好 🌍 مرحبا", wantEvents: []ministralEvent{
ministralEventContent{content: "你好 🌍 مرحبا"},
}},
},
},
{
desc: "unicode in tool args",
tools: []api.Tool{{Function: api.ToolFunction{Name: "greet"}}},
steps: []step{
{input: `[TOOL_CALLS]greet[ARGS]{"message": "你好 🌍"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "greet", args: `{"message": "你好 🌍"}`},
}},
},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
parser := MinistralParser{}
parser.hasThinkingSupport = tc.think
parser.Init(tc.tools, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
gotEvents := parser.parseEvents()
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
// avoid deep equal on empty vs. nil slices
continue
}
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
}
}
})
}
}
func TestMinistralParser_Errors(t *testing.T) {
t.Run("unknown tool returns error", func(t *testing.T) {
p := &MinistralParser{}
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "known_tool"}}}, nil, nil)
_, _, _, err := p.Add(`[TOOL_CALLS]unknown_tool[ARGS]{"a": 1}`, true)
if err == nil {
t.Fatal("expected error for unknown tool")
}
})
t.Run("invalid JSON returns error", func(t *testing.T) {
p := &MinistralParser{}
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "test"}}}, nil, nil)
_, _, _, err := p.Add(`[TOOL_CALLS]test[ARGS]{invalid json}`, true)
if err == nil {
t.Fatal("expected error for invalid JSON")
}
})
}
func TestFindJSONEnd(t *testing.T) {
tests := []struct {
name string
input string
expected int
}{
{
name: "simple object",
input: `{"a": 1}`,
expected: 7,
},
{
name: "nested object",
input: `{"a": {"b": 2}}`,
expected: 14,
},
{
name: "array inside object",
input: `{"items": [1, 2, 3]}`,
expected: 19,
},
{
name: "braces in string",
input: `{"template": "Hello {name}!"}`,
expected: 28,
},
{
name: "escaped quotes",
input: `{"msg": "say \"hi\""}`,
expected: 20,
},
{
name: "incomplete object",
input: `{"a": {"b": 1}`,
expected: -1,
},
{
name: "deeply nested",
input: `{"a": {"b": {"c": {"d": 1}}}}`,
expected: 28,
},
{
name: "object with trailing content",
input: `{"a": 1} extra`,
expected: 7,
},
{
name: "array",
input: `[{"a": 1}, {"b": 2}]`,
expected: 19,
},
{
name: "escaped backslash before quote",
input: `{"path": "C:\\"}`,
expected: 15,
},
{
name: "empty string",
input: "",
expected: -1,
},
{
name: "no opening brace",
input: "hello world",
expected: -1,
},
{
name: "only opening brace",
input: "{",
expected: -1,
},
{
name: "unclosed string",
input: `{"key": "unclosed`,
expected: -1,
},
{
name: "double escaped backslash then quote",
input: `{"path": "C:\\\\"}`,
expected: 17,
},
{
name: "unicode in key and value",
input: `{"키": "값"}`,
expected: 13,
},
{
name: "nested arrays",
input: `{"matrix": [[1, 2], [3, 4]]}`,
expected: 27,
},
{
name: "mixed nesting",
input: `{"a": [{"b": {"c": [1, 2, 3]}}]}`,
expected: 31,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := findJSONEnd(tt.input)
if result != tt.expected {
t.Errorf("findJSONEnd(%q) = %d, want %d", tt.input, result, tt.expected)
}
})
}
}
func TestMinistralParser_HasToolSupport(t *testing.T) {
p := &MinistralParser{}
if !p.HasToolSupport() {
t.Error("expected HasToolSupport to return true")
}
}
func TestMinistralParser_HasThinkingSupport(t *testing.T) {
p := &MinistralParser{hasThinkingSupport: false}
if p.HasThinkingSupport() {
t.Error("expected HasThinkingSupport to return false")
}
p = &MinistralParser{hasThinkingSupport: true}
if !p.HasThinkingSupport() {
t.Error("expected HasThinkingSupport to return true")
}
}

View File

@@ -3,7 +3,6 @@ package parsers
import (
"strings"
"unicode"
"unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
@@ -115,33 +114,3 @@ func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string
sb.WriteString(after)
return before, after // return events
}
// overlap returns the longest overlap between the suffix of s and the prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
// trailingWhitespaceLen returns the length in bytes of trailing whitespace in s
func trailingWhitespaceLen(s string) int {
remaining := s
total := 0
for len(remaining) > 0 {
r, size := utf8.DecodeLastRuneInString(remaining)
// if it's an invalid utf8 rune, assume it isn't whitespace
if r == utf8.RuneError && size == 1 {
break
}
if !unicode.IsSpace(r) {
break
}
total += size
remaining = remaining[:len(remaining)-size]
}
return total
}

View File

@@ -11,6 +11,7 @@ import (
"strconv"
"strings"
"unicode"
"unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
@@ -193,6 +194,36 @@ func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
}
}
// TODO(drifkin): move this to a shared location
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
func trailingWhitespaceLen(s string) int {
remaining := s
total := 0
for len(remaining) > 0 {
r, size := utf8.DecodeLastRuneInString(remaining)
// if it's an invalid utf8 rune, assume it isn't whitespace
if r == utf8.RuneError && size == 1 {
break
}
if !unicode.IsSpace(r) {
break
}
total += size
remaining = remaining[:len(remaining)-size]
}
return total
}
type XMLFunctionCall struct {
XMLName xml.Name `xml:"function"`
Name string `xml:"name,attr"`

View File

@@ -37,6 +37,7 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/tokenizer"
_ "github.com/ollama/ollama/model/models"
)
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tokenizer tokenizer.Tokenizer) []llm.Logprob {
decoder := func(tokenID int) string {
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
text, _ := tokenizer.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -766,7 +767,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -775,14 +776,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}
@@ -873,7 +874,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return

View File

@@ -7,7 +7,7 @@ import (
"slices"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
// token represents information about a single token during sampling
@@ -168,15 +168,15 @@ type GrammarSampler struct {
grammar *llama.Grammar
}
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
func NewGrammarSampler(tokenizer tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(tokenizer.Vocabulary().Values))
pieces := make([]string, len(tokenizer.Vocabulary().Values))
for i := range tokenizer.Vocabulary().Values {
pieces[i], _ = tokenizer.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tokenizer.Vocabulary().EOS)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}

View File

@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
func TestWeighted(t *testing.T) {
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
}
}
func modelHelper(t testing.TB) model.BytePairEncoding {
func modelHelper(t testing.TB) tokenizer.Tokenizer {
t.Helper()
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.Join("..", "testdata", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
return model.NewBytePairEncoding(
&model.Vocabulary{
return tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: tokens,
Types: make([]int32, len(vocab)),
Merges: merges,

View File

@@ -1,8 +1,10 @@
package model
package tokenizer
import (
"cmp"
"fmt"
"iter"
"log/slog"
"slices"
"strings"
@@ -11,24 +13,24 @@ import (
"github.com/ollama/ollama/logutil"
)
type BytePairEncoding struct {
type bytePairEncoding struct {
vocab *Vocabulary
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
var _ Tokenizer = (*bytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) bytePairEncoding {
if len(pretokenizer) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
return bytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
for _, p := range pretokenizer {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
@@ -37,15 +39,15 @@ func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEnc
}
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
func (bpe bytePairEncoding) Vocabulary() *Vocabulary {
return bpe.vocab
}
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
func (bpe bytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
func (bpe *bytePairEncoding) split(s string) iter.Seq[string] {
parts := []string{s}
for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
@@ -96,7 +98,7 @@ type merge struct {
runes []rune
}
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
func (bpe bytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
@@ -243,7 +245,15 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
type lazyIdsString struct {
ids []int32
}
func (l lazyIdsString) LogValue() slog.Value {
return slog.AnyValue(fmt.Sprint(l.ids))
}
func (bpe bytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
@@ -267,6 +277,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
}
}
logutil.Trace("decoded", "string", sb.String(), "from", ids)
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
return sb.String(), nil
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"bufio"
@@ -14,10 +14,10 @@ import (
"github.com/google/go-cmp/cmp"
)
func llama(t testing.TB) BytePairEncoding {
func llama(t testing.TB) bytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
}
}
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"container/heap"
@@ -12,18 +12,18 @@ import (
const spmWhitespaceSep = "▁"
type SentencePiece struct {
type sentencePiece struct {
maxTokenLen int
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePiece)(nil)
var _ Tokenizer = (*sentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
func (spm sentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
}
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
func NewSentencePiece(vocab *Vocabulary) sentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
@@ -42,17 +42,17 @@ func NewSentencePiece(vocab *Vocabulary) SentencePiece {
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePiece{
return sentencePiece{
maxTokenLen: maxTokenLen,
vocab: vocab,
}
}
func (spm SentencePiece) Is(id int32, special Special) bool {
func (spm sentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
func (spm sentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special)
@@ -218,13 +218,13 @@ func (q *queue) Pop() interface{} {
return item
}
func (spm SentencePiece) Decode(ids []int32) (string, error) {
func (spm sentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>"
// For tokenizer that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"log/slog"
@@ -12,10 +12,10 @@ import (
"github.com/ollama/ollama/convert/sentencepiece"
)
func loadSentencePieceVocab(t *testing.T) SentencePiece {
func loadSentencePieceVocab(t *testing.T) sentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
const (
TOKEN_TYPE_NORMAL = iota + 1
@@ -9,7 +9,7 @@ const (
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
type Tokenizer interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"log/slog"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"testing"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"fmt"
@@ -9,7 +9,7 @@ import (
"github.com/ollama/ollama/logutil"
)
type WordPiece struct {
type wordPiece struct {
vocab *Vocabulary
lowercase bool
}
@@ -32,8 +32,8 @@ var wordPieceReplacer = strings.NewReplacer(
" 're", "'re",
)
// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
// Decode implements Tokenizer.
func (wpm wordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
if id < 0 || int(id) >= len(wpm.vocab.Values) {
@@ -56,7 +56,7 @@ func (wpm WordPiece) Decode(ids []int32) (string, error) {
// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
func (wpm wordPiece) words(s string) iter.Seq[string] {
return func(yield func(string) bool) {
runes := make([]rune, 0, len(s)*3)
for _, r := range s {
@@ -96,8 +96,8 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
}
}
// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
// Encode implements Tokenizer.
func (wpm wordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
// TODO: use [UNK] from config
@@ -151,20 +151,20 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
// Is implements Tokenizer.
func (wpm wordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
// Vocabulary implements Tokenizer.
func (wpm wordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
var _ Tokenizer = (*wordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{
func NewWordPiece(vocab *Vocabulary, lowercase bool) wordPiece {
return wordPiece{
vocab: vocab,
lowercase: lowercase,
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"slices"
@@ -39,7 +39,7 @@ func TestWordPiece(t *testing.T) {
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
var wpm wordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {

View File

@@ -17,7 +17,7 @@ import (
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.Join("..", "..", "tokenizer", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
}
}
f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
f, err = os.Open(filepath.Join("..", "..", "tokenizer", "testdata", "llama3.2", "vocab.bpe"))
if err != nil {
t.Fatal(err)
}

View File

@@ -9,6 +9,7 @@ import (
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
@@ -18,7 +19,7 @@ import (
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*VisionModel `gguf:"vision_tower.vision_model"`
*TextModel `gguf:"language_model.model"`
@@ -58,8 +59,8 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
func New(c fs.Config) (model.Model, error) {
// slog.Info("XXX Config", "c", c)
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -15,7 +15,7 @@ import (
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model"))
bts, err := os.ReadFile(filepath.Join("..", "..", "tokenizer", "testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}