mirror of
https://github.com/ollama/ollama.git
synced 2026-01-27 08:51:39 -05:00
Compare commits
15 Commits
v0.15.0-rc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
26acab64b7 | ||
|
|
e0f03790b1 | ||
|
|
3ab842b0f5 | ||
|
|
b8e8ef8929 | ||
|
|
465d124183 | ||
|
|
d310e56fa3 | ||
|
|
a1ca428c90 | ||
|
|
16750865d1 | ||
|
|
f3b476c592 | ||
|
|
5267d31d56 | ||
|
|
b44f56319f | ||
|
|
0209c268bb | ||
|
|
912d984346 | ||
|
|
aae6ecbaff | ||
|
|
64737330a4 |
@@ -169,8 +169,10 @@ COPY . .
|
||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
|
||||
@@ -558,7 +558,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [LiteLLM](https://github.com/BerriAI/litellm)
|
||||
- [OllamaFarm for Go](https://github.com/presbrey/ollamafarm)
|
||||
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
|
||||
- [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
|
||||
- [Ollama for Ruby](https://github.com/crmne/ruby_llm)
|
||||
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
|
||||
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
|
||||
- [Ollama4j for Java](https://github.com/ollama4j/ollama4j)
|
||||
|
||||
@@ -75,9 +75,9 @@ The `-dev` flag enables:
|
||||
CI builds with Xcode 14.1 for OS compatibility prior to v13. If you want to manually build v11+ support, you can download the older Xcode [here](https://developer.apple.com/services-account/download?path=/Developer_Tools/Xcode_14.1/Xcode_14.1.xip), extract, then `mv ./Xcode.app /Applications/Xcode_14.1.0.app` then activate with:
|
||||
|
||||
```
|
||||
export CGO_CFLAGS=-mmacosx-version-min=12.0
|
||||
export CGO_CXXFLAGS=-mmacosx-version-min=12.0
|
||||
export CGO_LDFLAGS=-mmacosx-version-min=12.0
|
||||
export CGO_CFLAGS="-O3 -mmacosx-version-min=12.0"
|
||||
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=12.0"
|
||||
export CGO_LDFLAGS="-mmacosx-version-min=12.0"
|
||||
export SDKROOT=/Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
|
||||
export DEVELOPER_DIR=/Applications/Xcode_14.1.0.app/Contents/Developer
|
||||
```
|
||||
|
||||
@@ -2031,7 +2031,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
config.ConfigCmd(checkServerHeartbeat),
|
||||
config.LaunchCmd(checkServerHeartbeat),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
@@ -18,12 +20,32 @@ func (c *Claude) args(model string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("claude"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".claude", "local", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string) error {
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command("claude", c.args(model)...)
|
||||
cmd := exec.Command(claudePath, c.args(model)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
@@ -19,6 +22,62 @@ func TestClaudeIntegration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeFindPath(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("finds claude in PATH", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fakeBin := filepath.Join(tmpDir, name)
|
||||
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fakeBin {
|
||||
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(tmpDir, ".claude", "local", name)
|
||||
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fallback {
|
||||
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
_, err := c.findPath()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeArgs(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
|
||||
193
cmd/config/clawdbot.go
Normal file
193
cmd/config/clawdbot.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Clawdbot struct{}
|
||||
|
||||
func (c *Clawdbot) String() string { return "Clawdbot" }
|
||||
|
||||
const ansiGreen = "\033[32m"
|
||||
|
||||
func (c *Clawdbot) Run(model string) error {
|
||||
if _, err := exec.LookPath("clawdbot"); err != nil {
|
||||
return fmt.Errorf("clawdbot is not installed, install from https://docs.clawd.bot")
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("clawdbot", "gateway")
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
// Capture output to detect "already running" message
|
||||
var outputBuf bytes.Buffer
|
||||
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
|
||||
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
|
||||
fmt.Fprintf(os.Stderr, "%sClawdbot has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Paths() []string {
|
||||
home, _ := os.UserHomeDir()
|
||||
p := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read into map[string]any to preserve unknown fields
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
// Navigate/create: models.providers.ollama (preserving other providers)
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
if modelsSection == nil {
|
||||
modelsSection = make(map[string]any)
|
||||
}
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
if providers == nil {
|
||||
providers = make(map[string]any)
|
||||
}
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
if ollama == nil {
|
||||
ollama = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama["baseUrl"] = "http://127.0.0.1:11434/v1"
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
// TODO(parthsareen): potentially move to responses
|
||||
ollama["api"] = "openai-completions"
|
||||
|
||||
// Build map of existing models to preserve user customizations
|
||||
existingModels, _ := ollama["models"].([]any)
|
||||
existingByID := make(map[string]map[string]any)
|
||||
for _, m := range existingModels {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
existingByID[id] = entry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var newModels []any
|
||||
for _, model := range models {
|
||||
entry := map[string]any{
|
||||
"id": model,
|
||||
"name": model,
|
||||
"reasoning": false,
|
||||
"input": []any{"text"},
|
||||
"cost": map[string]any{
|
||||
"input": 0,
|
||||
"output": 0,
|
||||
"cacheRead": 0,
|
||||
"cacheWrite": 0,
|
||||
},
|
||||
// TODO(parthsareen): get these values from API
|
||||
"contextWindow": 131072,
|
||||
"maxTokens": 16384,
|
||||
}
|
||||
// Merge existing fields (user customizations)
|
||||
if existing, ok := existingByID[model]; ok {
|
||||
for k, v := range existing {
|
||||
if _, isNew := entry[k]; !isNew {
|
||||
entry[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
newModels = append(newModels, entry)
|
||||
}
|
||||
ollama["models"] = newModels
|
||||
|
||||
providers["ollama"] = ollama
|
||||
modelsSection["providers"] = providers
|
||||
config["models"] = modelsSection
|
||||
|
||||
// Update agents.defaults.model.primary (preserving other agent settings)
|
||||
agents, _ := config["agents"].(map[string]any)
|
||||
if agents == nil {
|
||||
agents = make(map[string]any)
|
||||
}
|
||||
defaults, _ := agents["defaults"].(map[string]any)
|
||||
if defaults == nil {
|
||||
defaults = make(map[string]any)
|
||||
}
|
||||
modelConfig, _ := defaults["model"].(map[string]any)
|
||||
if modelConfig == nil {
|
||||
modelConfig = make(map[string]any)
|
||||
}
|
||||
modelConfig["primary"] = "ollama/" + models[0]
|
||||
defaults["model"] = modelConfig
|
||||
agents["defaults"] = defaults
|
||||
config["agents"] = agents
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
modelList, _ := ollama["models"].([]any)
|
||||
|
||||
var result []string
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
625
cmd/config/clawdbot_test.go
Normal file
625
cmd/config/clawdbot_test.go
Normal file
@@ -0,0 +1,625 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClawdbotIntegration(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Clawdbot" {
|
||||
t.Errorf("String() = %q, want %q", got, "Clawdbot")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEdit(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("fresh install", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("multiple models - first is primary", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotModelExists(t, configPath, "mistral")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other providers", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
if providers["anthropic"] == nil {
|
||||
t.Error("anthropic provider was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme was removed")
|
||||
}
|
||||
if cfg["mcp"] == nil {
|
||||
t.Error("mcp was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve user customizations on models", func(t *testing.T) {
|
||||
cleanup()
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
// User adds custom field
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
entry["customField"] = "user-value"
|
||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configPath, configData, 0o644)
|
||||
|
||||
// Re-run Edit
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ = os.ReadFile(configPath)
|
||||
json.Unmarshal(data, &cfg)
|
||||
models = cfg["models"].(map[string]any)
|
||||
providers = models["providers"].(map[string]any)
|
||||
ollama = providers["ollama"].(map[string]any)
|
||||
modelList = ollama["models"].([]any)
|
||||
entry = modelList[0].(map[string]any)
|
||||
if entry["customField"] != "user-value" {
|
||||
t.Error("custom field was lost")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("edit replaces models list", func(t *testing.T) {
|
||||
cleanup()
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
original := `{"existing":"data"}`
|
||||
os.WriteFile(configPath, []byte(original), 0o644)
|
||||
|
||||
c.Edit([]string{})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != original {
|
||||
t.Error("empty models should not modify file")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Error("result should be valid JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type models section", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotModels(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("no config returns nil", func(t *testing.T) {
|
||||
if models := c.Models(); len(models) > 0 {
|
||||
t.Errorf("expected nil/empty, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns all ollama models", func(t *testing.T) {
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[
|
||||
{"id":"llama3.2"},
|
||||
{"id":"mistral"}
|
||||
]}}}
|
||||
}`), 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 2 {
|
||||
t.Errorf("expected 2 models, got %v", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func assertClawdbotModelExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if entry["id"] == model {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Errorf("model %s not found", model)
|
||||
}
|
||||
|
||||
func assertClawdbotModelNotExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models, _ := cfg["models"].(map[string]any)
|
||||
providers, _ := models["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
modelList, _ := ollama["models"].([]any)
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if entry["id"] == model {
|
||||
t.Errorf("model %s should not exist", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertClawdbotPrimaryModel(t *testing.T, path, expected string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
model := defaults["model"].(map[string]any)
|
||||
if model["primary"] != expected {
|
||||
t.Errorf("primary model = %v, want %v", model["primary"], expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotPaths(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when config missing", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
if paths := c.Paths(); paths != nil {
|
||||
t.Errorf("expected nil, got %v", paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotModelsEdgeCases(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("corrupted JSON returns nil", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at models level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at providers level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at ollama level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model entry missing id", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
|
||||
if len(c.Models()) != 0 {
|
||||
t.Error("expected empty for missing id")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model id is not string", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
|
||||
if len(c.Models()) != 0 {
|
||||
t.Error("expected empty for non-string id")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEditSchemaFields(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
|
||||
// Verify required schema fields
|
||||
if entry["reasoning"] != false {
|
||||
t.Error("reasoning should be false")
|
||||
}
|
||||
if entry["input"] == nil {
|
||||
t.Error("input should be set")
|
||||
}
|
||||
if entry["contextWindow"] == nil {
|
||||
t.Error("contextWindow should be set")
|
||||
}
|
||||
if entry["maxTokens"] == nil {
|
||||
t.Error("maxTokens should be set")
|
||||
}
|
||||
cost := entry["cost"].(map[string]any)
|
||||
if cost["cacheRead"] == nil {
|
||||
t.Error("cost.cacheRead should be set")
|
||||
}
|
||||
if cost["cacheWrite"] == nil {
|
||||
t.Error("cost.cacheWrite should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEditModelNames(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
|
||||
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".clawdbot")) }
|
||||
|
||||
t.Run("model with colon tag", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2:70b")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2:70b")
|
||||
})
|
||||
|
||||
t.Run("model with slash", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"library/model:tag"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "library/model:tag")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/library/model:tag")
|
||||
})
|
||||
|
||||
t.Run("model with hyphen", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"test-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "test-model")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEditAgentsPreservation(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("preserve other agent defaults", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
if defaults["temperature"] != 0.7 {
|
||||
t.Error("temperature setting was lost")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve other agents besides defaults", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
if agents["custom-agent"] == nil {
|
||||
t.Error("custom-agent was lost")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const testClawdbotFixture = `{
|
||||
"theme": "dark",
|
||||
"mcp": {"servers": {"custom": {"enabled": true}}},
|
||||
"models": {
|
||||
"providers": {
|
||||
"anthropic": {"apiKey": "xxx"},
|
||||
"ollama": {
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"models": [{"id": "old-model", "customField": "preserved"}]
|
||||
}
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
|
||||
"custom-agent": {"foo": "bar"}
|
||||
}
|
||||
}`
|
||||
|
||||
func TestClawdbotEdit_RoundTrip(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
|
||||
// Verify top-level preserved
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme not preserved")
|
||||
}
|
||||
mcp := cfg["mcp"].(map[string]any)
|
||||
servers := mcp["servers"].(map[string]any)
|
||||
if servers["custom"] == nil {
|
||||
t.Error("mcp.servers.custom not preserved")
|
||||
}
|
||||
|
||||
// Verify other providers preserved
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
if providers["anthropic"] == nil {
|
||||
t.Error("anthropic provider not preserved")
|
||||
}
|
||||
|
||||
// Verify agents preserved
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
if agents["custom-agent"] == nil {
|
||||
t.Error("custom-agent not preserved")
|
||||
}
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
if defaults["temperature"] != 0.7 {
|
||||
t.Error("temperature not preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_Idempotent(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
firstData, _ := os.ReadFile(configPath)
|
||||
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
secondData, _ := os.ReadFile(configPath)
|
||||
|
||||
if string(firstData) != string(secondData) {
|
||||
t.Error("repeated edits with same models produced different results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_MultipleConsecutiveEdits(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
|
||||
for i := range 10 {
|
||||
models := []string{"model-a", "model-b"}
|
||||
if i%2 == 0 {
|
||||
models = []string{"model-x", "model-y", "model-z"}
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
t.Fatalf("edit %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
|
||||
}
|
||||
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme lost after multiple edits")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_BackupCreated(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
|
||||
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
|
||||
os.WriteFile(configPath, []byte(original), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
backups, _ := filepath.Glob(filepath.Join(backupDir, "clawdbot.json.*"))
|
||||
foundBackup := false
|
||||
for _, backup := range backups {
|
||||
data, _ := os.ReadFile(backup)
|
||||
if string(data) == original {
|
||||
foundBackup = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup with original content not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_CreatesDirectoryIfMissing(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
|
||||
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
|
||||
t.Fatal("directory should not exist before test")
|
||||
}
|
||||
|
||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
t.Fatal("directory was not created")
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,7 @@ type Editor interface {
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
"clawdbot": &Clawdbot{},
|
||||
"codex": &Codex{},
|
||||
"droid": &Droid{},
|
||||
"opencode": &OpenCode{},
|
||||
@@ -230,26 +231,28 @@ func runIntegration(name, modelName string) error {
|
||||
return r.Run(modelName)
|
||||
}
|
||||
|
||||
// ConfigCmd returns the cobra command for configuring integrations.
|
||||
func ConfigCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
|
||||
var modelFlag string
|
||||
var launchFlag bool
|
||||
var configFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "config [INTEGRATION]",
|
||||
Short: "Configure an external integration to use Ollama",
|
||||
Long: `Configure an external application to use Ollama models.
|
||||
Use: "launch [INTEGRATION]",
|
||||
Short: "Launch an integration with Ollama",
|
||||
Long: `Launch an integration configured with Ollama models.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
clawdbot Clawdbot
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
|
||||
Examples:
|
||||
ollama config
|
||||
ollama config claude
|
||||
ollama config droid --launch`,
|
||||
ollama launch
|
||||
ollama launch claude
|
||||
ollama launch claude --model <model>
|
||||
ollama launch droid --config (does not auto-launch)`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
@@ -272,8 +275,8 @@ Examples:
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
// If --launch without --model, use saved config if available
|
||||
if launchFlag && modelFlag == "" {
|
||||
// If launching without --model, use saved config if available
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0])
|
||||
}
|
||||
@@ -334,29 +337,19 @@ Examples:
|
||||
}
|
||||
}
|
||||
|
||||
if slices.ContainsFunc(models, func(m string) bool {
|
||||
return !strings.HasSuffix(m, "cloud")
|
||||
}) {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "Coding agents work best with at least 64k context. Either:")
|
||||
fmt.Fprintln(os.Stderr, " - Set the context slider in Ollama app settings")
|
||||
fmt.Fprintln(os.Stderr, " - Run: OLLAMA_CONTEXT_LENGTH=64000 ollama serve")
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||
return runIntegration(name, models[0])
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
if launchFlag {
|
||||
return runIntegration(name, models[0])
|
||||
}
|
||||
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||
return runIntegration(name, models[0])
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Run 'ollama config %s --launch' to start with %s\n", strings.ToLower(name), models[0])
|
||||
return nil
|
||||
return runIntegration(name, models[0])
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||
cmd.Flags().BoolVar(&launchFlag, "launch", false, "Launch the integration after configuring")
|
||||
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -81,17 +81,17 @@ func TestHasLocalModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCmd(t *testing.T) {
|
||||
func TestLaunchCmd(t *testing.T) {
|
||||
// Mock checkServerHeartbeat that always succeeds
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := ConfigCmd(mockCheck)
|
||||
cmd := LaunchCmd(mockCheck)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "config [INTEGRATION]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "config [INTEGRATION]")
|
||||
if cmd.Use != "launch [INTEGRATION]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
@@ -107,9 +107,9 @@ func TestConfigCmd(t *testing.T) {
|
||||
t.Error("--model flag should exist")
|
||||
}
|
||||
|
||||
launchFlag := cmd.Flags().Lookup("launch")
|
||||
if launchFlag == nil {
|
||||
t.Error("--launch flag should exist")
|
||||
configFlag := cmd.Flags().Lookup("config")
|
||||
if configFlag == nil {
|
||||
t.Error("--config flag should exist")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -158,11 +158,11 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCmd_NilHeartbeat(t *testing.T) {
|
||||
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
||||
// This should not panic - cmd creation should work even with nil
|
||||
cmd := ConfigCmd(nil)
|
||||
cmd := LaunchCmd(nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("ConfigCmd returned nil")
|
||||
t.Fatal("LaunchCmd returned nil")
|
||||
}
|
||||
|
||||
// PreRunE should be nil when passed nil
|
||||
|
||||
@@ -105,17 +105,26 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
|
||||
for name, cfg := range models {
|
||||
if cfgMap, ok := cfg.(map[string]any); ok {
|
||||
if displayName, ok := cfgMap["name"].(string); ok {
|
||||
if strings.HasSuffix(displayName, "[Ollama]") && !selectedSet[name] {
|
||||
delete(models, name)
|
||||
}
|
||||
if isOllamaModel(cfgMap) && !selectedSet[name] {
|
||||
delete(models, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
if isOllamaModel(existing) {
|
||||
existing["_launch"] = true
|
||||
if name, ok := existing["name"].(string); ok {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
models[model] = map[string]any{
|
||||
"name": fmt.Sprintf("%s [Ollama]", model),
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,3 +210,15 @@ func (o *OpenCode) Models() []string {
|
||||
slices.Sort(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
// isOllamaModel reports whether a model config entry is managed by us
|
||||
func isOllamaModel(cfg map[string]any) bool {
|
||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
||||
return true
|
||||
}
|
||||
// previously used [Ollama] as a suffix for the model managed by ollama launch
|
||||
if name, ok := cfg["name"].(string); ok {
|
||||
return strings.HasSuffix(name, "[Ollama]")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -161,6 +161,76 @@ func TestOpenCodeEdit(t *testing.T) {
|
||||
assertOpenCodeModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("preserve user customizations on managed models", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add custom fields to the model entry (simulating user edits)
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry := models["llama3.2"].(map[string]any)
|
||||
entry["_myPref"] = "custom-value"
|
||||
entry["_myNum"] = 42
|
||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configPath, configData, 0o644)
|
||||
|
||||
// Re-run Edit — should preserve custom fields
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ = os.ReadFile(configPath)
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider = cfg["provider"].(map[string]any)
|
||||
ollama = provider["ollama"].(map[string]any)
|
||||
models = ollama["models"].(map[string]any)
|
||||
entry = models["llama3.2"].(map[string]any)
|
||||
|
||||
if entry["_myPref"] != "custom-value" {
|
||||
t.Errorf("_myPref was lost: got %v", entry["_myPref"])
|
||||
}
|
||||
if entry["_myNum"] != float64(42) {
|
||||
t.Errorf("_myNum was lost: got %v", entry["_myNum"])
|
||||
}
|
||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
||||
t.Errorf("_launch marker missing or false: got %v", entry["_launch"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrate legacy [Ollama] suffix entries", func(t *testing.T) {
|
||||
cleanup()
|
||||
// Write a config with a legacy entry (has [Ollama] suffix but no _launch marker)
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"llama3.2":{"name":"llama3.2 [Ollama]"}}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry := models["llama3.2"].(map[string]any)
|
||||
|
||||
// _launch marker should be added
|
||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
||||
t.Errorf("_launch marker not added during migration: got %v", entry["_launch"])
|
||||
}
|
||||
// [Ollama] suffix should be stripped
|
||||
if name, ok := entry["name"].(string); !ok || name != "llama3.2" {
|
||||
t.Errorf("name suffix not stripped: got %q", entry["name"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
@@ -465,7 +465,7 @@ func confirmPrompt(prompt string) (bool, error) {
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s [y/n] ", prompt)
|
||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
|
||||
@@ -6,6 +6,10 @@ import (
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
@@ -69,6 +73,9 @@ func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
|
||||
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
|
||||
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
|
||||
|
||||
kv["glm4moelite.attention.key_length_mla"] = p.KVLoraRank + p.QKRopeHeadDim
|
||||
kv["glm4moelite.attention.value_length_mla"] = p.KVLoraRank
|
||||
|
||||
kv["tokenizer.ggml.pre"] = "glm4"
|
||||
|
||||
return kv
|
||||
@@ -100,6 +107,67 @@ func (p *glm4MoeLiteModel) Replacements() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// repackKVB extracts K or V from the combined KV_B tensor for MLA absorption.
|
||||
// K output row-major: [n_head, kv_lora_rank, qk_nope] -> GGML ne[]={qk_nope, kv_lora_rank, n_head}
|
||||
// V output row-major: [n_head, v_head, kv_lora_rank] -> GGML ne[]={kv_lora_rank, v_head, n_head}
|
||||
func (p *glm4MoeLiteModel) repackKVB(extractK bool, kvFirst bool, numHeads int) Repacker {
|
||||
qkNope := int(p.QKNopeHeadDim)
|
||||
vHeadDim := int(p.VHeadDim)
|
||||
kvLoraRank := int(p.KVLoraRank)
|
||||
kvPerHead := qkNope + vHeadDim
|
||||
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
var err error
|
||||
|
||||
// Normalize to [n_head * (qk_nope + v_head), kv_lora_rank] layout
|
||||
if kvFirst {
|
||||
tt, err = tensor.Transpose(tt, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
}
|
||||
|
||||
// Reshape to [n_head, qk_nope + v_head, kv_lora_rank]
|
||||
if err := tt.Reshape(numHeads, kvPerHead, kvLoraRank); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if extractK {
|
||||
// Slice K: [n_head, qk_nope, kv_lora_rank]
|
||||
tt, err = tt.Slice(nil, tensor.S(0, qkNope), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
// Transpose to [n_head, kv_lora_rank, qk_nope]
|
||||
tt, err = tensor.Transpose(tt, 0, 2, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
} else {
|
||||
// Slice V: [n_head, v_head, kv_lora_rank] - already correct layout
|
||||
tt, err = tt.Slice(nil, tensor.S(qkNope, kvPerHead), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
}
|
||||
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
merges := make([]merge, p.HiddenLayers*3)
|
||||
for i := range p.HiddenLayers {
|
||||
@@ -139,6 +207,52 @@ func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
slog.Debug("skipping layer", "name", t.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
// Split attn_kv_b into separate attn_k_b and attn_v_b for MLA absorption
|
||||
if strings.HasSuffix(t.Name(), ".attn_kv_b.weight") {
|
||||
qkNope := int(p.QKNopeHeadDim)
|
||||
vHeadDim := int(p.VHeadDim)
|
||||
kvLoraRank := int(p.KVLoraRank)
|
||||
kvPerHead := qkNope + vHeadDim
|
||||
numHeads := int(p.NumAttentionHeads)
|
||||
kvFirst := true
|
||||
if len(t.Shape()) == 2 {
|
||||
switch {
|
||||
case int(t.Shape()[0]) == kvLoraRank:
|
||||
if kvPerHead > 0 && int(t.Shape()[1])%kvPerHead == 0 {
|
||||
numHeads = int(t.Shape()[1]) / kvPerHead
|
||||
}
|
||||
kvFirst = true
|
||||
case int(t.Shape()[1]) == kvLoraRank:
|
||||
if kvPerHead > 0 && int(t.Shape()[0])%kvPerHead == 0 {
|
||||
numHeads = int(t.Shape()[0]) / kvPerHead
|
||||
}
|
||||
kvFirst = false
|
||||
default:
|
||||
slog.Warn("glm4moelite: unexpected attn_kv_b layout", "name", t.Name(), "shape", t.Shape())
|
||||
}
|
||||
}
|
||||
|
||||
kTensor := t.Clone()
|
||||
kTensor.SetRepacker(p.repackKVB(true, kvFirst, numHeads))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_k_b", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(numHeads), uint64(kvLoraRank), uint64(qkNope)},
|
||||
WriterTo: kTensor,
|
||||
})
|
||||
|
||||
vTensor := t.Clone()
|
||||
vTensor.SetRepacker(p.repackKVB(false, kvFirst, numHeads))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_v_b", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(numHeads), uint64(vHeadDim), uint64(kvLoraRank)},
|
||||
WriterTo: vTensor,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
|
||||
@@ -4,16 +4,6 @@ title: Anthropic compatibility
|
||||
|
||||
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
|
||||
|
||||
## Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
|
||||
|
||||
Pull a model before use:
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment variables
|
||||
@@ -22,8 +12,8 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||
export ANTHROPIC_API_KEY="" # required but ignored
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
|
||||
### Simple `/v1/messages` example
|
||||
@@ -245,10 +235,41 @@ curl -X POST http://localhost:11434/v1/messages \
|
||||
|
||||
## Using with Claude Code
|
||||
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend.
|
||||
|
||||
### Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7`, `minimax-m2.1`, and `qwen3-coder` are recommended.
|
||||
|
||||
Download a model before use:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
ollama pull qwen3-coder
|
||||
```
|
||||
> Note: Qwen 3 coder is a 30B parameter model requiring at least 24GB of VRAM to run smoothly. More is required for longer context lengths.
|
||||
|
||||
```shell
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
### Quick setup
|
||||
|
||||
```shell
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
This will prompt you to select a model, configure Claude Code automatically, and launch it. To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Set the environment variables and run Claude Code:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
@@ -256,19 +277,13 @@ Or set the environment variables in your shell profile:
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
export ANTHROPIC_API_KEY=""
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
```shell
|
||||
# Local models
|
||||
claude --model qwen3-coder
|
||||
claude --model gpt-oss:20b
|
||||
|
||||
# Cloud models
|
||||
claude --model glm-4.7:cloud
|
||||
claude --model minimax-m2.1:cloud
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
41
docs/cli.mdx
41
docs/cli.mdx
@@ -8,6 +8,47 @@ title: CLI Reference
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
### Launch integrations
|
||||
|
||||
```
|
||||
ollama launch
|
||||
```
|
||||
|
||||
Configure and launch external applications to use Ollama models. This provides an interactive way to set up and start integrations with supported apps.
|
||||
|
||||
#### Supported integrations
|
||||
|
||||
- **OpenCode** - Open-source coding assistant
|
||||
- **Claude Code** - Anthropic's agentic coding tool
|
||||
- **Codex** - OpenAI's coding assistant
|
||||
- **Droid** - Factory's AI coding agent
|
||||
|
||||
#### Examples
|
||||
|
||||
Launch an integration interactively:
|
||||
|
||||
```
|
||||
ollama launch
|
||||
```
|
||||
|
||||
Launch a specific integration:
|
||||
|
||||
```
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
Launch with a specific model:
|
||||
|
||||
```
|
||||
ollama launch claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Configure without launching:
|
||||
|
||||
```
|
||||
ollama launch droid --config
|
||||
```
|
||||
|
||||
#### Multiline input
|
||||
|
||||
For multiline input, you can wrap text with `"""`:
|
||||
|
||||
@@ -3,8 +3,6 @@ title: Cloud
|
||||
sidebarTitle: Cloud
|
||||
---
|
||||
|
||||
<Info>Ollama's cloud is currently in preview.</Info>
|
||||
|
||||
## Cloud Models
|
||||
|
||||
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
|
||||
|
||||
@@ -8,7 +8,7 @@ Context length is the maximum number of tokens that the model has access to in m
|
||||
The default context length in Ollama is 4096 tokens.
|
||||
</Note>
|
||||
|
||||
Tasks which require large context like web search, agents, and coding tools should be set to at least 32000 tokens.
|
||||
Tasks which require large context like web search, agents, and coding tools should be set to at least 64000 tokens.
|
||||
|
||||
## Setting context length
|
||||
|
||||
@@ -24,7 +24,7 @@ Change the slider in the Ollama app under settings to your desired context lengt
|
||||
### CLI
|
||||
If editing the context length for Ollama is not possible, the context length can also be updated when serving Ollama.
|
||||
```
|
||||
OLLAMA_CONTEXT_LENGTH=32000 ollama serve
|
||||
OLLAMA_CONTEXT_LENGTH=64000 ollama serve
|
||||
```
|
||||
|
||||
### Check allocated context length and model offloading
|
||||
|
||||
@@ -102,18 +102,20 @@
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/codex",
|
||||
"/integrations/clawdbot",
|
||||
"/integrations/cline",
|
||||
"/integrations/codex",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose",
|
||||
"/integrations/zed",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/marimo",
|
||||
"/integrations/n8n",
|
||||
"/integrations/xcode",
|
||||
"/integrations/onyx",
|
||||
"/integrations/marimo"
|
||||
"/integrations/opencode",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/xcode",
|
||||
"/integrations/zed"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,7 +9,7 @@ sidebarTitle: Welcome
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="Quickstart" icon="rocket" href="/quickstart">
|
||||
Get up and running with your first model
|
||||
Get up and running with your first model or integrate Ollama with your favorite tools
|
||||
</Card>
|
||||
<Card
|
||||
title="Download Ollama"
|
||||
|
||||
@@ -4,7 +4,7 @@ title: Claude Code
|
||||
|
||||
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
|
||||
|
||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
|
||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `glm-4.7`, `qwen3-coder`, `gpt-oss`.
|
||||
|
||||

|
||||
|
||||
@@ -26,12 +26,27 @@ irm https://claude.ai/install.ps1 | iex
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```shell
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_API_KEY=""
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
```
|
||||
|
||||
@@ -44,35 +59,17 @@ claude --model gpt-oss:20b
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
||||
```
|
||||
|
||||
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
|
||||
2. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=https://ollama.com
|
||||
export ANTHROPIC_API_KEY=<your-api-key>
|
||||
```
|
||||
|
||||
3. Run Claude Code with a cloud model:
|
||||
|
||||
```shell
|
||||
claude --model glm-4.7:cloud
|
||||
```
|
||||
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||
|
||||
## Recommended Models
|
||||
|
||||
### Cloud models
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
- `qwen3-coder:480b` - Large coding model
|
||||
- `qwen3-coder`
|
||||
- `glm-4.7`
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
|
||||
### Local models
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks
|
||||
48
docs/integrations/clawdbot.mdx
Normal file
48
docs/integrations/clawdbot.mdx
Normal file
@@ -0,0 +1,48 @@
|
||||
---
|
||||
title: Clawdbot
|
||||
---
|
||||
|
||||
Clawdbot is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
|
||||
|
||||
## Install
|
||||
|
||||
Install [Clawdbot](https://clawd.bot/)
|
||||
|
||||
```bash
|
||||
npm install -g clawdbot@latest
|
||||
```
|
||||
|
||||
Then run the onboarding wizard:
|
||||
|
||||
```bash
|
||||
clawdbot onboard --install-daemon
|
||||
```
|
||||
|
||||
<Note>Clawdbot requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch clawdbot
|
||||
```
|
||||
|
||||
This configures Clawdbot to use Ollama and starts the gateway.
|
||||
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
|
||||
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch clawdbot --config
|
||||
```
|
||||
|
||||
## Recommended Models
|
||||
|
||||
- `qwen3-coder`
|
||||
- `glm-4.7`
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
@@ -13,7 +13,21 @@ npm install -g @openai/codex
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
|
||||
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 64k tokens.</Note>
|
||||
|
||||
### Quick setup
|
||||
|
||||
```
|
||||
ollama launch codex
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch codex --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
To use `codex` with Ollama, use the `--oss` flag:
|
||||
|
||||
|
||||
@@ -11,10 +11,24 @@ Install the [Droid CLI](https://factory.ai/):
|
||||
curl -fsSL https://app.factory.ai/cli | sh
|
||||
```
|
||||
|
||||
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
|
||||
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch droid
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch droid --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a local configuration block to `~/.factory/config.json`:
|
||||
|
||||
```json
|
||||
@@ -73,4 +87,4 @@ Add the cloud configuration block to `~/.factory/config.json`:
|
||||
}
|
||||
```
|
||||
|
||||
Run `droid` in a new terminal to load the new settings.
|
||||
Run `droid` in a new terminal to load the new settings.
|
||||
|
||||
106
docs/integrations/opencode.mdx
Normal file
106
docs/integrations/opencode.mdx
Normal file
@@ -0,0 +1,106 @@
|
||||
---
|
||||
title: OpenCode
|
||||
---
|
||||
|
||||
OpenCode is an open-source AI coding assistant that runs in your terminal.
|
||||
|
||||
## Install
|
||||
|
||||
Install the [OpenCode CLI](https://opencode.ai):
|
||||
|
||||
```bash
|
||||
curl -fsSL https://opencode.ai/install.sh | bash
|
||||
```
|
||||
|
||||
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch opencode
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch opencode --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a configuration block to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"qwen3-coder": {
|
||||
"name": "qwen3-coder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Cloud Models
|
||||
|
||||
`glm-4.7:cloud` is the recommended model for use with OpenCode.
|
||||
|
||||
Add the cloud configuration to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
|
||||
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama Cloud",
|
||||
"options": {
|
||||
"baseURL": "https://ollama.com/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Run `opencode` in a new terminal to load the new settings.
|
||||
@@ -18,13 +18,13 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
<Tab title="CLI">
|
||||
Open a terminal and run the command:
|
||||
|
||||
```
|
||||
```sh
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
</Tab>
|
||||
<Tab title="cURL">
|
||||
```
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
@@ -45,13 +45,13 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
<Tab title="Python">
|
||||
Start by downloading a model:
|
||||
|
||||
```
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
Then install Ollama's Python library:
|
||||
|
||||
```
|
||||
```sh
|
||||
pip install ollama
|
||||
```
|
||||
|
||||
@@ -101,3 +101,42 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
</Tabs>
|
||||
|
||||
See a full list of available models [here](https://ollama.com/models).
|
||||
|
||||
## Coding
|
||||
|
||||
For coding use cases, we recommend using the `glm-4.7-flash` model.
|
||||
|
||||
Note: this model requires 23 GB of VRAM with 64000 tokens context length.
|
||||
```sh
|
||||
ollama pull glm-4.7-flash
|
||||
```
|
||||
|
||||
Alternatively, you can use a more powerful cloud model (with full context length):
|
||||
```sh
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
Use `ollama launch` to quickly set up a coding tool with Ollama models:
|
||||
|
||||
```sh
|
||||
ollama launch
|
||||
```
|
||||
|
||||
### Supported integrations
|
||||
|
||||
- [OpenCode](/integrations/opencode) - Open-source coding assistant
|
||||
- [Claude Code](/integrations/claude-code) - Anthropic's agentic coding tool
|
||||
- [Codex](/integrations/codex) - OpenAI's coding assistant
|
||||
- [Droid](/integrations/droid) - Factory's AI coding agent
|
||||
|
||||
### Launch with a specific model
|
||||
|
||||
```sh
|
||||
ollama launch claude --model glm-4.7-flash
|
||||
```
|
||||
|
||||
### Configure without launching
|
||||
|
||||
```sh
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
@@ -0,0 +1,309 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: nobody <>
|
||||
Date: Sat, 24 Jan 2026 02:31:01 +0000
|
||||
Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash
|
||||
|
||||
Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
|
||||
uses head size 576 with gqa_ratio 4, which was previously only supported
|
||||
for gqa_ratio 16 (DeepSeek).
|
||||
|
||||
Metal changes:
|
||||
- Enable head size 576 for flash attention
|
||||
- Increase simdgroups to 8 for large heads (>=512)
|
||||
- Add case 8 kernel dispatch for 8 simdgroups
|
||||
|
||||
CUDA changes:
|
||||
- Add gqa_ratio 4 support for head 576/512
|
||||
- Add tile configs for (576, 512, 4) and (576, 512, 8)
|
||||
- Add MMA config cases for ncols 4
|
||||
- Add template instances for ncols2=4
|
||||
- Fix nbatch_fa values in nvidia_fp32 config (32->64)
|
||||
---
|
||||
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 40 +++++++++++++++----
|
||||
ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++
|
||||
ggml/src/ggml-cuda/fattn.cu | 12 ++++--
|
||||
...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
|
||||
...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
|
||||
...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
|
||||
...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal-device.m | 8 +---
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 1 +
|
||||
10 files changed, 64 insertions(+), 19 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
index 7bd1044c1..3dea2205e 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
||||
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
||||
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
||||
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
||||
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
T_A_KQ K_A;
|
||||
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||
|
||||
- // Wide version of KQ_C is column-major => swap A and B.
|
||||
- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
+ if constexpr (cols_per_warp == 8) {
|
||||
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
+ } else {
|
||||
+ // Wide version of KQ_C is column-major
|
||||
+#if defined(AMD_WMMA_AVAILABLE)
|
||||
+ // RDNA matrix C is column-major.
|
||||
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
+#else
|
||||
+ // swap A and B for CUDA.
|
||||
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
+#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
+ }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
||||
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
+#ifdef VOLTA_MMA_AVAILABLE
|
||||
+ if (ncols1*ncols2 < 32) {
|
||||
+ NO_DEVICE_CODE;
|
||||
+ return;
|
||||
+ }
|
||||
+#endif // VOLTA_MMA_AVAILABLE
|
||||
+
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
if (ncols1*ncols2 > 32) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
+
|
||||
+// For GLM 4.7 Flash
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
index 7c4d6fe67..371be7442 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
|
||||
return 0;
|
||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
+ return;
|
||||
+ }
|
||||
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
+ return;
|
||||
+ }
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
|
||||
index 015540666..1693479cb 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn.cu
|
||||
+++ b/ggml/src/ggml-cuda/fattn.cu
|
||||
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
+ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
- GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
+ GGML_ASSERT(gqa_ratio % 4 == 0);
|
||||
+ if (gqa_ratio % 16 == 0) {
|
||||
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
+ } else {
|
||||
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
+ }
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
- if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
+ if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
index 2074e954a..517993cb0 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
index 24c64cf00..97b19c67a 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
index 1ada657f1..989626dfa 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
index 86d4ffae2..173de7aac 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
index f24270bb1..7b5ee968c 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
- op->src[0]->ne[0] != 256) {
|
||||
- return false;
|
||||
- }
|
||||
- if (op->src[0]->ne[0] == 576) {
|
||||
- // DeepSeek sizes
|
||||
- // TODO: disabled for now, until optmized
|
||||
+ op->src[0]->ne[0] != 256 &&
|
||||
+ op->src[0]->ne[0] != 576) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index e99c1763f..80864f303 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
- int32_t nsg = 4;
|
||||
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index c98d269d1..d33c16079 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
+ case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
||||
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
||||
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
||||
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
||||
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
T_A_KQ K_A;
|
||||
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||
|
||||
// Wide version of KQ_C is column-major => swap A and B.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
if constexpr (cols_per_warp == 8) {
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
} else {
|
||||
// Wide version of KQ_C is column-major
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
||||
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if (ncols1*ncols2 < 32) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
if (ncols1*ncols2 > 32) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
|
||||
// For GLM 4.7 Flash
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
|
||||
return 0;
|
||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
|
||||
12
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
vendored
12
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
vendored
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
// For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
GGML_ASSERT(gqa_ratio % 4 == 0);
|
||||
if (gqa_ratio % 16 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
|
||||
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
op->src[0]->ne[0] != 256) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek sizes
|
||||
// TODO: disabled for now, until optmized
|
||||
op->src[0]->ne[0] != 256 &&
|
||||
op->src[0]->ne[0] != 576) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
|
||||
@@ -8967,6 +8967,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
|
||||
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
int32_t nsg = 4;
|
||||
int32_t nsg = ne00 >= 512 ? 8 : 4;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
|
||||
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
|
||||
@@ -39,6 +39,13 @@ type Model interface {
|
||||
Config() config
|
||||
}
|
||||
|
||||
// Validator is an optional interface that models can implement to perform
|
||||
// validation after tensors have been loaded. If validation fails, model
|
||||
// loading will fail with the returned error.
|
||||
type Validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// MultimodalProcessor must be implemented by multimodal models.
|
||||
type MultimodalProcessor interface {
|
||||
// EncodeMultimodal processes a single input (such as an image) and
|
||||
@@ -116,6 +123,13 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
base := Base{b: b, config: m.Config()}
|
||||
v := reflect.ValueOf(m)
|
||||
v.Elem().Set(populateFields(base, v.Elem()))
|
||||
|
||||
if validator, ok := m.(Validator); ok {
|
||||
if err := validator.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package glm4moelite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
@@ -11,6 +12,8 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
|
||||
type Options struct {
|
||||
numExpertsUsed int
|
||||
numExperts int
|
||||
@@ -47,7 +50,9 @@ type Attention struct {
|
||||
|
||||
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
|
||||
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
|
||||
KVB *nn.Linear `gguf:"attn_kv_b"`
|
||||
|
||||
KB *nn.Linear `gguf:"attn_k_b"`
|
||||
VB *nn.Linear `gguf:"attn_v_b"`
|
||||
|
||||
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||
}
|
||||
@@ -78,15 +83,16 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
||||
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
|
||||
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
|
||||
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
kPass = attn.KVB.Forward(ctx, kPass)
|
||||
|
||||
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
|
||||
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
|
||||
// MLA absorption: absorb K projection into query
|
||||
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
|
||||
qPassAbsorb := attn.KB.Forward(ctx, qPass).Permute(ctx, 0, 2, 1, 3)
|
||||
query = qRot.Concat(ctx, qPassAbsorb, 0)
|
||||
|
||||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||
kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
|
||||
key := kRot.Concat(ctx, kPass, 0)
|
||||
|
||||
attention := nn.AttentionWithVMLA(ctx, query, key, kPass, nil, attn.VB.Weight, opts.kqScale, cache)
|
||||
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
@@ -217,7 +223,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
|
||||
keyLength := int(c.Uint("attention.key_length"))
|
||||
valueLength := int(c.Uint("attention.value_length"))
|
||||
|
||||
kqScale := 1.0 / math.Sqrt(float64(keyLength))
|
||||
|
||||
var pre []string
|
||||
@@ -236,7 +241,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
@@ -279,6 +284,15 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
for _, layer := range m.Layers {
|
||||
if layer.Attention != nil && (layer.Attention.KB == nil || layer.Attention.VB == nil) {
|
||||
return ErrOldModelFormat
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
|
||||
73
model/models/glm4moelite/model_test.go
Normal file
73
model/models/glm4moelite/model_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package glm4moelite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model *Model
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid model with KB and VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{KB: &nn.Linear{}, VB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing KB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{VB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{KB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing both KB and VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil Attention is ok",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: nil},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.model.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr && err != ErrOldModelFormat {
|
||||
t.Errorf("Validate() error = %v, want %v", err, ErrOldModelFormat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -17,12 +18,34 @@ const (
|
||||
ministralCollectingToolArgs
|
||||
)
|
||||
|
||||
// ministralEvent represents an event emitted during parsing
|
||||
type ministralEvent interface {
|
||||
isMinistralEvent()
|
||||
}
|
||||
|
||||
type ministralEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type ministralEventThinking struct {
|
||||
thinking string
|
||||
}
|
||||
|
||||
type ministralEventToolCall struct {
|
||||
name string
|
||||
args string // raw JSON string
|
||||
}
|
||||
|
||||
func (ministralEventContent) isMinistralEvent() {}
|
||||
func (ministralEventThinking) isMinistralEvent() {}
|
||||
func (ministralEventToolCall) isMinistralEvent() {}
|
||||
|
||||
type MinistralParser struct {
|
||||
state ministralParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
hasThinkingSupport bool
|
||||
currentTool *api.Tool
|
||||
pendingToolName string // stores tool name while collecting args
|
||||
}
|
||||
|
||||
func (p *MinistralParser) HasToolSupport() bool {
|
||||
@@ -63,74 +86,251 @@ func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
|
||||
return nil, fmt.Errorf("tool '%s' not found", n)
|
||||
}
|
||||
|
||||
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
const (
|
||||
ministralToolCallsTag = "[TOOL_CALLS]"
|
||||
ministralThinkTag = "[THINK]"
|
||||
ministralThinkEndTag = "[/THINK]"
|
||||
ministralArgsTag = "[ARGS]"
|
||||
)
|
||||
|
||||
// eat consumes the parser's buffer, and returns a list of any unambiguous
|
||||
// events from the current parser state. The second return value indicates
|
||||
// whether to keep looping (true when state transitions, false when waiting
|
||||
// for more data).
|
||||
func (p *MinistralParser) eat() ([]ministralEvent, bool) {
|
||||
var events []ministralEvent
|
||||
|
||||
switch p.state {
|
||||
case ministralCollectingContent:
|
||||
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
|
||||
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
|
||||
if before != "" {
|
||||
return before, "", calls, nil
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
// Check for [TOOL_CALLS] tag
|
||||
if strings.Contains(bufStr, ministralToolCallsTag) {
|
||||
split := strings.SplitN(bufStr, ministralToolCallsTag, 2)
|
||||
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, ministralEventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingToolName
|
||||
} else if strings.Contains(p.buffer.String(), "[THINK]") {
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for [THINK] tag
|
||||
if strings.Contains(bufStr, ministralThinkTag) {
|
||||
split := strings.SplitN(bufStr, ministralThinkTag, 2)
|
||||
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, ministralEventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingThinkingContent
|
||||
return "", "", calls, nil
|
||||
} else {
|
||||
p.buffer.Reset()
|
||||
return s, "", calls, nil
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for partial tag overlap with [TOOL_CALLS] or [THINK]
|
||||
overlapToolCalls := overlap(bufStr, ministralToolCallsTag)
|
||||
overlapThink := overlap(bufStr, ministralThinkTag)
|
||||
maxOverlap := max(overlapToolCalls, overlapThink)
|
||||
|
||||
if maxOverlap > 0 {
|
||||
// Withhold the potential partial tag
|
||||
beforePartialTag := bufStr[:len(bufStr)-maxOverlap]
|
||||
trailingWS := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWS
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
// No tag found: emit content but withhold trailing whitespace
|
||||
whitespaceLen := trailingWhitespaceLen(bufStr)
|
||||
ambiguousStart := len(bufStr) - whitespaceLen
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case ministralCollectingThinkingContent:
|
||||
if strings.Contains(p.buffer.String(), "[/THINK]") {
|
||||
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
|
||||
p.state = ministralCollectingContent
|
||||
if after != "" {
|
||||
p.buffer.Reset()
|
||||
return after, thinkingContent, calls, nil
|
||||
}
|
||||
return "", thinkingContent, calls, nil
|
||||
} else {
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
if strings.Contains(bufStr, ministralThinkEndTag) {
|
||||
split := strings.SplitN(bufStr, ministralThinkEndTag, 2)
|
||||
thinkingContent := split[0]
|
||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
return "", s, calls, nil
|
||||
}
|
||||
case ministralCollectingToolName:
|
||||
if strings.Contains(p.buffer.String(), "[ARGS]") {
|
||||
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
|
||||
|
||||
t, err := toolByName(p.tools, name)
|
||||
if err != nil {
|
||||
return "", "", calls, err
|
||||
p.buffer.WriteString(after)
|
||||
if len(thinkingContent) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: thinkingContent})
|
||||
}
|
||||
p.currentTool = t
|
||||
p.state = ministralCollectingToolArgs
|
||||
return "", "", calls, nil
|
||||
}
|
||||
return "", "", calls, nil
|
||||
case ministralCollectingToolArgs:
|
||||
if strings.Contains(p.buffer.String(), "}") {
|
||||
before, _ := splitAtTag(&p.buffer, "}", false)
|
||||
before += "}"
|
||||
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(before), &args); err != nil {
|
||||
// todo - throw a better error
|
||||
return "", "", calls, err
|
||||
}
|
||||
|
||||
p.state = ministralCollectingContent
|
||||
return events, true
|
||||
}
|
||||
|
||||
call := api.ToolCall{
|
||||
// Check for partial overlap with [/THINK]
|
||||
if overlapLen := overlap(bufStr, ministralThinkEndTag); overlapLen > 0 {
|
||||
unambiguous := bufStr[:len(bufStr)-overlapLen]
|
||||
ambiguous := bufStr[len(bufStr)-overlapLen:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
// No tag found: emit all thinking content
|
||||
p.buffer.Reset()
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: bufStr})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case ministralCollectingToolName:
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
if strings.Contains(bufStr, ministralArgsTag) {
|
||||
split := strings.SplitN(bufStr, ministralArgsTag, 2)
|
||||
toolName := split[0]
|
||||
after := split[1]
|
||||
p.pendingToolName = toolName
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingToolArgs
|
||||
return events, true
|
||||
}
|
||||
// Wait for more data
|
||||
return events, false
|
||||
|
||||
case ministralCollectingToolArgs:
|
||||
bufStr := p.buffer.String()
|
||||
jsonEnd := findJSONEnd(bufStr)
|
||||
|
||||
if jsonEnd != -1 {
|
||||
jsonStr := bufStr[:jsonEnd+1]
|
||||
remaining := bufStr[jsonEnd+1:]
|
||||
|
||||
events = append(events, ministralEventToolCall{
|
||||
name: p.pendingToolName,
|
||||
args: jsonStr,
|
||||
})
|
||||
|
||||
p.pendingToolName = ""
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = ministralCollectingContent
|
||||
return events, true
|
||||
}
|
||||
// Wait for more data
|
||||
return events, false
|
||||
|
||||
default:
|
||||
panic("unexpected ministral event")
|
||||
}
|
||||
}
|
||||
|
||||
// parseEvents loops calling eat() until it returns false
|
||||
func (p *MinistralParser) parseEvents() []ministralEvent {
|
||||
var all []ministralEvent
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []ministralEvent
|
||||
events, keepLooping = p.eat()
|
||||
all = append(all, events...)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentBuilder, thinkingBuilder strings.Builder
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, event := range events {
|
||||
switch e := event.(type) {
|
||||
case ministralEventContent:
|
||||
contentBuilder.WriteString(e.content)
|
||||
case ministralEventThinking:
|
||||
thinkingBuilder.WriteString(e.thinking)
|
||||
case ministralEventToolCall:
|
||||
// Validate tool exists
|
||||
tool, toolErr := toolByName(p.tools, e.name)
|
||||
if toolErr != nil {
|
||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, toolErr
|
||||
}
|
||||
// Parse JSON arguments
|
||||
var args api.ToolCallFunctionArguments
|
||||
if jsonErr := json.Unmarshal([]byte(e.args), &args); jsonErr != nil {
|
||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, jsonErr
|
||||
}
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: p.currentTool.Function.Name,
|
||||
Name: tool.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
calls = append(calls, call)
|
||||
return "", "", calls, nil
|
||||
})
|
||||
}
|
||||
return "", "", calls, nil
|
||||
}
|
||||
|
||||
return p.buffer.String(), thinking, calls, nil
|
||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
// findJSONEnd finds the index of the closing brace that completes a JSON object.
|
||||
// It properly handles nested objects, arrays, and strings (including escaped characters).
|
||||
// Returns -1 if the JSON is not yet complete.
|
||||
func findJSONEnd(s string) int {
|
||||
depth := 0
|
||||
inString := false
|
||||
escaped := false
|
||||
|
||||
for i, r := range s {
|
||||
if inString {
|
||||
switch {
|
||||
case escaped:
|
||||
// If the previous character was a backslash, skip this character
|
||||
escaped = false
|
||||
case r == '\\':
|
||||
// Mark the next character as escaped
|
||||
escaped = true
|
||||
case r == '"':
|
||||
// End of string literal
|
||||
inString = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
// Start of string literal
|
||||
inString = true
|
||||
case '{', '[':
|
||||
// Increase nesting level for objects and arrays
|
||||
depth++
|
||||
case '}', ']':
|
||||
// Decrease nesting level
|
||||
depth--
|
||||
if depth == 0 {
|
||||
// Reached the end of the root JSON structure
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
545
model/parsers/ministral_test.go
Normal file
545
model/parsers/ministral_test.go
Normal file
@@ -0,0 +1,545 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestMinistralParserStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantEvents []ministralEvent
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
tools []api.Tool
|
||||
steps []step
|
||||
think bool // whether to enable thinking support
|
||||
}{
|
||||
// Content streaming
|
||||
{
|
||||
desc: "simple content",
|
||||
steps: []step{
|
||||
{input: "Hello, how can I help you?", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "Hello, how can I help you?"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "streaming content word by word",
|
||||
steps: []step{
|
||||
{input: "Hello,", wantEvents: []ministralEvent{ministralEventContent{content: "Hello,"}}},
|
||||
{input: " how", wantEvents: []ministralEvent{ministralEventContent{content: " how"}}},
|
||||
{input: " can I help?", wantEvents: []ministralEvent{ministralEventContent{content: " can I help?"}}},
|
||||
},
|
||||
},
|
||||
|
||||
// Simple tool calls
|
||||
{
|
||||
desc: "simple tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "San Francisco"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "get_weather", args: `{"location": "San Francisco"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with nested object",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]create_entities[ARGS]{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with deeply nested objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "update_config"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]update_config[ARGS]{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "update_config", args: `{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with array of objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "process_items"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]process_items[ARGS]{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "process_items", args: `{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with escaped quotes in string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "search"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]search[ARGS]{"query": "say \"hello\""}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "search", args: `{"query": "say \"hello\""}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with braces inside string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "format"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]format[ARGS]{"template": "Hello {name}!"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "format", args: `{"template": "Hello {name}!"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty JSON object",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "no_args"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]no_args[ARGS]{}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "no_args", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "JSON with newlines in string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "write"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]write[ARGS]{"content": "line1\nline2\nline3"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "write", args: `{"content": "line1\nline2\nline3"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "backslash in string value",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "path"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]path[ARGS]{"dir": "C:\\Users\\test"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "path", args: `{"dir": "C:\\Users\\test"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Content after tool call
|
||||
{
|
||||
desc: "content after tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
// NOTE: It's unclear if this is valid Ministral output, but the parser
|
||||
// currently treats text after a tool call as regular content. This test
|
||||
// documents that behavior so we notice if it changes.
|
||||
{input: `[TOOL_CALLS]test[ARGS]{"a": 1}some content after`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{"a": 1}`},
|
||||
ministralEventContent{content: "some content after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Multiple tool calls
|
||||
{
|
||||
desc: "multiple tool calls in sequence",
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "get_weather"}},
|
||||
{Function: api.ToolFunction{Name: "get_time"}},
|
||||
},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "NYC"}[TOOL_CALLS]get_time[ARGS]{"timezone": "EST"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "get_weather", args: `{"location": "NYC"}`},
|
||||
ministralEventToolCall{name: "get_time", args: `{"timezone": "EST"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple tool calls streamed separately",
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "tool_a"}},
|
||||
{Function: api.ToolFunction{Name: "tool_b"}},
|
||||
},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]tool_a[ARGS]{"x": 1}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "tool_a", args: `{"x": 1}`},
|
||||
}},
|
||||
{input: `[TOOL_CALLS]tool_b[ARGS]{"y": 2}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "tool_b", args: `{"y": 2}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Streaming tool calls
|
||||
{
|
||||
desc: "streaming tool call with nested objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_CALLS]create_entities[ARGS]", wantEvents: []ministralEvent{}},
|
||||
{input: `{"entities": [{"entityType": "Person",`, wantEvents: []ministralEvent{}},
|
||||
{input: ` "name": "Jack",`, wantEvents: []ministralEvent{}},
|
||||
{input: ` "observations": ["Works`, wantEvents: []ministralEvent{}},
|
||||
{input: ` as a baker"]}`, wantEvents: []ministralEvent{}},
|
||||
{input: `]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "streaming with incomplete JSON waits for completion",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_CALLS]test[ARGS]{", wantEvents: []ministralEvent{}},
|
||||
{input: `"a": {`, wantEvents: []ministralEvent{}},
|
||||
{input: `"b": 1`, wantEvents: []ministralEvent{}},
|
||||
{input: `}`, wantEvents: []ministralEvent{}},
|
||||
{input: `}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{"a": {"b": 1}}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Partial tag handling
|
||||
{
|
||||
desc: "partial tool tag fakeout",
|
||||
steps: []step{
|
||||
{input: "abc[TOOL", wantEvents: []ministralEvent{ministralEventContent{content: "abc"}}},
|
||||
{input: " not a tag", wantEvents: []ministralEvent{ministralEventContent{content: "[TOOL not a tag"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call tag split across chunks",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_", wantEvents: []ministralEvent{}},
|
||||
{input: "CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "content before tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
|
||||
steps: []step{
|
||||
{input: "hello [TOOL_CALLS]get_weather[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "hello"},
|
||||
ministralEventToolCall{name: "get_weather", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace between content and tool call is trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "content \n [TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tabs and newlines before tool call are trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "content\t\n\t[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "non-breaking space before tool call is trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
// \u00a0 is non-breaking space, which unicode.IsSpace considers whitespace
|
||||
{input: "content\u00a0[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace before THINK tag is trimmed",
|
||||
steps: []step{
|
||||
{input: "content \n [THINK]thinking[/THINK]after", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace withheld then emitted",
|
||||
steps: []step{
|
||||
{input: "Hello ", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
|
||||
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: " world"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing newline withheld then emitted",
|
||||
steps: []step{
|
||||
{input: "Hello\n", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
|
||||
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: "\nworld"}}},
|
||||
},
|
||||
},
|
||||
|
||||
// Thinking support
|
||||
{
|
||||
desc: "thinking content",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking here[/THINK]", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "thinking here"},
|
||||
}},
|
||||
{input: "content after", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking with whitespace after end tag",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "my thoughts[/THINK] \n response", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "my thoughts"},
|
||||
ministralEventContent{content: "response"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "non-breaking space after think end tag is trimmed",
|
||||
think: true,
|
||||
steps: []step{
|
||||
// \u00a0 is non-breaking space
|
||||
{input: "thinking[/THINK]\u00a0response", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "response"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial think end tag",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
|
||||
{input: "NK]after", wantEvents: []ministralEvent{ministralEventContent{content: "after"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "think tag fakeout",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
|
||||
{input: "not end tag", wantEvents: []ministralEvent{ministralEventThinking{thinking: "[/THInot end tag"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking then tool call",
|
||||
think: true,
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "let me think[/THINK][TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "let me think"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Content then THINK tag transition
|
||||
{
|
||||
desc: "content then think tag",
|
||||
steps: []step{
|
||||
{input: "content[THINK]thinking[/THINK]more", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "more"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Unicode handling
|
||||
{
|
||||
desc: "unicode content",
|
||||
steps: []step{
|
||||
{input: "你好 🌍 مرحبا", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "你好 🌍 مرحبا"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "unicode in tool args",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "greet"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]greet[ARGS]{"message": "你好 🌍"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "greet", args: `{"message": "你好 🌍"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
parser := MinistralParser{}
|
||||
parser.hasThinkingSupport = tc.think
|
||||
parser.Init(tc.tools, nil, nil)
|
||||
|
||||
for i, step := range tc.steps {
|
||||
parser.buffer.WriteString(step.input)
|
||||
gotEvents := parser.parseEvents()
|
||||
|
||||
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
|
||||
// avoid deep equal on empty vs. nil slices
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
|
||||
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinistralParser_Errors(t *testing.T) {
|
||||
t.Run("unknown tool returns error", func(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "known_tool"}}}, nil, nil)
|
||||
|
||||
_, _, _, err := p.Add(`[TOOL_CALLS]unknown_tool[ARGS]{"a": 1}`, true)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown tool")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON returns error", func(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "test"}}}, nil, nil)
|
||||
|
||||
_, _, _, err := p.Add(`[TOOL_CALLS]test[ARGS]{invalid json}`, true)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFindJSONEnd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "simple object",
|
||||
input: `{"a": 1}`,
|
||||
expected: 7,
|
||||
},
|
||||
{
|
||||
name: "nested object",
|
||||
input: `{"a": {"b": 2}}`,
|
||||
expected: 14,
|
||||
},
|
||||
{
|
||||
name: "array inside object",
|
||||
input: `{"items": [1, 2, 3]}`,
|
||||
expected: 19,
|
||||
},
|
||||
{
|
||||
name: "braces in string",
|
||||
input: `{"template": "Hello {name}!"}`,
|
||||
expected: 28,
|
||||
},
|
||||
{
|
||||
name: "escaped quotes",
|
||||
input: `{"msg": "say \"hi\""}`,
|
||||
expected: 20,
|
||||
},
|
||||
{
|
||||
name: "incomplete object",
|
||||
input: `{"a": {"b": 1}`,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "deeply nested",
|
||||
input: `{"a": {"b": {"c": {"d": 1}}}}`,
|
||||
expected: 28,
|
||||
},
|
||||
{
|
||||
name: "object with trailing content",
|
||||
input: `{"a": 1} extra`,
|
||||
expected: 7,
|
||||
},
|
||||
{
|
||||
name: "array",
|
||||
input: `[{"a": 1}, {"b": 2}]`,
|
||||
expected: 19,
|
||||
},
|
||||
{
|
||||
name: "escaped backslash before quote",
|
||||
input: `{"path": "C:\\"}`,
|
||||
expected: 15,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "no opening brace",
|
||||
input: "hello world",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "only opening brace",
|
||||
input: "{",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "unclosed string",
|
||||
input: `{"key": "unclosed`,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "double escaped backslash then quote",
|
||||
input: `{"path": "C:\\\\"}`,
|
||||
expected: 17,
|
||||
},
|
||||
{
|
||||
name: "unicode in key and value",
|
||||
input: `{"키": "값"}`,
|
||||
expected: 13,
|
||||
},
|
||||
{
|
||||
name: "nested arrays",
|
||||
input: `{"matrix": [[1, 2], [3, 4]]}`,
|
||||
expected: 27,
|
||||
},
|
||||
{
|
||||
name: "mixed nesting",
|
||||
input: `{"a": [{"b": {"c": [1, 2, 3]}}]}`,
|
||||
expected: 31,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := findJSONEnd(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("findJSONEnd(%q) = %d, want %d", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinistralParser_HasToolSupport(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
if !p.HasToolSupport() {
|
||||
t.Error("expected HasToolSupport to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinistralParser_HasThinkingSupport(t *testing.T) {
|
||||
p := &MinistralParser{hasThinkingSupport: false}
|
||||
if p.HasThinkingSupport() {
|
||||
t.Error("expected HasThinkingSupport to return false")
|
||||
}
|
||||
|
||||
p = &MinistralParser{hasThinkingSupport: true}
|
||||
if !p.HasThinkingSupport() {
|
||||
t.Error("expected HasThinkingSupport to return true")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package parsers
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
@@ -114,3 +115,33 @@ func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string
|
||||
sb.WriteString(after)
|
||||
return before, after // return events
|
||||
}
|
||||
|
||||
// overlap returns the longest overlap between the suffix of s and the prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// trailingWhitespaceLen returns the length in bytes of trailing whitespace in s
|
||||
func trailingWhitespaceLen(s string) int {
|
||||
remaining := s
|
||||
total := 0
|
||||
for len(remaining) > 0 {
|
||||
r, size := utf8.DecodeLastRuneInString(remaining)
|
||||
// if it's an invalid utf8 rune, assume it isn't whitespace
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
break
|
||||
}
|
||||
if !unicode.IsSpace(r) {
|
||||
break
|
||||
}
|
||||
total += size
|
||||
remaining = remaining[:len(remaining)-size]
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
@@ -194,36 +193,6 @@ func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(drifkin): move this to a shared location
|
||||
// longest overlap between suffix of s and prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func trailingWhitespaceLen(s string) int {
|
||||
remaining := s
|
||||
total := 0
|
||||
for len(remaining) > 0 {
|
||||
r, size := utf8.DecodeLastRuneInString(remaining)
|
||||
// if it's an invalid utf8 rune, assume it isn't whitespace
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
break
|
||||
}
|
||||
if !unicode.IsSpace(r) {
|
||||
break
|
||||
}
|
||||
total += size
|
||||
remaining = remaining[:len(remaining)-size]
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
type XMLFunctionCall struct {
|
||||
XMLName xml.Name `xml:"function"`
|
||||
Name string `xml:"name,attr"`
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
VOL_NAME=${VOL_NAME:-"Ollama"}
|
||||
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
|
||||
export CGO_CFLAGS="-mmacosx-version-min=14.0"
|
||||
export CGO_CXXFLAGS="-mmacosx-version-min=14.0"
|
||||
export CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||
export CGO_LDFLAGS="-mmacosx-version-min=14.0"
|
||||
|
||||
set -e
|
||||
|
||||
@@ -56,6 +56,12 @@ function checkEnv {
|
||||
|
||||
$script:DIST_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
|
||||
$env:CGO_ENABLED="1"
|
||||
if (-not $env:CGO_CFLAGS) {
|
||||
$env:CGO_CFLAGS = "-O3"
|
||||
}
|
||||
if (-not $env:CGO_CXXFLAGS) {
|
||||
$env:CGO_CXXFLAGS = "-O3"
|
||||
}
|
||||
Write-Output "Checking version"
|
||||
if (!$env:VERSION) {
|
||||
$data=(git describe --tags --first-parent --abbrev=7 --long --dirty --always)
|
||||
|
||||
@@ -95,6 +95,13 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
||||
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
}
|
||||
} else if strings.Contains(name, "attn_k_b.weight") ||
|
||||
strings.Contains(name, "attn_v_b.weight") ||
|
||||
strings.Contains(name, "attn_kv_a_mqa.weight") ||
|
||||
strings.Contains(name, "attn_q_a.weight") ||
|
||||
strings.Contains(name, "attn_q_b.weight") {
|
||||
// MLA tensors need higher precision to avoid quality degradation
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
} else if strings.Contains(name, "ffn_down") {
|
||||
iLayer := qs.iFfnDown
|
||||
n_layer := qs.nFfnDown
|
||||
|
||||
Reference in New Issue
Block a user