mirror of
https://github.com/ollama/ollama.git
synced 2026-01-29 09:43:35 -05:00
Compare commits
1 Commits
parth-laun
...
mxyng/toke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
07d944fdfc |
@@ -6,8 +6,6 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
@@ -15,15 +13,11 @@ type Claude struct{}
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string, extraArgs []string) []string {
|
||||
var args []string
|
||||
func (c *Claude) args(model string) []string {
|
||||
if model != "" {
|
||||
args = append(args, "--model", model)
|
||||
return []string{"--model", model}
|
||||
}
|
||||
if len(extraArgs) > 0 {
|
||||
args = append(args, extraArgs...)
|
||||
}
|
||||
return args
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
@@ -45,18 +39,18 @@ func (c *Claude) findPath() (string, error) {
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string, extraArgs []string) error {
|
||||
func (c *Claude) Run(model string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model, extraArgs)...)
|
||||
cmd := exec.Command(claudePath, c.args(model)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_BASE_URL=http://localhost:11434",
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
|
||||
@@ -82,23 +82,19 @@ func TestClaudeArgs(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
extraArgs []string
|
||||
want []string
|
||||
name string
|
||||
model string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil, nil},
|
||||
{"with model and extra args", "llama3.2", []string{"--yolo", "--hi"}, []string{"--model", "llama3.2", "--yolo", "--hi"}},
|
||||
{"empty model with extra args", "", []string{"--help"}, []string{"--help"}},
|
||||
{"multiple extra args", "llama3.2", []string{"--flag1", "--flag2", "value"}, []string{"--model", "llama3.2", "--flag1", "--flag2", "value"}},
|
||||
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model, tt.extraArgs)
|
||||
got := c.args(tt.model)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.extraArgs, got, tt.want)
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type Clawdbot struct{}
|
||||
|
||||
func (c *Clawdbot) String() string { return "Clawdbot" }
|
||||
|
||||
const ansiGreen = "\033[32m"
|
||||
|
||||
func (c *Clawdbot) Run(model string, extraArgs []string) error {
|
||||
if _, err := exec.LookPath("clawdbot"); err != nil {
|
||||
return fmt.Errorf("clawdbot is not installed, install from https://docs.clawd.bot")
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
// Build args: "gateway" first, then any extra args
|
||||
args := []string{"gateway"}
|
||||
if len(extraArgs) > 0 {
|
||||
args = append(args, extraArgs...)
|
||||
}
|
||||
|
||||
cmd := exec.Command("clawdbot", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
// Capture output to detect "already running" message
|
||||
var outputBuf bytes.Buffer
|
||||
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
|
||||
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
|
||||
fmt.Fprintf(os.Stderr, "%sClawdbot has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Paths() []string {
|
||||
home, _ := os.UserHomeDir()
|
||||
p := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read into map[string]any to preserve unknown fields
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
// Navigate/create: models.providers.ollama (preserving other providers)
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
if modelsSection == nil {
|
||||
modelsSection = make(map[string]any)
|
||||
}
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
if providers == nil {
|
||||
providers = make(map[string]any)
|
||||
}
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
if ollama == nil {
|
||||
ollama = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
// TODO(parthsareen): potentially move to responses
|
||||
ollama["api"] = "openai-completions"
|
||||
|
||||
// Build map of existing models to preserve user customizations
|
||||
existingModels, _ := ollama["models"].([]any)
|
||||
existingByID := make(map[string]map[string]any)
|
||||
for _, m := range existingModels {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
existingByID[id] = entry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var newModels []any
|
||||
for _, model := range models {
|
||||
entry := map[string]any{
|
||||
"id": model,
|
||||
"name": model,
|
||||
"reasoning": false,
|
||||
"input": []any{"text"},
|
||||
"cost": map[string]any{
|
||||
"input": 0,
|
||||
"output": 0,
|
||||
"cacheRead": 0,
|
||||
"cacheWrite": 0,
|
||||
},
|
||||
// TODO(parthsareen): get these values from API
|
||||
"contextWindow": 131072,
|
||||
"maxTokens": 16384,
|
||||
}
|
||||
// Merge existing fields (user customizations)
|
||||
if existing, ok := existingByID[model]; ok {
|
||||
for k, v := range existing {
|
||||
if _, isNew := entry[k]; !isNew {
|
||||
entry[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
newModels = append(newModels, entry)
|
||||
}
|
||||
ollama["models"] = newModels
|
||||
|
||||
providers["ollama"] = ollama
|
||||
modelsSection["providers"] = providers
|
||||
config["models"] = modelsSection
|
||||
|
||||
// Update agents.defaults.model.primary (preserving other agent settings)
|
||||
agents, _ := config["agents"].(map[string]any)
|
||||
if agents == nil {
|
||||
agents = make(map[string]any)
|
||||
}
|
||||
defaults, _ := agents["defaults"].(map[string]any)
|
||||
if defaults == nil {
|
||||
defaults = make(map[string]any)
|
||||
}
|
||||
modelConfig, _ := defaults["model"].(map[string]any)
|
||||
if modelConfig == nil {
|
||||
modelConfig = make(map[string]any)
|
||||
}
|
||||
modelConfig["primary"] = "ollama/" + models[0]
|
||||
defaults["model"] = modelConfig
|
||||
agents["defaults"] = defaults
|
||||
config["agents"] = agents
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
modelList, _ := ollama["models"].([]any)
|
||||
|
||||
var result []string
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -1,625 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClawdbotIntegration(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Clawdbot" {
|
||||
t.Errorf("String() = %q, want %q", got, "Clawdbot")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEdit(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("fresh install", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("multiple models - first is primary", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotModelExists(t, configPath, "mistral")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other providers", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
if providers["anthropic"] == nil {
|
||||
t.Error("anthropic provider was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme was removed")
|
||||
}
|
||||
if cfg["mcp"] == nil {
|
||||
t.Error("mcp was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve user customizations on models", func(t *testing.T) {
|
||||
cleanup()
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
// User adds custom field
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
entry["customField"] = "user-value"
|
||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configPath, configData, 0o644)
|
||||
|
||||
// Re-run Edit
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ = os.ReadFile(configPath)
|
||||
json.Unmarshal(data, &cfg)
|
||||
models = cfg["models"].(map[string]any)
|
||||
providers = models["providers"].(map[string]any)
|
||||
ollama = providers["ollama"].(map[string]any)
|
||||
modelList = ollama["models"].([]any)
|
||||
entry = modelList[0].(map[string]any)
|
||||
if entry["customField"] != "user-value" {
|
||||
t.Error("custom field was lost")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("edit replaces models list", func(t *testing.T) {
|
||||
cleanup()
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
original := `{"existing":"data"}`
|
||||
os.WriteFile(configPath, []byte(original), 0o644)
|
||||
|
||||
c.Edit([]string{})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != original {
|
||||
t.Error("empty models should not modify file")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Error("result should be valid JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type models section", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotModels(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("no config returns nil", func(t *testing.T) {
|
||||
if models := c.Models(); len(models) > 0 {
|
||||
t.Errorf("expected nil/empty, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns all ollama models", func(t *testing.T) {
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[
|
||||
{"id":"llama3.2"},
|
||||
{"id":"mistral"}
|
||||
]}}}
|
||||
}`), 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 2 {
|
||||
t.Errorf("expected 2 models, got %v", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func assertClawdbotModelExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if entry["id"] == model {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Errorf("model %s not found", model)
|
||||
}
|
||||
|
||||
func assertClawdbotModelNotExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models, _ := cfg["models"].(map[string]any)
|
||||
providers, _ := models["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
modelList, _ := ollama["models"].([]any)
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if entry["id"] == model {
|
||||
t.Errorf("model %s should not exist", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertClawdbotPrimaryModel(t *testing.T, path, expected string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
model := defaults["model"].(map[string]any)
|
||||
if model["primary"] != expected {
|
||||
t.Errorf("primary model = %v, want %v", model["primary"], expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotPaths(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when config missing", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
if paths := c.Paths(); paths != nil {
|
||||
t.Errorf("expected nil, got %v", paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotModelsEdgeCases(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("corrupted JSON returns nil", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at models level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at providers level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at ollama level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model entry missing id", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
|
||||
if len(c.Models()) != 0 {
|
||||
t.Error("expected empty for missing id")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model id is not string", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
|
||||
if len(c.Models()) != 0 {
|
||||
t.Error("expected empty for non-string id")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEditSchemaFields(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
|
||||
// Verify required schema fields
|
||||
if entry["reasoning"] != false {
|
||||
t.Error("reasoning should be false")
|
||||
}
|
||||
if entry["input"] == nil {
|
||||
t.Error("input should be set")
|
||||
}
|
||||
if entry["contextWindow"] == nil {
|
||||
t.Error("contextWindow should be set")
|
||||
}
|
||||
if entry["maxTokens"] == nil {
|
||||
t.Error("maxTokens should be set")
|
||||
}
|
||||
cost := entry["cost"].(map[string]any)
|
||||
if cost["cacheRead"] == nil {
|
||||
t.Error("cost.cacheRead should be set")
|
||||
}
|
||||
if cost["cacheWrite"] == nil {
|
||||
t.Error("cost.cacheWrite should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEditModelNames(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
|
||||
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".clawdbot")) }
|
||||
|
||||
t.Run("model with colon tag", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2:70b")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2:70b")
|
||||
})
|
||||
|
||||
t.Run("model with slash", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"library/model:tag"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "library/model:tag")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/library/model:tag")
|
||||
})
|
||||
|
||||
t.Run("model with hyphen", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"test-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "test-model")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEditAgentsPreservation(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("preserve other agent defaults", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
if defaults["temperature"] != 0.7 {
|
||||
t.Error("temperature setting was lost")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve other agents besides defaults", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
if agents["custom-agent"] == nil {
|
||||
t.Error("custom-agent was lost")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const testClawdbotFixture = `{
|
||||
"theme": "dark",
|
||||
"mcp": {"servers": {"custom": {"enabled": true}}},
|
||||
"models": {
|
||||
"providers": {
|
||||
"anthropic": {"apiKey": "xxx"},
|
||||
"ollama": {
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"models": [{"id": "old-model", "customField": "preserved"}]
|
||||
}
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
|
||||
"custom-agent": {"foo": "bar"}
|
||||
}
|
||||
}`
|
||||
|
||||
func TestClawdbotEdit_RoundTrip(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
|
||||
// Verify top-level preserved
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme not preserved")
|
||||
}
|
||||
mcp := cfg["mcp"].(map[string]any)
|
||||
servers := mcp["servers"].(map[string]any)
|
||||
if servers["custom"] == nil {
|
||||
t.Error("mcp.servers.custom not preserved")
|
||||
}
|
||||
|
||||
// Verify other providers preserved
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
if providers["anthropic"] == nil {
|
||||
t.Error("anthropic provider not preserved")
|
||||
}
|
||||
|
||||
// Verify agents preserved
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
if agents["custom-agent"] == nil {
|
||||
t.Error("custom-agent not preserved")
|
||||
}
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
if defaults["temperature"] != 0.7 {
|
||||
t.Error("temperature not preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_Idempotent(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
firstData, _ := os.ReadFile(configPath)
|
||||
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
secondData, _ := os.ReadFile(configPath)
|
||||
|
||||
if string(firstData) != string(secondData) {
|
||||
t.Error("repeated edits with same models produced different results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_MultipleConsecutiveEdits(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
|
||||
for i := range 10 {
|
||||
models := []string{"model-a", "model-b"}
|
||||
if i%2 == 0 {
|
||||
models = []string{"model-x", "model-y", "model-z"}
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
t.Fatalf("edit %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
|
||||
}
|
||||
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme lost after multiple edits")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_BackupCreated(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
|
||||
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
|
||||
os.WriteFile(configPath, []byte(original), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
backups, _ := filepath.Glob(filepath.Join(backupDir, "clawdbot.json.*"))
|
||||
foundBackup := false
|
||||
for _, backup := range backups {
|
||||
data, _ := os.ReadFile(backup)
|
||||
if string(data) == original {
|
||||
foundBackup = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup with original content not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_CreatesDirectoryIfMissing(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
|
||||
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
|
||||
t.Fatal("directory should not exist before test")
|
||||
}
|
||||
|
||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
t.Fatal("directory was not created")
|
||||
}
|
||||
}
|
||||
@@ -14,23 +14,20 @@ type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
func (c *Codex) args(model string, extraArgs []string) []string {
|
||||
func (c *Codex) args(model string) []string {
|
||||
args := []string{"--oss"}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
if len(extraArgs) > 0 {
|
||||
args = append(args, extraArgs...)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string, extraArgs []string) error {
|
||||
func (c *Codex) Run(model string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model, extraArgs)...)
|
||||
cmd := exec.Command("codex", c.args(model)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
@@ -9,22 +9,19 @@ func TestCodexArgs(t *testing.T) {
|
||||
c := &Codex{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
extraArgs []string
|
||||
want []string
|
||||
name string
|
||||
model string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", nil, []string{"--oss"}},
|
||||
{"with model and extra args", "qwen3-coder", []string{"--yolo"}, []string{"--oss", "-m", "qwen3-coder", "--yolo"}},
|
||||
{"empty model with extra args", "", []string{"--help"}, []string{"--oss", "--help"}},
|
||||
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", []string{"--oss"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model, tt.extraArgs)
|
||||
got := c.args(tt.model)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.extraArgs, got, tt.want)
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Droid implements Runner and Editor for Droid integration
|
||||
@@ -39,7 +37,7 @@ type modelEntry struct {
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string, extraArgs []string) error {
|
||||
func (d *Droid) Run(model string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
@@ -53,7 +51,7 @@ func (d *Droid) Run(model string, extraArgs []string) error {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid", extraArgs...)
|
||||
cmd := exec.Command("droid")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -119,7 +117,7 @@ func (d *Droid) Edit(models []string) error {
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
DisplayName: model,
|
||||
BaseURL: envconfig.Host().String() + "/v1",
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
|
||||
@@ -218,7 +218,7 @@ func TestDroidEdit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if model["baseUrl"] != "http://127.0.0.1:11434/v1" {
|
||||
if model["baseUrl"] != "http://localhost:11434/v1" {
|
||||
t.Errorf("unexpected baseUrl: %s", model["baseUrl"])
|
||||
}
|
||||
if model["apiKey"] != "ollama" {
|
||||
@@ -447,7 +447,7 @@ const testDroidSettingsFixture = `{
|
||||
{
|
||||
"model": "existing-ollama-model",
|
||||
"displayName": "existing-ollama-model",
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"apiKey": "ollama",
|
||||
"provider": "generic-chat-completion-api",
|
||||
"maxOutputTokens": 64000,
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
// Runner can run an integration with a model.
|
||||
|
||||
type Runner interface {
|
||||
Run(model string, extraArgs []string) error
|
||||
Run(model string) error
|
||||
// String returns the human-readable name of the integration
|
||||
String() string
|
||||
}
|
||||
@@ -41,7 +41,6 @@ type Editor interface {
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
"clawdbot": &Clawdbot{},
|
||||
"codex": &Codex{},
|
||||
"droid": &Droid{},
|
||||
"opencode": &OpenCode{},
|
||||
@@ -222,13 +221,13 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string, extraArgs []string) error {
|
||||
func runIntegration(name, modelName string) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName, extraArgs)
|
||||
return r.Run(modelName)
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
@@ -237,13 +236,12 @@ func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) erro
|
||||
var configFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||
Use: "launch [INTEGRATION]",
|
||||
Short: "Launch an integration with Ollama",
|
||||
Long: `Launch an integration configured with Ollama models.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
clawdbot Clawdbot
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
@@ -252,17 +250,13 @@ Examples:
|
||||
ollama launch
|
||||
ollama launch claude
|
||||
ollama launch claude --model <model>
|
||||
ollama launch droid --config (does not auto-launch)
|
||||
ollama launch claude -- --yolo --hi (pass extra args to integration)`,
|
||||
Args: cobra.ArbitraryArgs,
|
||||
ollama launch droid --config (does not auto-launch)`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Extract integration name and pass through remaining args
|
||||
var name string
|
||||
var extraArgs []string
|
||||
if len(args) > 0 {
|
||||
name = args[0]
|
||||
extraArgs = args[1:]
|
||||
} else {
|
||||
var err error
|
||||
name, err = selectIntegration()
|
||||
@@ -282,7 +276,7 @@ Examples:
|
||||
// If launching without --model, use saved config if available
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0], extraArgs)
|
||||
return runIntegration(name, config.Models[0])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -343,13 +337,13 @@ Examples:
|
||||
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||
return runIntegration(name, models[0], extraArgs)
|
||||
return runIntegration(name, models[0])
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
return runIntegration(name, models[0], extraArgs)
|
||||
return runIntegration(name, models[0])
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -90,8 +90,8 @@ func TestLaunchCmd(t *testing.T) {
|
||||
cmd := LaunchCmd(mockCheck)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||
if cmd.Use != "launch [INTEGRATION]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
@@ -121,7 +121,7 @@ func TestLaunchCmd(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model", nil)
|
||||
err := runIntegration("unknown-integration", "model")
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown integration, got nil")
|
||||
}
|
||||
@@ -182,69 +182,7 @@ func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||
|
||||
// Test Run() exists (we can't call it without actually running the command)
|
||||
// Just verify the method is available
|
||||
var _ func(string, []string) error = r.Run
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseExtraArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
wantArgs []string
|
||||
wantExtraArgs []string
|
||||
}{
|
||||
{
|
||||
name: "no extra args",
|
||||
args: []string{"claude"},
|
||||
wantArgs: []string{"claude"},
|
||||
wantExtraArgs: nil,
|
||||
},
|
||||
{
|
||||
name: "with extra args after --",
|
||||
args: []string{"claude", "--", "--yolo", "--hi"},
|
||||
wantArgs: []string{"claude"},
|
||||
wantExtraArgs: []string{"--yolo", "--hi"},
|
||||
},
|
||||
{
|
||||
name: "extra args only after --",
|
||||
args: []string{"codex", "--", "--help"},
|
||||
wantArgs: []string{"codex"},
|
||||
wantExtraArgs: []string{"--help"},
|
||||
},
|
||||
{
|
||||
name: "-- at end with no args after",
|
||||
args: []string{"claude", "--"},
|
||||
wantArgs: []string{"claude", "--"},
|
||||
wantExtraArgs: nil,
|
||||
},
|
||||
{
|
||||
name: "multiple args after --",
|
||||
args: []string{"claude", "--", "--flag1", "--flag2", "value", "--flag3"},
|
||||
wantArgs: []string{"claude"},
|
||||
wantExtraArgs: []string{"--flag1", "--flag2", "value", "--flag3"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the parsing logic from LaunchCmd
|
||||
args := tt.args
|
||||
var extraArgs []string
|
||||
for i, arg := range args {
|
||||
if arg == "--" && i < len(args)-1 {
|
||||
extraArgs = args[i+1:]
|
||||
args = args[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !slices.Equal(args, tt.wantArgs) {
|
||||
t.Errorf("args = %v, want %v", args, tt.wantArgs)
|
||||
}
|
||||
if !slices.Equal(extraArgs, tt.wantExtraArgs) {
|
||||
t.Errorf("extraArgs = %v, want %v", extraArgs, tt.wantExtraArgs)
|
||||
}
|
||||
var _ func(string) error = r.Run
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
@@ -18,7 +16,7 @@ type OpenCode struct{}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string, extraArgs []string) error {
|
||||
func (o *OpenCode) Run(model string) error {
|
||||
if _, err := exec.LookPath("opencode"); err != nil {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
@@ -32,7 +30,7 @@ func (o *OpenCode) Run(model string, extraArgs []string) error {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode", extraArgs...)
|
||||
cmd := exec.Command("opencode")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -90,7 +88,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama (local)",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
"baseURL": "http://localhost:11434/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,7 +102,6 @@
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/clawdbot",
|
||||
"/integrations/cline",
|
||||
"/integrations/codex",
|
||||
"/integrations/droid",
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
---
|
||||
title: Clawdbot
|
||||
---
|
||||
|
||||
Clawdbot is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
|
||||
|
||||
## Install
|
||||
|
||||
Install [Clawdbot](https://clawd.bot/)
|
||||
|
||||
```bash
|
||||
npm install -g clawdbot@latest
|
||||
```
|
||||
|
||||
Then run the onboarding wizard:
|
||||
|
||||
```bash
|
||||
clawdbot onboard --install-daemon
|
||||
```
|
||||
|
||||
<Note>Clawdbot requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch clawdbot
|
||||
```
|
||||
|
||||
This configures Clawdbot to use Ollama and starts the gateway.
|
||||
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
|
||||
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch clawdbot --config
|
||||
```
|
||||
|
||||
## Recommended Models
|
||||
|
||||
- `qwen3-coder`
|
||||
- `glm-4.7`
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type filteredEnv []string
|
||||
@@ -115,7 +116,7 @@ type llamaServer struct {
|
||||
type ollamaServer struct {
|
||||
llmServer
|
||||
|
||||
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
|
||||
tokenizer tokenizer.Tokenizer // textProcessor handles text encoding/decoding
|
||||
}
|
||||
|
||||
// LoadModel will load a model from disk. The model must be in the GGML format.
|
||||
@@ -141,11 +142,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
|
||||
// NewLlamaServer will run a server for the given GPUs
|
||||
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
var llamaModel *llama.Model
|
||||
var textProcessor model.TextProcessor
|
||||
var tokenizer tokenizer.Tokenizer
|
||||
var err error
|
||||
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
||||
if len(projectors) == 0 {
|
||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||
tokenizer, err = model.NewTextProcessor(modelPath)
|
||||
} else {
|
||||
err = errors.New("split vision models aren't supported")
|
||||
}
|
||||
@@ -154,7 +155,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
||||
}
|
||||
}
|
||||
if textProcessor == nil {
|
||||
if tokenizer == nil {
|
||||
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -210,7 +211,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
|
||||
kvct := strings.ToLower(envconfig.KvCacheType())
|
||||
|
||||
if textProcessor == nil {
|
||||
if tokenizer == nil {
|
||||
flashAttention := ml.FlashAttentionAuto
|
||||
if faUserSet {
|
||||
if fa {
|
||||
@@ -260,7 +261,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
gpuLibs := ml.LibraryPaths(gpus)
|
||||
status := NewStatusWriter(os.Stderr)
|
||||
cmd, port, err := StartRunner(
|
||||
textProcessor != nil,
|
||||
tokenizer != nil,
|
||||
modelPath,
|
||||
gpuLibs,
|
||||
status,
|
||||
@@ -309,8 +310,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
}
|
||||
}()
|
||||
|
||||
if textProcessor != nil {
|
||||
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
|
||||
if tokenizer != nil {
|
||||
return &ollamaServer{llmServer: s, tokenizer: tokenizer}, nil
|
||||
} else {
|
||||
return &llamaServer{llmServer: s, ggml: f}, nil
|
||||
}
|
||||
@@ -1772,7 +1773,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
|
||||
}
|
||||
|
||||
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
tokens, err := s.textProcessor.Encode(content, false)
|
||||
tokens, err := s.tokenizer.Encode(content, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1807,7 +1808,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
|
||||
toks[i] = int32(t)
|
||||
}
|
||||
|
||||
content, err := s.textProcessor.Decode(toks)
|
||||
content, err := s.tokenizer.Decode(toks)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
410
model/ignore_test.go
Normal file
410
model/ignore_test.go
Normal file
File diff suppressed because one or more lines are too long
@@ -23,6 +23,7 @@ import (
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
|
||||
r, err := os.Open(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, ok := m.(TextProcessor)
|
||||
tp, ok := m.(tokenizer.Tokenizer)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocab := &model.Vocabulary{
|
||||
vocab := &tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
|
||||
},
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model", "bert") {
|
||||
case "bert":
|
||||
processor = model.NewWordPiece(vocab, true)
|
||||
t = tokenizer.NewWordPiece(vocab, true)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Tokenizer: t,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
Sam *samModel `gguf:"s"`
|
||||
Vision *visionModel `gguf:"v"`
|
||||
@@ -134,8 +135,8 @@ func init() {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -43,8 +44,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
poolingType pooling.Type
|
||||
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
|
||||
),
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model") {
|
||||
case "gpt2":
|
||||
processor = model.NewBytePairEncoding(&vocabulary)
|
||||
t = tokenizer.NewBytePairEncoding(&vocabulary)
|
||||
default:
|
||||
// Previous uploads of Gemma 3 on Ollama did not have token 106
|
||||
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
|
||||
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
t = tokenizer.NewSentencePiece(&vocabulary)
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Tokenizer: t,
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
}
|
||||
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Transformer struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TransformerBlocks []TransformerBlock `gguf:"blk"`
|
||||
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Transformer{
|
||||
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedModel
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
vocabulary := model.Vocabulary{
|
||||
var processor tokenizer.Tokenizer
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
}
|
||||
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
case "llama":
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
processor = tokenizer.NewSentencePiece(&vocabulary)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
ImageProcessor
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -28,12 +29,12 @@ type Model struct {
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
// Implement TextProcessor interface
|
||||
var _ model.TextProcessor = (*Model)(nil)
|
||||
var _ tokenizer.Tokenizer = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -32,8 +33,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -178,8 +179,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
numHeads := int(c.Uint("attention.head_count"))
|
||||
headDim := hiddenSize / numHeads
|
||||
|
||||
processor := model.NewWordPiece(
|
||||
&model.Vocabulary{
|
||||
tokenizer := tokenizer.NewWordPiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -219,8 +220,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: layers,
|
||||
Tokenizer: tokenizer,
|
||||
Layers: layers,
|
||||
Options: Options{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,7 +34,7 @@ type Options struct {
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -44,7 +45,7 @@ type Model struct {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -58,14 +59,14 @@ func New(c fs.Config) (model.Model, error) {
|
||||
),
|
||||
}
|
||||
|
||||
processor := model.NewBytePairEncoding(
|
||||
tokenizer := tokenizer.NewBytePairEncoding(
|
||||
&vocabulary,
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []DecoderLayer `gguf:"blk"`
|
||||
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
m := Model{
|
||||
Layers: make([]DecoderLayer, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*Model
|
||||
poolingType pooling.Type
|
||||
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
m := embedModel{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -18,34 +17,12 @@ const (
|
||||
ministralCollectingToolArgs
|
||||
)
|
||||
|
||||
// ministralEvent represents an event emitted during parsing
|
||||
type ministralEvent interface {
|
||||
isMinistralEvent()
|
||||
}
|
||||
|
||||
type ministralEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type ministralEventThinking struct {
|
||||
thinking string
|
||||
}
|
||||
|
||||
type ministralEventToolCall struct {
|
||||
name string
|
||||
args string // raw JSON string
|
||||
}
|
||||
|
||||
func (ministralEventContent) isMinistralEvent() {}
|
||||
func (ministralEventThinking) isMinistralEvent() {}
|
||||
func (ministralEventToolCall) isMinistralEvent() {}
|
||||
|
||||
type MinistralParser struct {
|
||||
state ministralParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
hasThinkingSupport bool
|
||||
pendingToolName string // stores tool name while collecting args
|
||||
currentTool *api.Tool
|
||||
}
|
||||
|
||||
func (p *MinistralParser) HasToolSupport() bool {
|
||||
@@ -86,251 +63,74 @@ func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
|
||||
return nil, fmt.Errorf("tool '%s' not found", n)
|
||||
}
|
||||
|
||||
const (
|
||||
ministralToolCallsTag = "[TOOL_CALLS]"
|
||||
ministralThinkTag = "[THINK]"
|
||||
ministralThinkEndTag = "[/THINK]"
|
||||
ministralArgsTag = "[ARGS]"
|
||||
)
|
||||
|
||||
// eat consumes the parser's buffer, and returns a list of any unambiguous
|
||||
// events from the current parser state. The second return value indicates
|
||||
// whether to keep looping (true when state transitions, false when waiting
|
||||
// for more data).
|
||||
func (p *MinistralParser) eat() ([]ministralEvent, bool) {
|
||||
var events []ministralEvent
|
||||
|
||||
switch p.state {
|
||||
case ministralCollectingContent:
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
// Check for [TOOL_CALLS] tag
|
||||
if strings.Contains(bufStr, ministralToolCallsTag) {
|
||||
split := strings.SplitN(bufStr, ministralToolCallsTag, 2)
|
||||
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, ministralEventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingToolName
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for [THINK] tag
|
||||
if strings.Contains(bufStr, ministralThinkTag) {
|
||||
split := strings.SplitN(bufStr, ministralThinkTag, 2)
|
||||
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, ministralEventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingThinkingContent
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for partial tag overlap with [TOOL_CALLS] or [THINK]
|
||||
overlapToolCalls := overlap(bufStr, ministralToolCallsTag)
|
||||
overlapThink := overlap(bufStr, ministralThinkTag)
|
||||
maxOverlap := max(overlapToolCalls, overlapThink)
|
||||
|
||||
if maxOverlap > 0 {
|
||||
// Withhold the potential partial tag
|
||||
beforePartialTag := bufStr[:len(bufStr)-maxOverlap]
|
||||
trailingWS := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWS
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
// No tag found: emit content but withhold trailing whitespace
|
||||
whitespaceLen := trailingWhitespaceLen(bufStr)
|
||||
ambiguousStart := len(bufStr) - whitespaceLen
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case ministralCollectingThinkingContent:
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
if strings.Contains(bufStr, ministralThinkEndTag) {
|
||||
split := strings.SplitN(bufStr, ministralThinkEndTag, 2)
|
||||
thinkingContent := split[0]
|
||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if len(thinkingContent) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: thinkingContent})
|
||||
}
|
||||
p.state = ministralCollectingContent
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for partial overlap with [/THINK]
|
||||
if overlapLen := overlap(bufStr, ministralThinkEndTag); overlapLen > 0 {
|
||||
unambiguous := bufStr[:len(bufStr)-overlapLen]
|
||||
ambiguous := bufStr[len(bufStr)-overlapLen:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
// No tag found: emit all thinking content
|
||||
p.buffer.Reset()
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: bufStr})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case ministralCollectingToolName:
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
if strings.Contains(bufStr, ministralArgsTag) {
|
||||
split := strings.SplitN(bufStr, ministralArgsTag, 2)
|
||||
toolName := split[0]
|
||||
after := split[1]
|
||||
p.pendingToolName = toolName
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingToolArgs
|
||||
return events, true
|
||||
}
|
||||
// Wait for more data
|
||||
return events, false
|
||||
|
||||
case ministralCollectingToolArgs:
|
||||
bufStr := p.buffer.String()
|
||||
jsonEnd := findJSONEnd(bufStr)
|
||||
|
||||
if jsonEnd != -1 {
|
||||
jsonStr := bufStr[:jsonEnd+1]
|
||||
remaining := bufStr[jsonEnd+1:]
|
||||
|
||||
events = append(events, ministralEventToolCall{
|
||||
name: p.pendingToolName,
|
||||
args: jsonStr,
|
||||
})
|
||||
|
||||
p.pendingToolName = ""
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = ministralCollectingContent
|
||||
return events, true
|
||||
}
|
||||
// Wait for more data
|
||||
return events, false
|
||||
|
||||
default:
|
||||
panic("unexpected ministral event")
|
||||
}
|
||||
}
|
||||
|
||||
// parseEvents loops calling eat() until it returns false
|
||||
func (p *MinistralParser) parseEvents() []ministralEvent {
|
||||
var all []ministralEvent
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []ministralEvent
|
||||
events, keepLooping = p.eat()
|
||||
all = append(all, events...)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentBuilder, thinkingBuilder strings.Builder
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, event := range events {
|
||||
switch e := event.(type) {
|
||||
case ministralEventContent:
|
||||
contentBuilder.WriteString(e.content)
|
||||
case ministralEventThinking:
|
||||
thinkingBuilder.WriteString(e.thinking)
|
||||
case ministralEventToolCall:
|
||||
// Validate tool exists
|
||||
tool, toolErr := toolByName(p.tools, e.name)
|
||||
if toolErr != nil {
|
||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, toolErr
|
||||
switch p.state {
|
||||
case ministralCollectingContent:
|
||||
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
|
||||
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
|
||||
if before != "" {
|
||||
return before, "", calls, nil
|
||||
}
|
||||
// Parse JSON arguments
|
||||
p.state = ministralCollectingToolName
|
||||
} else if strings.Contains(p.buffer.String(), "[THINK]") {
|
||||
p.state = ministralCollectingThinkingContent
|
||||
return "", "", calls, nil
|
||||
} else {
|
||||
p.buffer.Reset()
|
||||
return s, "", calls, nil
|
||||
}
|
||||
case ministralCollectingThinkingContent:
|
||||
if strings.Contains(p.buffer.String(), "[/THINK]") {
|
||||
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
|
||||
p.state = ministralCollectingContent
|
||||
if after != "" {
|
||||
p.buffer.Reset()
|
||||
return after, thinkingContent, calls, nil
|
||||
}
|
||||
return "", thinkingContent, calls, nil
|
||||
} else {
|
||||
p.buffer.Reset()
|
||||
return "", s, calls, nil
|
||||
}
|
||||
case ministralCollectingToolName:
|
||||
if strings.Contains(p.buffer.String(), "[ARGS]") {
|
||||
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
|
||||
|
||||
t, err := toolByName(p.tools, name)
|
||||
if err != nil {
|
||||
return "", "", calls, err
|
||||
}
|
||||
p.currentTool = t
|
||||
p.state = ministralCollectingToolArgs
|
||||
return "", "", calls, nil
|
||||
}
|
||||
return "", "", calls, nil
|
||||
case ministralCollectingToolArgs:
|
||||
if strings.Contains(p.buffer.String(), "}") {
|
||||
before, _ := splitAtTag(&p.buffer, "}", false)
|
||||
before += "}"
|
||||
|
||||
var args api.ToolCallFunctionArguments
|
||||
if jsonErr := json.Unmarshal([]byte(e.args), &args); jsonErr != nil {
|
||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, jsonErr
|
||||
if err := json.Unmarshal([]byte(before), &args); err != nil {
|
||||
// todo - throw a better error
|
||||
return "", "", calls, err
|
||||
}
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
|
||||
p.state = ministralCollectingContent
|
||||
|
||||
call := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: tool.Function.Name,
|
||||
Name: p.currentTool.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
calls = append(calls, call)
|
||||
return "", "", calls, nil
|
||||
}
|
||||
return "", "", calls, nil
|
||||
}
|
||||
|
||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
// findJSONEnd finds the index of the closing brace that completes a JSON object.
|
||||
// It properly handles nested objects, arrays, and strings (including escaped characters).
|
||||
// Returns -1 if the JSON is not yet complete.
|
||||
func findJSONEnd(s string) int {
|
||||
depth := 0
|
||||
inString := false
|
||||
escaped := false
|
||||
|
||||
for i, r := range s {
|
||||
if inString {
|
||||
switch {
|
||||
case escaped:
|
||||
// If the previous character was a backslash, skip this character
|
||||
escaped = false
|
||||
case r == '\\':
|
||||
// Mark the next character as escaped
|
||||
escaped = true
|
||||
case r == '"':
|
||||
// End of string literal
|
||||
inString = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
// Start of string literal
|
||||
inString = true
|
||||
case '{', '[':
|
||||
// Increase nesting level for objects and arrays
|
||||
depth++
|
||||
case '}', ']':
|
||||
// Decrease nesting level
|
||||
depth--
|
||||
if depth == 0 {
|
||||
// Reached the end of the root JSON structure
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
return p.buffer.String(), thinking, calls, nil
|
||||
}
|
||||
|
||||
@@ -1,545 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestMinistralParserStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantEvents []ministralEvent
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
tools []api.Tool
|
||||
steps []step
|
||||
think bool // whether to enable thinking support
|
||||
}{
|
||||
// Content streaming
|
||||
{
|
||||
desc: "simple content",
|
||||
steps: []step{
|
||||
{input: "Hello, how can I help you?", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "Hello, how can I help you?"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "streaming content word by word",
|
||||
steps: []step{
|
||||
{input: "Hello,", wantEvents: []ministralEvent{ministralEventContent{content: "Hello,"}}},
|
||||
{input: " how", wantEvents: []ministralEvent{ministralEventContent{content: " how"}}},
|
||||
{input: " can I help?", wantEvents: []ministralEvent{ministralEventContent{content: " can I help?"}}},
|
||||
},
|
||||
},
|
||||
|
||||
// Simple tool calls
|
||||
{
|
||||
desc: "simple tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "San Francisco"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "get_weather", args: `{"location": "San Francisco"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with nested object",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]create_entities[ARGS]{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with deeply nested objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "update_config"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]update_config[ARGS]{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "update_config", args: `{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with array of objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "process_items"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]process_items[ARGS]{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "process_items", args: `{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with escaped quotes in string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "search"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]search[ARGS]{"query": "say \"hello\""}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "search", args: `{"query": "say \"hello\""}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with braces inside string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "format"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]format[ARGS]{"template": "Hello {name}!"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "format", args: `{"template": "Hello {name}!"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty JSON object",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "no_args"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]no_args[ARGS]{}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "no_args", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "JSON with newlines in string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "write"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]write[ARGS]{"content": "line1\nline2\nline3"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "write", args: `{"content": "line1\nline2\nline3"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "backslash in string value",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "path"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]path[ARGS]{"dir": "C:\\Users\\test"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "path", args: `{"dir": "C:\\Users\\test"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Content after tool call
|
||||
{
|
||||
desc: "content after tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
// NOTE: It's unclear if this is valid Ministral output, but the parser
|
||||
// currently treats text after a tool call as regular content. This test
|
||||
// documents that behavior so we notice if it changes.
|
||||
{input: `[TOOL_CALLS]test[ARGS]{"a": 1}some content after`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{"a": 1}`},
|
||||
ministralEventContent{content: "some content after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Multiple tool calls
|
||||
{
|
||||
desc: "multiple tool calls in sequence",
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "get_weather"}},
|
||||
{Function: api.ToolFunction{Name: "get_time"}},
|
||||
},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "NYC"}[TOOL_CALLS]get_time[ARGS]{"timezone": "EST"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "get_weather", args: `{"location": "NYC"}`},
|
||||
ministralEventToolCall{name: "get_time", args: `{"timezone": "EST"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple tool calls streamed separately",
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "tool_a"}},
|
||||
{Function: api.ToolFunction{Name: "tool_b"}},
|
||||
},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]tool_a[ARGS]{"x": 1}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "tool_a", args: `{"x": 1}`},
|
||||
}},
|
||||
{input: `[TOOL_CALLS]tool_b[ARGS]{"y": 2}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "tool_b", args: `{"y": 2}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Streaming tool calls
|
||||
{
|
||||
desc: "streaming tool call with nested objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_CALLS]create_entities[ARGS]", wantEvents: []ministralEvent{}},
|
||||
{input: `{"entities": [{"entityType": "Person",`, wantEvents: []ministralEvent{}},
|
||||
{input: ` "name": "Jack",`, wantEvents: []ministralEvent{}},
|
||||
{input: ` "observations": ["Works`, wantEvents: []ministralEvent{}},
|
||||
{input: ` as a baker"]}`, wantEvents: []ministralEvent{}},
|
||||
{input: `]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "streaming with incomplete JSON waits for completion",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_CALLS]test[ARGS]{", wantEvents: []ministralEvent{}},
|
||||
{input: `"a": {`, wantEvents: []ministralEvent{}},
|
||||
{input: `"b": 1`, wantEvents: []ministralEvent{}},
|
||||
{input: `}`, wantEvents: []ministralEvent{}},
|
||||
{input: `}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{"a": {"b": 1}}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Partial tag handling
|
||||
{
|
||||
desc: "partial tool tag fakeout",
|
||||
steps: []step{
|
||||
{input: "abc[TOOL", wantEvents: []ministralEvent{ministralEventContent{content: "abc"}}},
|
||||
{input: " not a tag", wantEvents: []ministralEvent{ministralEventContent{content: "[TOOL not a tag"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call tag split across chunks",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_", wantEvents: []ministralEvent{}},
|
||||
{input: "CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "content before tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
|
||||
steps: []step{
|
||||
{input: "hello [TOOL_CALLS]get_weather[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "hello"},
|
||||
ministralEventToolCall{name: "get_weather", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace between content and tool call is trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "content \n [TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tabs and newlines before tool call are trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "content\t\n\t[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "non-breaking space before tool call is trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
// \u00a0 is non-breaking space, which unicode.IsSpace considers whitespace
|
||||
{input: "content\u00a0[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace before THINK tag is trimmed",
|
||||
steps: []step{
|
||||
{input: "content \n [THINK]thinking[/THINK]after", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace withheld then emitted",
|
||||
steps: []step{
|
||||
{input: "Hello ", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
|
||||
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: " world"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing newline withheld then emitted",
|
||||
steps: []step{
|
||||
{input: "Hello\n", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
|
||||
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: "\nworld"}}},
|
||||
},
|
||||
},
|
||||
|
||||
// Thinking support
|
||||
{
|
||||
desc: "thinking content",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking here[/THINK]", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "thinking here"},
|
||||
}},
|
||||
{input: "content after", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking with whitespace after end tag",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "my thoughts[/THINK] \n response", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "my thoughts"},
|
||||
ministralEventContent{content: "response"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "non-breaking space after think end tag is trimmed",
|
||||
think: true,
|
||||
steps: []step{
|
||||
// \u00a0 is non-breaking space
|
||||
{input: "thinking[/THINK]\u00a0response", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "response"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial think end tag",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
|
||||
{input: "NK]after", wantEvents: []ministralEvent{ministralEventContent{content: "after"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "think tag fakeout",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
|
||||
{input: "not end tag", wantEvents: []ministralEvent{ministralEventThinking{thinking: "[/THInot end tag"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking then tool call",
|
||||
think: true,
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "let me think[/THINK][TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "let me think"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Content then THINK tag transition
|
||||
{
|
||||
desc: "content then think tag",
|
||||
steps: []step{
|
||||
{input: "content[THINK]thinking[/THINK]more", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "more"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Unicode handling
|
||||
{
|
||||
desc: "unicode content",
|
||||
steps: []step{
|
||||
{input: "你好 🌍 مرحبا", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "你好 🌍 مرحبا"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "unicode in tool args",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "greet"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]greet[ARGS]{"message": "你好 🌍"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "greet", args: `{"message": "你好 🌍"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
parser := MinistralParser{}
|
||||
parser.hasThinkingSupport = tc.think
|
||||
parser.Init(tc.tools, nil, nil)
|
||||
|
||||
for i, step := range tc.steps {
|
||||
parser.buffer.WriteString(step.input)
|
||||
gotEvents := parser.parseEvents()
|
||||
|
||||
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
|
||||
// avoid deep equal on empty vs. nil slices
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
|
||||
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinistralParser_Errors(t *testing.T) {
|
||||
t.Run("unknown tool returns error", func(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "known_tool"}}}, nil, nil)
|
||||
|
||||
_, _, _, err := p.Add(`[TOOL_CALLS]unknown_tool[ARGS]{"a": 1}`, true)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown tool")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON returns error", func(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "test"}}}, nil, nil)
|
||||
|
||||
_, _, _, err := p.Add(`[TOOL_CALLS]test[ARGS]{invalid json}`, true)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFindJSONEnd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "simple object",
|
||||
input: `{"a": 1}`,
|
||||
expected: 7,
|
||||
},
|
||||
{
|
||||
name: "nested object",
|
||||
input: `{"a": {"b": 2}}`,
|
||||
expected: 14,
|
||||
},
|
||||
{
|
||||
name: "array inside object",
|
||||
input: `{"items": [1, 2, 3]}`,
|
||||
expected: 19,
|
||||
},
|
||||
{
|
||||
name: "braces in string",
|
||||
input: `{"template": "Hello {name}!"}`,
|
||||
expected: 28,
|
||||
},
|
||||
{
|
||||
name: "escaped quotes",
|
||||
input: `{"msg": "say \"hi\""}`,
|
||||
expected: 20,
|
||||
},
|
||||
{
|
||||
name: "incomplete object",
|
||||
input: `{"a": {"b": 1}`,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "deeply nested",
|
||||
input: `{"a": {"b": {"c": {"d": 1}}}}`,
|
||||
expected: 28,
|
||||
},
|
||||
{
|
||||
name: "object with trailing content",
|
||||
input: `{"a": 1} extra`,
|
||||
expected: 7,
|
||||
},
|
||||
{
|
||||
name: "array",
|
||||
input: `[{"a": 1}, {"b": 2}]`,
|
||||
expected: 19,
|
||||
},
|
||||
{
|
||||
name: "escaped backslash before quote",
|
||||
input: `{"path": "C:\\"}`,
|
||||
expected: 15,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "no opening brace",
|
||||
input: "hello world",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "only opening brace",
|
||||
input: "{",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "unclosed string",
|
||||
input: `{"key": "unclosed`,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "double escaped backslash then quote",
|
||||
input: `{"path": "C:\\\\"}`,
|
||||
expected: 17,
|
||||
},
|
||||
{
|
||||
name: "unicode in key and value",
|
||||
input: `{"키": "값"}`,
|
||||
expected: 13,
|
||||
},
|
||||
{
|
||||
name: "nested arrays",
|
||||
input: `{"matrix": [[1, 2], [3, 4]]}`,
|
||||
expected: 27,
|
||||
},
|
||||
{
|
||||
name: "mixed nesting",
|
||||
input: `{"a": [{"b": {"c": [1, 2, 3]}}]}`,
|
||||
expected: 31,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := findJSONEnd(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("findJSONEnd(%q) = %d, want %d", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinistralParser_HasToolSupport(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
if !p.HasToolSupport() {
|
||||
t.Error("expected HasToolSupport to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinistralParser_HasThinkingSupport(t *testing.T) {
|
||||
p := &MinistralParser{hasThinkingSupport: false}
|
||||
if p.HasThinkingSupport() {
|
||||
t.Error("expected HasThinkingSupport to return false")
|
||||
}
|
||||
|
||||
p = &MinistralParser{hasThinkingSupport: true}
|
||||
if !p.HasThinkingSupport() {
|
||||
t.Error("expected HasThinkingSupport to return true")
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package parsers
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
@@ -115,33 +114,3 @@ func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string
|
||||
sb.WriteString(after)
|
||||
return before, after // return events
|
||||
}
|
||||
|
||||
// overlap returns the longest overlap between the suffix of s and the prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// trailingWhitespaceLen returns the length in bytes of trailing whitespace in s
|
||||
func trailingWhitespaceLen(s string) int {
|
||||
remaining := s
|
||||
total := 0
|
||||
for len(remaining) > 0 {
|
||||
r, size := utf8.DecodeLastRuneInString(remaining)
|
||||
// if it's an invalid utf8 rune, assume it isn't whitespace
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
break
|
||||
}
|
||||
if !unicode.IsSpace(r) {
|
||||
break
|
||||
}
|
||||
total += size
|
||||
remaining = remaining[:len(remaining)-size]
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
@@ -193,6 +194,36 @@ func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(drifkin): move this to a shared location
|
||||
// longest overlap between suffix of s and prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func trailingWhitespaceLen(s string) int {
|
||||
remaining := s
|
||||
total := 0
|
||||
for len(remaining) > 0 {
|
||||
r, size := utf8.DecodeLastRuneInString(remaining)
|
||||
// if it's an invalid utf8 rune, assume it isn't whitespace
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
break
|
||||
}
|
||||
if !unicode.IsSpace(r) {
|
||||
break
|
||||
}
|
||||
total += size
|
||||
remaining = remaining[:len(remaining)-size]
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
type XMLFunctionCall struct {
|
||||
XMLName xml.Name `xml:"function"`
|
||||
Name string `xml:"name,attr"`
|
||||
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
|
||||
_ "github.com/ollama/ollama/model/models"
|
||||
)
|
||||
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
}
|
||||
|
||||
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tokenizer tokenizer.Tokenizer) []llm.Logprob {
|
||||
decoder := func(tokenID int) string {
|
||||
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
|
||||
text, _ := tokenizer.Decode([]int32{int32(tokenID)})
|
||||
return text
|
||||
}
|
||||
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
|
||||
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
|
||||
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -766,7 +767,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
nextBatchTokens[i].Token = token
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
@@ -775,14 +776,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
|
||||
if err != nil {
|
||||
panic("failed to decode token")
|
||||
}
|
||||
|
||||
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
|
||||
if seq.logprobs {
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
|
||||
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
|
||||
}
|
||||
|
||||
@@ -873,7 +874,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var grammar *sample.GrammarSampler
|
||||
var err error
|
||||
if req.Grammar != "" {
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// token represents information about a single token during sampling
|
||||
@@ -168,15 +168,15 @@ type GrammarSampler struct {
|
||||
grammar *llama.Grammar
|
||||
}
|
||||
|
||||
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(model.Vocabulary().Values))
|
||||
pieces := make([]string, len(model.Vocabulary().Values))
|
||||
for i := range model.Vocabulary().Values {
|
||||
pieces[i], _ = model.Decode([]int32{int32(i)})
|
||||
func NewGrammarSampler(tokenizer tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(tokenizer.Vocabulary().Values))
|
||||
pieces := make([]string, len(tokenizer.Vocabulary().Values))
|
||||
for i := range tokenizer.Vocabulary().Values {
|
||||
pieces[i], _ = tokenizer.Decode([]int32{int32(i)})
|
||||
vocabIds[i] = uint32(i)
|
||||
}
|
||||
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tokenizer.Vocabulary().EOS)
|
||||
if grammar == nil {
|
||||
return nil, errors.New("sample: failed to initialize grammar")
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
func modelHelper(t testing.TB) tokenizer.Tokenizer {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
f, err := os.Open(filepath.Join("..", "testdata", "testdata", "llama3.2", "encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
|
||||
merges := make([]string, 0, 1)
|
||||
// Only need vocab for Grammar Test
|
||||
return model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
return tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: tokens,
|
||||
Types: make([]int32, len(vocab)),
|
||||
Merges: merges,
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
@@ -11,24 +13,24 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type BytePairEncoding struct {
|
||||
type bytePairEncoding struct {
|
||||
vocab *Vocabulary
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
var _ Tokenizer = (*bytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) bytePairEncoding {
|
||||
if len(pretokenizer) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
|
||||
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
return bytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizers {
|
||||
for _, p := range pretokenizer {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
@@ -37,15 +39,15 @@ func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEnc
|
||||
}
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||
func (bpe bytePairEncoding) Vocabulary() *Vocabulary {
|
||||
return bpe.vocab
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||
func (bpe bytePairEncoding) Is(id int32, special Special) bool {
|
||||
return bpe.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||
func (bpe *bytePairEncoding) split(s string) iter.Seq[string] {
|
||||
parts := []string{s}
|
||||
for _, re := range bpe.regexps {
|
||||
parts = slices.Collect(func(yield func(string) bool) {
|
||||
@@ -96,7 +98,7 @@ type merge struct {
|
||||
runes []rune
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
func (bpe bytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
@@ -243,7 +245,15 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
type lazyIdsString struct {
|
||||
ids []int32
|
||||
}
|
||||
|
||||
func (l lazyIdsString) LogValue() slog.Value {
|
||||
return slog.AnyValue(fmt.Sprint(l.ids))
|
||||
}
|
||||
|
||||
func (bpe bytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
for _, r := range bpe.vocab.Decode(id) {
|
||||
@@ -267,6 +277,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", ids)
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -14,10 +14,10 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func llama(t testing.TB) BytePairEncoding {
|
||||
func llama(t testing.TB) bytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
|
||||
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
|
||||
}
|
||||
}
|
||||
|
||||
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
|
||||
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
@@ -12,18 +12,18 @@ import (
|
||||
|
||||
const spmWhitespaceSep = "▁"
|
||||
|
||||
type SentencePiece struct {
|
||||
type sentencePiece struct {
|
||||
maxTokenLen int
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*SentencePiece)(nil)
|
||||
var _ Tokenizer = (*sentencePiece)(nil)
|
||||
|
||||
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||
func (spm sentencePiece) Vocabulary() *Vocabulary {
|
||||
return spm.vocab
|
||||
}
|
||||
|
||||
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
|
||||
func NewSentencePiece(vocab *Vocabulary) sentencePiece {
|
||||
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
counter := map[int]int{}
|
||||
@@ -42,17 +42,17 @@ func NewSentencePiece(vocab *Vocabulary) SentencePiece {
|
||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||
"max token len", maxTokenLen)
|
||||
|
||||
return SentencePiece{
|
||||
return sentencePiece{
|
||||
maxTokenLen: maxTokenLen,
|
||||
vocab: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Is(id int32, special Special) bool {
|
||||
func (spm sentencePiece) Is(id int32, special Special) bool {
|
||||
return spm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
func (spm sentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||
id := spm.vocab.Encode(special)
|
||||
@@ -218,13 +218,13 @@ func (q *queue) Pop() interface{} {
|
||||
return item
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
func (spm sentencePiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
data := spm.vocab.Decode(id)
|
||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
||||
|
||||
// For tokenizers that use byte tokens like "<0xEA>"
|
||||
// For tokenizer that use byte tokens like "<0xEA>"
|
||||
// convert them to the partial unicode character
|
||||
// so they are buffered correctly by the runner instead
|
||||
// of being sent back to the api as "<0xEA>"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
"github.com/ollama/ollama/convert/sentencepiece"
|
||||
)
|
||||
|
||||
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||
func loadSentencePieceVocab(t *testing.T) sentencePiece {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
||||
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
const (
|
||||
TOKEN_TYPE_NORMAL = iota + 1
|
||||
@@ -9,7 +9,7 @@ const (
|
||||
TOKEN_TYPE_BYTE
|
||||
)
|
||||
|
||||
type TextProcessor interface {
|
||||
type Tokenizer interface {
|
||||
Encode(s string, addSpecial bool) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(int32, Special) bool
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type WordPiece struct {
|
||||
type wordPiece struct {
|
||||
vocab *Vocabulary
|
||||
lowercase bool
|
||||
}
|
||||
@@ -32,8 +32,8 @@ var wordPieceReplacer = strings.NewReplacer(
|
||||
" 're", "'re",
|
||||
)
|
||||
|
||||
// Decode implements TextProcessor.
|
||||
func (wpm WordPiece) Decode(ids []int32) (string, error) {
|
||||
// Decode implements Tokenizer.
|
||||
func (wpm wordPiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for i, id := range ids {
|
||||
if id < 0 || int(id) >= len(wpm.vocab.Values) {
|
||||
@@ -56,7 +56,7 @@ func (wpm WordPiece) Decode(ids []int32) (string, error) {
|
||||
|
||||
// words splits a string into words, treating CJK characters as separate words.
|
||||
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
|
||||
func (wpm WordPiece) words(s string) iter.Seq[string] {
|
||||
func (wpm wordPiece) words(s string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
runes := make([]rune, 0, len(s)*3)
|
||||
for _, r := range s {
|
||||
@@ -96,8 +96,8 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
|
||||
}
|
||||
}
|
||||
|
||||
// Encode implements TextProcessor.
|
||||
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
// Encode implements Tokenizer.
|
||||
func (wpm wordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
var ids []int32
|
||||
|
||||
// TODO: use [UNK] from config
|
||||
@@ -151,20 +151,20 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// Is implements TextProcessor.
|
||||
func (wpm WordPiece) Is(id int32, special Special) bool {
|
||||
// Is implements Tokenizer.
|
||||
func (wpm wordPiece) Is(id int32, special Special) bool {
|
||||
return wpm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
// Vocabulary implements TextProcessor.
|
||||
func (wpm WordPiece) Vocabulary() *Vocabulary {
|
||||
// Vocabulary implements Tokenizer.
|
||||
func (wpm wordPiece) Vocabulary() *Vocabulary {
|
||||
return wpm.vocab
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*WordPiece)(nil)
|
||||
var _ Tokenizer = (*wordPiece)(nil)
|
||||
|
||||
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
|
||||
return WordPiece{
|
||||
func NewWordPiece(vocab *Vocabulary, lowercase bool) wordPiece {
|
||||
return wordPiece{
|
||||
vocab: vocab,
|
||||
lowercase: lowercase,
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"slices"
|
||||
@@ -39,7 +39,7 @@ func TestWordPiece(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWordPieceWords(t *testing.T) {
|
||||
var wpm WordPiece
|
||||
var wpm wordPiece
|
||||
|
||||
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
|
||||
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
func llama(t testing.TB) BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
f, err := os.Open(filepath.Join("..", "..", "tokenizer", "testdata", "llama3.2", "encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
|
||||
}
|
||||
}
|
||||
|
||||
f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
|
||||
f, err = os.Open(filepath.Join("..", "..", "tokenizer", "testdata", "llama3.2", "vocab.bpe"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
"github.com/ollama/ollama/x/kvcache"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/ml/nn"
|
||||
@@ -18,7 +19,7 @@ import (
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"vision_tower.vision_model"`
|
||||
*TextModel `gguf:"language_model.model"`
|
||||
@@ -58,8 +59,8 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
// slog.Info("XXX Config", "c", c)
|
||||
m := Model{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model"))
|
||||
bts, err := os.ReadFile(filepath.Join("..", "..", "tokenizer", "testdata", "gemma2", "tokenizer.model"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user