mirror of
https://github.com/ollama/ollama.git
synced 2026-01-28 09:20:33 -05:00
Compare commits
29 Commits
fix-imageg
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b62c41060 | ||
|
|
26acab64b7 | ||
|
|
e0f03790b1 | ||
|
|
3ab842b0f5 | ||
|
|
b8e8ef8929 | ||
|
|
465d124183 | ||
|
|
d310e56fa3 | ||
|
|
a1ca428c90 | ||
|
|
16750865d1 | ||
|
|
f3b476c592 | ||
|
|
5267d31d56 | ||
|
|
b44f56319f | ||
|
|
0209c268bb | ||
|
|
912d984346 | ||
|
|
aae6ecbaff | ||
|
|
64737330a4 | ||
|
|
2eda97f1c3 | ||
|
|
66831dcf70 | ||
|
|
1044b0419a | ||
|
|
771d9280ec | ||
|
|
862bc0a3bf | ||
|
|
c01608b6a1 | ||
|
|
199c41e16e | ||
|
|
3b3bf6c217 | ||
|
|
f52c21f457 | ||
|
|
b5d0f72f16 | ||
|
|
148a1be0a3 | ||
|
|
d6dd430abd | ||
|
|
ae78112c50 |
@@ -169,8 +169,10 @@ COPY . .
|
||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
|
||||
@@ -558,7 +558,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [LiteLLM](https://github.com/BerriAI/litellm)
|
||||
- [OllamaFarm for Go](https://github.com/presbrey/ollamafarm)
|
||||
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
|
||||
- [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
|
||||
- [Ollama for Ruby](https://github.com/crmne/ruby_llm)
|
||||
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
|
||||
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
|
||||
- [Ollama4j for Java](https://github.com/ollama4j/ollama4j)
|
||||
|
||||
@@ -75,9 +75,9 @@ The `-dev` flag enables:
|
||||
CI builds with Xcode 14.1 for OS compatibility prior to v13. If you want to manually build v11+ support, you can download the older Xcode [here](https://developer.apple.com/services-account/download?path=/Developer_Tools/Xcode_14.1/Xcode_14.1.xip), extract, then `mv ./Xcode.app /Applications/Xcode_14.1.0.app` then activate with:
|
||||
|
||||
```
|
||||
export CGO_CFLAGS=-mmacosx-version-min=12.0
|
||||
export CGO_CXXFLAGS=-mmacosx-version-min=12.0
|
||||
export CGO_LDFLAGS=-mmacosx-version-min=12.0
|
||||
export CGO_CFLAGS="-O3 -mmacosx-version-min=12.0"
|
||||
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=12.0"
|
||||
export CGO_LDFLAGS="-mmacosx-version-min=12.0"
|
||||
export SDKROOT=/Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
|
||||
export DEVELOPER_DIR=/Applications/Xcode_14.1.0.app/Contents/Developer
|
||||
```
|
||||
|
||||
12
cmd/cmd.go
12
cmd/cmd.go
@@ -35,6 +35,7 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
@@ -1018,8 +1019,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
}
|
||||
|
||||
if resp.ModelInfo != nil {
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
arch, _ := resp.ModelInfo["general.architecture"].(string)
|
||||
if arch != "" {
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
}
|
||||
|
||||
var paramStr string
|
||||
if resp.Details.ParameterSize != "" {
|
||||
@@ -1029,7 +1032,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
paramStr = format.HumanNumber(uint64(f))
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
if paramStr != "" {
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
}
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
@@ -2026,6 +2031,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
60
cmd/config/claude.go
Normal file
60
cmd/config/claude.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string) []string {
|
||||
if model != "" {
|
||||
return []string{"--model", model}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("claude"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".claude", "local", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
101
cmd/config/claude_test.go
Normal file
101
cmd/config/claude_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClaudeIntegration(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Claude Code" {
|
||||
t.Errorf("String() = %q, want %q", got, "Claude Code")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeFindPath(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("finds claude in PATH", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fakeBin := filepath.Join(tmpDir, name)
|
||||
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fakeBin {
|
||||
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(tmpDir, ".claude", "local", name)
|
||||
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fallback {
|
||||
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
_, err := c.findPath()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeArgs(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
195
cmd/config/clawdbot.go
Normal file
195
cmd/config/clawdbot.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type Clawdbot struct{}
|
||||
|
||||
func (c *Clawdbot) String() string { return "Clawdbot" }
|
||||
|
||||
const ansiGreen = "\033[32m"
|
||||
|
||||
func (c *Clawdbot) Run(model string) error {
|
||||
if _, err := exec.LookPath("clawdbot"); err != nil {
|
||||
return fmt.Errorf("clawdbot is not installed, install from https://docs.clawd.bot")
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("clawdbot", "gateway")
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
// Capture output to detect "already running" message
|
||||
var outputBuf bytes.Buffer
|
||||
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
|
||||
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
|
||||
fmt.Fprintf(os.Stderr, "%sClawdbot has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Paths() []string {
|
||||
home, _ := os.UserHomeDir()
|
||||
p := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read into map[string]any to preserve unknown fields
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
// Navigate/create: models.providers.ollama (preserving other providers)
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
if modelsSection == nil {
|
||||
modelsSection = make(map[string]any)
|
||||
}
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
if providers == nil {
|
||||
providers = make(map[string]any)
|
||||
}
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
if ollama == nil {
|
||||
ollama = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
// TODO(parthsareen): potentially move to responses
|
||||
ollama["api"] = "openai-completions"
|
||||
|
||||
// Build map of existing models to preserve user customizations
|
||||
existingModels, _ := ollama["models"].([]any)
|
||||
existingByID := make(map[string]map[string]any)
|
||||
for _, m := range existingModels {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
existingByID[id] = entry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var newModels []any
|
||||
for _, model := range models {
|
||||
entry := map[string]any{
|
||||
"id": model,
|
||||
"name": model,
|
||||
"reasoning": false,
|
||||
"input": []any{"text"},
|
||||
"cost": map[string]any{
|
||||
"input": 0,
|
||||
"output": 0,
|
||||
"cacheRead": 0,
|
||||
"cacheWrite": 0,
|
||||
},
|
||||
// TODO(parthsareen): get these values from API
|
||||
"contextWindow": 131072,
|
||||
"maxTokens": 16384,
|
||||
}
|
||||
// Merge existing fields (user customizations)
|
||||
if existing, ok := existingByID[model]; ok {
|
||||
for k, v := range existing {
|
||||
if _, isNew := entry[k]; !isNew {
|
||||
entry[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
newModels = append(newModels, entry)
|
||||
}
|
||||
ollama["models"] = newModels
|
||||
|
||||
providers["ollama"] = ollama
|
||||
modelsSection["providers"] = providers
|
||||
config["models"] = modelsSection
|
||||
|
||||
// Update agents.defaults.model.primary (preserving other agent settings)
|
||||
agents, _ := config["agents"].(map[string]any)
|
||||
if agents == nil {
|
||||
agents = make(map[string]any)
|
||||
}
|
||||
defaults, _ := agents["defaults"].(map[string]any)
|
||||
if defaults == nil {
|
||||
defaults = make(map[string]any)
|
||||
}
|
||||
modelConfig, _ := defaults["model"].(map[string]any)
|
||||
if modelConfig == nil {
|
||||
modelConfig = make(map[string]any)
|
||||
}
|
||||
modelConfig["primary"] = "ollama/" + models[0]
|
||||
defaults["model"] = modelConfig
|
||||
agents["defaults"] = defaults
|
||||
config["agents"] = agents
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
modelList, _ := ollama["models"].([]any)
|
||||
|
||||
var result []string
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
625
cmd/config/clawdbot_test.go
Normal file
625
cmd/config/clawdbot_test.go
Normal file
@@ -0,0 +1,625 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClawdbotIntegration(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Clawdbot" {
|
||||
t.Errorf("String() = %q, want %q", got, "Clawdbot")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEdit(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("fresh install", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("multiple models - first is primary", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotModelExists(t, configPath, "mistral")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other providers", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
if providers["anthropic"] == nil {
|
||||
t.Error("anthropic provider was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme was removed")
|
||||
}
|
||||
if cfg["mcp"] == nil {
|
||||
t.Error("mcp was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve user customizations on models", func(t *testing.T) {
|
||||
cleanup()
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
// User adds custom field
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
entry["customField"] = "user-value"
|
||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configPath, configData, 0o644)
|
||||
|
||||
// Re-run Edit
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ = os.ReadFile(configPath)
|
||||
json.Unmarshal(data, &cfg)
|
||||
models = cfg["models"].(map[string]any)
|
||||
providers = models["providers"].(map[string]any)
|
||||
ollama = providers["ollama"].(map[string]any)
|
||||
modelList = ollama["models"].([]any)
|
||||
entry = modelList[0].(map[string]any)
|
||||
if entry["customField"] != "user-value" {
|
||||
t.Error("custom field was lost")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("edit replaces models list", func(t *testing.T) {
|
||||
cleanup()
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
original := `{"existing":"data"}`
|
||||
os.WriteFile(configPath, []byte(original), 0o644)
|
||||
|
||||
c.Edit([]string{})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != original {
|
||||
t.Error("empty models should not modify file")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Error("result should be valid JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type models section", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotModels(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("no config returns nil", func(t *testing.T) {
|
||||
if models := c.Models(); len(models) > 0 {
|
||||
t.Errorf("expected nil/empty, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns all ollama models", func(t *testing.T) {
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[
|
||||
{"id":"llama3.2"},
|
||||
{"id":"mistral"}
|
||||
]}}}
|
||||
}`), 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 2 {
|
||||
t.Errorf("expected 2 models, got %v", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func assertClawdbotModelExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if entry["id"] == model {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Errorf("model %s not found", model)
|
||||
}
|
||||
|
||||
func assertClawdbotModelNotExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models, _ := cfg["models"].(map[string]any)
|
||||
providers, _ := models["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
modelList, _ := ollama["models"].([]any)
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if entry["id"] == model {
|
||||
t.Errorf("model %s should not exist", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertClawdbotPrimaryModel(t *testing.T, path, expected string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
model := defaults["model"].(map[string]any)
|
||||
if model["primary"] != expected {
|
||||
t.Errorf("primary model = %v, want %v", model["primary"], expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotPaths(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when config missing", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
if paths := c.Paths(); paths != nil {
|
||||
t.Errorf("expected nil, got %v", paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotModelsEdgeCases(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("corrupted JSON returns nil", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at models level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at providers level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at ollama level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model entry missing id", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
|
||||
if len(c.Models()) != 0 {
|
||||
t.Error("expected empty for missing id")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model id is not string", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
|
||||
if len(c.Models()) != 0 {
|
||||
t.Error("expected empty for non-string id")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEditSchemaFields(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
|
||||
// Verify required schema fields
|
||||
if entry["reasoning"] != false {
|
||||
t.Error("reasoning should be false")
|
||||
}
|
||||
if entry["input"] == nil {
|
||||
t.Error("input should be set")
|
||||
}
|
||||
if entry["contextWindow"] == nil {
|
||||
t.Error("contextWindow should be set")
|
||||
}
|
||||
if entry["maxTokens"] == nil {
|
||||
t.Error("maxTokens should be set")
|
||||
}
|
||||
cost := entry["cost"].(map[string]any)
|
||||
if cost["cacheRead"] == nil {
|
||||
t.Error("cost.cacheRead should be set")
|
||||
}
|
||||
if cost["cacheWrite"] == nil {
|
||||
t.Error("cost.cacheWrite should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEditModelNames(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
|
||||
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".clawdbot")) }
|
||||
|
||||
t.Run("model with colon tag", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2:70b")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2:70b")
|
||||
})
|
||||
|
||||
t.Run("model with slash", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"library/model:tag"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "library/model:tag")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/library/model:tag")
|
||||
})
|
||||
|
||||
t.Run("model with hyphen", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"test-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "test-model")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEditAgentsPreservation(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("preserve other agent defaults", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
if defaults["temperature"] != 0.7 {
|
||||
t.Error("temperature setting was lost")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve other agents besides defaults", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
if agents["custom-agent"] == nil {
|
||||
t.Error("custom-agent was lost")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const testClawdbotFixture = `{
|
||||
"theme": "dark",
|
||||
"mcp": {"servers": {"custom": {"enabled": true}}},
|
||||
"models": {
|
||||
"providers": {
|
||||
"anthropic": {"apiKey": "xxx"},
|
||||
"ollama": {
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"models": [{"id": "old-model", "customField": "preserved"}]
|
||||
}
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
|
||||
"custom-agent": {"foo": "bar"}
|
||||
}
|
||||
}`
|
||||
|
||||
func TestClawdbotEdit_RoundTrip(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
|
||||
// Verify top-level preserved
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme not preserved")
|
||||
}
|
||||
mcp := cfg["mcp"].(map[string]any)
|
||||
servers := mcp["servers"].(map[string]any)
|
||||
if servers["custom"] == nil {
|
||||
t.Error("mcp.servers.custom not preserved")
|
||||
}
|
||||
|
||||
// Verify other providers preserved
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
if providers["anthropic"] == nil {
|
||||
t.Error("anthropic provider not preserved")
|
||||
}
|
||||
|
||||
// Verify agents preserved
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
if agents["custom-agent"] == nil {
|
||||
t.Error("custom-agent not preserved")
|
||||
}
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
if defaults["temperature"] != 0.7 {
|
||||
t.Error("temperature not preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_Idempotent(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
firstData, _ := os.ReadFile(configPath)
|
||||
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
secondData, _ := os.ReadFile(configPath)
|
||||
|
||||
if string(firstData) != string(secondData) {
|
||||
t.Error("repeated edits with same models produced different results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_MultipleConsecutiveEdits(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
|
||||
for i := range 10 {
|
||||
models := []string{"model-a", "model-b"}
|
||||
if i%2 == 0 {
|
||||
models = []string{"model-x", "model-y", "model-z"}
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
t.Fatalf("edit %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
|
||||
}
|
||||
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme lost after multiple edits")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_BackupCreated(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
|
||||
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
|
||||
os.WriteFile(configPath, []byte(original), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
backups, _ := filepath.Glob(filepath.Join(backupDir, "clawdbot.json.*"))
|
||||
foundBackup := false
|
||||
for _, backup := range backups {
|
||||
data, _ := os.ReadFile(backup)
|
||||
if string(data) == original {
|
||||
foundBackup = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup with original content not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_CreatesDirectoryIfMissing(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
|
||||
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
|
||||
t.Fatal("directory should not exist before test")
|
||||
}
|
||||
|
||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
t.Fatal("directory was not created")
|
||||
}
|
||||
}
|
||||
61
cmd/config/codex.go
Normal file
61
cmd/config/codex.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
// Codex implements Runner for Codex integration
|
||||
type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
func (c *Codex) args(model string) []string {
|
||||
args := []string{"--oss"}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func checkCodexVersion() error {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
|
||||
}
|
||||
|
||||
out, err := exec.Command("codex", "--version").Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get codex version: %w", err)
|
||||
}
|
||||
|
||||
// Parse output like "codex-cli 0.87.0"
|
||||
fields := strings.Fields(strings.TrimSpace(string(out)))
|
||||
if len(fields) < 2 {
|
||||
return fmt.Errorf("unexpected codex version output: %s", string(out))
|
||||
}
|
||||
|
||||
version := "v" + fields[len(fields)-1]
|
||||
minVersion := "v0.81.0"
|
||||
|
||||
if semver.Compare(version, minVersion) < 0 {
|
||||
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
28
cmd/config/codex_test.go
Normal file
28
cmd/config/codex_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCodexArgs(t *testing.T) {
|
||||
c := &Codex{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", []string{"--oss"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
115
cmd/config/config.go
Normal file
115
cmd/config/config.go
Normal file
@@ -0,0 +1,115 @@
|
||||
// Package config provides integration configuration for external coding tools
|
||||
// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||
}
|
||||
|
||||
func load() (*config, error) {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &config{Integrations: make(map[string]*integration)}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
|
||||
}
|
||||
if cfg.Integrations == nil {
|
||||
cfg.Integrations = make(map[string]*integration)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func save(cfg *config) error {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeWithBackup(path, data)
|
||||
}
|
||||
|
||||
func saveIntegration(appName string, models []string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||
Models: models,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func loadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ic, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||
if !ok {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]integration, 0, len(cfg.Integrations))
|
||||
for _, ic := range cfg.Integrations {
|
||||
result = append(result, *ic)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
373
cmd/config/config_test.go
Normal file
373
cmd/config/config_test.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||
func setTestHome(t *testing.T, dir string) {
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
||||
func editorPaths(r Runner) []string {
|
||||
if editor, ok := r.(Editor); ok {
|
||||
return editor.Paths()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIntegrationConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("save and load round-trip", func(t *testing.T) {
|
||||
models := []string{"llama3.2", "mistral", "qwen2.5"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(config.Models) != len(models) {
|
||||
t.Errorf("expected %d models, got %d", len(models), len(config.Models))
|
||||
}
|
||||
for i, m := range models {
|
||||
if config.Models[i] != m {
|
||||
t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
config, _ := loadIntegration("codex")
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "model-a" {
|
||||
t.Errorf("expected model-a, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns empty for no models", func(t *testing.T) {
|
||||
config := &integration{Models: []string{}}
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "" {
|
||||
t.Errorf("expected empty string, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||
saveIntegration("Claude", []string{"model-x"})
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "model-x" {
|
||||
t.Errorf("expected model-x, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple integrations in single file", func(t *testing.T) {
|
||||
saveIntegration("app1", []string{"model-1"})
|
||||
saveIntegration("app2", []string{"model-2"})
|
||||
|
||||
config1, _ := loadIntegration("app1")
|
||||
config2, _ := loadIntegration("app2")
|
||||
|
||||
defaultModel1 := ""
|
||||
if len(config1.Models) > 0 {
|
||||
defaultModel1 = config1.Models[0]
|
||||
}
|
||||
defaultModel2 := ""
|
||||
if len(config2.Models) > 0 {
|
||||
defaultModel2 = config2.Models[0]
|
||||
}
|
||||
if defaultModel1 != "model-1" {
|
||||
t.Errorf("expected model-1, got %s", defaultModel1)
|
||||
}
|
||||
if defaultModel2 != "model-2" {
|
||||
t.Errorf("expected model-2, got %s", defaultModel2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestListIntegrations(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty when no integrations", func(t *testing.T) {
|
||||
configs, err := listIntegrations()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(configs) != 0 {
|
||||
t.Errorf("expected 0 integrations, got %d", len(configs))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns all saved integrations", func(t *testing.T) {
|
||||
saveIntegration("claude", []string{"model-1"})
|
||||
saveIntegration("droid", []string{"model-2"})
|
||||
|
||||
configs, err := listIntegrations()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(configs) != 2 {
|
||||
t.Errorf("expected 2 integrations, got %d", len(configs))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEditorPaths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
||||
r := integrations["claude"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for claude, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
||||
r := integrations["codex"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for codex, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
||||
settingsDir, _ := os.UserHomeDir()
|
||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
configDir := filepath.Join(home, ".config", "opencode")
|
||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["opencode"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 2 {
|
||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Create corrupted config.json file
|
||||
dir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
// Corrupted file is treated as empty, so loadIntegration returns not found
|
||||
_, err := loadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveIntegration_NilModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveIntegration("test", nil); err != nil {
|
||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("test")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration failed: %v", err)
|
||||
}
|
||||
|
||||
if config.Models == nil {
|
||||
// nil is acceptable
|
||||
} else if len(config.Models) != 0 {
|
||||
t.Errorf("expected empty or nil models, got %v", config.Models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveIntegration_EmptyAppName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
err := saveIntegration("", []string{"model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name, got nil")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
|
||||
t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
_, err := loadIntegration("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration, got nil")
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Logf("error type is os.ErrNotExist as expected: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigPath(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
|
||||
if path != expected {
|
||||
t.Errorf("expected %s, got %s", expected, path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty config when file does not exist", func(t *testing.T) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg == nil {
|
||||
t.Fatal("expected non-nil config")
|
||||
}
|
||||
if cfg.Integrations == nil {
|
||||
t.Error("expected non-nil Integrations map")
|
||||
}
|
||||
if len(cfg.Integrations) != 0 {
|
||||
t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("loads existing config", func(t *testing.T) {
|
||||
path, _ := configPath()
|
||||
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||
os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Integrations["test"] == nil {
|
||||
t.Fatal("expected test integration")
|
||||
}
|
||||
if len(cfg.Integrations["test"].Models) != 1 {
|
||||
t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for corrupted JSON", func(t *testing.T) {
|
||||
path, _ := configPath()
|
||||
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||
os.WriteFile(path, []byte(`{corrupted`), 0o644)
|
||||
|
||||
_, err := load()
|
||||
if err == nil {
|
||||
t.Error("expected error for corrupted JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("creates config file", func(t *testing.T) {
|
||||
cfg := &config{
|
||||
Integrations: map[string]*integration{
|
||||
"test": {Models: []string{"model-a", "model-b"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
path, _ := configPath()
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
t.Error("config file was not created")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("round-trip preserves data", func(t *testing.T) {
|
||||
cfg := &config{
|
||||
Integrations: map[string]*integration{
|
||||
"claude": {Models: []string{"llama3.2", "mistral"}},
|
||||
"codex": {Models: []string{"qwen2.5"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(loaded.Integrations) != 2 {
|
||||
t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
|
||||
}
|
||||
if loaded.Integrations["claude"] == nil {
|
||||
t.Error("missing claude integration")
|
||||
}
|
||||
if len(loaded.Integrations["claude"].Models) != 2 {
|
||||
t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
|
||||
}
|
||||
})
|
||||
}
|
||||
186
cmd/config/droid.go
Normal file
186
cmd/config/droid.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Droid implements Runner and Editor for Droid integration
|
||||
type Droid struct{}
|
||||
|
||||
// droidSettings represents the Droid settings.json file (only fields we use)
|
||||
type droidSettings struct {
|
||||
CustomModels []modelEntry `json:"customModels"`
|
||||
SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
|
||||
}
|
||||
|
||||
type sessionSettings struct {
|
||||
Model string `json:"model"`
|
||||
ReasoningEffort string `json:"reasoningEffort"`
|
||||
}
|
||||
|
||||
type modelEntry struct {
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"displayName"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKey string `json:"apiKey"`
|
||||
Provider string `json:"provider"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens"`
|
||||
SupportsImages bool `json:"supportsImages"`
|
||||
ID string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := d.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (d *Droid) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(home, ".factory", "settings.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Droid) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settingsPath := filepath.Join(home, ".factory", "settings.json")
|
||||
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read file once, unmarshal twice:
|
||||
// map preserves unknown fields for writing back (including extra fields in model entries)
|
||||
settingsMap := make(map[string]any)
|
||||
var settings droidSettings
|
||||
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||
if err := json.Unmarshal(data, &settingsMap); err != nil {
|
||||
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
|
||||
}
|
||||
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||
}
|
||||
|
||||
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||
// Rebuild Ollama models
|
||||
var nonOllamaModels []any
|
||||
if rawModels, ok := settingsMap["customModels"].([]any); ok {
|
||||
for _, raw := range rawModels {
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
if m["apiKey"] != "ollama" {
|
||||
nonOllamaModels = append(nonOllamaModels, raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
modelID := fmt.Sprintf("custom:%s-%d", model, i)
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
DisplayName: model,
|
||||
BaseURL: envconfig.Host().String() + "/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
SupportsImages: false,
|
||||
ID: modelID,
|
||||
Index: i,
|
||||
})
|
||||
if i == 0 {
|
||||
defaultModelID = modelID
|
||||
}
|
||||
}
|
||||
|
||||
settingsMap["customModels"] = append(newModels, nonOllamaModels...)
|
||||
|
||||
// Update session default settings (preserve unknown fields in the nested object)
|
||||
sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
sessionSettings = make(map[string]any)
|
||||
}
|
||||
sessionSettings["model"] = defaultModelID
|
||||
|
||||
if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
|
||||
sessionSettings["reasoningEffort"] = "none"
|
||||
}
|
||||
|
||||
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||
|
||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, data)
|
||||
}
|
||||
|
||||
func (d *Droid) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings droidSettings
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, m := range settings.CustomModels {
|
||||
if m.APIKey == "ollama" {
|
||||
result = append(result, m.Model)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
|
||||
|
||||
func isValidReasoningEffort(effort string) bool {
|
||||
return slices.Contains(validReasoningEfforts, effort)
|
||||
}
|
||||
1302
cmd/config/droid_test.go
Normal file
1302
cmd/config/droid_test.go
Normal file
File diff suppressed because it is too large
Load Diff
99
cmd/config/files.go
Normal file
99
cmd/config/files.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
func readJSONFile(path string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
info, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||
}
|
||||
|
||||
func backupDir() string {
|
||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||
}
|
||||
|
||||
func backupToTmp(srcPath string) (string, error) {
|
||||
dir := backupDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
|
||||
if err := copyFile(srcPath, backupPath); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return backupPath, nil
|
||||
}
|
||||
|
||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
||||
func writeWithBackup(path string, data []byte) error {
|
||||
var backupPath string
|
||||
// backup must be created before any writes to the target file
|
||||
if existingContent, err := os.ReadFile(path); err == nil {
|
||||
if !bytes.Equal(existingContent, data) {
|
||||
backupPath, err = backupToTmp(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backup failed: %w", err)
|
||||
}
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("read existing file: %w", err)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
tmp, err := os.CreateTemp(dir, ".tmp-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp failed: %w", err)
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("write failed: %w", err)
|
||||
}
|
||||
if err := tmp.Sync(); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("sync failed: %w", err)
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("close failed: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
if backupPath != "" {
|
||||
_ = copyFile(backupPath, path)
|
||||
}
|
||||
return fmt.Errorf("rename failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
502
cmd/config/files_test.go
Normal file
502
cmd/config/files_test.go
Normal file
@@ -0,0 +1,502 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustMarshal(t *testing.T, v any) []byte {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestWriteWithBackup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
t.Run("creates file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "new.json")
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
if err := json.Unmarshal(content, &result); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result["key"] != "value" {
|
||||
t.Errorf("expected value, got %s", result["key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "backup.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(backupDir())
|
||||
if err != nil {
|
||||
t.Fatal("backup directory not created")
|
||||
}
|
||||
|
||||
var foundBackup bool
|
||||
for _, entry := range entries {
|
||||
if filepath.Ext(entry.Name()) != ".json" {
|
||||
name := entry.Name()
|
||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||
backupPath := filepath.Join(backupDir(), name)
|
||||
backup, err := os.ReadFile(backupPath)
|
||||
if err == nil {
|
||||
var backupData map[string]bool
|
||||
json.Unmarshal(backup, &backupData)
|
||||
if backupData["original"] {
|
||||
foundBackup = true
|
||||
os.Remove(backupPath)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup file not created in /tmp/ollama-backups")
|
||||
}
|
||||
|
||||
current, _ := os.ReadFile(path)
|
||||
var currentData map[string]bool
|
||||
json.Unmarshal(current, ¤tData)
|
||||
if !currentData["updated"] {
|
||||
t.Error("file doesn't contain updated data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup for new file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "nobak.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
for _, entry := range entries {
|
||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||
t.Error("backup should not exist for new file")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup when content unchanged", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "unchanged.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries1, _ := os.ReadDir(backupDir())
|
||||
countBefore := 0
|
||||
for _, e := range entries1 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countBefore++
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries2, _ := os.ReadDir(backupDir())
|
||||
countAfter := 0
|
||||
for _, e := range entries2 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countAfter++
|
||||
}
|
||||
}
|
||||
|
||||
if countAfter != countBefore {
|
||||
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "timestamped.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
data := mustMarshal(t, map[string]int{"v": 2})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
var found bool
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
|
||||
timestamp := name[len("timestamped.json."):]
|
||||
for _, c := range timestamp {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
|
||||
}
|
||||
}
|
||||
found = true
|
||||
os.Remove(filepath.Join(backupDir(), name))
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("backup file with timestamp not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for files.go
|
||||
|
||||
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
|
||||
// User could lose their config with no way to recover.
|
||||
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
// Create original file
|
||||
originalContent := []byte(`{"original": true}`)
|
||||
os.WriteFile(path, originalContent, 0o644)
|
||||
|
||||
// Make backup directory read-only to force backup failure
|
||||
backupDir := backupDir()
|
||||
os.MkdirAll(backupDir, 0o755)
|
||||
os.Chmod(backupDir, 0o444) // Read-only
|
||||
defer os.Chmod(backupDir, 0o755)
|
||||
|
||||
newContent := []byte(`{"updated": true}`)
|
||||
err := writeWithBackup(path, newContent)
|
||||
|
||||
// Should fail because backup couldn't be created
|
||||
if err == nil {
|
||||
t.Error("expected error when backup fails, got nil")
|
||||
}
|
||||
|
||||
// Original file should be preserved
|
||||
current, _ := os.ReadFile(path)
|
||||
if string(current) != string(originalContent) {
|
||||
t.Errorf("original file was modified despite backup failure: got %s", string(current))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
|
||||
// Common issue when config owned by root or wrong perms.
|
||||
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a read-only directory
|
||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||
os.MkdirAll(readOnlyDir, 0o755)
|
||||
os.Chmod(readOnlyDir, 0o444)
|
||||
defer os.Chmod(readOnlyDir, 0o755)
|
||||
|
||||
path := filepath.Join(readOnlyDir, "config.json")
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected permission error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||
// writeWithBackup doesn't create directories - caller is responsible.
|
||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
// Should fail because directory doesn't exist
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
|
||||
// Documents what happens if user symlinks their config file.
|
||||
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlink tests may require admin on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
realFile := filepath.Join(tmpDir, "real.json")
|
||||
symlink := filepath.Join(tmpDir, "link.json")
|
||||
|
||||
// Create real file and symlink
|
||||
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
|
||||
os.Symlink(realFile, symlink)
|
||||
|
||||
// Write through symlink
|
||||
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||
}
|
||||
|
||||
// The real file should be updated (symlink followed for temp file creation)
|
||||
content, _ := os.ReadFile(symlink)
|
||||
if string(content) != `{"v": 2}` {
|
||||
t.Errorf("symlink target not updated correctly: got %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||
// User may have config files with unusual names.
|
||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// File with spaces and special chars
|
||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
|
||||
|
||||
backupPath, err := backupToTmp(path)
|
||||
if err != nil {
|
||||
t.Fatalf("backupToTmp with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify backup exists and has correct content
|
||||
content, err := os.ReadFile(backupPath)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read backup: %v", err)
|
||||
}
|
||||
if string(content) != `{"test": true}` {
|
||||
t.Errorf("backup content mismatch: got %s", string(content))
|
||||
}
|
||||
|
||||
os.Remove(backupPath)
|
||||
}
|
||||
|
||||
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
|
||||
func TestCopyFile_PreservesPermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission preservation tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
src := filepath.Join(tmpDir, "src.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
// Create source with specific permissions
|
||||
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err != nil {
|
||||
t.Fatalf("copyFile failed: %v", err)
|
||||
}
|
||||
|
||||
srcInfo, _ := os.Stat(src)
|
||||
dstInfo, _ := os.Stat(dst)
|
||||
|
||||
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
|
||||
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent source, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||
os.MkdirAll(dirPath, 0o755)
|
||||
|
||||
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when target is a directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "empty.json")
|
||||
|
||||
err := writeWithBackup(path, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read file: %v", err)
|
||||
}
|
||||
if len(content) != 0 {
|
||||
t.Errorf("expected empty file, got %d bytes", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
|
||||
// cannot be read (for backup comparison) but directory is writable.
|
||||
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "unreadable.json")
|
||||
|
||||
// Create file and make it unreadable
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
os.Chmod(path, 0o000)
|
||||
defer os.Chmod(path, 0o644)
|
||||
|
||||
// Should fail because we can't read the file to compare/backup
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when file is unreadable, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||
// within the same second (timestamp collision scenario).
|
||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "rapid.json")
|
||||
|
||||
// Create initial file
|
||||
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
|
||||
|
||||
// Rapid successive writes
|
||||
for i := 1; i <= 3; i++ {
|
||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatalf("write %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify final content
|
||||
content, _ := os.ReadFile(path)
|
||||
if string(content) != `{"v": 3}` {
|
||||
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
|
||||
}
|
||||
|
||||
// Verify at least one backup exists
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
var backupCount int
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||
backupCount++
|
||||
}
|
||||
}
|
||||
if backupCount == 0 {
|
||||
t.Error("expected at least one backup file from rapid writes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
|
||||
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("test modifies system temp directory")
|
||||
}
|
||||
|
||||
// Create a file at the backup directory path
|
||||
backupPath := backupDir()
|
||||
// Clean up any existing directory first
|
||||
os.RemoveAll(backupPath)
|
||||
// Create a file instead of directory
|
||||
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
|
||||
defer func() {
|
||||
os.Remove(backupPath)
|
||||
os.MkdirAll(backupPath, 0o755)
|
||||
}()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "test.json")
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when backup dir is a file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
|
||||
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Count existing temp files
|
||||
countTempFiles := func() int {
|
||||
entries, _ := os.ReadDir(tmpDir)
|
||||
count := 0
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
before := countTempFiles()
|
||||
|
||||
// Create a file, then make directory read-only to cause rename failure
|
||||
path := filepath.Join(tmpDir, "orphan.json")
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make a subdirectory and try to write there after making parent read-only
|
||||
subDir := filepath.Join(tmpDir, "subdir")
|
||||
os.MkdirAll(subDir, 0o755)
|
||||
subPath := filepath.Join(subDir, "config.json")
|
||||
os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make subdir read-only after creating temp file would succeed but rename would fail
|
||||
// This is tricky to test - the temp file is created in the same dir, so if we can't
|
||||
// rename, we also couldn't create. Let's just verify normal failure cleanup works.
|
||||
|
||||
// Force a failure by making the target a directory
|
||||
badPath := filepath.Join(tmpDir, "isdir")
|
||||
os.MkdirAll(badPath, 0o755)
|
||||
|
||||
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
||||
|
||||
after := countTempFiles()
|
||||
if after > before {
|
||||
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
|
||||
}
|
||||
}
|
||||
355
cmd/config/integrations.go
Normal file
355
cmd/config/integrations.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Runners execute the launching of a model with the integration - claude, codex
|
||||
// Editors can edit config files (supports multi-model selection) - opencode, droid
|
||||
// They are composable interfaces where in some cases an editor is also a runner - opencode, droid
|
||||
// Runner can run an integration with a model.
|
||||
|
||||
type Runner interface {
|
||||
Run(model string) error
|
||||
// String returns the human-readable name of the integration
|
||||
String() string
|
||||
}
|
||||
|
||||
// Editor can edit config files (supports multi-model selection)
|
||||
type Editor interface {
|
||||
// Paths returns the paths to the config files for the integration
|
||||
Paths() []string
|
||||
// Edit updates the config files for the integration with the given models
|
||||
Edit(models []string) error
|
||||
// Models returns the models currently configured for the integration
|
||||
// TODO(parthsareen): add error return to Models()
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
"clawdbot": &Clawdbot{},
|
||||
"codex": &Codex{},
|
||||
"droid": &Droid{},
|
||||
"opencode": &OpenCode{},
|
||||
}
|
||||
|
||||
func selectIntegration() (string, error) {
|
||||
if len(integrations) == 0 {
|
||||
return "", fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
names := slices.Sorted(maps.Keys(integrations))
|
||||
var items []selectItem
|
||||
for _, name := range names {
|
||||
r := integrations[name]
|
||||
description := r.String()
|
||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
||||
}
|
||||
items = append(items, selectItem{Name: name, Description: description})
|
||||
}
|
||||
|
||||
return selectPrompt("Select integration:", items)
|
||||
}
|
||||
|
||||
// selectModels lets the user select models for an integration
|
||||
func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(models.Models) == 0 {
|
||||
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
var items []selectItem
|
||||
cloudModels := make(map[string]bool)
|
||||
for _, m := range models.Models {
|
||||
if m.RemoteModel != "" {
|
||||
cloudModels[m.Name] = true
|
||||
}
|
||||
items = append(items, selectItem{Name: m.Name})
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
// Get previously configured models (saved config takes precedence)
|
||||
var preChecked []string
|
||||
if saved, err := loadIntegration(name); err == nil {
|
||||
preChecked = saved.Models
|
||||
} else if editor, ok := r.(Editor); ok {
|
||||
preChecked = editor.Models()
|
||||
}
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If current model is configured, move to front of preChecked
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
}
|
||||
|
||||
// Sort: checked first, then alphabetical
|
||||
slices.SortFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
|
||||
var selected []string
|
||||
// only editors support multi-model selection
|
||||
if _, ok := r.(Editor); ok {
|
||||
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selected = []string{model}
|
||||
}
|
||||
|
||||
// if any model in selected is a cloud model, ensure signed in
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) > 0 {
|
||||
// ensure user is signed in
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return nil, fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
// TODO(parthsareen): extract into auth package for cmd
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName)
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
|
||||
var modelFlag string
|
||||
var configFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "launch [INTEGRATION]",
|
||||
Short: "Launch an integration with Ollama",
|
||||
Long: `Launch an integration configured with Ollama models.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
clawdbot Clawdbot
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
|
||||
Examples:
|
||||
ollama launch
|
||||
ollama launch claude
|
||||
ollama launch claude --model <model>
|
||||
ollama launch droid --config (does not auto-launch)`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
var name string
|
||||
if len(args) > 0 {
|
||||
name = args[0]
|
||||
} else {
|
||||
var err error
|
||||
name, err = selectIntegration()
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r, ok := integrations[strings.ToLower(name)]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
// If launching without --model, use saved config if available
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0])
|
||||
}
|
||||
}
|
||||
|
||||
var models []string
|
||||
if modelFlag != "" {
|
||||
// When --model is specified, merge with existing models (new model becomes default)
|
||||
models = []string{modelFlag}
|
||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||
for _, m := range existing.Models {
|
||||
if m != modelFlag {
|
||||
models = append(models, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if editor, isEditor := r.(Editor); isEditor {
|
||||
paths := editor.Paths()
|
||||
if len(paths) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
|
||||
for _, p := range paths {
|
||||
fmt.Fprintf(os.Stderr, " %s\n", p)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
|
||||
|
||||
if ok, _ := confirmPrompt("Proceed?"); !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := saveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
if editor, isEditor := r.(Editor); isEditor {
|
||||
if err := editor.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, isEditor := r.(Editor); isEditor {
|
||||
if len(models) == 1 {
|
||||
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
|
||||
}
|
||||
}
|
||||
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||
return runIntegration(name, models[0])
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
return runIntegration(name, models[0])
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||
return cmd
|
||||
}
|
||||
188
cmd/config/integrations_test.go
Normal file
188
cmd/config/integrations_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestIntegrationLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantFound bool
|
||||
wantName string
|
||||
}{
|
||||
{"claude lowercase", "claude", true, "Claude Code"},
|
||||
{"claude uppercase", "CLAUDE", true, "Claude Code"},
|
||||
{"claude mixed case", "Claude", true, "Claude Code"},
|
||||
{"codex", "codex", true, "Codex"},
|
||||
{"droid", "droid", true, "Droid"},
|
||||
{"opencode", "opencode", true, "OpenCode"},
|
||||
{"unknown integration", "unknown", false, ""},
|
||||
{"empty string", "", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, found := integrations[strings.ToLower(tt.input)]
|
||||
if found != tt.wantFound {
|
||||
t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound)
|
||||
}
|
||||
if found && r.String() != tt.wantName {
|
||||
t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationRegistry(t *testing.T) {
|
||||
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
|
||||
|
||||
for _, name := range expectedIntegrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
t.Fatalf("integration %q not found in registry", name)
|
||||
}
|
||||
if r.String() == "" {
|
||||
t.Error("integration.String() should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasLocalModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
models []string
|
||||
want bool
|
||||
}{
|
||||
{"empty list", []string{}, false},
|
||||
{"single local model", []string{"llama3.2"}, true},
|
||||
{"single cloud model", []string{"cloud-model"}, false},
|
||||
{"mixed models", []string{"cloud-model", "llama3.2"}, true},
|
||||
{"multiple local models", []string{"llama3.2", "qwen2.5"}, true},
|
||||
{"multiple cloud models", []string{"cloud-a", "cloud-b"}, false},
|
||||
{"local model first", []string{"llama3.2", "cloud-model"}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
||||
return !strings.Contains(m, "cloud")
|
||||
})
|
||||
if got != tt.want {
|
||||
t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmd(t *testing.T) {
|
||||
// Mock checkServerHeartbeat that always succeeds
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
}
|
||||
if cmd.Long == "" {
|
||||
t.Error("Long description should not be empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
modelFlag := cmd.Flags().Lookup("model")
|
||||
if modelFlag == nil {
|
||||
t.Error("--model flag should exist")
|
||||
}
|
||||
|
||||
configFlag := cmd.Flags().Lookup("config")
|
||||
if configFlag == nil {
|
||||
t.Error("--config flag should exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreRunE is set", func(t *testing.T) {
|
||||
if cmd.PreRunE == nil {
|
||||
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model")
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown integration, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown integration") {
|
||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
models []string
|
||||
want bool
|
||||
reason string
|
||||
}{
|
||||
{"empty list", []string{}, false, "empty list has no local models"},
|
||||
{"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"},
|
||||
{"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"},
|
||||
{"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"},
|
||||
{"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"},
|
||||
{"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"},
|
||||
{"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
||||
return !strings.Contains(m, "cloud")
|
||||
})
|
||||
if got != tt.want {
|
||||
t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
||||
// This should not panic - cmd creation should work even with nil
|
||||
cmd := LaunchCmd(nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("LaunchCmd returned nil")
|
||||
}
|
||||
|
||||
// PreRunE should be nil when passed nil
|
||||
if cmd.PreRunE != nil {
|
||||
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||
for name, r := range integrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test String() doesn't panic and returns non-empty
|
||||
displayName := r.String()
|
||||
if displayName == "" {
|
||||
t.Error("String() should not return empty")
|
||||
}
|
||||
|
||||
// Test Run() exists (we can't call it without actually running the command)
|
||||
// Just verify the method is available
|
||||
var _ func(string) error = r.Run
|
||||
})
|
||||
}
|
||||
}
|
||||
226
cmd/config/opencode.go
Normal file
226
cmd/config/opencode.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string) error {
|
||||
if _, err := exec.LookPath("opencode"); err != nil {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := o.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (o *OpenCode) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var paths []string
|
||||
p := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
paths = append(paths, p)
|
||||
}
|
||||
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if _, err := os.Stat(sp); err == nil {
|
||||
paths = append(paths, sp)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func (o *OpenCode) Edit(modelList []string) error {
|
||||
if len(modelList) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
|
||||
}
|
||||
|
||||
config["$schema"] = "https://opencode.ai/config.json"
|
||||
|
||||
provider, ok := config["provider"].(map[string]any)
|
||||
if !ok {
|
||||
provider = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama (local)",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
models = make(map[string]any)
|
||||
}
|
||||
|
||||
selectedSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
selectedSet[m] = true
|
||||
}
|
||||
|
||||
for name, cfg := range models {
|
||||
if cfgMap, ok := cfg.(map[string]any); ok {
|
||||
if isOllamaModel(cfgMap) && !selectedSet[name] {
|
||||
delete(models, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
if isOllamaModel(existing) {
|
||||
existing["_launch"] = true
|
||||
if name, ok := existing["name"].(string); ok {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
models[model] = map[string]any{
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
provider["ollama"] = ollama
|
||||
config["provider"] = provider
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := map[string]any{
|
||||
"recent": []any{},
|
||||
"favorite": []any{},
|
||||
"variant": map[string]any{},
|
||||
}
|
||||
if data, err := os.ReadFile(statePath); err == nil {
|
||||
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
|
||||
}
|
||||
|
||||
recent, _ := state["recent"].([]any)
|
||||
|
||||
modelSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
modelSet[m] = true
|
||||
}
|
||||
|
||||
// Filter out existing Ollama models we're about to re-add
|
||||
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
|
||||
e, ok := entry.(map[string]any)
|
||||
if !ok || e["providerID"] != "ollama" {
|
||||
return false
|
||||
}
|
||||
modelID, _ := e["modelID"].(string)
|
||||
return modelSet[modelID]
|
||||
})
|
||||
|
||||
// Prepend models in reverse order so first model ends up first
|
||||
for _, model := range slices.Backward(modelList) {
|
||||
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
|
||||
"providerID": "ollama",
|
||||
"modelID": model,
|
||||
}))
|
||||
}
|
||||
|
||||
const maxRecentModels = 10
|
||||
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
|
||||
|
||||
state["recent"] = newRecent
|
||||
|
||||
stateData, err := json.MarshalIndent(state, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(statePath, stateData)
|
||||
}
|
||||
|
||||
func (o *OpenCode) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
provider, _ := config["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := slices.Collect(maps.Keys(models))
|
||||
slices.Sort(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
// isOllamaModel reports whether a model config entry is managed by us
|
||||
func isOllamaModel(cfg map[string]any) bool {
|
||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
||||
return true
|
||||
}
|
||||
// previously used [Ollama] as a suffix for the model managed by ollama launch
|
||||
if name, ok := cfg["name"].(string); ok {
|
||||
return strings.HasSuffix(name, "[Ollama]")
|
||||
}
|
||||
return false
|
||||
}
|
||||
507
cmd/config/opencode_test.go
Normal file
507
cmd/config/opencode_test.go
Normal file
@@ -0,0 +1,507 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpenCodeIntegration(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := o.String(); got != "OpenCode" {
|
||||
t.Errorf("String() = %q, want %q", got, "OpenCode")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = o
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = o
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(configDir)
|
||||
os.RemoveAll(stateDir)
|
||||
}
|
||||
|
||||
t.Run("fresh install", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other providers", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
if provider["anthropic"] == nil {
|
||||
t.Error("anthropic provider was removed")
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("update existing model", func(t *testing.T) {
|
||||
cleanup()
|
||||
o.Edit([]string{"llama3.2"})
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme was removed")
|
||||
}
|
||||
if cfg["keybindings"] == nil {
|
||||
t.Error("keybindings was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model state - insert at index 0", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
|
||||
})
|
||||
|
||||
t.Run("model state - preserve favorites and variants", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
if len(state["favorite"].([]any)) != 1 {
|
||||
t.Error("favorite was modified")
|
||||
}
|
||||
if state["variant"].(map[string]any)["a"] != "b" {
|
||||
t.Error("variant was modified")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model state - deduplicate on re-add", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
recent := state["recent"].([]any)
|
||||
if len(recent) != 2 {
|
||||
t.Errorf("expected 2 recent entries, got %d", len(recent))
|
||||
}
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("remove model", func(t *testing.T) {
|
||||
cleanup()
|
||||
// First add two models
|
||||
o.Edit([]string{"llama3.2", "mistral"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||
|
||||
// Then remove one by only selecting the other
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("preserve user customizations on managed models", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add custom fields to the model entry (simulating user edits)
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry := models["llama3.2"].(map[string]any)
|
||||
entry["_myPref"] = "custom-value"
|
||||
entry["_myNum"] = 42
|
||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configPath, configData, 0o644)
|
||||
|
||||
// Re-run Edit — should preserve custom fields
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ = os.ReadFile(configPath)
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider = cfg["provider"].(map[string]any)
|
||||
ollama = provider["ollama"].(map[string]any)
|
||||
models = ollama["models"].(map[string]any)
|
||||
entry = models["llama3.2"].(map[string]any)
|
||||
|
||||
if entry["_myPref"] != "custom-value" {
|
||||
t.Errorf("_myPref was lost: got %v", entry["_myPref"])
|
||||
}
|
||||
if entry["_myNum"] != float64(42) {
|
||||
t.Errorf("_myNum was lost: got %v", entry["_myNum"])
|
||||
}
|
||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
||||
t.Errorf("_launch marker missing or false: got %v", entry["_launch"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrate legacy [Ollama] suffix entries", func(t *testing.T) {
|
||||
cleanup()
|
||||
// Write a config with a legacy entry (has [Ollama] suffix but no _launch marker)
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"llama3.2":{"name":"llama3.2 [Ollama]"}}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry := models["llama3.2"].(map[string]any)
|
||||
|
||||
// _launch marker should be added
|
||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
||||
t.Errorf("_launch marker not added during migration: got %v", entry["_launch"])
|
||||
}
|
||||
// [Ollama] suffix should be stripped
|
||||
if name, ok := entry["name"].(string); !ok || name != "llama3.2" {
|
||||
t.Errorf("name suffix not stripped: got %q", entry["name"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
// Add a non-Ollama model manually
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
|
||||
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
|
||||
})
|
||||
}
|
||||
|
||||
func assertOpenCodeModelExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider not found")
|
||||
}
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("ollama provider not found")
|
||||
}
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("models not found")
|
||||
}
|
||||
if models[model] == nil {
|
||||
t.Errorf("model %s not found", model)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
return // No provider means no model
|
||||
}
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
return // No ollama means no model
|
||||
}
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
return // No models means no model
|
||||
}
|
||||
if models[model] != nil {
|
||||
t.Errorf("model %s should not exist but was found", model)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
recent, ok := state["recent"].([]any)
|
||||
if !ok {
|
||||
t.Fatal("recent not found")
|
||||
}
|
||||
if index >= len(recent) {
|
||||
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
|
||||
}
|
||||
entry, ok := recent[index].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("entry is not a map")
|
||||
}
|
||||
if entry["providerID"] != providerID {
|
||||
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
|
||||
}
|
||||
if entry["modelID"] != modelID {
|
||||
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
|
||||
}
|
||||
}
|
||||
|
||||
// Edge case tests for opencode.go
|
||||
|
||||
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
|
||||
|
||||
// Should not panic - corrupted JSON should be treated as empty
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit failed with corrupted config: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid JSON was created
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Errorf("resulting config is not valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit failed with corrupted state: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid state was created
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Errorf("resulting state is not valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with wrong type provider failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify provider is now correct type
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
|
||||
}
|
||||
if provider["ollama"] == nil {
|
||||
t.Error("ollama provider should be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with wrong type recent failed: %v", err)
|
||||
}
|
||||
|
||||
// The function should handle this gracefully
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
|
||||
// recent should be properly set after setup
|
||||
recent, ok := state["recent"].([]any)
|
||||
if !ok {
|
||||
t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
|
||||
} else if len(recent) == 0 {
|
||||
t.Logf("Note: recent is empty (documenting behavior)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
|
||||
os.WriteFile(configPath, []byte(originalContent), 0o644)
|
||||
|
||||
// Empty models should be no-op
|
||||
err := o.Edit([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with empty models failed: %v", err)
|
||||
}
|
||||
|
||||
// Original content should be preserved (file not modified)
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != originalContent {
|
||||
t.Errorf("empty models should not modify file, but content changed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Model name with special characters (though unusual)
|
||||
specialModel := `model-with-"quotes"`
|
||||
|
||||
err := o.Edit([]string{specialModel})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was stored correctly
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
data, _ := os.ReadFile(configPath)
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("resulting config is invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Model should be accessible
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
|
||||
if models[specialModel] == nil {
|
||||
t.Errorf("model with special chars not found in config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
models := o.Models()
|
||||
if len(models) > 0 {
|
||||
t.Errorf("expected nil/empty for missing config, got %v", models)
|
||||
}
|
||||
}
|
||||
499
cmd/config/selector.go
Normal file
499
cmd/config/selector.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiHideCursor = "\033[?25l"
|
||||
ansiShowCursor = "\033[?25h"
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiClearDown = "\033[J"
|
||||
)
|
||||
|
||||
const maxDisplayedItems = 10
|
||||
|
||||
var errCancelled = errors.New("cancelled")
|
||||
|
||||
type selectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
type inputEvent int
|
||||
|
||||
const (
|
||||
eventNone inputEvent = iota
|
||||
eventEnter
|
||||
eventEscape
|
||||
eventUp
|
||||
eventDown
|
||||
eventTab
|
||||
eventBackspace
|
||||
eventChar
|
||||
)
|
||||
|
||||
type selectState struct {
|
||||
items []selectItem
|
||||
filter string
|
||||
selected int
|
||||
scrollOffset int
|
||||
}
|
||||
|
||||
func newSelectState(items []selectItem) *selectState {
|
||||
return &selectState{items: items}
|
||||
}
|
||||
|
||||
func (s *selectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if len(filtered) > 0 && s.selected < len(filtered) {
|
||||
return true, filtered[s.selected].Name, nil
|
||||
}
|
||||
case eventEscape:
|
||||
return true, "", errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
case eventUp:
|
||||
if s.selected > 0 {
|
||||
s.selected--
|
||||
if s.selected < s.scrollOffset {
|
||||
s.scrollOffset = s.selected
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.selected < len(filtered)-1 {
|
||||
s.selected++
|
||||
if s.selected >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.selected - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
type multiSelectState struct {
|
||||
items []selectItem
|
||||
itemIndex map[string]int
|
||||
filter string
|
||||
highlighted int
|
||||
scrollOffset int
|
||||
checked map[int]bool
|
||||
checkOrder []int
|
||||
focusOnButton bool
|
||||
}
|
||||
|
||||
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
|
||||
s := &multiSelectState{
|
||||
items: items,
|
||||
itemIndex: make(map[string]int, len(items)),
|
||||
checked: make(map[int]bool),
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
s.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := s.itemIndex[name]; ok {
|
||||
s.checked[idx] = true
|
||||
s.checkOrder = append(s.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *multiSelectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *multiSelectState) toggleItem() {
|
||||
filtered := s.filtered()
|
||||
if len(filtered) == 0 || s.highlighted >= len(filtered) {
|
||||
return
|
||||
}
|
||||
|
||||
item := filtered[s.highlighted]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
if s.checked[origIdx] {
|
||||
delete(s.checked, origIdx)
|
||||
for i, idx := range s.checkOrder {
|
||||
if idx == origIdx {
|
||||
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.checked[origIdx] = true
|
||||
s.checkOrder = append(s.checkOrder, origIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if s.focusOnButton && len(s.checkOrder) > 0 {
|
||||
var res []string
|
||||
for _, idx := range s.checkOrder {
|
||||
res = append(res, s.items[idx].Name)
|
||||
}
|
||||
return true, res, nil
|
||||
} else if !s.focusOnButton {
|
||||
s.toggleItem()
|
||||
}
|
||||
case eventTab:
|
||||
if len(s.checkOrder) > 0 {
|
||||
s.focusOnButton = !s.focusOnButton
|
||||
}
|
||||
case eventEscape:
|
||||
return true, nil, errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
case eventUp:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted > 0 {
|
||||
s.highlighted--
|
||||
if s.highlighted < s.scrollOffset {
|
||||
s.scrollOffset = s.highlighted
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted < len(filtered)-1 {
|
||||
s.highlighted++
|
||||
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
func (s *multiSelectState) selectedCount() int {
|
||||
return len(s.checkOrder)
|
||||
}
|
||||
|
||||
// Terminal I/O handling
|
||||
|
||||
type terminalState struct {
|
||||
fd int
|
||||
oldState *term.State
|
||||
}
|
||||
|
||||
func enterRawMode() (*terminalState, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Fprint(os.Stderr, ansiHideCursor)
|
||||
return &terminalState{fd: fd, oldState: oldState}, nil
|
||||
}
|
||||
|
||||
func (t *terminalState) restore() {
|
||||
fmt.Fprint(os.Stderr, ansiShowCursor)
|
||||
term.Restore(t.fd, t.oldState)
|
||||
}
|
||||
|
||||
func clearLines(n int) {
|
||||
if n > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA", n)
|
||||
fmt.Fprint(os.Stderr, ansiClearDown)
|
||||
}
|
||||
}
|
||||
|
||||
func parseInput(r io.Reader) (inputEvent, byte, error) {
|
||||
buf := make([]byte, 3)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case n == 1 && buf[0] == 13:
|
||||
return eventEnter, 0, nil
|
||||
case n == 1 && (buf[0] == 3 || buf[0] == 27):
|
||||
return eventEscape, 0, nil
|
||||
case n == 1 && buf[0] == 9:
|
||||
return eventTab, 0, nil
|
||||
case n == 1 && buf[0] == 127:
|
||||
return eventBackspace, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
|
||||
return eventUp, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
|
||||
return eventDown, 0, nil
|
||||
case n == 1 && buf[0] >= 32 && buf[0] < 127:
|
||||
return eventChar, buf[0], nil
|
||||
}
|
||||
|
||||
return eventNone, 0, nil
|
||||
}
|
||||
|
||||
// Rendering
|
||||
|
||||
func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
prefix := " "
|
||||
if idx == s.selected {
|
||||
prefix = " " + ansiBold + "> "
|
||||
}
|
||||
if item.Description != "" {
|
||||
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
checkbox := "[ ]"
|
||||
if s.checked[origIdx] {
|
||||
checkbox = "[x]"
|
||||
}
|
||||
|
||||
prefix := " "
|
||||
suffix := ""
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
prefix = "> "
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
|
||||
suffix = " " + ansiGray + "(default)" + ansiReset
|
||||
}
|
||||
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "\r\n")
|
||||
lineCount++
|
||||
count := s.selectedCount()
|
||||
switch {
|
||||
case count == 0:
|
||||
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
|
||||
case s.focusOnButton:
|
||||
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
|
||||
default:
|
||||
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
// selectPrompt prompts the user to select a single item from a list.
|
||||
func selectPrompt(prompt string, items []selectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newSelectState(items)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
|
||||
// multiSelectPrompt prompts the user to select multiple items from a list.
|
||||
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newMultiSelectState(items, preChecked)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
|
||||
func confirmPrompt(prompt string) (bool, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterItems(items []selectItem, filter string) []selectItem {
|
||||
if filter == "" {
|
||||
return items
|
||||
}
|
||||
var result []selectItem
|
||||
filterLower := strings.ToLower(filter)
|
||||
for _, item := range items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
913
cmd/config/selector_test.go
Normal file
913
cmd/config/selector_test.go
Normal file
@@ -0,0 +1,913 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilterItems(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama3.2:latest"},
|
||||
{Name: "qwen2.5:7b"},
|
||||
{Name: "deepseek-v3:cloud"},
|
||||
{Name: "GPT-OSS:20b"},
|
||||
}
|
||||
|
||||
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
|
||||
result := filterItems(items, "")
|
||||
if len(result) != len(items) {
|
||||
t.Errorf("expected %d items, got %d", len(items), len(result))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
|
||||
result := filterItems(items, "LLAMA")
|
||||
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
|
||||
t.Errorf("expected llama3.2:latest, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
|
||||
result := filterItems(items, "gpt")
|
||||
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
|
||||
t.Errorf("expected GPT-OSS:20b, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PartialMatch", func(t *testing.T) {
|
||||
result := filterItems(items, "deep")
|
||||
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
|
||||
t.Errorf("expected deepseek-v3:cloud, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
|
||||
result := filterItems(items, "nonexistent")
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0 items, got %d", len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0, got %d", s.selected)
|
||||
}
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected empty filter, got %q", s.filter)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item1" || err != nil {
|
||||
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "item3"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item3" || err != nil {
|
||||
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "nonexistent"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != "" || err != nil {
|
||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != "" || err != errCancelled {
|
||||
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 2 {
|
||||
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventChar, 'i')
|
||||
s.handleInput(eventChar, 't')
|
||||
s.handleInput(eventChar, 'e')
|
||||
s.handleInput(eventChar, 'm')
|
||||
s.handleInput(eventChar, '2')
|
||||
if s.filter != "item2" {
|
||||
t.Errorf("expected filter='item2', got %q", s.filter)
|
||||
}
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 1 || filtered[0].Name != "item2" {
|
||||
t.Errorf("expected [item2], got %v", filtered)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after typing, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected filter='', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.selected = 2
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
|
||||
// maxDisplayedItems is 10, so with 15 items we need to scroll
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
|
||||
// move down 12 times (past the 10-item viewport)
|
||||
for range 12 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != 12 {
|
||||
t.Errorf("expected selected=12, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 3 {
|
||||
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
s.selected = 5
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventUp, 0)
|
||||
|
||||
if s.selected != 4 {
|
||||
t.Errorf("expected selected=4, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 4 {
|
||||
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false initially")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item3"})
|
||||
if s.selectedCount() != 2 {
|
||||
t.Errorf("expected 2 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if !s.checked[1] || !s.checked[2] {
|
||||
t.Error("expected item2 and item3 to be checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
|
||||
// order matters: first checked = default model
|
||||
s := newMultiSelectState(items, []string{"item3", "item1"})
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
|
||||
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
|
||||
if s.selectedCount() != 1 {
|
||||
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.toggleItem()
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.toggleItem()
|
||||
if s.checked[0] {
|
||||
t.Error("expected item1 to be unchecked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
|
||||
s.highlighted = 1 // toggle item2
|
||||
s.toggleItem()
|
||||
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
// should be [0, 2] (item1, item3) with item2 removed
|
||||
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
|
||||
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventEnter, 0)
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after enter")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
s.focusOnButton = true
|
||||
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
|
||||
if !done || err != nil {
|
||||
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
|
||||
}
|
||||
// result should preserve selection order
|
||||
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
|
||||
t.Errorf("expected [item2, item1], got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.focusOnButton = true
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != nil || err != nil {
|
||||
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("tab should not focus button when nothing selected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after first tab")
|
||||
}
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus back on list after second tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != nil || err != errCancelled {
|
||||
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
|
||||
t.Error("expected item2 (idx 1) to be default (first checked)")
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected item1 (idx 0) to NOT be default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected isDefault=false when nothing checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.highlighted != 1 {
|
||||
t.Errorf("expected highlighted=1, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 1
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus to return to list on arrow key")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.filter != "x" {
|
||||
t.Errorf("expected filter='x', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newMultiSelectState(manyItems, nil)
|
||||
s.highlighted = 10
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventChar, 'x')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.filter = "x"
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false after backspace")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseInput(t *testing.T) {
|
||||
t.Run("Enter", func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{13}))
|
||||
if err != nil || event != eventEnter || char != 0 {
|
||||
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{3}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{9}))
|
||||
if err != nil || event != eventTab {
|
||||
t.Errorf("expected eventTab, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{127}))
|
||||
if err != nil || event != eventBackspace {
|
||||
t.Errorf("expected eventBackspace, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
|
||||
if err != nil || event != eventUp {
|
||||
t.Errorf("expected eventUp, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DownArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
|
||||
if err != nil || event != eventDown {
|
||||
t.Errorf("expected eventDown, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PrintableChars", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
char byte
|
||||
}{
|
||||
{"lowercase", 'a'},
|
||||
{"uppercase", 'Z'},
|
||||
{"digit", '5'},
|
||||
{"space", ' '},
|
||||
{"tilde", '~'},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
|
||||
if err != nil || event != eventChar || char != tt.char {
|
||||
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "first item"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsPromptAndItems", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Select:") {
|
||||
t.Error("expected prompt in output")
|
||||
}
|
||||
if !strings.Contains(output, "item1") {
|
||||
t.Error("expected item1 in output")
|
||||
}
|
||||
if !strings.Contains(output, "first item") {
|
||||
t.Error("expected description in output")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item2 in output")
|
||||
}
|
||||
if lineCount != 3 { // 1 prompt + 2 items
|
||||
t.Errorf("expected 3 lines, got %d", lineCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "xyz"
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' message")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
// 15 items - 10 displayed = 5 more
|
||||
if !strings.Contains(buf.String(), "5 more") {
|
||||
t.Error("expected '5 more' indicator")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderMultiSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsCheckboxes", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "[x]") {
|
||||
t.Error("expected checked checkbox [x]")
|
||||
}
|
||||
if !strings.Contains(output, "[ ]") {
|
||||
t.Error("expected unchecked checkbox [ ]")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsDefaultMarker", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "(default)") {
|
||||
t.Error("expected (default) marker for first checked item")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsSelectedCount", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "2 selected") {
|
||||
t.Error("expected '2 selected' in output")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "Select at least one") {
|
||||
t.Error("expected 'Select at least one' helper text")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
t.Error("errCancelled should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Message", func(t *testing.T) {
|
||||
if errCancelled.Error() != "cancelled" {
|
||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for selector.go
|
||||
|
||||
// TestSelectState_SingleItem verifies that single item list works without crash.
|
||||
// List with only one item should still work.
|
||||
func TestSelectState_SingleItem(t *testing.T) {
|
||||
items := []selectItem{{Name: "only-one"}}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Down should do nothing (already at bottom)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Up should do nothing (already at top)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Enter should select the only item
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "only-one" || err != nil {
|
||||
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
|
||||
// List with exactly maxDisplayedItems items should not scroll.
|
||||
func TestSelectState_ExactlyMaxItems(t *testing.T) {
|
||||
items := make([]selectItem, maxDisplayedItems)
|
||||
for i := range items {
|
||||
items[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Move to last item
|
||||
for range maxDisplayedItems - 1 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
|
||||
// Should not scroll when exactly at max
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
|
||||
}
|
||||
|
||||
// One more down should do nothing
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
|
||||
// User typing "model.v1" shouldn't match "modelsv1".
|
||||
func TestFilterItems_RegexSpecialChars(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "model.v1"},
|
||||
{Name: "modelsv1"},
|
||||
{Name: "model-v1"},
|
||||
}
|
||||
|
||||
// Filter with dot should only match literal dot
|
||||
result := filterItems(items, "model.v1")
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 exact match, got %d", len(result))
|
||||
}
|
||||
if len(result) > 0 && result[0].Name != "model.v1" {
|
||||
t.Errorf("expected 'model.v1', got %s", result[0].Name)
|
||||
}
|
||||
|
||||
// Other regex special chars should be literal too
|
||||
items2 := []selectItem{
|
||||
{Name: "test[0]"},
|
||||
{Name: "test0"},
|
||||
{Name: "test(1)"},
|
||||
}
|
||||
|
||||
result2 := filterItems(items2, "test[0]")
|
||||
if len(result2) != 1 || result2[0].Name != "test[0]" {
|
||||
t.Errorf("expected only 'test[0]', got %v", result2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
|
||||
// itemIndex uses name as key - duplicates cause collision. This documents
|
||||
// the current behavior: the last index for a duplicate name is stored
|
||||
func TestMultiSelectState_DuplicateNames(t *testing.T) {
|
||||
// Duplicate names - this is an edge case that shouldn't happen in practice
|
||||
items := []selectItem{
|
||||
{Name: "duplicate"},
|
||||
{Name: "duplicate"},
|
||||
{Name: "unique"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
|
||||
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
|
||||
// When there are duplicates, only the last occurrence's index is stored
|
||||
if s.itemIndex["duplicate"] != 1 {
|
||||
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
|
||||
}
|
||||
|
||||
// Toggle item at highlighted=0 (first "duplicate")
|
||||
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
|
||||
// So it actually toggles the SECOND duplicate item, not the first
|
||||
s.toggleItem()
|
||||
|
||||
// This documents the potentially surprising behavior:
|
||||
// We toggled at highlighted=0, but itemIndex lookup returned 1
|
||||
if !s.checked[1] {
|
||||
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
|
||||
}
|
||||
if s.checked[0] {
|
||||
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
|
||||
// Prevents index-out-of-bounds on next keystroke
|
||||
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
s.selected = 2 // Select "cherry"
|
||||
|
||||
// Type a filter that removes cherry from results
|
||||
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
|
||||
|
||||
// Selection should reset to 0
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after filter, got %d", s.selected)
|
||||
}
|
||||
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 2 {
|
||||
t.Errorf("expected 2 filtered items, got %d", len(filtered))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
|
||||
// Model names might contain unicode characters
|
||||
func TestFilterItems_UnicodeCharacters(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama-日本語"},
|
||||
{Name: "模型-chinese"},
|
||||
{Name: "émoji-🦙"},
|
||||
{Name: "regular-model"},
|
||||
}
|
||||
|
||||
t.Run("filter japanese", func(t *testing.T) {
|
||||
result := filterItems(items, "日本")
|
||||
if len(result) != 1 || result[0].Name != "llama-日本語" {
|
||||
t.Errorf("expected llama-日本語, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter chinese", func(t *testing.T) {
|
||||
result := filterItems(items, "模型")
|
||||
if len(result) != 1 || result[0].Name != "模型-chinese" {
|
||||
t.Errorf("expected 模型-chinese, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter emoji", func(t *testing.T) {
|
||||
result := filterItems(items, "🦙")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter accented char", func(t *testing.T) {
|
||||
result := filterItems(items, "émoji")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
|
||||
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 2 // Highlight "cherry"
|
||||
|
||||
// Type a filter that removes cherry
|
||||
s.handleInput(eventChar, 'a')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
|
||||
// Empty list should be handled gracefully.
|
||||
func TestMultiSelectState_EmptyItems(t *testing.T) {
|
||||
s := newMultiSelectState([]selectItem{}, nil)
|
||||
|
||||
// Toggle should not panic on empty list
|
||||
s.toggleItem()
|
||||
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
|
||||
}
|
||||
|
||||
// Render should handle empty list
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderMultiSelect(&buf, "Select:", s)
|
||||
if lineCount == 0 {
|
||||
t.Error("renderMultiSelect should produce output even for empty list")
|
||||
}
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' for empty list")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
|
||||
func TestSelectState_RenderWithDescriptions(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "First item description"},
|
||||
{Name: "item2", Description: ""},
|
||||
{Name: "item3", Description: "Third item"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "First item description") {
|
||||
t.Error("expected description to be rendered")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item without description to be rendered")
|
||||
}
|
||||
}
|
||||
@@ -159,6 +159,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
sb.WriteString(before)
|
||||
if !ok {
|
||||
fmt.Fprintln(&sb)
|
||||
scanner.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,10 @@ import (
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
@@ -69,6 +73,9 @@ func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
|
||||
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
|
||||
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
|
||||
|
||||
kv["glm4moelite.attention.key_length_mla"] = p.KVLoraRank + p.QKRopeHeadDim
|
||||
kv["glm4moelite.attention.value_length_mla"] = p.KVLoraRank
|
||||
|
||||
kv["tokenizer.ggml.pre"] = "glm4"
|
||||
|
||||
return kv
|
||||
@@ -100,6 +107,67 @@ func (p *glm4MoeLiteModel) Replacements() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// repackKVB extracts K or V from the combined KV_B tensor for MLA absorption.
|
||||
// K output row-major: [n_head, kv_lora_rank, qk_nope] -> GGML ne[]={qk_nope, kv_lora_rank, n_head}
|
||||
// V output row-major: [n_head, v_head, kv_lora_rank] -> GGML ne[]={kv_lora_rank, v_head, n_head}
|
||||
func (p *glm4MoeLiteModel) repackKVB(extractK bool, kvFirst bool, numHeads int) Repacker {
|
||||
qkNope := int(p.QKNopeHeadDim)
|
||||
vHeadDim := int(p.VHeadDim)
|
||||
kvLoraRank := int(p.KVLoraRank)
|
||||
kvPerHead := qkNope + vHeadDim
|
||||
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
var err error
|
||||
|
||||
// Normalize to [n_head * (qk_nope + v_head), kv_lora_rank] layout
|
||||
if kvFirst {
|
||||
tt, err = tensor.Transpose(tt, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
}
|
||||
|
||||
// Reshape to [n_head, qk_nope + v_head, kv_lora_rank]
|
||||
if err := tt.Reshape(numHeads, kvPerHead, kvLoraRank); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if extractK {
|
||||
// Slice K: [n_head, qk_nope, kv_lora_rank]
|
||||
tt, err = tt.Slice(nil, tensor.S(0, qkNope), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
// Transpose to [n_head, kv_lora_rank, qk_nope]
|
||||
tt, err = tensor.Transpose(tt, 0, 2, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
} else {
|
||||
// Slice V: [n_head, v_head, kv_lora_rank] - already correct layout
|
||||
tt, err = tt.Slice(nil, tensor.S(qkNope, kvPerHead), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
}
|
||||
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
merges := make([]merge, p.HiddenLayers*3)
|
||||
for i := range p.HiddenLayers {
|
||||
@@ -139,6 +207,52 @@ func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
slog.Debug("skipping layer", "name", t.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
// Split attn_kv_b into separate attn_k_b and attn_v_b for MLA absorption
|
||||
if strings.HasSuffix(t.Name(), ".attn_kv_b.weight") {
|
||||
qkNope := int(p.QKNopeHeadDim)
|
||||
vHeadDim := int(p.VHeadDim)
|
||||
kvLoraRank := int(p.KVLoraRank)
|
||||
kvPerHead := qkNope + vHeadDim
|
||||
numHeads := int(p.NumAttentionHeads)
|
||||
kvFirst := true
|
||||
if len(t.Shape()) == 2 {
|
||||
switch {
|
||||
case int(t.Shape()[0]) == kvLoraRank:
|
||||
if kvPerHead > 0 && int(t.Shape()[1])%kvPerHead == 0 {
|
||||
numHeads = int(t.Shape()[1]) / kvPerHead
|
||||
}
|
||||
kvFirst = true
|
||||
case int(t.Shape()[1]) == kvLoraRank:
|
||||
if kvPerHead > 0 && int(t.Shape()[0])%kvPerHead == 0 {
|
||||
numHeads = int(t.Shape()[0]) / kvPerHead
|
||||
}
|
||||
kvFirst = false
|
||||
default:
|
||||
slog.Warn("glm4moelite: unexpected attn_kv_b layout", "name", t.Name(), "shape", t.Shape())
|
||||
}
|
||||
}
|
||||
|
||||
kTensor := t.Clone()
|
||||
kTensor.SetRepacker(p.repackKVB(true, kvFirst, numHeads))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_k_b", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(numHeads), uint64(kvLoraRank), uint64(qkNope)},
|
||||
WriterTo: kTensor,
|
||||
})
|
||||
|
||||
vTensor := t.Clone()
|
||||
vTensor.SetRepacker(p.repackKVB(false, kvFirst, numHeads))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_v_b", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(numHeads), uint64(vHeadDim), uint64(kvLoraRank)},
|
||||
WriterTo: vTensor,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
|
||||
@@ -4,16 +4,6 @@ title: Anthropic compatibility
|
||||
|
||||
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
|
||||
|
||||
## Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
|
||||
|
||||
Pull a model before use:
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment variables
|
||||
@@ -22,8 +12,8 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||
export ANTHROPIC_API_KEY="" # required but ignored
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
|
||||
### Simple `/v1/messages` example
|
||||
@@ -245,10 +235,41 @@ curl -X POST http://localhost:11434/v1/messages \
|
||||
|
||||
## Using with Claude Code
|
||||
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend.
|
||||
|
||||
### Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7`, `minimax-m2.1`, and `qwen3-coder` are recommended.
|
||||
|
||||
Download a model before use:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
ollama pull qwen3-coder
|
||||
```
|
||||
> Note: Qwen 3 coder is a 30B parameter model requiring at least 24GB of VRAM to run smoothly. More is required for longer context lengths.
|
||||
|
||||
```shell
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
### Quick setup
|
||||
|
||||
```shell
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
This will prompt you to select a model, configure Claude Code automatically, and launch it. To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Set the environment variables and run Claude Code:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
@@ -256,19 +277,13 @@ Or set the environment variables in your shell profile:
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
export ANTHROPIC_API_KEY=""
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
```shell
|
||||
# Local models
|
||||
claude --model qwen3-coder
|
||||
claude --model gpt-oss:20b
|
||||
|
||||
# Cloud models
|
||||
claude --model glm-4.7:cloud
|
||||
claude --model minimax-m2.1:cloud
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
41
docs/cli.mdx
41
docs/cli.mdx
@@ -8,6 +8,47 @@ title: CLI Reference
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
### Launch integrations
|
||||
|
||||
```
|
||||
ollama launch
|
||||
```
|
||||
|
||||
Configure and launch external applications to use Ollama models. This provides an interactive way to set up and start integrations with supported apps.
|
||||
|
||||
#### Supported integrations
|
||||
|
||||
- **OpenCode** - Open-source coding assistant
|
||||
- **Claude Code** - Anthropic's agentic coding tool
|
||||
- **Codex** - OpenAI's coding assistant
|
||||
- **Droid** - Factory's AI coding agent
|
||||
|
||||
#### Examples
|
||||
|
||||
Launch an integration interactively:
|
||||
|
||||
```
|
||||
ollama launch
|
||||
```
|
||||
|
||||
Launch a specific integration:
|
||||
|
||||
```
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
Launch with a specific model:
|
||||
|
||||
```
|
||||
ollama launch claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Configure without launching:
|
||||
|
||||
```
|
||||
ollama launch droid --config
|
||||
```
|
||||
|
||||
#### Multiline input
|
||||
|
||||
For multiline input, you can wrap text with `"""`:
|
||||
|
||||
@@ -3,8 +3,6 @@ title: Cloud
|
||||
sidebarTitle: Cloud
|
||||
---
|
||||
|
||||
<Info>Ollama's cloud is currently in preview.</Info>
|
||||
|
||||
## Cloud Models
|
||||
|
||||
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
|
||||
|
||||
@@ -8,7 +8,7 @@ Context length is the maximum number of tokens that the model has access to in m
|
||||
The default context length in Ollama is 4096 tokens.
|
||||
</Note>
|
||||
|
||||
Tasks which require large context like web search, agents, and coding tools should be set to at least 32000 tokens.
|
||||
Tasks which require large context like web search, agents, and coding tools should be set to at least 64000 tokens.
|
||||
|
||||
## Setting context length
|
||||
|
||||
@@ -24,7 +24,7 @@ Change the slider in the Ollama app under settings to your desired context lengt
|
||||
### CLI
|
||||
If editing the context length for Ollama is not possible, the context length can also be updated when serving Ollama.
|
||||
```
|
||||
OLLAMA_CONTEXT_LENGTH=32000 ollama serve
|
||||
OLLAMA_CONTEXT_LENGTH=64000 ollama serve
|
||||
```
|
||||
|
||||
### Check allocated context length and model offloading
|
||||
|
||||
@@ -102,18 +102,20 @@
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/codex",
|
||||
"/integrations/clawdbot",
|
||||
"/integrations/cline",
|
||||
"/integrations/codex",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose",
|
||||
"/integrations/zed",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/marimo",
|
||||
"/integrations/n8n",
|
||||
"/integrations/xcode",
|
||||
"/integrations/onyx",
|
||||
"/integrations/marimo"
|
||||
"/integrations/opencode",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/xcode",
|
||||
"/integrations/zed"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,7 +9,7 @@ sidebarTitle: Welcome
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="Quickstart" icon="rocket" href="/quickstart">
|
||||
Get up and running with your first model
|
||||
Get up and running with your first model or integrate Ollama with your favorite tools
|
||||
</Card>
|
||||
<Card
|
||||
title="Download Ollama"
|
||||
|
||||
@@ -4,7 +4,7 @@ title: Claude Code
|
||||
|
||||
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
|
||||
|
||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
|
||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `glm-4.7`, `qwen3-coder`, `gpt-oss`.
|
||||
|
||||

|
||||
|
||||
@@ -26,12 +26,27 @@ irm https://claude.ai/install.ps1 | iex
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```shell
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_API_KEY=""
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
```
|
||||
|
||||
@@ -44,35 +59,17 @@ claude --model gpt-oss:20b
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
||||
```
|
||||
|
||||
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
|
||||
2. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=https://ollama.com
|
||||
export ANTHROPIC_API_KEY=<your-api-key>
|
||||
```
|
||||
|
||||
3. Run Claude Code with a cloud model:
|
||||
|
||||
```shell
|
||||
claude --model glm-4.7:cloud
|
||||
```
|
||||
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||
|
||||
## Recommended Models
|
||||
|
||||
### Cloud models
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
- `qwen3-coder:480b` - Large coding model
|
||||
- `qwen3-coder`
|
||||
- `glm-4.7`
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
|
||||
### Local models
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks
|
||||
48
docs/integrations/clawdbot.mdx
Normal file
48
docs/integrations/clawdbot.mdx
Normal file
@@ -0,0 +1,48 @@
|
||||
---
|
||||
title: Clawdbot
|
||||
---
|
||||
|
||||
Clawdbot is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
|
||||
|
||||
## Install
|
||||
|
||||
Install [Clawdbot](https://clawd.bot/)
|
||||
|
||||
```bash
|
||||
npm install -g clawdbot@latest
|
||||
```
|
||||
|
||||
Then run the onboarding wizard:
|
||||
|
||||
```bash
|
||||
clawdbot onboard --install-daemon
|
||||
```
|
||||
|
||||
<Note>Clawdbot requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch clawdbot
|
||||
```
|
||||
|
||||
This configures Clawdbot to use Ollama and starts the gateway.
|
||||
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
|
||||
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch clawdbot --config
|
||||
```
|
||||
|
||||
## Recommended Models
|
||||
|
||||
- `qwen3-coder`
|
||||
- `glm-4.7`
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
@@ -13,7 +13,21 @@ npm install -g @openai/codex
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
|
||||
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 64k tokens.</Note>
|
||||
|
||||
### Quick setup
|
||||
|
||||
```
|
||||
ollama launch codex
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch codex --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
To use `codex` with Ollama, use the `--oss` flag:
|
||||
|
||||
|
||||
@@ -11,10 +11,24 @@ Install the [Droid CLI](https://factory.ai/):
|
||||
curl -fsSL https://app.factory.ai/cli | sh
|
||||
```
|
||||
|
||||
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
|
||||
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch droid
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch droid --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a local configuration block to `~/.factory/config.json`:
|
||||
|
||||
```json
|
||||
@@ -73,4 +87,4 @@ Add the cloud configuration block to `~/.factory/config.json`:
|
||||
}
|
||||
```
|
||||
|
||||
Run `droid` in a new terminal to load the new settings.
|
||||
Run `droid` in a new terminal to load the new settings.
|
||||
|
||||
106
docs/integrations/opencode.mdx
Normal file
106
docs/integrations/opencode.mdx
Normal file
@@ -0,0 +1,106 @@
|
||||
---
|
||||
title: OpenCode
|
||||
---
|
||||
|
||||
OpenCode is an open-source AI coding assistant that runs in your terminal.
|
||||
|
||||
## Install
|
||||
|
||||
Install the [OpenCode CLI](https://opencode.ai):
|
||||
|
||||
```bash
|
||||
curl -fsSL https://opencode.ai/install.sh | bash
|
||||
```
|
||||
|
||||
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch opencode
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch opencode --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a configuration block to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"qwen3-coder": {
|
||||
"name": "qwen3-coder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Cloud Models
|
||||
|
||||
`glm-4.7:cloud` is the recommended model for use with OpenCode.
|
||||
|
||||
Add the cloud configuration to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
|
||||
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama Cloud",
|
||||
"options": {
|
||||
"baseURL": "https://ollama.com/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Run `opencode` in a new terminal to load the new settings.
|
||||
@@ -18,13 +18,13 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
<Tab title="CLI">
|
||||
Open a terminal and run the command:
|
||||
|
||||
```
|
||||
```sh
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
</Tab>
|
||||
<Tab title="cURL">
|
||||
```
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
@@ -45,13 +45,13 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
<Tab title="Python">
|
||||
Start by downloading a model:
|
||||
|
||||
```
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
Then install Ollama's Python library:
|
||||
|
||||
```
|
||||
```sh
|
||||
pip install ollama
|
||||
```
|
||||
|
||||
@@ -101,3 +101,42 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
</Tabs>
|
||||
|
||||
See a full list of available models [here](https://ollama.com/models).
|
||||
|
||||
## Coding
|
||||
|
||||
For coding use cases, we recommend using the `glm-4.7-flash` model.
|
||||
|
||||
Note: this model requires 23 GB of VRAM with 64000 tokens context length.
|
||||
```sh
|
||||
ollama pull glm-4.7-flash
|
||||
```
|
||||
|
||||
Alternatively, you can use a more powerful cloud model (with full context length):
|
||||
```sh
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
Use `ollama launch` to quickly set up a coding tool with Ollama models:
|
||||
|
||||
```sh
|
||||
ollama launch
|
||||
```
|
||||
|
||||
### Supported integrations
|
||||
|
||||
- [OpenCode](/integrations/opencode) - Open-source coding assistant
|
||||
- [Claude Code](/integrations/claude-code) - Anthropic's agentic coding tool
|
||||
- [Codex](/integrations/codex) - OpenAI's coding assistant
|
||||
- [Droid](/integrations/droid) - Factory's AI coding agent
|
||||
|
||||
### Launch with a specific model
|
||||
|
||||
```sh
|
||||
ollama launch claude --model glm-4.7-flash
|
||||
```
|
||||
|
||||
### Configure without launching
|
||||
|
||||
```sh
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
@@ -38,6 +38,7 @@ var (
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"lfm2.5-thinking",
|
||||
"ministral-3",
|
||||
"qwen3-coder:30b",
|
||||
"gpt-oss:20b",
|
||||
@@ -143,6 +144,7 @@ var (
|
||||
"granite3.3",
|
||||
"hermes3",
|
||||
"internlm2",
|
||||
"lfm2.5-thinking",
|
||||
"llama-guard3",
|
||||
"llama-pro",
|
||||
"llama2-chinese",
|
||||
@@ -263,6 +265,7 @@ var (
|
||||
"snowflake-arctic-embed2",
|
||||
}
|
||||
libraryToolsModels = []string{
|
||||
"lfm2.5-thinking",
|
||||
"qwen3-vl",
|
||||
"gpt-oss:20b",
|
||||
"gpt-oss:120b",
|
||||
|
||||
@@ -0,0 +1,309 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: nobody <>
|
||||
Date: Sat, 24 Jan 2026 02:31:01 +0000
|
||||
Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash
|
||||
|
||||
Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
|
||||
uses head size 576 with gqa_ratio 4, which was previously only supported
|
||||
for gqa_ratio 16 (DeepSeek).
|
||||
|
||||
Metal changes:
|
||||
- Enable head size 576 for flash attention
|
||||
- Increase simdgroups to 8 for large heads (>=512)
|
||||
- Add case 8 kernel dispatch for 8 simdgroups
|
||||
|
||||
CUDA changes:
|
||||
- Add gqa_ratio 4 support for head 576/512
|
||||
- Add tile configs for (576, 512, 4) and (576, 512, 8)
|
||||
- Add MMA config cases for ncols 4
|
||||
- Add template instances for ncols2=4
|
||||
- Fix nbatch_fa values in nvidia_fp32 config (32->64)
|
||||
---
|
||||
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 40 +++++++++++++++----
|
||||
ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++
|
||||
ggml/src/ggml-cuda/fattn.cu | 12 ++++--
|
||||
...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
|
||||
...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
|
||||
...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
|
||||
...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal-device.m | 8 +---
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 1 +
|
||||
10 files changed, 64 insertions(+), 19 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
index 7bd1044c1..3dea2205e 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
||||
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
||||
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
||||
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
||||
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
T_A_KQ K_A;
|
||||
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||
|
||||
- // Wide version of KQ_C is column-major => swap A and B.
|
||||
- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
+ if constexpr (cols_per_warp == 8) {
|
||||
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
+ } else {
|
||||
+ // Wide version of KQ_C is column-major
|
||||
+#if defined(AMD_WMMA_AVAILABLE)
|
||||
+ // RDNA matrix C is column-major.
|
||||
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
+#else
|
||||
+ // swap A and B for CUDA.
|
||||
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
+#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
+ }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
||||
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
+#ifdef VOLTA_MMA_AVAILABLE
|
||||
+ if (ncols1*ncols2 < 32) {
|
||||
+ NO_DEVICE_CODE;
|
||||
+ return;
|
||||
+ }
|
||||
+#endif // VOLTA_MMA_AVAILABLE
|
||||
+
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
if (ncols1*ncols2 > 32) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
+
|
||||
+// For GLM 4.7 Flash
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
index 7c4d6fe67..371be7442 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
|
||||
return 0;
|
||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
+ return;
|
||||
+ }
|
||||
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
+ return;
|
||||
+ }
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
|
||||
index 015540666..1693479cb 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn.cu
|
||||
+++ b/ggml/src/ggml-cuda/fattn.cu
|
||||
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
+ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
- GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
+ GGML_ASSERT(gqa_ratio % 4 == 0);
|
||||
+ if (gqa_ratio % 16 == 0) {
|
||||
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
+ } else {
|
||||
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
+ }
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
- if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
+ if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
index 2074e954a..517993cb0 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
index 24c64cf00..97b19c67a 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
index 1ada657f1..989626dfa 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
index 86d4ffae2..173de7aac 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
index f24270bb1..7b5ee968c 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
- op->src[0]->ne[0] != 256) {
|
||||
- return false;
|
||||
- }
|
||||
- if (op->src[0]->ne[0] == 576) {
|
||||
- // DeepSeek sizes
|
||||
- // TODO: disabled for now, until optmized
|
||||
+ op->src[0]->ne[0] != 256 &&
|
||||
+ op->src[0]->ne[0] != 576) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index e99c1763f..80864f303 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
- int32_t nsg = 4;
|
||||
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index c98d269d1..d33c16079 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
+ case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
@@ -14,7 +14,7 @@ type Layer struct {
|
||||
Size int64 `json:"size"`
|
||||
From string `json:"from,omitempty"`
|
||||
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
|
||||
status string
|
||||
Status string `json:"-"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -22,7 +22,7 @@ const (
|
||||
)
|
||||
|
||||
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
blobs, err := GetBlobsPath("")
|
||||
blobs, err := BlobsPath("")
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
@@ -45,7 +45,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
}
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
|
||||
blob, err := GetBlobsPath(digest)
|
||||
blob, err := BlobsPath(digest)
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
MediaType: mediatype,
|
||||
Digest: digest,
|
||||
Size: n,
|
||||
status: fmt.Sprintf("%s %s", status, digest),
|
||||
Status: fmt.Sprintf("%s %s", status, digest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||
return Layer{}, errors.New("creating new layer from layer with empty digest")
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(digest)
|
||||
blob, err := BlobsPath(digest)
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
@@ -89,7 +89,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||
Digest: digest,
|
||||
Size: fi.Size(),
|
||||
From: from,
|
||||
status: fmt.Sprintf("using existing layer %s", digest),
|
||||
Status: fmt.Sprintf("using existing layer %s", digest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -98,7 +98,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
||||
return nil, errors.New("opening layer with empty digest")
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(l.Digest)
|
||||
blob, err := BlobsPath(l.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -126,7 +126,7 @@ func (l *Layer) Remove() error {
|
||||
}
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(l.Digest)
|
||||
blob, err := BlobsPath(l.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1,10 +1,9 @@
|
||||
package server
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -33,12 +32,38 @@ func (m *Manifest) Size() (size int64) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Manifest) Digest() string {
|
||||
return m.digest
|
||||
}
|
||||
|
||||
func (m *Manifest) FileInfo() os.FileInfo {
|
||||
return m.fi
|
||||
}
|
||||
|
||||
// ReadConfigJSON reads and unmarshals a config layer as JSON.
|
||||
func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
|
||||
for _, layer := range m.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
|
||||
blobPath, err := BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("config %q not found in manifest", configPath)
|
||||
}
|
||||
|
||||
func (m *Manifest) Remove() error {
|
||||
if err := os.Remove(m.filepath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -70,11 +95,11 @@ func (m *Manifest) RemoveLayers() error {
|
||||
if _, used := inUse[layer.Digest]; used {
|
||||
continue
|
||||
}
|
||||
blob, err := GetBlobsPath(layer.Digest)
|
||||
blob, err := BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.Remove(blob); os.IsNotExist(err) {
|
||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||
} else if err != nil {
|
||||
return err
|
||||
@@ -89,7 +114,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||
return nil, model.Unqualified(n)
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -121,7 +146,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||
}
|
||||
|
||||
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -148,7 +173,7 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||
}
|
||||
|
||||
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
95
manifest/paths.go
Normal file
95
manifest/paths.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||
|
||||
func Path() (string, error) {
|
||||
path := filepath.Join(envconfig.Models(), "manifests")
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// PathForName returns the path to the manifest file for a specific model name.
|
||||
func PathForName(n model.Name) (string, error) {
|
||||
if !n.IsValid() {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(manifests, n.Filepath()), nil
|
||||
}
|
||||
|
||||
func BlobsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
path := filepath.Join(envconfig.Models(), "blobs", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// PruneDirectory removes empty directories recursively.
|
||||
func PruneDirectory(path string) error {
|
||||
info, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
entries, err = os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(entries) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -609,3 +609,49 @@ func ImageGenerationsMiddleware() gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ImageEditsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ImageEditRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Image == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
|
||||
return
|
||||
}
|
||||
|
||||
genReq, err := openai.FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ImageWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1112,3 +1112,129 @@ func TestImageWriterResponse(t *testing.T) {
|
||||
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageEditsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.GenerateRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.GenerateRequest
|
||||
|
||||
// Base64-encoded test image (1x1 pixel PNG)
|
||||
testImage := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
|
||||
decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "image edit basic",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Images: []api.ImageData{decodedImage},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit with size",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `",
|
||||
"size": "512x768"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Images: []api.ImageData{decodedImage},
|
||||
Width: 512,
|
||||
Height: 768,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing prompt",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "prompt is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing model",
|
||||
body: `{
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "model is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing image",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "image is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if tc.err.Error.Message != "" {
|
||||
var errResp openai.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match:\n%s", diff)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
||||
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
||||
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
||||
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
||||
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
T_A_KQ K_A;
|
||||
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||
|
||||
// Wide version of KQ_C is column-major => swap A and B.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
if constexpr (cols_per_warp == 8) {
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
} else {
|
||||
// Wide version of KQ_C is column-major
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
||||
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if (ncols1*ncols2 < 32) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
if (ncols1*ncols2 > 32) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
|
||||
// For GLM 4.7 Flash
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
|
||||
return 0;
|
||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
|
||||
12
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
vendored
12
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
vendored
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
// For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
GGML_ASSERT(gqa_ratio % 4 == 0);
|
||||
if (gqa_ratio % 16 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
|
||||
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
op->src[0]->ne[0] != 256) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek sizes
|
||||
// TODO: disabled for now, until optmized
|
||||
op->src[0]->ne[0] != 256 &&
|
||||
op->src[0]->ne[0] != 576) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
|
||||
@@ -8967,6 +8967,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
|
||||
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
int32_t nsg = 4;
|
||||
int32_t nsg = ne00 >= 512 ? 8 : 4;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
|
||||
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
|
||||
@@ -39,6 +39,13 @@ type Model interface {
|
||||
Config() config
|
||||
}
|
||||
|
||||
// Validator is an optional interface that models can implement to perform
|
||||
// validation after tensors have been loaded. If validation fails, model
|
||||
// loading will fail with the returned error.
|
||||
type Validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// MultimodalProcessor must be implemented by multimodal models.
|
||||
type MultimodalProcessor interface {
|
||||
// EncodeMultimodal processes a single input (such as an image) and
|
||||
@@ -116,6 +123,13 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
base := Base{b: b, config: m.Config()}
|
||||
v := reflect.ValueOf(m)
|
||||
v.Elem().Set(populateFields(base, v.Elem()))
|
||||
|
||||
if validator, ok := m.(Validator); ok {
|
||||
if err := validator.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package glm4moelite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
@@ -11,6 +12,8 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
|
||||
type Options struct {
|
||||
numExpertsUsed int
|
||||
numExperts int
|
||||
@@ -47,7 +50,9 @@ type Attention struct {
|
||||
|
||||
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
|
||||
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
|
||||
KVB *nn.Linear `gguf:"attn_kv_b"`
|
||||
|
||||
KB *nn.Linear `gguf:"attn_k_b"`
|
||||
VB *nn.Linear `gguf:"attn_v_b"`
|
||||
|
||||
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||
}
|
||||
@@ -78,15 +83,16 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
||||
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
|
||||
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
|
||||
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
kPass = attn.KVB.Forward(ctx, kPass)
|
||||
|
||||
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
|
||||
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
|
||||
// MLA absorption: absorb K projection into query
|
||||
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
|
||||
qPassAbsorb := attn.KB.Forward(ctx, qPass).Permute(ctx, 0, 2, 1, 3)
|
||||
query = qRot.Concat(ctx, qPassAbsorb, 0)
|
||||
|
||||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||
kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
|
||||
key := kRot.Concat(ctx, kPass, 0)
|
||||
|
||||
attention := nn.AttentionWithVMLA(ctx, query, key, kPass, nil, attn.VB.Weight, opts.kqScale, cache)
|
||||
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
@@ -217,7 +223,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
|
||||
keyLength := int(c.Uint("attention.key_length"))
|
||||
valueLength := int(c.Uint("attention.value_length"))
|
||||
|
||||
kqScale := 1.0 / math.Sqrt(float64(keyLength))
|
||||
|
||||
var pre []string
|
||||
@@ -236,7 +241,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
@@ -279,6 +284,15 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
for _, layer := range m.Layers {
|
||||
if layer.Attention != nil && (layer.Attention.KB == nil || layer.Attention.VB == nil) {
|
||||
return ErrOldModelFormat
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
|
||||
73
model/models/glm4moelite/model_test.go
Normal file
73
model/models/glm4moelite/model_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package glm4moelite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model *Model
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid model with KB and VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{KB: &nn.Linear{}, VB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing KB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{VB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{KB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing both KB and VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil Attention is ok",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: nil},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.model.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr && err != ErrOldModelFormat {
|
||||
t.Errorf("Validate() error = %v, want %v", err, ErrOldModelFormat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -17,12 +18,34 @@ const (
|
||||
ministralCollectingToolArgs
|
||||
)
|
||||
|
||||
// ministralEvent represents an event emitted during parsing
|
||||
type ministralEvent interface {
|
||||
isMinistralEvent()
|
||||
}
|
||||
|
||||
type ministralEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type ministralEventThinking struct {
|
||||
thinking string
|
||||
}
|
||||
|
||||
type ministralEventToolCall struct {
|
||||
name string
|
||||
args string // raw JSON string
|
||||
}
|
||||
|
||||
func (ministralEventContent) isMinistralEvent() {}
|
||||
func (ministralEventThinking) isMinistralEvent() {}
|
||||
func (ministralEventToolCall) isMinistralEvent() {}
|
||||
|
||||
type MinistralParser struct {
|
||||
state ministralParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
hasThinkingSupport bool
|
||||
currentTool *api.Tool
|
||||
pendingToolName string // stores tool name while collecting args
|
||||
}
|
||||
|
||||
func (p *MinistralParser) HasToolSupport() bool {
|
||||
@@ -63,74 +86,251 @@ func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
|
||||
return nil, fmt.Errorf("tool '%s' not found", n)
|
||||
}
|
||||
|
||||
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
const (
|
||||
ministralToolCallsTag = "[TOOL_CALLS]"
|
||||
ministralThinkTag = "[THINK]"
|
||||
ministralThinkEndTag = "[/THINK]"
|
||||
ministralArgsTag = "[ARGS]"
|
||||
)
|
||||
|
||||
// eat consumes the parser's buffer, and returns a list of any unambiguous
|
||||
// events from the current parser state. The second return value indicates
|
||||
// whether to keep looping (true when state transitions, false when waiting
|
||||
// for more data).
|
||||
func (p *MinistralParser) eat() ([]ministralEvent, bool) {
|
||||
var events []ministralEvent
|
||||
|
||||
switch p.state {
|
||||
case ministralCollectingContent:
|
||||
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
|
||||
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
|
||||
if before != "" {
|
||||
return before, "", calls, nil
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
// Check for [TOOL_CALLS] tag
|
||||
if strings.Contains(bufStr, ministralToolCallsTag) {
|
||||
split := strings.SplitN(bufStr, ministralToolCallsTag, 2)
|
||||
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, ministralEventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingToolName
|
||||
} else if strings.Contains(p.buffer.String(), "[THINK]") {
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for [THINK] tag
|
||||
if strings.Contains(bufStr, ministralThinkTag) {
|
||||
split := strings.SplitN(bufStr, ministralThinkTag, 2)
|
||||
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, ministralEventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingThinkingContent
|
||||
return "", "", calls, nil
|
||||
} else {
|
||||
p.buffer.Reset()
|
||||
return s, "", calls, nil
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for partial tag overlap with [TOOL_CALLS] or [THINK]
|
||||
overlapToolCalls := overlap(bufStr, ministralToolCallsTag)
|
||||
overlapThink := overlap(bufStr, ministralThinkTag)
|
||||
maxOverlap := max(overlapToolCalls, overlapThink)
|
||||
|
||||
if maxOverlap > 0 {
|
||||
// Withhold the potential partial tag
|
||||
beforePartialTag := bufStr[:len(bufStr)-maxOverlap]
|
||||
trailingWS := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWS
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
// No tag found: emit content but withhold trailing whitespace
|
||||
whitespaceLen := trailingWhitespaceLen(bufStr)
|
||||
ambiguousStart := len(bufStr) - whitespaceLen
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case ministralCollectingThinkingContent:
|
||||
if strings.Contains(p.buffer.String(), "[/THINK]") {
|
||||
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
|
||||
p.state = ministralCollectingContent
|
||||
if after != "" {
|
||||
p.buffer.Reset()
|
||||
return after, thinkingContent, calls, nil
|
||||
}
|
||||
return "", thinkingContent, calls, nil
|
||||
} else {
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
if strings.Contains(bufStr, ministralThinkEndTag) {
|
||||
split := strings.SplitN(bufStr, ministralThinkEndTag, 2)
|
||||
thinkingContent := split[0]
|
||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
return "", s, calls, nil
|
||||
}
|
||||
case ministralCollectingToolName:
|
||||
if strings.Contains(p.buffer.String(), "[ARGS]") {
|
||||
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
|
||||
|
||||
t, err := toolByName(p.tools, name)
|
||||
if err != nil {
|
||||
return "", "", calls, err
|
||||
p.buffer.WriteString(after)
|
||||
if len(thinkingContent) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: thinkingContent})
|
||||
}
|
||||
p.currentTool = t
|
||||
p.state = ministralCollectingToolArgs
|
||||
return "", "", calls, nil
|
||||
}
|
||||
return "", "", calls, nil
|
||||
case ministralCollectingToolArgs:
|
||||
if strings.Contains(p.buffer.String(), "}") {
|
||||
before, _ := splitAtTag(&p.buffer, "}", false)
|
||||
before += "}"
|
||||
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(before), &args); err != nil {
|
||||
// todo - throw a better error
|
||||
return "", "", calls, err
|
||||
}
|
||||
|
||||
p.state = ministralCollectingContent
|
||||
return events, true
|
||||
}
|
||||
|
||||
call := api.ToolCall{
|
||||
// Check for partial overlap with [/THINK]
|
||||
if overlapLen := overlap(bufStr, ministralThinkEndTag); overlapLen > 0 {
|
||||
unambiguous := bufStr[:len(bufStr)-overlapLen]
|
||||
ambiguous := bufStr[len(bufStr)-overlapLen:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
// No tag found: emit all thinking content
|
||||
p.buffer.Reset()
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, ministralEventThinking{thinking: bufStr})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case ministralCollectingToolName:
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
if strings.Contains(bufStr, ministralArgsTag) {
|
||||
split := strings.SplitN(bufStr, ministralArgsTag, 2)
|
||||
toolName := split[0]
|
||||
after := split[1]
|
||||
p.pendingToolName = toolName
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = ministralCollectingToolArgs
|
||||
return events, true
|
||||
}
|
||||
// Wait for more data
|
||||
return events, false
|
||||
|
||||
case ministralCollectingToolArgs:
|
||||
bufStr := p.buffer.String()
|
||||
jsonEnd := findJSONEnd(bufStr)
|
||||
|
||||
if jsonEnd != -1 {
|
||||
jsonStr := bufStr[:jsonEnd+1]
|
||||
remaining := bufStr[jsonEnd+1:]
|
||||
|
||||
events = append(events, ministralEventToolCall{
|
||||
name: p.pendingToolName,
|
||||
args: jsonStr,
|
||||
})
|
||||
|
||||
p.pendingToolName = ""
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = ministralCollectingContent
|
||||
return events, true
|
||||
}
|
||||
// Wait for more data
|
||||
return events, false
|
||||
|
||||
default:
|
||||
panic("unexpected ministral event")
|
||||
}
|
||||
}
|
||||
|
||||
// parseEvents loops calling eat() until it returns false
|
||||
func (p *MinistralParser) parseEvents() []ministralEvent {
|
||||
var all []ministralEvent
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []ministralEvent
|
||||
events, keepLooping = p.eat()
|
||||
all = append(all, events...)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentBuilder, thinkingBuilder strings.Builder
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, event := range events {
|
||||
switch e := event.(type) {
|
||||
case ministralEventContent:
|
||||
contentBuilder.WriteString(e.content)
|
||||
case ministralEventThinking:
|
||||
thinkingBuilder.WriteString(e.thinking)
|
||||
case ministralEventToolCall:
|
||||
// Validate tool exists
|
||||
tool, toolErr := toolByName(p.tools, e.name)
|
||||
if toolErr != nil {
|
||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, toolErr
|
||||
}
|
||||
// Parse JSON arguments
|
||||
var args api.ToolCallFunctionArguments
|
||||
if jsonErr := json.Unmarshal([]byte(e.args), &args); jsonErr != nil {
|
||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, jsonErr
|
||||
}
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: p.currentTool.Function.Name,
|
||||
Name: tool.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
calls = append(calls, call)
|
||||
return "", "", calls, nil
|
||||
})
|
||||
}
|
||||
return "", "", calls, nil
|
||||
}
|
||||
|
||||
return p.buffer.String(), thinking, calls, nil
|
||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
// findJSONEnd finds the index of the closing brace that completes a JSON object.
|
||||
// It properly handles nested objects, arrays, and strings (including escaped characters).
|
||||
// Returns -1 if the JSON is not yet complete.
|
||||
func findJSONEnd(s string) int {
|
||||
depth := 0
|
||||
inString := false
|
||||
escaped := false
|
||||
|
||||
for i, r := range s {
|
||||
if inString {
|
||||
switch {
|
||||
case escaped:
|
||||
// If the previous character was a backslash, skip this character
|
||||
escaped = false
|
||||
case r == '\\':
|
||||
// Mark the next character as escaped
|
||||
escaped = true
|
||||
case r == '"':
|
||||
// End of string literal
|
||||
inString = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
// Start of string literal
|
||||
inString = true
|
||||
case '{', '[':
|
||||
// Increase nesting level for objects and arrays
|
||||
depth++
|
||||
case '}', ']':
|
||||
// Decrease nesting level
|
||||
depth--
|
||||
if depth == 0 {
|
||||
// Reached the end of the root JSON structure
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
545
model/parsers/ministral_test.go
Normal file
545
model/parsers/ministral_test.go
Normal file
@@ -0,0 +1,545 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestMinistralParserStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantEvents []ministralEvent
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
tools []api.Tool
|
||||
steps []step
|
||||
think bool // whether to enable thinking support
|
||||
}{
|
||||
// Content streaming
|
||||
{
|
||||
desc: "simple content",
|
||||
steps: []step{
|
||||
{input: "Hello, how can I help you?", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "Hello, how can I help you?"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "streaming content word by word",
|
||||
steps: []step{
|
||||
{input: "Hello,", wantEvents: []ministralEvent{ministralEventContent{content: "Hello,"}}},
|
||||
{input: " how", wantEvents: []ministralEvent{ministralEventContent{content: " how"}}},
|
||||
{input: " can I help?", wantEvents: []ministralEvent{ministralEventContent{content: " can I help?"}}},
|
||||
},
|
||||
},
|
||||
|
||||
// Simple tool calls
|
||||
{
|
||||
desc: "simple tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "San Francisco"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "get_weather", args: `{"location": "San Francisco"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with nested object",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]create_entities[ARGS]{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with deeply nested objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "update_config"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]update_config[ARGS]{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "update_config", args: `{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with array of objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "process_items"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]process_items[ARGS]{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "process_items", args: `{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with escaped quotes in string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "search"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]search[ARGS]{"query": "say \"hello\""}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "search", args: `{"query": "say \"hello\""}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with braces inside string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "format"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]format[ARGS]{"template": "Hello {name}!"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "format", args: `{"template": "Hello {name}!"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty JSON object",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "no_args"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]no_args[ARGS]{}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "no_args", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "JSON with newlines in string",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "write"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]write[ARGS]{"content": "line1\nline2\nline3"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "write", args: `{"content": "line1\nline2\nline3"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "backslash in string value",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "path"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]path[ARGS]{"dir": "C:\\Users\\test"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "path", args: `{"dir": "C:\\Users\\test"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Content after tool call
|
||||
{
|
||||
desc: "content after tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
// NOTE: It's unclear if this is valid Ministral output, but the parser
|
||||
// currently treats text after a tool call as regular content. This test
|
||||
// documents that behavior so we notice if it changes.
|
||||
{input: `[TOOL_CALLS]test[ARGS]{"a": 1}some content after`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{"a": 1}`},
|
||||
ministralEventContent{content: "some content after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Multiple tool calls
|
||||
{
|
||||
desc: "multiple tool calls in sequence",
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "get_weather"}},
|
||||
{Function: api.ToolFunction{Name: "get_time"}},
|
||||
},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "NYC"}[TOOL_CALLS]get_time[ARGS]{"timezone": "EST"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "get_weather", args: `{"location": "NYC"}`},
|
||||
ministralEventToolCall{name: "get_time", args: `{"timezone": "EST"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple tool calls streamed separately",
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "tool_a"}},
|
||||
{Function: api.ToolFunction{Name: "tool_b"}},
|
||||
},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]tool_a[ARGS]{"x": 1}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "tool_a", args: `{"x": 1}`},
|
||||
}},
|
||||
{input: `[TOOL_CALLS]tool_b[ARGS]{"y": 2}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "tool_b", args: `{"y": 2}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Streaming tool calls
|
||||
{
|
||||
desc: "streaming tool call with nested objects",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_CALLS]create_entities[ARGS]", wantEvents: []ministralEvent{}},
|
||||
{input: `{"entities": [{"entityType": "Person",`, wantEvents: []ministralEvent{}},
|
||||
{input: ` "name": "Jack",`, wantEvents: []ministralEvent{}},
|
||||
{input: ` "observations": ["Works`, wantEvents: []ministralEvent{}},
|
||||
{input: ` as a baker"]}`, wantEvents: []ministralEvent{}},
|
||||
{input: `]}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "streaming with incomplete JSON waits for completion",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_CALLS]test[ARGS]{", wantEvents: []ministralEvent{}},
|
||||
{input: `"a": {`, wantEvents: []ministralEvent{}},
|
||||
{input: `"b": 1`, wantEvents: []ministralEvent{}},
|
||||
{input: `}`, wantEvents: []ministralEvent{}},
|
||||
{input: `}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{"a": {"b": 1}}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Partial tag handling
|
||||
{
|
||||
desc: "partial tool tag fakeout",
|
||||
steps: []step{
|
||||
{input: "abc[TOOL", wantEvents: []ministralEvent{ministralEventContent{content: "abc"}}},
|
||||
{input: " not a tag", wantEvents: []ministralEvent{ministralEventContent{content: "[TOOL not a tag"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call tag split across chunks",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "[TOOL_", wantEvents: []ministralEvent{}},
|
||||
{input: "CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "content before tool call",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
|
||||
steps: []step{
|
||||
{input: "hello [TOOL_CALLS]get_weather[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "hello"},
|
||||
ministralEventToolCall{name: "get_weather", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace between content and tool call is trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "content \n [TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tabs and newlines before tool call are trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "content\t\n\t[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "non-breaking space before tool call is trimmed",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
// \u00a0 is non-breaking space, which unicode.IsSpace considers whitespace
|
||||
{input: "content\u00a0[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace before THINK tag is trimmed",
|
||||
steps: []step{
|
||||
{input: "content \n [THINK]thinking[/THINK]after", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace withheld then emitted",
|
||||
steps: []step{
|
||||
{input: "Hello ", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
|
||||
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: " world"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing newline withheld then emitted",
|
||||
steps: []step{
|
||||
{input: "Hello\n", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
|
||||
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: "\nworld"}}},
|
||||
},
|
||||
},
|
||||
|
||||
// Thinking support
|
||||
{
|
||||
desc: "thinking content",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking here[/THINK]", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "thinking here"},
|
||||
}},
|
||||
{input: "content after", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content after"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking with whitespace after end tag",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "my thoughts[/THINK] \n response", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "my thoughts"},
|
||||
ministralEventContent{content: "response"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "non-breaking space after think end tag is trimmed",
|
||||
think: true,
|
||||
steps: []step{
|
||||
// \u00a0 is non-breaking space
|
||||
{input: "thinking[/THINK]\u00a0response", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "response"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial think end tag",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
|
||||
{input: "NK]after", wantEvents: []ministralEvent{ministralEventContent{content: "after"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "think tag fakeout",
|
||||
think: true,
|
||||
steps: []step{
|
||||
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
|
||||
{input: "not end tag", wantEvents: []ministralEvent{ministralEventThinking{thinking: "[/THInot end tag"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking then tool call",
|
||||
think: true,
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
|
||||
steps: []step{
|
||||
{input: "let me think[/THINK][TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
|
||||
ministralEventThinking{thinking: "let me think"},
|
||||
ministralEventToolCall{name: "test", args: `{}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Content then THINK tag transition
|
||||
{
|
||||
desc: "content then think tag",
|
||||
steps: []step{
|
||||
{input: "content[THINK]thinking[/THINK]more", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "content"},
|
||||
ministralEventThinking{thinking: "thinking"},
|
||||
ministralEventContent{content: "more"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
||||
// Unicode handling
|
||||
{
|
||||
desc: "unicode content",
|
||||
steps: []step{
|
||||
{input: "你好 🌍 مرحبا", wantEvents: []ministralEvent{
|
||||
ministralEventContent{content: "你好 🌍 مرحبا"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "unicode in tool args",
|
||||
tools: []api.Tool{{Function: api.ToolFunction{Name: "greet"}}},
|
||||
steps: []step{
|
||||
{input: `[TOOL_CALLS]greet[ARGS]{"message": "你好 🌍"}`, wantEvents: []ministralEvent{
|
||||
ministralEventToolCall{name: "greet", args: `{"message": "你好 🌍"}`},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
parser := MinistralParser{}
|
||||
parser.hasThinkingSupport = tc.think
|
||||
parser.Init(tc.tools, nil, nil)
|
||||
|
||||
for i, step := range tc.steps {
|
||||
parser.buffer.WriteString(step.input)
|
||||
gotEvents := parser.parseEvents()
|
||||
|
||||
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
|
||||
// avoid deep equal on empty vs. nil slices
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
|
||||
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinistralParser_Errors(t *testing.T) {
|
||||
t.Run("unknown tool returns error", func(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "known_tool"}}}, nil, nil)
|
||||
|
||||
_, _, _, err := p.Add(`[TOOL_CALLS]unknown_tool[ARGS]{"a": 1}`, true)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown tool")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON returns error", func(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "test"}}}, nil, nil)
|
||||
|
||||
_, _, _, err := p.Add(`[TOOL_CALLS]test[ARGS]{invalid json}`, true)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFindJSONEnd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "simple object",
|
||||
input: `{"a": 1}`,
|
||||
expected: 7,
|
||||
},
|
||||
{
|
||||
name: "nested object",
|
||||
input: `{"a": {"b": 2}}`,
|
||||
expected: 14,
|
||||
},
|
||||
{
|
||||
name: "array inside object",
|
||||
input: `{"items": [1, 2, 3]}`,
|
||||
expected: 19,
|
||||
},
|
||||
{
|
||||
name: "braces in string",
|
||||
input: `{"template": "Hello {name}!"}`,
|
||||
expected: 28,
|
||||
},
|
||||
{
|
||||
name: "escaped quotes",
|
||||
input: `{"msg": "say \"hi\""}`,
|
||||
expected: 20,
|
||||
},
|
||||
{
|
||||
name: "incomplete object",
|
||||
input: `{"a": {"b": 1}`,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "deeply nested",
|
||||
input: `{"a": {"b": {"c": {"d": 1}}}}`,
|
||||
expected: 28,
|
||||
},
|
||||
{
|
||||
name: "object with trailing content",
|
||||
input: `{"a": 1} extra`,
|
||||
expected: 7,
|
||||
},
|
||||
{
|
||||
name: "array",
|
||||
input: `[{"a": 1}, {"b": 2}]`,
|
||||
expected: 19,
|
||||
},
|
||||
{
|
||||
name: "escaped backslash before quote",
|
||||
input: `{"path": "C:\\"}`,
|
||||
expected: 15,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "no opening brace",
|
||||
input: "hello world",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "only opening brace",
|
||||
input: "{",
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "unclosed string",
|
||||
input: `{"key": "unclosed`,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "double escaped backslash then quote",
|
||||
input: `{"path": "C:\\\\"}`,
|
||||
expected: 17,
|
||||
},
|
||||
{
|
||||
name: "unicode in key and value",
|
||||
input: `{"키": "값"}`,
|
||||
expected: 13,
|
||||
},
|
||||
{
|
||||
name: "nested arrays",
|
||||
input: `{"matrix": [[1, 2], [3, 4]]}`,
|
||||
expected: 27,
|
||||
},
|
||||
{
|
||||
name: "mixed nesting",
|
||||
input: `{"a": [{"b": {"c": [1, 2, 3]}}]}`,
|
||||
expected: 31,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := findJSONEnd(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("findJSONEnd(%q) = %d, want %d", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinistralParser_HasToolSupport(t *testing.T) {
|
||||
p := &MinistralParser{}
|
||||
if !p.HasToolSupport() {
|
||||
t.Error("expected HasToolSupport to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinistralParser_HasThinkingSupport(t *testing.T) {
|
||||
p := &MinistralParser{hasThinkingSupport: false}
|
||||
if p.HasThinkingSupport() {
|
||||
t.Error("expected HasThinkingSupport to return false")
|
||||
}
|
||||
|
||||
p = &MinistralParser{hasThinkingSupport: true}
|
||||
if !p.HasThinkingSupport() {
|
||||
t.Error("expected HasThinkingSupport to return true")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package parsers
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
@@ -114,3 +115,33 @@ func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string
|
||||
sb.WriteString(after)
|
||||
return before, after // return events
|
||||
}
|
||||
|
||||
// overlap returns the longest overlap between the suffix of s and the prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// trailingWhitespaceLen returns the length in bytes of trailing whitespace in s
|
||||
func trailingWhitespaceLen(s string) int {
|
||||
remaining := s
|
||||
total := 0
|
||||
for len(remaining) > 0 {
|
||||
r, size := utf8.DecodeLastRuneInString(remaining)
|
||||
// if it's an invalid utf8 rune, assume it isn't whitespace
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
break
|
||||
}
|
||||
if !unicode.IsSpace(r) {
|
||||
break
|
||||
}
|
||||
total += size
|
||||
remaining = remaining[:len(remaining)-size]
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
@@ -194,36 +193,6 @@ func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(drifkin): move this to a shared location
|
||||
// longest overlap between suffix of s and prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func trailingWhitespaceLen(s string) int {
|
||||
remaining := s
|
||||
total := 0
|
||||
for len(remaining) > 0 {
|
||||
r, size := utf8.DecodeLastRuneInString(remaining)
|
||||
// if it's an invalid utf8 rune, assume it isn't whitespace
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
break
|
||||
}
|
||||
if !unicode.IsSpace(r) {
|
||||
break
|
||||
}
|
||||
total += size
|
||||
remaining = remaining[:len(remaining)-size]
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
type XMLFunctionCall struct {
|
||||
XMLName xml.Name `xml:"function"`
|
||||
Name string `xml:"name,attr"`
|
||||
|
||||
@@ -794,3 +794,47 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ImageEditRequest is an OpenAI-compatible image edit request.
|
||||
type ImageEditRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Image string `json:"image"` // Base64-encoded image data
|
||||
Size string `json:"size,omitempty"` // e.g., "1024x1024"
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
|
||||
func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
|
||||
req := api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
}
|
||||
|
||||
// Decode the input image
|
||||
if r.Image != "" {
|
||||
imgData, err := decodeImageURL(r.Image)
|
||||
if err != nil {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
|
||||
}
|
||||
req.Images = append(req.Images, imgData)
|
||||
}
|
||||
|
||||
// Parse size if provided (e.g., "1024x768")
|
||||
if r.Size != "" {
|
||||
var w, h int32
|
||||
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
|
||||
req.Width = w
|
||||
req.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
if r.Seed != nil {
|
||||
if req.Options == nil {
|
||||
req.Options = map[string]any{}
|
||||
}
|
||||
req.Options["seed"] = *r.Seed
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -448,3 +448,86 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_Basic(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if result.Prompt != "make it blue" {
|
||||
t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
|
||||
}
|
||||
|
||||
if len(result.Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Images))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_WithSize(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
Size: "512x768",
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Width != 512 {
|
||||
t.Errorf("expected width 512, got %d", result.Width)
|
||||
}
|
||||
|
||||
if result.Height != 768 {
|
||||
t.Errorf("expected height 768, got %d", result.Height)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_WithSeed(t *testing.T) {
|
||||
seed := int64(12345)
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
Seed: &seed,
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Options == nil {
|
||||
t.Fatal("expected options to be set")
|
||||
}
|
||||
|
||||
if result.Options["seed"] != seed {
|
||||
t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_InvalidImage(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: "not-valid-base64",
|
||||
}
|
||||
|
||||
_, err := FromImageEditRequest(req)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid image")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +95,21 @@ func (i *Instance) Readline() (string, error) {
|
||||
|
||||
var currentLineBuf []rune
|
||||
|
||||
// draining tracks if we're processing buffered input from cooked mode.
|
||||
// In cooked mode Enter sends \n, but in raw mode Ctrl+J sends \n.
|
||||
// We treat \n from cooked mode as submit, not multiline.
|
||||
// We check Buffered() after the first read since the bufio buffer is
|
||||
// empty until then. This is compatible with """ multiline mode in
|
||||
// interactive.go since each Readline() call is independent.
|
||||
var draining, stopDraining bool
|
||||
|
||||
for {
|
||||
// Apply deferred state change from previous iteration
|
||||
if stopDraining {
|
||||
draining = false
|
||||
stopDraining = false
|
||||
}
|
||||
|
||||
// don't show placeholder when pasting unless we're in multiline mode
|
||||
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
|
||||
if buf.IsEmpty() && showPlaceholder {
|
||||
@@ -105,6 +119,15 @@ func (i *Instance) Readline() (string, error) {
|
||||
|
||||
r, err := i.Terminal.Read()
|
||||
|
||||
// After reading, check if there's more buffered data. If so, we're
|
||||
// processing cooked-mode input. Once buffer empties, the current
|
||||
// char is the last buffered one (still drain it), then stop next iteration.
|
||||
if i.Terminal.reader.Buffered() > 0 {
|
||||
draining = true
|
||||
} else if draining {
|
||||
stopDraining = true
|
||||
}
|
||||
|
||||
if buf.IsEmpty() {
|
||||
fmt.Print(ClearToEOL)
|
||||
}
|
||||
@@ -232,15 +255,20 @@ func (i *Instance) Readline() (string, error) {
|
||||
fd := os.Stdin.Fd()
|
||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||
case CharCtrlJ:
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
// If not draining cooked-mode input, treat as multiline
|
||||
if !draining {
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
// Draining cooked-mode input: treat \n as submit
|
||||
fallthrough
|
||||
case CharEnter:
|
||||
output := buf.String()
|
||||
if len(i.pastedLines) > 0 {
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
VOL_NAME=${VOL_NAME:-"Ollama"}
|
||||
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
|
||||
export CGO_CFLAGS="-mmacosx-version-min=14.0"
|
||||
export CGO_CXXFLAGS="-mmacosx-version-min=14.0"
|
||||
export CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||
export CGO_LDFLAGS="-mmacosx-version-min=14.0"
|
||||
|
||||
set -e
|
||||
|
||||
@@ -56,6 +56,12 @@ function checkEnv {
|
||||
|
||||
$script:DIST_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
|
||||
$env:CGO_ENABLED="1"
|
||||
if (-not $env:CGO_CFLAGS) {
|
||||
$env:CGO_CFLAGS = "-O3"
|
||||
}
|
||||
if (-not $env:CGO_CXXFLAGS) {
|
||||
$env:CGO_CXXFLAGS = "-O3"
|
||||
}
|
||||
Write-Output "Checking version"
|
||||
if (!$env:VERSION) {
|
||||
$data=(git describe --tags --first-parent --abbrev=7 --long --dirty --always)
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -90,7 +91,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
oldManifest, _ := ParseNamedManifest(name)
|
||||
oldManifest, _ := manifest.ParseNamedManifest(name)
|
||||
|
||||
var baseLayers []*layerGGML
|
||||
var err error
|
||||
@@ -123,9 +124,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
|
||||
manifest, mErr := ParseNamedManifest(fromName)
|
||||
if mErr == nil && manifest.Config.Digest != "" {
|
||||
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
|
||||
mf, mErr := manifest.ParseNamedManifest(fromName)
|
||||
if mErr == nil && mf.Config.Digest != "" {
|
||||
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
|
||||
if pErr == nil {
|
||||
if cfgFile, fErr := os.Open(configPath); fErr == nil {
|
||||
var baseConfig model.ConfigV2
|
||||
@@ -342,7 +343,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
|
||||
return "gguf"
|
||||
} else {
|
||||
// try to see if we can find a gguf file even without the file extension
|
||||
blobPath, err := GetBlobsPath(files[fn])
|
||||
blobPath, err := manifest.BlobsPath(files[fn])
|
||||
if err != nil {
|
||||
slog.Error("error getting blobs path", "file", fn)
|
||||
return ""
|
||||
@@ -394,7 +395,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
|
||||
}
|
||||
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
blobPath, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -432,7 +433,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer, err := NewLayer(t, mediaType)
|
||||
layer, err := manifest.NewLayer(t, mediaType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -465,7 +466,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
|
||||
}
|
||||
|
||||
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
|
||||
var layers []Layer
|
||||
var layers []manifest.Layer
|
||||
for _, layer := range baseLayers {
|
||||
if layer.GGML != nil {
|
||||
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
|
||||
@@ -550,13 +551,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
if layer.status != "" {
|
||||
fn(api.ProgressResponse{Status: layer.status})
|
||||
if layer.Status != "" {
|
||||
fn(api.ProgressResponse{Status: layer.Status})
|
||||
}
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
if err := WriteManifest(name, *configLayer, layers); err != nil {
|
||||
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -577,7 +578,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(layer.Digest)
|
||||
blob, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -599,7 +600,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
|
||||
}
|
||||
temp.Seek(0, io.SeekStart)
|
||||
fn(api.ProgressResponse{Status: "verifying conversion"})
|
||||
newLayer, err := NewLayer(temp, layer.MediaType)
|
||||
newLayer, err := manifest.NewLayer(temp, layer.MediaType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -619,7 +620,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
||||
var layers []*layerGGML
|
||||
|
||||
fn(api.ProgressResponse{Status: "parsing GGUF"})
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
blobPath, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -654,7 +655,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||
layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||
if err != nil {
|
||||
slog.Debug("could not create new layer from layer", "error", err)
|
||||
return nil, err
|
||||
@@ -665,8 +666,8 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
||||
return detectChatTemplate(layers)
|
||||
}
|
||||
|
||||
func removeLayer(layers []Layer, mediatype string) []Layer {
|
||||
return slices.DeleteFunc(layers, func(layer Layer) bool {
|
||||
func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
|
||||
return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
|
||||
if layer.MediaType != mediatype {
|
||||
return false
|
||||
}
|
||||
@@ -680,7 +681,7 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
|
||||
})
|
||||
}
|
||||
|
||||
func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
||||
func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
|
||||
layers = removeLayer(layers, "application/vnd.ollama.image.template")
|
||||
if _, err := template.Parse(t); err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
|
||||
@@ -690,7 +691,7 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
||||
}
|
||||
|
||||
blob := strings.NewReader(t)
|
||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.template")
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -699,11 +700,11 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setSystem(layers []Layer, s string) ([]Layer, error) {
|
||||
func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
|
||||
layers = removeLayer(layers, "application/vnd.ollama.image.system")
|
||||
if s != "" {
|
||||
blob := strings.NewReader(s)
|
||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.system")
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -712,9 +713,9 @@ func setSystem(layers []Layer, s string) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setLicense(layers []Layer, l string) ([]Layer, error) {
|
||||
func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
|
||||
blob := strings.NewReader(l)
|
||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.license")
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -722,7 +723,7 @@ func setLicense(layers []Layer, l string) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
|
||||
if p == nil {
|
||||
p = make(map[string]any)
|
||||
}
|
||||
@@ -731,7 +732,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
digestPath, err := GetBlobsPath(layer.Digest)
|
||||
digestPath, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -765,7 +766,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
if err := json.NewEncoder(&b).Encode(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -773,7 +774,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||
func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
|
||||
// this leaves the old messages intact if no new messages were specified
|
||||
// which may not be the correct behaviour
|
||||
if len(m) == 0 {
|
||||
@@ -786,7 +787,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||
if err := json.NewEncoder(&b).Encode(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -794,7 +795,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
|
||||
func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
|
||||
digests := make([]string, len(layers))
|
||||
for i, layer := range layers {
|
||||
digests[i] = layer.Digest
|
||||
@@ -805,7 +806,7 @@ func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
|
||||
if err := json.NewEncoder(&b).Encode(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
)
|
||||
|
||||
func TestConvertFromSafetensors(t *testing.T) {
|
||||
@@ -17,7 +18,7 @@ func TestConvertFromSafetensors(t *testing.T) {
|
||||
|
||||
// Helper function to create a new layer and return its digest
|
||||
makeTemp := func(content string) string {
|
||||
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
|
||||
l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create layer: %v", err)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,8 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const maxRetries = 6
|
||||
@@ -456,7 +458,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||
}
|
||||
|
||||
type downloadOpts struct {
|
||||
mp ModelPath
|
||||
n model.Name
|
||||
digest string
|
||||
regOpts *registryOptions
|
||||
fn func(api.ProgressResponse)
|
||||
@@ -465,10 +467,10 @@ type downloadOpts struct {
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
|
||||
if opts.digest == "" {
|
||||
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
|
||||
return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
|
||||
}
|
||||
|
||||
fp, err := GetBlobsPath(opts.digest)
|
||||
fp, err := manifest.BlobsPath(opts.digest)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -492,8 +494,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
|
||||
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
|
||||
download := data.(*blobDownload)
|
||||
if !ok {
|
||||
requestURL := opts.mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
|
||||
requestURL := opts.n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
|
||||
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
|
||||
blobDownloadManager.Delete(opts.digest)
|
||||
return false, err
|
||||
|
||||
211
server/images.go
211
server/images.go
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -24,6 +23,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
@@ -75,12 +75,6 @@ type Model struct {
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for image generation model via config capabilities
|
||||
if slices.Contains(m.Config.Capabilities, "image") {
|
||||
return []model.Capability{model.CapabilityImage}
|
||||
}
|
||||
|
||||
// Check for completion capability
|
||||
if m.ModelPath != "" {
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
@@ -274,44 +268,22 @@ func (m *Model) String() string {
|
||||
return modelfile.String()
|
||||
}
|
||||
|
||||
func GetManifest(mp ModelPath) (*Manifest, string, error) {
|
||||
fp, err := mp.GetManifestPath()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
f, err := os.Open(fp)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
sha256sum := sha256.New()
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func GetModel(name string) (*Model, error) {
|
||||
mp := ParseModelPath(name)
|
||||
manifest, digest, err := GetManifest(mp)
|
||||
n := model.ParseName(name)
|
||||
mf, err := manifest.ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model := &Model{
|
||||
Name: mp.GetFullTagname(),
|
||||
ShortName: mp.GetShortTagname(),
|
||||
Digest: digest,
|
||||
m := &Model{
|
||||
Name: n.String(),
|
||||
ShortName: n.DisplayShortest(),
|
||||
Digest: mf.Digest(),
|
||||
Template: template.DefaultTemplate,
|
||||
}
|
||||
|
||||
if manifest.Config.Digest != "" {
|
||||
filename, err := GetBlobsPath(manifest.Config.Digest)
|
||||
if mf.Config.Digest != "" {
|
||||
filename, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -322,29 +294,29 @@ func GetModel(name string) (*Model, error) {
|
||||
}
|
||||
defer configFile.Close()
|
||||
|
||||
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
|
||||
if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, layer := range manifest.Layers {
|
||||
filename, err := GetBlobsPath(layer.Digest)
|
||||
for _, layer := range mf.Layers {
|
||||
filename, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch layer.MediaType {
|
||||
case "application/vnd.ollama.image.model":
|
||||
model.ModelPath = filename
|
||||
model.ParentModel = layer.From
|
||||
m.ModelPath = filename
|
||||
m.ParentModel = layer.From
|
||||
case "application/vnd.ollama.image.embed":
|
||||
// Deprecated in versions > 0.1.2
|
||||
// TODO: remove this warning in a future version
|
||||
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
||||
case "application/vnd.ollama.image.adapter":
|
||||
model.AdapterPaths = append(model.AdapterPaths, filename)
|
||||
m.AdapterPaths = append(m.AdapterPaths, filename)
|
||||
case "application/vnd.ollama.image.projector":
|
||||
model.ProjectorPaths = append(model.ProjectorPaths, filename)
|
||||
m.ProjectorPaths = append(m.ProjectorPaths, filename)
|
||||
case "application/vnd.ollama.image.prompt",
|
||||
"application/vnd.ollama.image.template":
|
||||
bts, err := os.ReadFile(filename)
|
||||
@@ -352,7 +324,7 @@ func GetModel(name string) (*Model, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model.Template, err = template.Parse(string(bts))
|
||||
m.Template, err = template.Parse(string(bts))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -362,7 +334,7 @@ func GetModel(name string) (*Model, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model.System = string(bts)
|
||||
m.System = string(bts)
|
||||
case "application/vnd.ollama.image.params":
|
||||
params, err := os.Open(filename)
|
||||
if err != nil {
|
||||
@@ -371,7 +343,7 @@ func GetModel(name string) (*Model, error) {
|
||||
defer params.Close()
|
||||
|
||||
// parse model options parameters into a map so that we can see which fields have been specified explicitly
|
||||
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
|
||||
if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "application/vnd.ollama.image.messages":
|
||||
@@ -381,7 +353,7 @@ func GetModel(name string) (*Model, error) {
|
||||
}
|
||||
defer msgs.Close()
|
||||
|
||||
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
|
||||
if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "application/vnd.ollama.image.license":
|
||||
@@ -389,11 +361,11 @@ func GetModel(name string) (*Model, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.License = append(model.License, string(bts))
|
||||
m.License = append(m.License, string(bts))
|
||||
}
|
||||
}
|
||||
|
||||
return model, nil
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func CopyModel(src, dst model.Name) error {
|
||||
@@ -408,7 +380,7 @@ func CopyModel(src, dst model.Name) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := manifest.Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -437,7 +409,7 @@ func CopyModel(src, dst model.Name) error {
|
||||
|
||||
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
|
||||
manifests, err := Manifests(true)
|
||||
manifests, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -452,7 +424,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||
|
||||
// only delete the files which are still in the deleteMap
|
||||
for k := range deleteMap {
|
||||
fp, err := GetBlobsPath(k)
|
||||
fp, err := manifest.BlobsPath(k)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
|
||||
continue
|
||||
@@ -468,7 +440,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||
|
||||
func PruneLayers() error {
|
||||
deleteMap := make(map[string]struct{})
|
||||
p, err := GetBlobsPath("")
|
||||
p, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -483,9 +455,9 @@ func PruneLayers() error {
|
||||
name := blob.Name()
|
||||
name = strings.ReplaceAll(name, "-", ":")
|
||||
|
||||
_, err := GetBlobsPath(name)
|
||||
_, err := manifest.BlobsPath(name)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrInvalidDigestFormat) {
|
||||
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
|
||||
// remove invalid blobs (e.g. partial downloads)
|
||||
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
|
||||
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
|
||||
@@ -510,63 +482,30 @@ func PruneLayers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func PruneDirectory(path string) error {
|
||||
info, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
entries, err = os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(entries) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
n := model.ParseName(name)
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
if n.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
manifest, _, err := GetManifest(mp)
|
||||
mf, err := manifest.ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
|
||||
return err
|
||||
}
|
||||
|
||||
var layers []Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
if manifest.Config.Digest != "" {
|
||||
layers = append(layers, manifest.Config)
|
||||
var layers []manifest.Layer
|
||||
layers = append(layers, mf.Layers...)
|
||||
if mf.Config.Digest != "" {
|
||||
layers = append(layers, mf.Config)
|
||||
}
|
||||
|
||||
// Use fast transfer for models with tensor layers (many small blobs)
|
||||
if hasTensorLayers(layers) {
|
||||
// Read raw manifest JSON to preserve tensor metadata fields
|
||||
manifestPath, err := mp.GetManifestPath()
|
||||
manifestPath, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -574,7 +513,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
|
||||
if err := pushWithTransfer(ctx, n, layers, manifestJSON, regOpts, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "success"})
|
||||
@@ -582,17 +521,17 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||
if err := uploadBlob(ctx, n, layer, regOpts, fn); err != nil {
|
||||
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pushing manifest"})
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
requestURL := n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
manifestJSON, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -611,44 +550,44 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
n := model.ParseName(name)
|
||||
|
||||
// build deleteMap to prune unused layers
|
||||
deleteMap := make(map[string]struct{})
|
||||
manifest, _, err := GetManifest(mp)
|
||||
existingMf, err := manifest.ParseNamedManifest(n)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// noop
|
||||
} else if err != nil {
|
||||
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
|
||||
} else {
|
||||
for _, l := range manifest.Layers {
|
||||
for _, l := range existingMf.Layers {
|
||||
deleteMap[l.Digest] = struct{}{}
|
||||
}
|
||||
if manifest.Config.Digest != "" {
|
||||
deleteMap[manifest.Config.Digest] = struct{}{}
|
||||
if existingMf.Config.Digest != "" {
|
||||
deleteMap[existingMf.Config.Digest] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
if n.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
|
||||
manifest, err = pullModelManifest(ctx, mp, regOpts)
|
||||
mf, err := pullModelManifest(ctx, n, regOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pull model manifest: %s", err)
|
||||
}
|
||||
|
||||
var layers []Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
if manifest.Config.Digest != "" {
|
||||
layers = append(layers, manifest.Config)
|
||||
var layers []manifest.Layer
|
||||
layers = append(layers, mf.Layers...)
|
||||
if mf.Config.Digest != "" {
|
||||
layers = append(layers, mf.Config)
|
||||
}
|
||||
|
||||
// Use fast transfer for models with tensor layers (many small blobs)
|
||||
if hasTensorLayers(layers) {
|
||||
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
|
||||
if err := pullWithTransfer(ctx, n, layers, mf, regOpts, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "success"})
|
||||
@@ -658,7 +597,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
skipVerify := make(map[string]bool)
|
||||
for _, layer := range layers {
|
||||
cacheHit, err := downloadBlob(ctx, downloadOpts{
|
||||
mp: mp,
|
||||
n: n,
|
||||
digest: layer.Digest,
|
||||
regOpts: regOpts,
|
||||
fn: fn,
|
||||
@@ -677,7 +616,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
if err := verifyBlob(layer.Digest); err != nil {
|
||||
if errors.Is(err, errDigestMismatch) {
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
fp, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -692,16 +631,16 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
for _, layer := range layers {
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
delete(deleteMap, manifest.Config.Digest)
|
||||
delete(deleteMap, mf.Config.Digest)
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
manifestJSON, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
fp, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -728,9 +667,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
// hasTensorLayers checks if any layer has tensor media type.
|
||||
func hasTensorLayers(layers []Layer) bool {
|
||||
func hasTensorLayers(layers []manifest.Layer) bool {
|
||||
for _, layer := range layers {
|
||||
if layer.MediaType == MediaTypeImageTensor {
|
||||
if layer.MediaType == manifest.MediaTypeImageTensor {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -738,7 +677,7 @@ func hasTensorLayers(layers []Layer) bool {
|
||||
}
|
||||
|
||||
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
|
||||
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, mf *manifest.Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
blobs := make([]transfer.Blob, len(layers))
|
||||
for i, layer := range layers {
|
||||
blobs[i] = transfer.Blob{
|
||||
@@ -747,12 +686,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
}
|
||||
}
|
||||
|
||||
destDir, err := GetBlobsPath("")
|
||||
destDir, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
base := mp.BaseURL()
|
||||
base := n.BaseURL()
|
||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
base.Scheme = "http"
|
||||
}
|
||||
@@ -784,7 +723,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
Blobs: blobs,
|
||||
BaseURL: baseURL,
|
||||
DestDir: destDir,
|
||||
Repository: mp.GetNamespaceRepository(),
|
||||
Repository: n.DisplayNamespaceModel(),
|
||||
Progress: progress,
|
||||
Token: regOpts.Token,
|
||||
GetToken: getToken,
|
||||
@@ -795,12 +734,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
|
||||
// Write manifest
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
manifestJSON, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
fp, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -812,7 +751,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
}
|
||||
|
||||
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
|
||||
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
blobs := make([]transfer.Blob, len(layers))
|
||||
for i, layer := range layers {
|
||||
blobs[i] = transfer.Blob{
|
||||
@@ -822,12 +761,12 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
}
|
||||
}
|
||||
|
||||
srcDir, err := GetBlobsPath("")
|
||||
srcDir, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
base := mp.BaseURL()
|
||||
base := n.BaseURL()
|
||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
base.Scheme = "http"
|
||||
}
|
||||
@@ -864,13 +803,13 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
GetToken: getToken,
|
||||
Logger: slog.Default(),
|
||||
Manifest: manifestJSON,
|
||||
ManifestRef: mp.Tag,
|
||||
Repository: mp.GetNamespaceRepository(),
|
||||
ManifestRef: n.Tag,
|
||||
Repository: n.DisplayNamespaceModel(),
|
||||
})
|
||||
}
|
||||
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
|
||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
|
||||
requestURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
@@ -880,7 +819,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var m Manifest
|
||||
var m manifest.Manifest
|
||||
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1042,7 +981,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
|
||||
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
|
||||
|
||||
func verifyBlob(digest string) error {
|
||||
fp, err := GetBlobsPath(digest)
|
||||
fp, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -56,6 +56,15 @@ func TestModelCapabilities(t *testing.T) {
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImage},
|
||||
},
|
||||
{
|
||||
name: "model with image and vision capability (image editing)",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image", "vision"},
|
||||
},
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImage, model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with completion capability",
|
||||
model: Model{
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -20,19 +21,19 @@ import (
|
||||
var intermediateBlobs map[string]string = make(map[string]string)
|
||||
|
||||
type layerGGML struct {
|
||||
Layer
|
||||
manifest.Layer
|
||||
*ggml.GGML
|
||||
}
|
||||
|
||||
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
||||
m, err := ParseNamedManifest(name)
|
||||
m, err := manifest.ParseNamedManifest(name)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err = ParseNamedManifest(name)
|
||||
m, err = manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -41,7 +42,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||
}
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
|
||||
layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -50,7 +51,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||
case "application/vnd.ollama.image.model",
|
||||
"application/vnd.ollama.image.projector",
|
||||
"application/vnd.ollama.image.adapter":
|
||||
blobpath, err := GetBlobsPath(layer.Digest)
|
||||
blobpath, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -81,12 +82,12 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
||||
if t, err := template.Named(s); err != nil {
|
||||
slog.Debug("template detection", "error", err, "template", s)
|
||||
} else {
|
||||
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||
layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
|
||||
layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
|
||||
layers = append(layers, &layerGGML{layer, nil})
|
||||
|
||||
if t.Parameters != nil {
|
||||
@@ -95,7 +96,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type ModelPath struct {
|
||||
ProtocolScheme string
|
||||
Registry string
|
||||
Namespace string
|
||||
Repository string
|
||||
Tag string
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultRegistry = "registry.ollama.ai"
|
||||
DefaultNamespace = "library"
|
||||
DefaultTag = "latest"
|
||||
DefaultProtocolScheme = "https"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidImageFormat = errors.New("invalid image format")
|
||||
ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
||||
ErrInsecureProtocol = errors.New("insecure protocol http")
|
||||
ErrModelPathInvalid = errors.New("invalid model path")
|
||||
)
|
||||
|
||||
func ParseModelPath(name string) ModelPath {
|
||||
mp := ModelPath{
|
||||
ProtocolScheme: DefaultProtocolScheme,
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "",
|
||||
Tag: DefaultTag,
|
||||
}
|
||||
|
||||
before, after, found := strings.Cut(name, "://")
|
||||
if found {
|
||||
mp.ProtocolScheme = before
|
||||
name = after
|
||||
}
|
||||
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
mp.Registry = parts[0]
|
||||
mp.Namespace = parts[1]
|
||||
mp.Repository = parts[2]
|
||||
case 2:
|
||||
mp.Namespace = parts[0]
|
||||
mp.Repository = parts[1]
|
||||
case 1:
|
||||
mp.Repository = parts[0]
|
||||
}
|
||||
|
||||
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
|
||||
mp.Repository = repo
|
||||
mp.Tag = tag
|
||||
}
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetNamespaceRepository() string {
|
||||
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetFullTagname() string {
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetShortTagname() string {
|
||||
if mp.Registry == DefaultRegistry {
|
||||
if mp.Namespace == DefaultNamespace {
|
||||
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
|
||||
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
|
||||
func (mp ModelPath) GetManifestPath() (string, error) {
|
||||
name := model.Name{
|
||||
Host: mp.Registry,
|
||||
Namespace: mp.Namespace,
|
||||
Model: mp.Repository,
|
||||
Tag: mp.Tag,
|
||||
}
|
||||
if !name.IsValid() {
|
||||
return "", fs.ErrNotExist
|
||||
}
|
||||
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
|
||||
}
|
||||
|
||||
func (mp ModelPath) BaseURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: mp.ProtocolScheme,
|
||||
Host: mp.Registry,
|
||||
}
|
||||
}
|
||||
|
||||
func GetManifestPath() (string, error) {
|
||||
path := filepath.Join(envconfig.Models(), "manifests")
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func GetBlobsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
path := filepath.Join(envconfig.Models(), "blobs", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetBlobsPath(t *testing.T) {
|
||||
// GetBlobsPath expects an actual directory to exist
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
digest string
|
||||
expected string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"empty digest",
|
||||
"",
|
||||
filepath.Join(tempDir, "blobs"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid with colon",
|
||||
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
|
||||
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid with dash",
|
||||
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
|
||||
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"digest too short",
|
||||
"sha256-45640291",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
{
|
||||
"digest too long",
|
||||
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
{
|
||||
"digest invalid chars",
|
||||
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", tempDir)
|
||||
|
||||
got, err := GetBlobsPath(tc.digest)
|
||||
|
||||
require.ErrorIs(t, tc.err, err, tc.name)
|
||||
assert.Equal(t, tc.expected, got, tc.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arg string
|
||||
want ModelPath
|
||||
}{
|
||||
{
|
||||
"full path https",
|
||||
"https://example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"full path http",
|
||||
"http://example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "http",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no protocol",
|
||||
"example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no registry",
|
||||
"ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no namespace",
|
||||
"repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no tag",
|
||||
"repo",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: DefaultTag,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ParseModelPath(tc.arg)
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("got: %q want: %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -95,6 +95,13 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
||||
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
}
|
||||
} else if strings.Contains(name, "attn_k_b.weight") ||
|
||||
strings.Contains(name, "attn_v_b.weight") ||
|
||||
strings.Contains(name, "attn_kv_a_mqa.weight") ||
|
||||
strings.Contains(name, "attn_q_a.weight") ||
|
||||
strings.Contains(name, "attn_q_b.weight") {
|
||||
// MLA tensors need higher precision to avoid quality degradation
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
} else if strings.Contains(name, "ffn_down") {
|
||||
iLayer := qs.iFfnDown
|
||||
n_layer := qs.nFfnDown
|
||||
|
||||
@@ -39,6 +39,7 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/middleware"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/model/renderers"
|
||||
@@ -974,7 +975,7 @@ func (s *Server) PushHandler(c *gin.Context) {
|
||||
// is.
|
||||
func getExistingName(n model.Name) (model.Name, error) {
|
||||
var zero model.Name
|
||||
existing, err := Manifests(true)
|
||||
existing, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
@@ -1018,7 +1019,7 @@ func (s *Server) DeleteHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
m, err := ParseNamedManifest(n)
|
||||
m, err := manifest.ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
@@ -1080,7 +1081,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
return nil, ErrModelPathInvalid
|
||||
return nil, model.Unqualified(name)
|
||||
}
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
@@ -1112,7 +1113,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
|
||||
// For safetensors LLM models (experimental), populate details from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
|
||||
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
|
||||
modelDetails.Family = arch
|
||||
}
|
||||
@@ -1121,7 +1122,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
}
|
||||
// Get torch_dtype directly from config.json for quantization level
|
||||
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
|
||||
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
|
||||
modelDetails.QuantizationLevel = dtype
|
||||
}
|
||||
}
|
||||
@@ -1135,7 +1136,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
|
||||
}
|
||||
|
||||
manifest, err := ParseNamedManifest(name)
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1147,7 +1148,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
Details: modelDetails,
|
||||
Messages: msgs,
|
||||
Capabilities: m.Capabilities(),
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
ModifiedAt: mf.FileInfo().ModTime(),
|
||||
Requires: m.Config.Requires,
|
||||
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
|
||||
// default we return an empty map.
|
||||
@@ -1214,7 +1215,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
@@ -1223,12 +1224,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
|
||||
// For safetensors LLM models (experimental), populate ModelInfo from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
|
||||
resp.ModelInfo = info
|
||||
}
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
@@ -1285,7 +1286,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
|
||||
}
|
||||
|
||||
func (s *Server) ListHandler(c *gin.Context) {
|
||||
ms, err := Manifests(true)
|
||||
ms, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1316,8 +1317,8 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
RemoteModel: cf.RemoteModel,
|
||||
RemoteHost: cf.RemoteHost,
|
||||
Size: m.Size(),
|
||||
Digest: m.digest,
|
||||
ModifiedAt: m.fi.ModTime(),
|
||||
Digest: m.Digest(),
|
||||
ModifiedAt: m.FileInfo().ModTime(),
|
||||
Details: api.ModelDetails{
|
||||
Format: cf.ModelFormat,
|
||||
Family: cf.ModelFamily,
|
||||
@@ -1376,7 +1377,7 @@ func (s *Server) CopyHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (s *Server) HeadBlobHandler(c *gin.Context) {
|
||||
path, err := GetBlobsPath(c.Param("digest"))
|
||||
path, err := manifest.BlobsPath(c.Param("digest"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1392,7 +1393,7 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
|
||||
|
||||
func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
|
||||
p, err := GetBlobsPath(ib)
|
||||
p, err := manifest.BlobsPath(ib)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1410,7 +1411,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
path, err := GetBlobsPath(c.Param("digest"))
|
||||
path, err := manifest.BlobsPath(c.Param("digest"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1428,7 +1429,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
layer, err := NewLayer(c.Request.Body, "")
|
||||
layer, err := manifest.NewLayer(c.Request.Body, "")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1603,8 +1604,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoint
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
@@ -1628,7 +1630,7 @@ func Serve(ln net.Listener) error {
|
||||
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
||||
slog.Info("server config", "env", envconfig.Values())
|
||||
|
||||
blobsDir, err := GetBlobsPath("")
|
||||
blobsDir, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1637,7 +1639,7 @@ func Serve(ln net.Listener) error {
|
||||
}
|
||||
|
||||
if !envconfig.NoPrune() {
|
||||
if _, err := Manifests(false); err != nil {
|
||||
if _, err := manifest.Manifests(false); err != nil {
|
||||
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
|
||||
} else {
|
||||
// clean up unused layers and manifests
|
||||
@@ -1645,12 +1647,12 @@ func Serve(ln net.Listener) error {
|
||||
return err
|
||||
}
|
||||
|
||||
manifestsPath, err := GetManifestPath()
|
||||
manifestsPath, err := manifest.Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := PruneDirectory(manifestsPath); err != nil {
|
||||
if err := manifest.PruneDirectory(manifestsPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -2506,8 +2508,14 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
return
|
||||
}
|
||||
|
||||
// Set headers for streaming response
|
||||
c.Header("Content-Type", "application/x-ndjson")
|
||||
// Check streaming preference
|
||||
isStreaming := req.Stream == nil || *req.Stream
|
||||
|
||||
contentType := "application/x-ndjson"
|
||||
if !isStreaming {
|
||||
contentType = "application/json; charset=utf-8"
|
||||
}
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
// Get seed from options if provided
|
||||
var seed int64
|
||||
@@ -2522,13 +2530,21 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
}
|
||||
}
|
||||
|
||||
var images []llm.ImageData
|
||||
for i, imgData := range req.Images {
|
||||
images = append(images, llm.ImageData{ID: i, Data: imgData})
|
||||
}
|
||||
|
||||
var streamStarted bool
|
||||
var finalResponse api.GenerateResponse
|
||||
|
||||
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: seed,
|
||||
Images: images,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
streamStarted = true
|
||||
res := api.GenerateResponse{
|
||||
@@ -2552,6 +2568,11 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
if !isStreaming {
|
||||
finalResponse = res
|
||||
return
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(res)
|
||||
c.Writer.Write(append(data, '\n'))
|
||||
c.Writer.Flush()
|
||||
@@ -2561,5 +2582,10 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
if !streamStarted {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !isStreaming {
|
||||
c.JSON(http.StatusOK, finalResponse)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -223,15 +224,15 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
manifest, err := ParseNamedManifest(model.ParseName("child"))
|
||||
mf, err := manifest.ParseNamedManifest(model.ParseName("child"))
|
||||
if err != nil {
|
||||
t.Fatalf("parse manifest: %v", err)
|
||||
}
|
||||
if manifest.Config.Digest == "" {
|
||||
if mf.Config.Digest == "" {
|
||||
t.Fatalf("unexpected empty config digest for child manifest")
|
||||
}
|
||||
|
||||
configPath, err := GetBlobsPath(manifest.Config.Digest)
|
||||
configPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
t.Fatalf("config blob path: %v", err)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -93,13 +94,13 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create a manifest with duplicate layers
|
||||
if err := WriteManifest(n, config, []Layer{config}); err != nil {
|
||||
if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,9 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||
@@ -71,6 +73,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
|
||||
return
|
||||
}
|
||||
|
||||
func (mockRunner) Ping(_ context.Context) error { return nil }
|
||||
|
||||
func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
|
||||
return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
|
||||
return mock, nil
|
||||
@@ -2193,3 +2197,246 @@ func TestGenerateUnload(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateWithImages(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("images passed to completion request", func(t *testing.T) {
|
||||
testImage := []byte("test-image-data")
|
||||
|
||||
mock.CompletionResponse.Content = "Image processed"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Describe this image",
|
||||
Images: []api.ImageData{testImage},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify images were passed to the completion request
|
||||
if len(mock.CompletionRequest.Images) != 1 {
|
||||
t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) {
|
||||
t.Errorf("image data mismatch in completion request")
|
||||
}
|
||||
|
||||
if mock.CompletionRequest.Images[0].ID != 0 {
|
||||
t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple images passed to completion request", func(t *testing.T) {
|
||||
testImage1 := []byte("test-image-1")
|
||||
testImage2 := []byte("test-image-2")
|
||||
|
||||
mock.CompletionResponse.Content = "Images processed"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Compare these images",
|
||||
Images: []api.ImageData{testImage1, testImage2},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify both images were passed
|
||||
if len(mock.CompletionRequest.Images) != 2 {
|
||||
t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) {
|
||||
t.Errorf("first image data mismatch")
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) {
|
||||
t.Errorf("second image data mismatch")
|
||||
}
|
||||
|
||||
if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 {
|
||||
t.Errorf("expected image IDs 0 and 1, got %d and %d",
|
||||
mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no images when none provided", func(t *testing.T) {
|
||||
mock.CompletionResponse.Content = "No images"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Hello",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify no images in completion request
|
||||
if len(mock.CompletionRequest.Images) != 0 {
|
||||
t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestImageGenerateStreamFalse tests that image generation respects stream=false
|
||||
// and returns a single JSON response instead of streaming ndjson.
|
||||
func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
|
||||
mock := mockRunner{}
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
|
||||
fn(llm.CompletionResponse{Step: 2, TotalSteps: 3, Done: false})
|
||||
fn(llm.CompletionResponse{Step: 3, TotalSteps: 3, Done: true, DoneReason: llm.DoneReasonStop, Image: "base64image"})
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: map[string]*runnerRef{
|
||||
"": {
|
||||
llama: &mock,
|
||||
Options: &opts,
|
||||
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
|
||||
numParallel: 1,
|
||||
},
|
||||
},
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create model manifest with image capability
|
||||
n := model.ParseName("test-image")
|
||||
cfg := model.ConfigV2{Capabilities: []string{"image"}}
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(&cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := manifest.WriteManifest(n, configLayer, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
streamFalse := false
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-image",
|
||||
Prompt: "test prompt",
|
||||
Stream: &streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
if ct := w.Header().Get("Content-Type"); ct != "application/json; charset=utf-8" {
|
||||
t.Errorf("expected Content-Type 'application/json; charset=utf-8', got %q", ct)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
lines := strings.Split(strings.TrimSpace(body), "\n")
|
||||
if len(lines) != 1 {
|
||||
t.Errorf("expected 1 response line, got %d:\n%s", len(lines), body)
|
||||
}
|
||||
|
||||
var resp api.GenerateResponse
|
||||
if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Image != "base64image" {
|
||||
t.Errorf("expected image 'base64image', got %q", resp.Image)
|
||||
}
|
||||
|
||||
if !resp.Done {
|
||||
t.Errorf("expected done=true")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,12 +21,14 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var blobUploadManager sync.Map
|
||||
|
||||
type blobUpload struct {
|
||||
Layer
|
||||
manifest.Layer
|
||||
|
||||
Total int64
|
||||
Completed atomic.Int64
|
||||
@@ -51,7 +53,7 @@ const (
|
||||
)
|
||||
|
||||
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||
p, err := GetBlobsPath(b.Digest)
|
||||
p, err := manifest.BlobsPath(b.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -59,7 +61,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
|
||||
if b.From != "" {
|
||||
values := requestURL.Query()
|
||||
values.Add("mount", b.Digest)
|
||||
values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
|
||||
values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
|
||||
requestURL.RawQuery = values.Encode()
|
||||
}
|
||||
|
||||
@@ -128,7 +130,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
|
||||
defer blobUploadManager.Delete(b.Digest)
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
|
||||
p, err := GetBlobsPath(b.Digest)
|
||||
p, err := manifest.BlobsPath(b.Digest)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
@@ -364,9 +366,9 @@ func (p *progressWriter) Rollback() {
|
||||
p.written = 0
|
||||
}
|
||||
|
||||
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
||||
func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
requestURL := n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
|
||||
switch {
|
||||
@@ -388,8 +390,8 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOp
|
||||
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
|
||||
upload := data.(*blobUpload)
|
||||
if !ok {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
|
||||
requestURL := n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
|
||||
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
|
||||
blobUploadManager.Delete(layer.Digest)
|
||||
return err
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
@@ -35,22 +36,25 @@ func Unqualified(n Name) error {
|
||||
const MissingPart = "!MISSING!"
|
||||
|
||||
const (
|
||||
defaultHost = "registry.ollama.ai"
|
||||
defaultNamespace = "library"
|
||||
defaultTag = "latest"
|
||||
defaultHost = "registry.ollama.ai"
|
||||
defaultNamespace = "library"
|
||||
defaultTag = "latest"
|
||||
defaultProtocolScheme = "https"
|
||||
)
|
||||
|
||||
// DefaultName returns a name with the default values for the host, namespace,
|
||||
// and tag parts. The model and digest parts are empty.
|
||||
// tag, and protocol scheme parts. The model and digest parts are empty.
|
||||
//
|
||||
// - The default host is ("registry.ollama.ai")
|
||||
// - The default namespace is ("library")
|
||||
// - The default tag is ("latest")
|
||||
// - The default protocol scheme is ("https")
|
||||
func DefaultName() Name {
|
||||
return Name{
|
||||
Host: defaultHost,
|
||||
Namespace: defaultNamespace,
|
||||
Tag: defaultTag,
|
||||
Host: defaultHost,
|
||||
Namespace: defaultNamespace,
|
||||
Tag: defaultTag,
|
||||
ProtocolScheme: defaultProtocolScheme,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,10 +91,11 @@ func (k partKind) String() string {
|
||||
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
|
||||
// is valid.
|
||||
type Name struct {
|
||||
Host string
|
||||
Namespace string
|
||||
Model string
|
||||
Tag string
|
||||
Host string
|
||||
Namespace string
|
||||
Model string
|
||||
Tag string
|
||||
ProtocolScheme string
|
||||
}
|
||||
|
||||
// ParseName parses and assembles a Name from a name string. The
|
||||
@@ -160,7 +165,9 @@ func ParseNameBare(s string) Name {
|
||||
}
|
||||
|
||||
scheme, host, ok := strings.Cut(s, "://")
|
||||
if !ok {
|
||||
if ok {
|
||||
n.ProtocolScheme = scheme
|
||||
} else {
|
||||
host = scheme
|
||||
}
|
||||
n.Host = host
|
||||
@@ -189,12 +196,13 @@ func ParseNameFromFilepath(s string) (n Name) {
|
||||
return n
|
||||
}
|
||||
|
||||
// Merge merges the host, namespace, and tag parts of the two names,
|
||||
// Merge merges the host, namespace, tag, and protocol scheme parts of the two names,
|
||||
// preferring the non-empty parts of a.
|
||||
func Merge(a, b Name) Name {
|
||||
a.Host = cmp.Or(a.Host, b.Host)
|
||||
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
|
||||
a.Tag = cmp.Or(a.Tag, b.Tag)
|
||||
a.ProtocolScheme = cmp.Or(a.ProtocolScheme, b.ProtocolScheme)
|
||||
return a
|
||||
}
|
||||
|
||||
@@ -305,6 +313,23 @@ func (n Name) EqualFold(o Name) bool {
|
||||
strings.EqualFold(n.Tag, o.Tag)
|
||||
}
|
||||
|
||||
// BaseURL returns the base URL for the registry.
|
||||
func (n Name) BaseURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: n.ProtocolScheme,
|
||||
Host: n.Host,
|
||||
}
|
||||
}
|
||||
|
||||
// DisplayNamespaceModel returns the namespace and model joined by "/".
|
||||
func (n Name) DisplayNamespaceModel() string {
|
||||
var b strings.Builder
|
||||
b.WriteString(n.Namespace)
|
||||
b.WriteByte('/')
|
||||
b.WriteString(n.Model)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func isValidLen(kind partKind, s string) bool {
|
||||
switch kind {
|
||||
case kindHost:
|
||||
|
||||
@@ -32,10 +32,11 @@ func TestParseNameParts(t *testing.T) {
|
||||
{
|
||||
in: "scheme://host:port/namespace/model:tag",
|
||||
want: Name{
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
ProtocolScheme: "scheme",
|
||||
},
|
||||
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
|
||||
},
|
||||
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
)
|
||||
@@ -103,7 +105,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
|
||||
func newLayerCreator() create.LayerCreator {
|
||||
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
layer, err := manifest.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, err
|
||||
}
|
||||
@@ -141,13 +143,13 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||
weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales
|
||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||
scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -169,7 +171,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||
qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -186,7 +188,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
||||
|
||||
// createUnquantizedLayer creates a single tensor layer without quantization.
|
||||
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
layer, err := manifest.NewLayer(r, manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -209,10 +211,23 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// TODO: find a better way to detect image input support
|
||||
// For now, hardcode Flux2KleinPipeline as supporting vision (image input)
|
||||
caps := capabilities
|
||||
modelIndex := filepath.Join(opts.ModelDir, "model_index.json")
|
||||
if data, err := os.ReadFile(modelIndex); err == nil {
|
||||
var cfg struct {
|
||||
ClassName string `json:"_class_name"`
|
||||
}
|
||||
if json.Unmarshal(data, &cfg) == nil && cfg.ClassName == "Flux2KleinPipeline" {
|
||||
caps = append(caps, "vision")
|
||||
}
|
||||
}
|
||||
|
||||
// Create config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: capabilities,
|
||||
Capabilities: caps,
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
@@ -221,15 +236,15 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
configLayer, err := manifest.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer
|
||||
serverLayers := make([]server.Layer, 0, len(layers))
|
||||
// Convert LayerInfo to manifest.Layer
|
||||
manifestLayers := make([]manifest.Layer, 0, len(layers))
|
||||
for _, l := range layers {
|
||||
serverLayers = append(serverLayers, server.Layer{
|
||||
manifestLayers = append(manifestLayers, manifest.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
@@ -243,19 +258,19 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverLayers = append(serverLayers, modelfileLayers...)
|
||||
manifestLayers = append(manifestLayers, modelfileLayers...)
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
return manifest.WriteManifest(name, configLayer, manifestLayers)
|
||||
}
|
||||
}
|
||||
|
||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
var layers []server.Layer
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
var layers []manifest.Layer
|
||||
|
||||
if mf.Template != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create template layer: %w", err)
|
||||
}
|
||||
@@ -263,7 +278,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
}
|
||||
|
||||
if mf.System != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system layer: %w", err)
|
||||
}
|
||||
@@ -271,7 +286,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
}
|
||||
|
||||
if mf.License != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create license layer: %w", err)
|
||||
}
|
||||
|
||||
6
x/imagegen/cache/step.go
vendored
6
x/imagegen/cache/step.go
vendored
@@ -9,7 +9,7 @@ import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
// shallow layers change little between consecutive steps, so we can
|
||||
// cache their outputs and skip recomputation on non-refresh steps.
|
||||
//
|
||||
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
|
||||
// Supports both single-stream and dual-stream architectures:
|
||||
// - Single-stream: use Get/Set for the single output per layer
|
||||
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
|
||||
//
|
||||
@@ -87,7 +87,7 @@ func (c *StepCache) Set(layer int, arr *mlx.Array) {
|
||||
}
|
||||
|
||||
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
// Used for dual-stream architectures.
|
||||
func (c *StepCache) Get2(layer int) *mlx.Array {
|
||||
if layer < len(c.layers2) {
|
||||
return c.layers2[layer]
|
||||
@@ -96,7 +96,7 @@ func (c *StepCache) Get2(layer int) *mlx.Array {
|
||||
}
|
||||
|
||||
// Set2 stores a layer output (stream 2), freeing any previous value.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
// Used for dual-stream architectures.
|
||||
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
|
||||
if layer < len(c.layers2) {
|
||||
if c.layers2[layer] != nil {
|
||||
|
||||
@@ -10,7 +10,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -75,6 +78,7 @@ Image Generation Flags (experimental):
|
||||
// RunCLI handles the CLI for image generation models.
|
||||
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
||||
// Supports flags: --width, --height, --steps, --seed, --negative
|
||||
// Image paths can be included in the prompt and will be extracted automatically.
|
||||
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
|
||||
// Get options from flags (with env var defaults)
|
||||
opts := DefaultOptions()
|
||||
@@ -111,9 +115,16 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract any image paths from the prompt
|
||||
prompt, images, err := extractFileData(prompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
@@ -254,14 +265,33 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
printCurrentSettings(opts)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
|
||||
// Check if it's a file path, not a command
|
||||
args := strings.Fields(line)
|
||||
isFile := false
|
||||
for _, f := range extractFileNames(line) {
|
||||
if strings.HasPrefix(f, args[0]) {
|
||||
isFile = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isFile {
|
||||
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", args[0])
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Extract any image paths from the input
|
||||
prompt, images, err := extractFileData(line)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generate image with current options
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: line,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
@@ -486,3 +516,61 @@ func displayImageInTerminal(imagePath string) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// extractFileNames finds image file paths in the input string.
|
||||
func extractFileNames(input string) []string {
|
||||
// Regex to match file paths with image extensions
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
return re.FindAllString(input, -1)
|
||||
}
|
||||
|
||||
// extractFileData extracts image data from file paths found in the input.
|
||||
// Returns the cleaned prompt (with file paths removed) and the image data.
|
||||
func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
filePaths := extractFileNames(input)
|
||||
var imgs []api.ImageData
|
||||
|
||||
for _, fp := range filePaths {
|
||||
// Normalize shell escapes
|
||||
nfp := strings.ReplaceAll(fp, "\\ ", " ")
|
||||
nfp = strings.ReplaceAll(nfp, "\\(", "(")
|
||||
nfp = strings.ReplaceAll(nfp, "\\)", ")")
|
||||
nfp = strings.ReplaceAll(nfp, "%20", " ")
|
||||
|
||||
data, err := getImageData(nfp)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
}
|
||||
return strings.TrimSpace(input), imgs, nil
|
||||
}
|
||||
|
||||
// getImageData reads and validates image data from a file.
|
||||
func getImageData(filePath string) ([]byte, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
buf := make([]byte, 512)
|
||||
_, err = file.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||
}
|
||||
|
||||
// Re-read the full file
|
||||
return os.ReadFile(filePath)
|
||||
}
|
||||
|
||||
@@ -21,8 +21,6 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
@@ -61,14 +59,11 @@ func main() {
|
||||
listTensors := flag.Bool("list", false, "List tensors only")
|
||||
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
|
||||
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
|
||||
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
|
||||
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
|
||||
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
|
||||
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
|
||||
var inputImages stringSlice
|
||||
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
||||
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
||||
@@ -166,60 +161,6 @@ func main() {
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImage:
|
||||
m, loadErr := qwen_image.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
LayerCache: *layerCache,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImageEdit:
|
||||
if len(inputImages) == 0 {
|
||||
log.Fatal("qwen-image-edit requires at least one -input-image")
|
||||
}
|
||||
|
||||
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
// For image editing, use 0 for dimensions to auto-detect from input image
|
||||
// unless explicitly overridden from defaults
|
||||
editWidth := int32(0)
|
||||
editHeight := int32(0)
|
||||
if *width != 1024 {
|
||||
editWidth = int32(*width)
|
||||
}
|
||||
if *height != 1024 {
|
||||
editHeight = int32(*height)
|
||||
}
|
||||
|
||||
cfg := &qwen_image_edit.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: editWidth,
|
||||
Height: editHeight,
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
}
|
||||
|
||||
var img *mlx.Array
|
||||
img, err = m.EditFromConfig(inputImages, cfg)
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *listTensors:
|
||||
err = listModelTensors(*modelPath)
|
||||
default:
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/draw"
|
||||
_ "image/jpeg"
|
||||
"image/png"
|
||||
"os"
|
||||
@@ -111,6 +113,7 @@ func clampF(v, min, max float32) float32 {
|
||||
}
|
||||
|
||||
// DecodeImage decodes image bytes with EXIF orientation applied.
|
||||
// Transparent images are composited onto a white background.
|
||||
func DecodeImage(data []byte) (image.Image, error) {
|
||||
orientation := readJPEGOrientation(data)
|
||||
|
||||
@@ -119,9 +122,33 @@ func DecodeImage(data []byte) (image.Image, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
img = flattenAlpha(img)
|
||||
return applyOrientation(img, orientation), nil
|
||||
}
|
||||
|
||||
// flattenAlpha composites an image onto a white background,
|
||||
// removing any transparency. This is needed because image
|
||||
// generation models don't handle alpha channels well.
|
||||
func flattenAlpha(img image.Image) image.Image {
|
||||
if _, ok := img.(*image.RGBA); !ok {
|
||||
if _, ok := img.(*image.NRGBA); !ok {
|
||||
// No alpha channel, return as-is
|
||||
return img
|
||||
}
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
dst := image.NewRGBA(bounds)
|
||||
|
||||
// Fill with white background
|
||||
draw.Draw(dst, bounds, &image.Uniform{color.White}, image.Point{}, draw.Src)
|
||||
|
||||
// Composite the image on top
|
||||
draw.Draw(dst, bounds, img, bounds.Min, draw.Over)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
|
||||
// Returns 1 (normal) for non-JPEG or if orientation not found.
|
||||
func readJPEGOrientation(data []byte) int {
|
||||
|
||||
@@ -33,6 +33,17 @@ type ModelManifest struct {
|
||||
BlobDir string
|
||||
}
|
||||
|
||||
func DefaultBlobDir() string {
|
||||
return filepath.Join(envconfig.Models(), "blobs")
|
||||
}
|
||||
|
||||
// DefaultManifestDir returns the manifest storage directory.
|
||||
// Respects OLLAMA_MODELS.
|
||||
|
||||
func DefaultManifestDir() string {
|
||||
return filepath.Join(envconfig.Models(), "manifests")
|
||||
}
|
||||
|
||||
// LoadManifest loads a manifest for the given model name.
|
||||
// Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag"
|
||||
func LoadManifest(modelName string) (*ModelManifest, error) {
|
||||
@@ -50,7 +61,7 @@ func LoadManifest(modelName string) (*ModelManifest, error) {
|
||||
|
||||
return &ModelManifest{
|
||||
Manifest: &manifest,
|
||||
BlobDir: filepath.Join(envconfig.Models(), "blobs"),
|
||||
BlobDir: DefaultBlobDir(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -81,7 +92,7 @@ func resolveManifestPath(modelName string) string {
|
||||
name = parts[1]
|
||||
}
|
||||
|
||||
return filepath.Join(envconfig.Models(), "manifests", host, namespace, name, tag)
|
||||
return filepath.Join(DefaultManifestDir(), host, namespace, name, tag)
|
||||
}
|
||||
|
||||
// BlobPath returns the full path to a blob given its digest.
|
||||
@@ -150,6 +161,17 @@ func (m *ModelManifest) HasTensorLayers() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// TotalTensorSize returns the total size in bytes of all tensor layers.
|
||||
func (m *ModelManifest) TotalTensorSize() int64 {
|
||||
var total int64
|
||||
for _, layer := range m.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
total += layer.Size
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// ModelInfo contains metadata about an image generation model.
|
||||
type ModelInfo struct {
|
||||
Architecture string
|
||||
|
||||
@@ -3,27 +3,55 @@ package imagegen
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func TestLoadManifestRespectsOLLAMA_MODELS(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", "/custom/models/path")
|
||||
|
||||
// Verify envconfig.Models() returns our custom path
|
||||
if got := envconfig.Models(); got != "/custom/models/path" {
|
||||
t.Fatalf("envconfig.Models() = %q, want %q", got, "/custom/models/path")
|
||||
func TestTotalTensorSize(t *testing.T) {
|
||||
m := &ModelManifest{
|
||||
Manifest: &Manifest{
|
||||
Layers: []ManifestLayer{
|
||||
{MediaType: "application/vnd.ollama.image.tensor", Size: 1000},
|
||||
{MediaType: "application/vnd.ollama.image.tensor", Size: 2000},
|
||||
{MediaType: "application/vnd.ollama.image.json", Size: 500}, // not a tensor
|
||||
{MediaType: "application/vnd.ollama.image.tensor", Size: 3000},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// LoadManifest will fail (no manifest exists), but we can verify
|
||||
// the error message contains the custom path
|
||||
_, err := LoadManifest("test-model")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
expectedPath := filepath.Join("/custom/models/path", "manifests", "registry.ollama.ai", "library", "test-model", "latest")
|
||||
if got := err.Error(); got != "read manifest: open "+expectedPath+": no such file or directory" {
|
||||
t.Errorf("error = %q, want path to contain %q", got, expectedPath)
|
||||
got := m.TotalTensorSize()
|
||||
want := int64(6000)
|
||||
if got != want {
|
||||
t.Errorf("TotalTensorSize() = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTotalTensorSizeEmpty(t *testing.T) {
|
||||
m := &ModelManifest{
|
||||
Manifest: &Manifest{
|
||||
Layers: []ManifestLayer{},
|
||||
},
|
||||
}
|
||||
|
||||
if got := m.TotalTensorSize(); got != 0 {
|
||||
t.Errorf("TotalTensorSize() = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
|
||||
modelsDir := filepath.Join(t.TempDir(), "models")
|
||||
|
||||
// Simulate packaged/systemd environment
|
||||
t.Setenv("OLLAMA_MODELS", modelsDir)
|
||||
t.Setenv("HOME", "/usr/share/ollama")
|
||||
|
||||
// Manifest dir must respect OLLAMA_MODELS
|
||||
wantManifest := filepath.Join(modelsDir, "manifests")
|
||||
if got := DefaultManifestDir(); got != wantManifest {
|
||||
t.Fatalf("DefaultManifestDir() = %q, want %q", got, wantManifest)
|
||||
}
|
||||
|
||||
// Blob dir must respect OLLAMA_MODELS
|
||||
wantBlobs := filepath.Join(modelsDir, "blobs")
|
||||
if got := DefaultBlobDir(); got != wantBlobs {
|
||||
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,18 +16,9 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// GB is a convenience constant for gigabytes.
|
||||
const GB = 1024 * 1024 * 1024
|
||||
|
||||
// SupportedBackends lists the backends that support image generation.
|
||||
var SupportedBackends = []string{"metal", "cuda", "cpu"}
|
||||
|
||||
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
|
||||
var modelVRAMEstimates = map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
|
||||
"FluxPipeline": 20 * GB, // ~20GB for Flux
|
||||
}
|
||||
|
||||
// CheckPlatformSupport validates that image generation is supported on the current platform.
|
||||
// Returns nil if supported, or an error describing why it's not supported.
|
||||
func CheckPlatformSupport() error {
|
||||
@@ -47,17 +38,6 @@ func CheckPlatformSupport() error {
|
||||
}
|
||||
}
|
||||
|
||||
// CheckMemoryRequirements validates that there's enough memory for image generation.
|
||||
// Returns nil if memory is sufficient, or an error if not.
|
||||
func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
|
||||
required := EstimateVRAM(modelName)
|
||||
if availableMemory < required {
|
||||
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
|
||||
required/GB, availableMemory/GB)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResolveModelName checks if a model name is a known image generation model.
|
||||
// Returns the normalized model name if found, empty string otherwise.
|
||||
func ResolveModelName(modelName string) string {
|
||||
@@ -68,16 +48,6 @@ func ResolveModelName(modelName string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
|
||||
// Returns a conservative default of 21GB if the model type cannot be determined.
|
||||
func EstimateVRAM(modelName string) uint64 {
|
||||
className := DetectModelType(modelName)
|
||||
if estimate, ok := modelVRAMEstimates[className]; ok {
|
||||
return estimate
|
||||
}
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// DetectModelType reads model_index.json and returns the model type.
|
||||
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
|
||||
// Returns empty string if detection fails.
|
||||
|
||||
@@ -30,69 +30,6 @@ func TestCheckPlatformSupport(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckMemoryRequirements(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
availableMemory uint64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "sufficient memory",
|
||||
availableMemory: 32 * GB,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "exactly enough memory",
|
||||
availableMemory: 21 * GB,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "insufficient memory",
|
||||
availableMemory: 16 * GB,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "zero memory",
|
||||
availableMemory: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Use a non-existent model name which will default to 21GB estimate
|
||||
err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelVRAMEstimates(t *testing.T) {
|
||||
// Verify the VRAM estimates map has expected entries
|
||||
expected := map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB,
|
||||
"FluxPipeline": 20 * GB,
|
||||
}
|
||||
|
||||
for name, expectedVRAM := range expected {
|
||||
if actual, ok := modelVRAMEstimates[name]; !ok {
|
||||
t.Errorf("Missing VRAM estimate for %s", name)
|
||||
} else if actual != expectedVRAM {
|
||||
t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateVRAMDefault(t *testing.T) {
|
||||
// Non-existent model should return default 21GB
|
||||
vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
|
||||
if vram != 21*GB {
|
||||
t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelName(t *testing.T) {
|
||||
// Non-existent model should return empty string
|
||||
result := ResolveModelName("nonexistent-model")
|
||||
|
||||
@@ -177,6 +177,20 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateImageWithInputs implements runner.ImageEditModel interface.
|
||||
// It generates an image conditioned on the provided input images for image editing.
|
||||
func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(ctx, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
InputImages: inputImages,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
|
||||
const MaxOutputPixels = 2048 * 2048
|
||||
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping qwen_image tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestPipelineOutput runs the full pipeline (integration test).
|
||||
// Skips if model weights not found. Requires ~50GB VRAM.
|
||||
func TestPipelineOutput(t *testing.T) {
|
||||
modelPath := "../../../weights/Qwen-Image-2512"
|
||||
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + modelPath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
pm, err := LoadPersistent(modelPath)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: failed to load model: %v", err)
|
||||
}
|
||||
|
||||
// Run 2-step pipeline (minimum for stable scheduler)
|
||||
cfg := &GenerateConfig{
|
||||
Prompt: "a cat",
|
||||
Width: 256,
|
||||
Height: 256,
|
||||
Steps: 2,
|
||||
Seed: 42,
|
||||
}
|
||||
|
||||
output, err := pm.GenerateFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Pipeline failed: %v", err)
|
||||
}
|
||||
mlx.Eval(output)
|
||||
|
||||
// Verify output shape [1, C, H, W]
|
||||
shape := output.Shape()
|
||||
if len(shape) != 4 {
|
||||
t.Errorf("Expected 4D output, got %v", shape)
|
||||
}
|
||||
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
|
||||
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
|
||||
}
|
||||
|
||||
// Verify values in expected range [0, 1]
|
||||
data := output.Data()
|
||||
minVal, maxVal := float32(1.0), float32(0.0)
|
||||
for _, v := range data {
|
||||
if v < minVal {
|
||||
minVal = v
|
||||
}
|
||||
if v > maxVal {
|
||||
maxVal = v
|
||||
}
|
||||
}
|
||||
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
|
||||
|
||||
if minVal < -0.1 || maxVal > 1.1 {
|
||||
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,367 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen_image implements the Qwen-Image diffusion transformer model.
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 30)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
|
||||
// Layer caching (DeepCache/Learning-to-Cache speedup)
|
||||
LayerCache bool // Enable layer caching (default: false)
|
||||
CacheInterval int // Refresh cache every N steps (default: 3)
|
||||
CacheLayers int // Number of shallow layers to cache (default: 25)
|
||||
}
|
||||
|
||||
// Model represents a Qwen-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *Qwen25VL
|
||||
Transformer *Transformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the Qwen-Image model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Qwen-Image model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
||||
tok, err := tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
|
||||
m.TextEncoder = &Qwen25VL{}
|
||||
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
peak := mlx.MetalGetPeakMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
||||
time.Since(start).Seconds(),
|
||||
float64(mem)/(1024*1024*1024),
|
||||
float64(peak)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
CFGScale: cfgScale,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
if cfg.CacheInterval <= 0 {
|
||||
cfg.CacheInterval = 3
|
||||
}
|
||||
if cfg.CacheLayers <= 0 {
|
||||
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
|
||||
}
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.Config
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
pH := latentH / tcfg.PatchSize
|
||||
pW := latentW / tcfg.PatchSize
|
||||
imgSeqLen := pH * pW
|
||||
|
||||
// Text encoding
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
|
||||
if useCFG {
|
||||
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
} else {
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
}
|
||||
}
|
||||
|
||||
// Pad sequences to same length for CFG
|
||||
txtLen := posEmb.Shape()[1]
|
||||
if useCFG {
|
||||
negLen := negEmb.Shape()[1]
|
||||
if negLen > txtLen {
|
||||
txtLen = negLen
|
||||
}
|
||||
if posEmb.Shape()[1] < txtLen {
|
||||
posEmb = padSequence(posEmb, txtLen)
|
||||
}
|
||||
if negEmb.Shape()[1] < txtLen {
|
||||
negEmb = padSequence(negEmb, txtLen)
|
||||
}
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
||||
var batchedEmb *mlx.Array
|
||||
if useCFG {
|
||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||
mlx.Keep(batchedEmb)
|
||||
mlx.Eval(batchedEmb)
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
|
||||
|
||||
// Init latents [B, C, T, H, W]
|
||||
var latents *mlx.Array
|
||||
{
|
||||
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
}
|
||||
|
||||
// RoPE cache
|
||||
var ropeCache *RoPECache
|
||||
{
|
||||
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
|
||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(ropeCache.ImgFreqs)
|
||||
}
|
||||
|
||||
// Layer cache for DeepCache/Learning-to-Cache speedup
|
||||
var stepCache *cache.StepCache
|
||||
if cfg.LayerCache {
|
||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
||||
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
|
||||
}
|
||||
|
||||
// Denoising loop
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
t := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
||||
|
||||
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
|
||||
latents2D := mlx.Squeeze(latents, 2)
|
||||
patches := PackLatents(latents2D, tcfg.PatchSize)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
// CFG Batching: single forward pass with batch=2
|
||||
// Note: layer caching with CFG is not supported yet (would need 2 caches)
|
||||
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
|
||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||
|
||||
// Single batched forward pass
|
||||
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||
L := batchedOutput.Shape()[1]
|
||||
D := batchedOutput.Shape()[2]
|
||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
|
||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
|
||||
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
combPred := mlx.Add(negOutput, scaledDiff)
|
||||
|
||||
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
|
||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
|
||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
||||
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
||||
} else if stepCache != nil {
|
||||
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
|
||||
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
|
||||
} else {
|
||||
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
}
|
||||
|
||||
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
|
||||
// Keep cached arrays alive across cleanup
|
||||
if stepCache != nil {
|
||||
mlx.Keep(stepCache.Arrays()...)
|
||||
}
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
|
||||
}
|
||||
|
||||
// Free denoising temporaries before VAE decode
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
if batchedEmb != nil {
|
||||
batchedEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
if stepCache != nil {
|
||||
stepCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode (Decode manages its own pools for staged memory)
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
||||
{
|
||||
decoded = mlx.Squeeze(decoded, 2)
|
||||
decoded = mlx.AddScalar(decoded, 1.0)
|
||||
decoded = mlx.DivScalar(decoded, 2.0)
|
||||
mlx.Eval(decoded)
|
||||
}
|
||||
|
||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// padSequence pads a sequence tensor to the target length with zeros
|
||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
// Pad on sequence dimension (axis 1)
|
||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
||||
}
|
||||
|
||||
// LoadPersistent is an alias for backward compatibility.
|
||||
// Use m := &Model{}; m.Load(path) instead.
|
||||
func LoadPersistent(modelPath string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
@@ -1,218 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
|
||||
type SchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
BaseShift float32 `json:"base_shift"` // 0.5
|
||||
MaxShift float32 `json:"max_shift"` // 0.9
|
||||
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
|
||||
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
|
||||
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
|
||||
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
|
||||
}
|
||||
|
||||
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
|
||||
func DefaultSchedulerConfig() *SchedulerConfig {
|
||||
return &SchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
BaseShift: 0.5,
|
||||
MaxShift: 0.9, // Matches scheduler_config.json
|
||||
BaseImageSeqLen: 256,
|
||||
MaxImageSeqLen: 8192,
|
||||
ShiftTerminal: 0.02,
|
||||
UseDynamicShift: true,
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
|
||||
type FlowMatchScheduler struct {
|
||||
Config *SchedulerConfig
|
||||
Timesteps []float32
|
||||
Sigmas []float32
|
||||
NumSteps int
|
||||
}
|
||||
|
||||
// NewFlowMatchScheduler creates a new scheduler
|
||||
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
|
||||
return &FlowMatchScheduler{
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateShift computes the dynamic shift based on image sequence length
|
||||
// This matches Python's calculate_shift function
|
||||
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
|
||||
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
|
||||
b := baseShift - m*float32(baseSeqLen)
|
||||
mu := float32(imageSeqLen)*m + b
|
||||
return mu
|
||||
}
|
||||
|
||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
||||
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
|
||||
// 1. Create sigmas from sigma_max to sigma_min (linspace)
|
||||
// 2. Apply time_shift with mu (if dynamic shifting)
|
||||
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
|
||||
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Calculate mu for dynamic shifting
|
||||
var mu float32
|
||||
if s.Config.UseDynamicShift {
|
||||
mu = CalculateShift(
|
||||
imageSeqLen,
|
||||
s.Config.BaseImageSeqLen,
|
||||
s.Config.MaxImageSeqLen,
|
||||
s.Config.BaseShift,
|
||||
s.Config.MaxShift,
|
||||
)
|
||||
}
|
||||
|
||||
// Step 1: Create sigmas from 1.0 to 1/num_steps
|
||||
// Python (pipeline_qwenimage.py:639):
|
||||
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
|
||||
sigmas := make([]float32, numSteps)
|
||||
sigmaMax := float32(1.0)
|
||||
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
|
||||
if numSteps == 1 {
|
||||
sigmas[0] = sigmaMax
|
||||
} else {
|
||||
for i := 0; i < numSteps; i++ {
|
||||
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Apply time shift if using dynamic shifting
|
||||
if s.Config.UseDynamicShift && mu != 0 {
|
||||
for i := range sigmas {
|
||||
sigmas[i] = s.timeShift(mu, sigmas[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Apply stretch_shift_to_terminal
|
||||
if s.Config.ShiftTerminal > 0 {
|
||||
sigmas = s.stretchShiftToTerminal(sigmas)
|
||||
}
|
||||
|
||||
// Step 4: Append terminal sigma (0) and store
|
||||
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
|
||||
// before passing to transformer. We skip both steps and just use sigmas directly.
|
||||
s.Sigmas = make([]float32, numSteps+1)
|
||||
s.Timesteps = make([]float32, numSteps+1)
|
||||
for i := 0; i < numSteps; i++ {
|
||||
s.Sigmas[i] = sigmas[i]
|
||||
s.Timesteps[i] = sigmas[i]
|
||||
}
|
||||
s.Sigmas[numSteps] = 0.0
|
||||
s.Timesteps[numSteps] = 0.0
|
||||
}
|
||||
|
||||
// stretchShiftToTerminal stretches and shifts the timestep schedule
|
||||
// so the final value equals shift_terminal (matches Python behavior)
|
||||
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
|
||||
if len(sigmas) == 0 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
// one_minus_z = 1 - t
|
||||
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
|
||||
// stretched_t = 1 - (one_minus_z / scale_factor)
|
||||
lastSigma := sigmas[len(sigmas)-1]
|
||||
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
|
||||
|
||||
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
|
||||
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
|
||||
if scaleFactor < 1e-6 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
result := make([]float32, len(sigmas))
|
||||
for i, t := range sigmas {
|
||||
oneMinusZ := 1.0 - t
|
||||
result[i] = 1.0 - (oneMinusZ / scaleFactor)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// timeShift applies the dynamic time shift (exponential)
|
||||
// exp(mu) / (exp(mu) + (1/t - 1))
|
||||
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + (1.0/t - 1.0))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
// modelOutput: predicted velocity from the transformer
|
||||
// sample: current noisy sample
|
||||
// timestepIdx: current timestep index
|
||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Get current and next sigma
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
||||
dt := sigmaNext - sigma
|
||||
|
||||
// Upcast to float32 to avoid precision issues (matches Python diffusers)
|
||||
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
|
||||
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
|
||||
|
||||
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
|
||||
result := mlx.Add(sampleF32, scaledOutput)
|
||||
|
||||
// Cast back to original dtype
|
||||
return mlx.ToBFloat16(result)
|
||||
}
|
||||
|
||||
// GetTimestep returns the timestep value at the given index
|
||||
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
|
||||
if idx < len(s.Timesteps) {
|
||||
return s.Timesteps[idx]
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
|
||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
|
||||
// This matches how Python diffusers generates noise - directly in packed space.
|
||||
// Generating in unpacked format and then packing produces different spatial
|
||||
// correlation structure, which affects model output quality.
|
||||
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
|
||||
shape := []int32{batchSize, seqLen, channels}
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// GetLatentShape returns the latent shape for a given image size
|
||||
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
|
||||
func GetLatentShape(batchSize, height, width int32) []int32 {
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
|
||||
}
|
||||
|
||||
// GetPatchedLatentShape returns the patchified latent shape
|
||||
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
|
||||
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
pH := latentH / patchSize
|
||||
pW := latentW / patchSize
|
||||
inChannels := int32(64) // 16 * patch_size^2
|
||||
return []int32{batchSize, pH * pW, inChannels}
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
|
||||
// Golden values generated via:
|
||||
//
|
||||
// python3 -c "
|
||||
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
// import numpy as np
|
||||
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
|
||||
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
|
||||
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
|
||||
// sigmas = np.linspace(1.0, 1.0/30, 30)
|
||||
// s.set_timesteps(sigmas=sigmas, mu=mu)
|
||||
// print(s.sigmas.numpy())"
|
||||
func TestSchedulerSetTimesteps(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Golden values from Python diffusers (first 3, last 3 before terminal)
|
||||
wantFirst := []float32{1.000000, 0.982251, 0.963889}
|
||||
wantLast := []float32{0.142924, 0.083384, 0.020000}
|
||||
|
||||
// Check first 3
|
||||
for i, want := range wantFirst {
|
||||
got := scheduler.Sigmas[i]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check last 3 (indices 27, 28, 29)
|
||||
for i, want := range wantLast {
|
||||
idx := 27 + i
|
||||
got := scheduler.Sigmas[idx]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check terminal is 0
|
||||
if scheduler.Sigmas[30] != 0.0 {
|
||||
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(scheduler.Sigmas) != 31 {
|
||||
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerProperties tests mathematical invariants of the scheduler.
|
||||
func TestSchedulerProperties(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Property: sigmas monotonically decreasing
|
||||
for i := 1; i < len(scheduler.Sigmas); i++ {
|
||||
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
|
||||
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
|
||||
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
|
||||
}
|
||||
}
|
||||
|
||||
// Property: first sigma should be ~1.0 (with time shift)
|
||||
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
|
||||
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
|
||||
}
|
||||
|
||||
// Property: terminal sigma should be exactly 0
|
||||
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
|
||||
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
|
||||
}
|
||||
|
||||
// Property: last non-terminal sigma should be shift_terminal (0.02)
|
||||
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
|
||||
if abs32(lastNonTerminal-0.02) > 1e-5 {
|
||||
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
|
||||
}
|
||||
|
||||
// Property: length = steps + 1
|
||||
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
|
||||
t.Errorf("sigmas length should be steps+1: got %d, want %d",
|
||||
len(scheduler.Sigmas), scheduler.NumSteps+1)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateShift verifies the mu calculation against Python reference.
|
||||
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
func TestCalculateShift(t *testing.T) {
|
||||
cases := []struct {
|
||||
imgSeqLen int32
|
||||
want float32
|
||||
}{
|
||||
{256, 0.5}, // base case
|
||||
{8192, 0.9}, // max case
|
||||
{4096, 0.6935}, // middle case (rounded)
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
|
||||
if abs32(got-c.want) > 0.001 {
|
||||
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerStep verifies the Euler step formula.
|
||||
func TestSchedulerStep(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Verify dt calculation for first step
|
||||
sigma0 := scheduler.Sigmas[0]
|
||||
sigma1 := scheduler.Sigmas[1]
|
||||
expectedDt := sigma1 - sigma0
|
||||
|
||||
// dt should be negative (sigmas decrease)
|
||||
if expectedDt >= 0 {
|
||||
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
|
||||
}
|
||||
}
|
||||
|
||||
func abs32(x float32) float32 {
|
||||
return float32(math.Abs(float64(x)))
|
||||
}
|
||||
@@ -1,174 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TinyTextEncoderConfig holds config for the tiny test text encoder
|
||||
type TinyTextEncoderConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
}
|
||||
|
||||
// loadTinyTextEncoder loads the tiny text encoder from testdata
|
||||
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
|
||||
t.Helper()
|
||||
|
||||
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
|
||||
|
||||
// Load config
|
||||
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
|
||||
}
|
||||
|
||||
var tinyCfg TinyTextEncoderConfig
|
||||
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
|
||||
t.Fatalf("Failed to parse config: %v", err)
|
||||
}
|
||||
|
||||
// Create encoder config (using Qwen25VLConfig)
|
||||
cfg := &Qwen25VLConfig{
|
||||
HiddenSize: tinyCfg.HiddenSize,
|
||||
NumHiddenLayers: tinyCfg.NumHiddenLayers,
|
||||
IntermediateSize: tinyCfg.IntermediateSize,
|
||||
NumAttentionHeads: tinyCfg.NumAttentionHeads,
|
||||
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
|
||||
VocabSize: tinyCfg.VocabSize,
|
||||
RMSNormEps: tinyCfg.RMSNormEps,
|
||||
RopeTheta: tinyCfg.RopeTheta,
|
||||
HeadDim: tinyCfg.HeadDim,
|
||||
MRoPESection: tinyCfg.MRoPESection,
|
||||
}
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(testdataDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load weights: %v", err)
|
||||
}
|
||||
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
t.Fatalf("Failed to bulk load weights: %v", err)
|
||||
}
|
||||
|
||||
// Build encoder
|
||||
embedding, err := weights.Get("model.embed_tokens.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get embedding: %v", err)
|
||||
}
|
||||
|
||||
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
block, err := newVLTextBlock(weights, int(i), cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load block %d: %v", i, err)
|
||||
}
|
||||
blocks[i] = block
|
||||
}
|
||||
|
||||
finalNorm, err := weights.Get("model.norm.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get final norm: %v", err)
|
||||
}
|
||||
|
||||
encoder := &Qwen25VL{
|
||||
Config: cfg,
|
||||
Embedding: embedding,
|
||||
Blocks: blocks,
|
||||
FinalNorm: finalNorm,
|
||||
HasVision: false, // Text-only mode
|
||||
}
|
||||
|
||||
return encoder, &tinyCfg
|
||||
}
|
||||
|
||||
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
|
||||
func TestTextEncoderForward(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// Create test tokens (within vocab range)
|
||||
tokens := []int32{1, 2, 3, 4, 5}
|
||||
|
||||
// Forward pass using EncodeTextOnly
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [batch, seq_len, hidden_size]
|
||||
wantShape := []int32{1, 5, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify output is finite (not NaN or Inf)
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
|
||||
t.Errorf("output[%d] is not finite: %v", i, v)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextEncoderBatch tests batch processing.
|
||||
func TestTextEncoderBatch(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// For batch test, we'll use EncodeTextOnly with a single sequence
|
||||
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
|
||||
tokens := []int32{1, 2, 3}
|
||||
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
wantShape := []int32{1, 3, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
|
||||
func TestMRoPEComputation(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
cossin := encoder.computeTextRoPE(10, 1)
|
||||
mlx.Eval(cossin[0], cossin[1])
|
||||
|
||||
// Verify shapes: [3, B, L, head_dim]
|
||||
wantShape := []int32{3, 1, 10, cfg.HeadDim}
|
||||
if !slices.Equal(cossin[0].Shape(), wantShape) {
|
||||
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
|
||||
}
|
||||
if !slices.Equal(cossin[1].Shape(), wantShape) {
|
||||
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify cos/sin values are in valid range [-1, 1]
|
||||
cosData := cossin[0].Data()
|
||||
sinData := cossin[1].Data()
|
||||
for i := 0; i < min(100, len(cosData)); i++ {
|
||||
if cosData[i] < -1.01 || cosData[i] > 1.01 {
|
||||
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
|
||||
}
|
||||
if sinData[i] < -1.01 || sinData[i] > 1.01 {
|
||||
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,868 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TransformerConfig holds Qwen-Image transformer configuration
|
||||
type TransformerConfig struct {
|
||||
HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
|
||||
NHeads int32 `json:"num_attention_heads"` // 24
|
||||
HeadDim int32 `json:"attention_head_dim"` // 128
|
||||
NLayers int32 `json:"num_layers"` // 60
|
||||
InChannels int32 `json:"in_channels"` // 64
|
||||
OutChannels int32 `json:"out_channels"` // 16
|
||||
PatchSize int32 `json:"patch_size"` // 2
|
||||
JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
|
||||
NormEps float32 `json:"norm_eps"` // 1e-6
|
||||
AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
|
||||
GuidanceEmbeds bool `json:"guidance_embeds"` // false
|
||||
}
|
||||
|
||||
// defaultTransformerConfig returns config for Qwen-Image transformer
|
||||
func defaultTransformerConfig() *TransformerConfig {
|
||||
return &TransformerConfig{
|
||||
HiddenDim: 3072, // 24 * 128
|
||||
NHeads: 24,
|
||||
HeadDim: 128,
|
||||
NLayers: 60,
|
||||
InChannels: 64,
|
||||
OutChannels: 16,
|
||||
PatchSize: 2,
|
||||
JointAttentionDim: 3584,
|
||||
NormEps: 1e-6,
|
||||
AxesDimsRope: []int32{16, 56, 56},
|
||||
GuidanceEmbeds: false,
|
||||
}
|
||||
}
|
||||
|
||||
// TimestepEmbedder creates timestep embeddings
|
||||
type TimestepEmbedder struct {
|
||||
Linear1Weight *mlx.Array // [256, hidden_dim]
|
||||
Linear1Bias *mlx.Array
|
||||
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
|
||||
Linear2Bias *mlx.Array
|
||||
}
|
||||
|
||||
// newTimestepEmbedder creates a timestep embedder from weights
|
||||
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
|
||||
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TimestepEmbedder{
|
||||
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
|
||||
Linear1Bias: linear1Bias,
|
||||
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
|
||||
Linear2Bias: linear2Bias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings
|
||||
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
|
||||
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
||||
half := int32(128) // embedding_dim / 2
|
||||
|
||||
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
|
||||
freqs := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
||||
}
|
||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
tExpanded := mlx.ExpandDims(t, 1)
|
||||
args := mlx.Mul(tExpanded, freqsArr)
|
||||
args = mlx.MulScalar(args, 1000.0) // scale
|
||||
|
||||
// [cos, sin] (flip_sin_to_cos=True)
|
||||
sinArgs := mlx.Sin(args)
|
||||
cosArgs := mlx.Cos(args)
|
||||
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
|
||||
|
||||
// MLP: linear1 -> silu -> linear2
|
||||
h := mlx.Linear(embedding, te.Linear1Weight)
|
||||
h = mlx.Add(h, te.Linear1Bias)
|
||||
h = mlx.SiLU(h)
|
||||
h = mlx.Linear(h, te.Linear2Weight)
|
||||
h = mlx.Add(h, te.Linear2Bias)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// JointAttention implements dual-stream joint attention
|
||||
type JointAttention struct {
|
||||
// Image projections
|
||||
ToQ *mlx.Array
|
||||
ToQB *mlx.Array
|
||||
ToK *mlx.Array
|
||||
ToKB *mlx.Array
|
||||
ToV *mlx.Array
|
||||
ToVB *mlx.Array
|
||||
ToOut *mlx.Array
|
||||
ToOutB *mlx.Array
|
||||
NormQ *mlx.Array
|
||||
NormK *mlx.Array
|
||||
|
||||
// Text (added) projections
|
||||
AddQProj *mlx.Array
|
||||
AddQProjB *mlx.Array
|
||||
AddKProj *mlx.Array
|
||||
AddKProjB *mlx.Array
|
||||
AddVProj *mlx.Array
|
||||
AddVProjB *mlx.Array
|
||||
ToAddOut *mlx.Array
|
||||
ToAddOutB *mlx.Array
|
||||
NormAddQ *mlx.Array
|
||||
NormAddK *mlx.Array
|
||||
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// newJointAttention creates a joint attention layer
|
||||
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
|
||||
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
|
||||
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
|
||||
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
|
||||
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
|
||||
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
|
||||
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
|
||||
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
|
||||
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
|
||||
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
|
||||
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
|
||||
|
||||
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
|
||||
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
|
||||
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
|
||||
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
|
||||
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
|
||||
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
|
||||
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
|
||||
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
|
||||
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
|
||||
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
|
||||
|
||||
return &JointAttention{
|
||||
ToQ: mlx.Transpose(toQ, 1, 0),
|
||||
ToQB: toQB,
|
||||
ToK: mlx.Transpose(toK, 1, 0),
|
||||
ToKB: toKB,
|
||||
ToV: mlx.Transpose(toV, 1, 0),
|
||||
ToVB: toVB,
|
||||
ToOut: mlx.Transpose(toOut, 1, 0),
|
||||
ToOutB: toOutB,
|
||||
NormQ: normQ,
|
||||
NormK: normK,
|
||||
AddQProj: mlx.Transpose(addQProj, 1, 0),
|
||||
AddQProjB: addQProjB,
|
||||
AddKProj: mlx.Transpose(addKProj, 1, 0),
|
||||
AddKProjB: addKProjB,
|
||||
AddVProj: mlx.Transpose(addVProj, 1, 0),
|
||||
AddVProjB: addVProjB,
|
||||
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
|
||||
ToAddOutB: toAddOutB,
|
||||
NormAddQ: normAddQ,
|
||||
NormAddK: normAddK,
|
||||
NHeads: cfg.NHeads,
|
||||
HeadDim: cfg.HeadDim,
|
||||
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes joint attention
|
||||
// img: [B, L_img, D], txt: [B, L_txt, D]
|
||||
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
|
||||
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
D := imgShape[2]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// === Image Q/K/V ===
|
||||
imgFlat := mlx.Reshape(img, B*Limg, D)
|
||||
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
|
||||
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
|
||||
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
|
||||
|
||||
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
|
||||
// QK norm (RMSNorm per head)
|
||||
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
|
||||
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
|
||||
|
||||
// Apply RoPE
|
||||
if imgFreqs != nil {
|
||||
qImg = applyRoPE(qImg, imgFreqs)
|
||||
kImg = applyRoPE(kImg, imgFreqs)
|
||||
}
|
||||
|
||||
// === Text Q/K/V ===
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
|
||||
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
|
||||
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
|
||||
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
|
||||
|
||||
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
|
||||
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
|
||||
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
|
||||
|
||||
if txtFreqs != nil {
|
||||
qTxt = applyRoPE(qTxt, txtFreqs)
|
||||
kTxt = applyRoPE(kTxt, txtFreqs)
|
||||
}
|
||||
|
||||
// Concatenate for joint attention: [txt, img] order
|
||||
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
|
||||
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
|
||||
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
|
||||
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
|
||||
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
|
||||
|
||||
// SDPA
|
||||
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
|
||||
|
||||
// Transpose back and split
|
||||
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
|
||||
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
|
||||
|
||||
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
|
||||
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
|
||||
|
||||
// Output projections
|
||||
outImg = mlx.Reshape(outImg, B*Limg, D)
|
||||
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
|
||||
outImg = mlx.Reshape(outImg, B, Limg, D)
|
||||
|
||||
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
|
||||
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
|
||||
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
|
||||
|
||||
return outImg, outTxt
|
||||
}
|
||||
|
||||
// applyRoPE applies rotary embeddings using complex multiplication
|
||||
// x: [B, L, nheads, head_dim]
|
||||
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
|
||||
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
nheads := shape[2]
|
||||
headDim := shape[3]
|
||||
halfDim := headDim / 2
|
||||
|
||||
// Reshape x to pairs: [B, L, nheads, half, 2]
|
||||
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
|
||||
|
||||
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
|
||||
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
|
||||
|
||||
// Extract real/imag parts
|
||||
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
xReal = mlx.Squeeze(xReal, 4)
|
||||
xImag = mlx.Squeeze(xImag, 4)
|
||||
|
||||
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
freqReal = mlx.Squeeze(freqReal, 4)
|
||||
freqImag = mlx.Squeeze(freqImag, 4)
|
||||
|
||||
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
|
||||
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
|
||||
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
|
||||
|
||||
// Interleave back
|
||||
outReal = mlx.ExpandDims(outReal, 4)
|
||||
outImag = mlx.ExpandDims(outImag, 4)
|
||||
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
|
||||
|
||||
return mlx.Reshape(out, B, L, nheads, headDim)
|
||||
}
|
||||
|
||||
// MLP implements GELU MLP (not GEGLU)
|
||||
type MLP struct {
|
||||
ProjWeight *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
OutWeight *mlx.Array
|
||||
OutBias *mlx.Array
|
||||
}
|
||||
|
||||
// newMLP creates a GELU MLP
|
||||
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
|
||||
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
|
||||
outWeight, _ := weights.Get(prefix + ".net.2.weight")
|
||||
outBias, _ := weights.Get(prefix + ".net.2.bias")
|
||||
|
||||
return &MLP{
|
||||
ProjWeight: mlx.Transpose(projWeight, 1, 0),
|
||||
ProjBias: projBias,
|
||||
OutWeight: mlx.Transpose(outWeight, 1, 0),
|
||||
OutBias: outBias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies GELU MLP
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
xFlat := mlx.Reshape(x, B*L, D)
|
||||
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
|
||||
h = geluApprox(h)
|
||||
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
|
||||
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
|
||||
}
|
||||
|
||||
// geluApprox implements approximate GELU
|
||||
func geluApprox(x *mlx.Array) *mlx.Array {
|
||||
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
|
||||
inner = mlx.MulScalar(inner, sqrt2OverPi)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
|
||||
}
|
||||
|
||||
// TransformerBlock is a single dual-stream transformer block
|
||||
type TransformerBlock struct {
|
||||
Attention *JointAttention
|
||||
ImgMLP *MLP
|
||||
TxtMLP *MLP
|
||||
|
||||
ImgModWeight *mlx.Array
|
||||
ImgModBias *mlx.Array
|
||||
TxtModWeight *mlx.Array
|
||||
TxtModBias *mlx.Array
|
||||
|
||||
HiddenDim int32
|
||||
NormEps float32
|
||||
}
|
||||
|
||||
// newTransformerBlock creates a transformer block
|
||||
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
|
||||
attn, err := newJointAttention(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
|
||||
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
|
||||
|
||||
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
|
||||
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
|
||||
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
|
||||
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
|
||||
|
||||
return &TransformerBlock{
|
||||
Attention: attn,
|
||||
ImgMLP: imgMLP,
|
||||
TxtMLP: txtMLP,
|
||||
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
|
||||
ImgModBias: imgModBias,
|
||||
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
|
||||
TxtModBias: txtModBias,
|
||||
HiddenDim: cfg.HiddenDim,
|
||||
NormEps: cfg.NormEps,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the transformer block
|
||||
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
|
||||
siluT := mlx.SiLU(temb)
|
||||
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
|
||||
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
|
||||
|
||||
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
|
||||
imgModParts := splitMod6(imgMod, tb.HiddenDim)
|
||||
txtModParts := splitMod6(txtMod, tb.HiddenDim)
|
||||
|
||||
// Pre-attention: norm + modulate
|
||||
imgNorm := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
|
||||
|
||||
txtNorm := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
|
||||
|
||||
// Joint attention
|
||||
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
|
||||
|
||||
// Pre-MLP: norm + modulate
|
||||
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
|
||||
|
||||
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
|
||||
|
||||
// MLP
|
||||
mlpImg := tb.ImgMLP.Forward(imgNorm2)
|
||||
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
|
||||
|
||||
return img, txt
|
||||
}
|
||||
|
||||
// splitMod6 splits modulation into 6 parts each [B, 1, D]
|
||||
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
|
||||
shape := mod.Shape()
|
||||
B := shape[0]
|
||||
parts := make([]*mlx.Array, 6)
|
||||
for i := int32(0); i < 6; i++ {
|
||||
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
|
||||
parts[i] = mlx.ExpandDims(part, 1)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
|
||||
// Transformer is the full Qwen-Image transformer model
|
||||
type Transformer struct {
|
||||
Config *TransformerConfig
|
||||
|
||||
ImgIn *mlx.Array
|
||||
ImgInBias *mlx.Array
|
||||
TxtIn *mlx.Array
|
||||
TxtInBias *mlx.Array
|
||||
TxtNorm *mlx.Array
|
||||
|
||||
TEmbed *TimestepEmbedder
|
||||
Layers []*TransformerBlock
|
||||
|
||||
NormOutWeight *mlx.Array
|
||||
NormOutBias *mlx.Array
|
||||
ProjOut *mlx.Array
|
||||
ProjOutBias *mlx.Array
|
||||
}
|
||||
|
||||
// Load loads the transformer from a directory
|
||||
func (m *Transformer) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image transformer...")
|
||||
|
||||
cfg := defaultTransformerConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Bulk load all weights as bf16
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading input projections... ")
|
||||
imgIn, _ := weights.Get("img_in.weight")
|
||||
imgInBias, _ := weights.Get("img_in.bias")
|
||||
txtIn, _ := weights.Get("txt_in.weight")
|
||||
txtInBias, _ := weights.Get("txt_in.bias")
|
||||
txtNorm, _ := weights.Get("txt_norm.weight")
|
||||
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
|
||||
m.ImgInBias = imgInBias
|
||||
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
|
||||
m.TxtInBias = txtInBias
|
||||
m.TxtNorm = txtNorm
|
||||
fmt.Println("✓")
|
||||
|
||||
fmt.Print(" Loading timestep embedder... ")
|
||||
m.TEmbed, err = newTimestepEmbedder(weights)
|
||||
if err != nil {
|
||||
return fmt.Errorf("timestep embedder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
for i := int32(0); i < cfg.NLayers; i++ {
|
||||
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
|
||||
prefix := fmt.Sprintf("transformer_blocks.%d", i)
|
||||
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("layer %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
|
||||
|
||||
fmt.Print(" Loading output layers... ")
|
||||
normOutWeight, _ := weights.Get("norm_out.linear.weight")
|
||||
normOutBias, _ := weights.Get("norm_out.linear.bias")
|
||||
projOut, _ := weights.Get("proj_out.weight")
|
||||
projOutBias, _ := weights.Get("proj_out.bias")
|
||||
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
|
||||
m.NormOutBias = normOutBias
|
||||
m.ProjOut = mlx.Transpose(projOut, 1, 0)
|
||||
m.ProjOutBias = projOutBias
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath is a convenience function to load transformer from path
|
||||
func LoadTransformerFromPath(path string) (*Transformer, error) {
|
||||
m := &Transformer{}
|
||||
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward runs the transformer
|
||||
// img: [B, L_img, in_channels] patchified latents
|
||||
// txt: [B, L_txt, joint_attention_dim] text embeddings
|
||||
// t: [B] timesteps (0-1)
|
||||
// imgFreqs, txtFreqs: RoPE frequencies
|
||||
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
for _, layer := range tr.Layers {
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// ForwardWithCache runs the transformer with layer caching for speedup.
|
||||
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
|
||||
// shallow layers change little between denoising steps, so we cache their
|
||||
// outputs and reuse them on non-refresh steps.
|
||||
//
|
||||
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
|
||||
// step: current denoising step (0-indexed)
|
||||
// cacheInterval: refresh cache every N steps (e.g., 3)
|
||||
// cacheLayers: number of shallow layers to cache (e.g., 15)
|
||||
func (tr *Transformer) ForwardWithCache(
|
||||
img, txt, t *mlx.Array,
|
||||
imgFreqs, txtFreqs *mlx.Array,
|
||||
stepCache *cache.StepCache,
|
||||
step, cacheInterval, cacheLayers int,
|
||||
) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
// Check if we should refresh the cache
|
||||
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
|
||||
|
||||
for i, layer := range tr.Layers {
|
||||
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
|
||||
// Use cached outputs for shallow layers
|
||||
imgH = stepCache.Get(i)
|
||||
txtH = stepCache.Get2(i)
|
||||
} else {
|
||||
// Compute layer
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
// Cache shallow layers on refresh steps
|
||||
if i < cacheLayers && refreshCache {
|
||||
stepCache.Set(i, imgH)
|
||||
stepCache.Set2(i, txtH)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// RoPECache holds precomputed RoPE frequencies
|
||||
type RoPECache struct {
|
||||
ImgFreqs *mlx.Array // [L_img, head_dim]
|
||||
TxtFreqs *mlx.Array // [L_txt, head_dim]
|
||||
}
|
||||
|
||||
// PrepareRoPE computes RoPE for image and text sequences
|
||||
// This matches Python's QwenEmbedRope with scale_rope=True
|
||||
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
|
||||
theta := float64(10000)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Compute base frequencies for each axis dimension
|
||||
freqsT := ComputeAxisFreqs(axesDims[0], theta)
|
||||
freqsH := ComputeAxisFreqs(axesDims[1], theta)
|
||||
freqsW := ComputeAxisFreqs(axesDims[2], theta)
|
||||
|
||||
// Build frequency lookup tables
|
||||
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
|
||||
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
|
||||
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
|
||||
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
|
||||
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
|
||||
|
||||
// Image frequencies with scale_rope=True
|
||||
imgLen := imgH * imgW
|
||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
||||
imgFreqsData := make([]float32, imgLen*headDim)
|
||||
|
||||
hHalf := imgH / 2
|
||||
wHalf := imgW / 2
|
||||
|
||||
idx := int32(0)
|
||||
for y := int32(0); y < imgH; y++ {
|
||||
for x := int32(0); x < imgW; x++ {
|
||||
// Frame = 0
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
|
||||
// Height: scale_rope pattern
|
||||
hNegCount := imgH - hHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
|
||||
// Width: scale_rope pattern
|
||||
wNegCount := imgW - wHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
}
|
||||
|
||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
|
||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
||||
|
||||
// Text frequencies
|
||||
maxVidIdx := max(hHalf, wHalf)
|
||||
txtFreqsData := make([]float32, txtLen*headDim)
|
||||
|
||||
idx = 0
|
||||
for t := int32(0); t < txtLen; t++ {
|
||||
pos := maxVidIdx + t
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
|
||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
||||
|
||||
return &RoPECache{
|
||||
ImgFreqs: imgFreqs,
|
||||
TxtFreqs: txtFreqs,
|
||||
}
|
||||
}
|
||||
|
||||
// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
|
||||
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
|
||||
halfDim := dim / 2
|
||||
freqs := make([]float64, halfDim)
|
||||
for i := int32(0); i < halfDim; i++ {
|
||||
freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
|
||||
}
|
||||
return freqs
|
||||
}
|
||||
|
||||
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
|
||||
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
|
||||
table := make([][]float32, maxIdx)
|
||||
for idx := int32(0); idx < maxIdx; idx++ {
|
||||
var pos float64
|
||||
if negative {
|
||||
pos = float64(-maxIdx + int32(idx))
|
||||
} else {
|
||||
pos = float64(idx)
|
||||
}
|
||||
|
||||
row := make([]float32, len(baseFreqs)*2)
|
||||
for i, f := range baseFreqs {
|
||||
angle := pos * f
|
||||
row[i*2] = float32(math.Cos(angle))
|
||||
row[i*2+1] = float32(math.Sin(angle))
|
||||
}
|
||||
table[idx] = row
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func max(a, b int32) int32 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
|
||||
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
|
||||
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
|
||||
// -> [B, pH, pW, C, 2, 2]
|
||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
||||
// -> [B, pH*pW, C*4]
|
||||
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
|
||||
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
|
||||
shape := patches.Shape()
|
||||
B := shape[0]
|
||||
channels := shape[2] / (patchSize * patchSize)
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
|
||||
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
|
||||
// -> [B, C, pH, 2, pW, 2]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
||||
// -> [B, C, H, W]
|
||||
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
|
||||
// Add temporal dimension for VAE: [B, C, 1, H, W]
|
||||
return mlx.ExpandDims(x, 2)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user