Compare commits

..

29 Commits

Author SHA1 Message Date
Gabe Goodhart
7b62c41060 cmd/config: use envconfig.Host() for base API in launch config packages (#13937) 2026-01-27 13:30:00 -08:00
Parth Sareen
26acab64b7 docs: add clawdbot (#13925) 2026-01-26 18:32:54 -08:00
Gyungrai Wang
e0f03790b1 parsers/ministral: fix nested tool call parsing by counting brace nesting (#13905)
* parsers/ministral: fix nested tool call parsing by counting brace nesting

* fix lint error

* parsers: refactor ministral parser

The old one was very tied to expecting to see only one token at a time,
which I don't like to assume (who knows what the future might hold wrt
speculative decoding, etc). This new one follows a similar structure to
qwen3-coder's parser, which incidentally makes it easier to test as well
(since we can test the individual events that come out when given
particular inputs).

---------

Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-26 15:03:43 -08:00
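The brace-counting idea above is simple to sketch outside the real parser. The snippet below is a hypothetical illustration (not the code in parsers/ministral): it tracks `{`/`}` nesting across arbitrarily split chunks and emits each complete top-level object, which is what nested tool-call arguments need when input is not delivered one token at a time. A real parser would additionally skip braces inside JSON string literals.

```go
package main

import "fmt"

// toolCallAccumulator is a hypothetical sketch of brace-nesting counting:
// it buffers streamed text and emits each complete top-level {...} object,
// no matter how the stream is split into chunks.
type toolCallAccumulator struct {
	depth int
	buf   []byte
}

// feed consumes one chunk and returns any complete tool-call payloads found.
func (a *toolCallAccumulator) feed(chunk string) []string {
	var complete []string
	for i := 0; i < len(chunk); i++ {
		c := chunk[i]
		switch c {
		case '{':
			a.depth++
			a.buf = append(a.buf, c)
		case '}':
			if a.depth == 0 {
				continue // stray close brace outside any object
			}
			a.buf = append(a.buf, c)
			a.depth--
			if a.depth == 0 {
				complete = append(complete, string(a.buf))
				a.buf = a.buf[:0]
			}
		default:
			if a.depth > 0 { // ignore text outside any object
				a.buf = append(a.buf, c)
			}
		}
	}
	return complete
}

func main() {
	a := &toolCallAccumulator{}
	// Nested arguments arrive split across two chunks, as a streaming model might emit them.
	fmt.Println(a.feed(`{"name":"get_weather","arguments":{"city":`)) // nothing complete yet
	fmt.Println(a.feed(`"Paris"}}`))                                  // full object emitted here
}
```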
Parth Sareen
3ab842b0f5 cmd: clawdbot config fixes (#13922) 2026-01-26 14:34:29 -08:00
Parth Sareen
b8e8ef8929 cmd: ollama launch clawdbot (#13921) 2026-01-26 13:40:59 -08:00
Parth Sareen
465d124183 cmd: fix opencode config (#13894) 2026-01-24 18:42:56 -08:00
Parth Sareen
d310e56fa3 cmd: add fallback for claude (#13892) 2026-01-24 18:26:01 -08:00
Jeffrey Morgan
a1ca428c90 glm4moelite: fix attention scale calculation (#13893)
Use the original key dimension (qkNopeHeadDim + qkRopeHeadDim = 256) for
the attention scale instead of the MLA absorbed dimension (kvLoraRank +
qkRopeHeadDim = 576).

MLA absorption is a mathematically equivalent reorganization of the
attention computation - it should not change the effective attention
scale. The scale should match training, which uses 1/sqrt(256).

This improves tool calling and reduces model looping issues.
2026-01-24 17:48:09 -08:00
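The size of the error is easy to see numerically. This is a minimal sketch, not GLM code; the two dimensions are the ones quoted in the message above.

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	const (
		keyDim      = 256 // qkNopeHeadDim + qkRopeHeadDim: the key size used at training time
		absorbedDim = 576 // kvLoraRank + qkRopeHeadDim: the MLA-absorbed key size
	)

	trained := 1 / math.Sqrt(keyDim)       // 1/sqrt(256) = 0.0625, matches training
	absorbed := 1 / math.Sqrt(absorbedDim) // 1/sqrt(576) ≈ 0.0417, about a third too small

	fmt.Printf("training scale: %.4f\n", trained)
	fmt.Printf("absorbed scale: %.4f (flattens the attention logits)\n", absorbed)
}
```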
Jeffrey Morgan
16750865d1 glm4moelite: quantize more tensors to q8_0 and avoid double BOS token (#13891) 2026-01-24 16:33:54 -08:00
Jeffrey Morgan
f3b476c592 build: add -O3 optimization to CGO flags (#13877)
CGO_CFLAGS and CGO_CXXFLAGS were being set without optimization flags,
which overrides Go's default -O2 and results in unoptimized C++ code.

This caused significant performance degradation in release builds
compared to local `go build` which uses the default optimization.

- build_darwin.sh: add -O3 to CGO_CFLAGS and CGO_CXXFLAGS exports
- Dockerfile: preserve CGO_CFLAGS/CGO_CXXFLAGS from build args instead
  of overwriting them
- app/README.md: update documentation to include -O3
2026-01-24 10:55:38 -08:00
Parth Sareen
5267d31d56 docs: ollama launch (#13852) 2026-01-23 23:18:50 -08:00
Stillhart
b44f56319f README: Update the "Ollama for Ruby" entry to the most popular and maintained Ruby gem. (#13855)
* update README ruby link

the ollama-ai Ruby gem is vastly less popular and seems unmaintained
https://rubygems.org/gems/ollama-ai

the de facto standard with the most downloads in the Ruby ecosystem is ruby_llm
https://rubygems.org/gems/ruby_llm

I would link to that to avoid complication and guarantee feature compatibility with ollama.

* Update gem link ruby_llm from website to GitHub

Ollama links mostly to GitHub, not project websites, hence the link to ruby_llm's GitHub.
2026-01-24 01:24:52 -05:00
Jeffrey Morgan
0209c268bb llama: fix CUDA MMA errors in release build (#13874) 2026-01-23 20:10:04 -08:00
Jeffrey Morgan
912d984346 llama: fix fattn-tile shared memory overflow on sm_50/52 (#13872)
Use nthreads=128 for ncols=4 configurations in flash attention tile
kernel to reduce shared memory usage below 48KB limit on Maxwell
architectures (sm_50/52).

With nthreads=256 and ncols=4, np=2 which caused shared memory to
exceed 48KB. With nthreads=128 and ncols=4, np=1 keeps shared memory
under the limit.
2026-01-23 19:22:32 -08:00
Parth Sareen
aae6ecbaff cmd: rename ollama config to ollama launch (#13871) 2026-01-23 18:40:40 -08:00
Jeffrey Morgan
64737330a4 Re-apply "model: add MLA absorption for glm4moelite" with fix (#13870)
The nvidia_fp32 config for (576, 512) head sizes had nbatch_fa=32,
which caused zero-sized arrays when computing array dimensions:
  nbatch_fa / (np * warp_size) = 32 / (2 * 32) = 0

This resulted in CUDA compilation failures on CUDA 12 (Windows and
Linux arm64):
- "static assertion failed with nbatch_fa % (np*warp_size) != 0"
- "the size of an array must be greater than zero"

Fix by changing nbatch_fa from 32 to 64 for all (576, 512) configs
in the nvidia_fp32 function, matching the nvidia_fp16 and AMD configs.
2026-01-23 18:40:28 -08:00
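The failing expression is plain integer division, so the before/after values can be sanity-checked outside CUDA. A small sketch using the numbers quoted above (np=2, warp_size=32):

```go
package main

import "fmt"

func main() {
	const (
		np       = 2
		warpSize = 32
	)
	for _, nbatchFA := range []int{32, 64} {
		dim := nbatchFA / (np * warpSize)
		fmt.Printf("nbatch_fa=%2d -> array dimension %d (zero fails the build)\n", nbatchFA, dim)
	}
}
```

With 64, the dimension is 1 and 64 is an exact multiple of np*warp_size, so both of the reported compile errors go away.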
Jeffrey Morgan
2eda97f1c3 Revert "model: add MLA absorption for glm4moelite (#13810)" (#13869)
This reverts commit 1044b0419a.
2026-01-23 17:14:15 -08:00
Jeffrey Morgan
66831dcf70 x/imagegen: fix image editing support (#13866)
- Fix panic in ollama show for image gen models (safe type assertion)
- Add vision capability for Flux2KleinPipeline models at create time
- Flatten transparent PNG images onto white background for better results
2026-01-23 15:37:17 -08:00
Jeffrey Morgan
1044b0419a model: add MLA absorption for glm4moelite (#13810)
* model: add MLA absorption for glm4moelite

Split the combined KV_B tensor into separate K_B and V_B tensors
during conversion, enabling MLA (Multi-head Latent Attention)
absorption which compresses the KV cache for improved efficiency.

* ggml: enable MLA flash attention for GLM-4.7-flash

Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
uses head size 576 with gqa_ratio 4, which was previously only supported
for gqa_ratio 16 (DeepSeek).

Metal changes:
- Enable head size 576 for flash attention
- Increase simdgroups to 8 for large heads (>=512)
- Add case 8 kernel dispatch for 8 simdgroups

CUDA changes:
- Add gqa_ratio 4 support for head 576/512
- Add tile configs for (576, 512, 4) and (576, 512, 8)
- Add MMA config cases for ncols 4
- Add template instances for ncols2=4

* model: add compatibility validation for glm4moelite architecture
2026-01-23 14:47:42 -08:00
Parth Sareen
771d9280ec cmd: ollama config fix droid model name configuration (#13856) 2026-01-23 11:44:22 -08:00
Jeffrey Morgan
862bc0a3bf x/imagegen: respect stream=false in /api/generate (#13853)
When stream=false is set for image generation requests, return a single
JSON response instead of streaming multiple ndjson progress updates.
2026-01-22 22:16:39 -08:00
Jeffrey Morgan
c01608b6a1 x/imagegen: add image edit capabilities (#13846) 2026-01-22 20:35:08 -08:00
Parth Sareen
199c41e16e cmd: ollama config command to help configure integrations to use Ollama (#13712) 2026-01-22 20:17:11 -08:00
Jeffrey Morgan
3b3bf6c217 x/imagegen: replace memory estimation with actual weight size (#13848)
Remove static VRAM estimation (EstimateVRAM, CheckMemoryRequirements)
which wasn't helpful. Instead, report the actual tensor weight size
from the manifest for ollama ps.

- Remove memory estimation check from runner startup
- Remove EstimateVRAM, CheckMemoryRequirements, modelVRAMEstimates
- Add TotalTensorSize() to get actual weight size from manifest
- Use weight size for Server.vramSize instead of estimates

Note: This is better than showing 0 or inaccurate estimates, but the
weight size is a drastic underestimation of actual memory usage since
it doesn't account for activations, intermediate tensors, or MLX
overhead. Future work should query real-time memory from MLX
(e.g., MetalGetActiveMemory) for accurate reporting.
2026-01-22 18:32:41 -08:00
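A reasonable mental model for the replacement is "sum the sizes of the weight layers recorded in the model manifest". The sketch below is hypothetical; the type and field names are illustrative rather than the actual x/imagegen API, and it shows why the number is still a lower bound: only weights on disk are counted, not activations or runtime overhead.

```go
package main

import "fmt"

// layer mirrors the size and media-type fields an OCI-style model manifest
// records per blob; the names here are illustrative, not the x/imagegen types.
type layer struct {
	MediaType string
	Size      int64
}

// totalTensorSize sums the on-disk size of weight layers only. Activations,
// intermediate tensors, and framework overhead are not included, which is why
// the commit calls the result a drastic underestimate of real memory use.
func totalTensorSize(layers []layer, weightMediaType string) int64 {
	var total int64
	for _, l := range layers {
		if l.MediaType == weightMediaType {
			total += l.Size
		}
	}
	return total
}

func main() {
	layers := []layer{
		{MediaType: "application/vnd.example.model.weights", Size: 6 << 30}, // 6 GiB of weights (made-up)
		{MediaType: "application/vnd.example.model.config", Size: 4 << 10},
	}
	fmt.Println(totalTensorSize(layers, "application/vnd.example.model.weights"), "bytes")
}
```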
Parth Sareen
f52c21f457 fix: handle Enter key pressed during model loading (#13839) 2026-01-22 18:32:02 -08:00
Jeffrey Morgan
b5d0f72f16 x/imagegen: remove qwen_image and qwen_image_edit models (#13827)
Remove the Qwen image generation and image editing model packages
to clean up the codebase. These models will be reintroduced later.

- Delete x/imagegen/models/qwen_image/ (10 files)
- Delete x/imagegen/models/qwen_image_edit/ (5 files)
- Remove related CLI flags and imports from cmd/engine/main.go
- Update comments in cache/step.go to remove Qwen-specific references
2026-01-21 13:37:08 -08:00
Patrick Devine
148a1be0a3 Clean up the manifest and modelpath (#13807) 2026-01-21 11:46:17 -08:00
next-n
d6dd430abd x/imagegen: respect OLLAMA_MODELS for manifests and blobs (#13797) 2026-01-20 13:01:52 -08:00
Daniel Hiltgen
ae78112c50 test: add lfm2.5-thinking coverage (#13802) 2026-01-20 12:57:02 -08:00
113 changed files with 9476 additions and 8448 deletions


@@ -169,8 +169,10 @@ COPY . .
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .


@@ -558,7 +558,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LiteLLM](https://github.com/BerriAI/litellm)
- [OllamaFarm for Go](https://github.com/presbrey/ollamafarm)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
- [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
- [Ollama for Ruby](https://github.com/crmne/ruby_llm)
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
- [Ollama4j for Java](https://github.com/ollama4j/ollama4j)


@@ -75,9 +75,9 @@ The `-dev` flag enables:
CI builds with Xcode 14.1 for OS compatibility prior to v13. If you want to manually build v11+ support, you can download the older Xcode [here](https://developer.apple.com/services-account/download?path=/Developer_Tools/Xcode_14.1/Xcode_14.1.xip), extract, then `mv ./Xcode.app /Applications/Xcode_14.1.0.app` then activate with:
```
export CGO_CFLAGS=-mmacosx-version-min=12.0
export CGO_CXXFLAGS=-mmacosx-version-min=12.0
export CGO_LDFLAGS=-mmacosx-version-min=12.0
export CGO_CFLAGS="-O3 -mmacosx-version-min=12.0"
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=12.0"
export CGO_LDFLAGS="-mmacosx-version-min=12.0"
export SDKROOT=/Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
export DEVELOPER_DIR=/Applications/Xcode_14.1.0.app/Contents/Developer
```


@@ -35,6 +35,7 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -1018,8 +1019,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
}
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
arch, _ := resp.ModelInfo["general.architecture"].(string)
if arch != "" {
rows = append(rows, []string{"", "architecture", arch})
}
var paramStr string
if resp.Details.ParameterSize != "" {
@@ -1029,7 +1032,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
paramStr = format.HumanNumber(uint64(f))
}
}
rows = append(rows, []string{"", "parameters", paramStr})
if paramStr != "" {
rows = append(rows, []string{"", "parameters", paramStr})
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
@@ -2026,6 +2031,7 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
config.LaunchCmd(checkServerHeartbeat),
)
return rootCmd

cmd/config/claude.go Normal file

@@ -0,0 +1,60 @@
package config
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner for Claude Code integration
type Claude struct{}
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string) []string {
if model != "" {
return []string{"--model", model}
}
return nil
}
func (c *Claude) findPath() (string, error) {
if p, err := exec.LookPath("claude"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(home, ".claude", "local", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Claude) Run(model string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)
return cmd.Run()
}

cmd/config/claude_test.go Normal file

@@ -0,0 +1,101 @@
package config
import (
"os"
"path/filepath"
"runtime"
"slices"
"testing"
)
func TestClaudeIntegration(t *testing.T) {
c := &Claude{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Claude Code" {
t.Errorf("String() = %q, want %q", got, "Claude Code")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
}
func TestClaudeFindPath(t *testing.T) {
c := &Claude{}
t.Run("finds claude in PATH", func(t *testing.T) {
tmpDir := t.TempDir()
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fakeBin := filepath.Join(tmpDir, name)
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
t.Setenv("PATH", tmpDir)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fakeBin {
t.Errorf("findPath() = %q, want %q", got, fakeBin)
}
})
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(tmpDir, ".claude", "local", name)
os.MkdirAll(filepath.Dir(fallback), 0o755)
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fallback {
t.Errorf("findPath() = %q, want %q", got, fallback)
}
})
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
}
func TestClaudeArgs(t *testing.T) {
c := &Claude{}
tests := []struct {
name string
model string
want []string
}{
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
{"empty model", "", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}
}

cmd/config/clawdbot.go Normal file

@@ -0,0 +1,195 @@
package config
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/ollama/ollama/envconfig"
)
type Clawdbot struct{}
func (c *Clawdbot) String() string { return "Clawdbot" }
const ansiGreen = "\033[32m"
func (c *Clawdbot) Run(model string) error {
if _, err := exec.LookPath("clawdbot"); err != nil {
return fmt.Errorf("clawdbot is not installed, install from https://docs.clawd.bot")
}
models := []string{model}
if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("clawdbot", "gateway")
cmd.Stdin = os.Stdin
// Capture output to detect "already running" message
var outputBuf bytes.Buffer
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
err := cmd.Run()
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
fmt.Fprintf(os.Stderr, "%sClawdbot has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
return nil
}
return err
}
func (c *Clawdbot) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".clawdbot", "clawdbot.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (c *Clawdbot) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
// Read into map[string]any to preserve unknown fields
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
}
// Navigate/create: models.providers.ollama (preserving other providers)
modelsSection, _ := config["models"].(map[string]any)
if modelsSection == nil {
modelsSection = make(map[string]any)
}
providers, _ := modelsSection["providers"].(map[string]any)
if providers == nil {
providers = make(map[string]any)
}
ollama, _ := providers["ollama"].(map[string]any)
if ollama == nil {
ollama = make(map[string]any)
}
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
// needed to register provider
ollama["apiKey"] = "ollama-local"
// TODO(parthsareen): potentially move to responses
ollama["api"] = "openai-completions"
// Build map of existing models to preserve user customizations
existingModels, _ := ollama["models"].([]any)
existingByID := make(map[string]map[string]any)
for _, m := range existingModels {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
existingByID[id] = entry
}
}
}
var newModels []any
for _, model := range models {
entry := map[string]any{
"id": model,
"name": model,
"reasoning": false,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
// TODO(parthsareen): get these values from API
"contextWindow": 131072,
"maxTokens": 16384,
}
// Merge existing fields (user customizations)
if existing, ok := existingByID[model]; ok {
for k, v := range existing {
if _, isNew := entry[k]; !isNew {
entry[k] = v
}
}
}
newModels = append(newModels, entry)
}
ollama["models"] = newModels
providers["ollama"] = ollama
modelsSection["providers"] = providers
config["models"] = modelsSection
// Update agents.defaults.model.primary (preserving other agent settings)
agents, _ := config["agents"].(map[string]any)
if agents == nil {
agents = make(map[string]any)
}
defaults, _ := agents["defaults"].(map[string]any)
if defaults == nil {
defaults = make(map[string]any)
}
modelConfig, _ := defaults["model"].(map[string]any)
if modelConfig == nil {
modelConfig = make(map[string]any)
}
modelConfig["primary"] = "ollama/" + models[0]
defaults["model"] = modelConfig
agents["defaults"] = defaults
config["agents"] = agents
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
return writeWithBackup(configPath, data)
}
func (c *Clawdbot) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
if err != nil {
return nil
}
modelsSection, _ := config["models"].(map[string]any)
providers, _ := modelsSection["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
var result []string
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
result = append(result, id)
}
}
}
return result
}

cmd/config/clawdbot_test.go Normal file

@@ -0,0 +1,625 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestClawdbotIntegration(t *testing.T) {
c := &Clawdbot{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Clawdbot" {
t.Errorf("String() = %q, want %q", got, "Clawdbot")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = c
})
}
func TestClawdbotEdit(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("fresh install", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("multiple models - first is primary", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotModelExists(t, configPath, "mistral")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("preserve other providers", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider was removed")
}
})
t.Run("preserve top-level keys", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("theme was removed")
}
if cfg["mcp"] == nil {
t.Error("mcp was removed")
}
})
t.Run("preserve user customizations on models", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2"})
// User adds custom field
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
entry["customField"] = "user-value"
configData, _ := json.MarshalIndent(cfg, "", " ")
os.WriteFile(configPath, configData, 0o644)
// Re-run Edit
c.Edit([]string{"llama3.2"})
data, _ = os.ReadFile(configPath)
json.Unmarshal(data, &cfg)
models = cfg["models"].(map[string]any)
providers = models["providers"].(map[string]any)
ollama = providers["ollama"].(map[string]any)
modelList = ollama["models"].([]any)
entry = modelList[0].(map[string]any)
if entry["customField"] != "user-value" {
t.Error("custom field was lost")
}
})
t.Run("edit replaces models list", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2", "mistral"})
c.Edit([]string{"llama3.2"})
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotModelNotExists(t, configPath, "mistral")
})
t.Run("empty models is no-op", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
original := `{"existing":"data"}`
os.WriteFile(configPath, []byte(original), 0o644)
c.Edit([]string{})
data, _ := os.ReadFile(configPath)
if string(data) != original {
t.Error("empty models should not modify file")
}
})
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Error("result should be valid JSON")
}
})
t.Run("wrong type models section", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
})
}
func TestClawdbotModels(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("no config returns nil", func(t *testing.T) {
if models := c.Models(); len(models) > 0 {
t.Errorf("expected nil/empty, got %v", models)
}
})
t.Run("returns all ollama models", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[
{"id":"llama3.2"},
{"id":"mistral"}
]}}}
}`), 0o644)
models := c.Models()
if len(models) != 2 {
t.Errorf("expected 2 models, got %v", models)
}
})
}
// Helper functions
func assertClawdbotModelExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
return
}
}
}
t.Errorf("model %s not found", model)
}
func assertClawdbotModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models, _ := cfg["models"].(map[string]any)
providers, _ := models["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
t.Errorf("model %s should not exist", model)
}
}
}
}
func assertClawdbotPrimaryModel(t *testing.T, path, expected string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
model := defaults["model"].(map[string]any)
if model["primary"] != expected {
t.Errorf("primary model = %v, want %v", model["primary"], expected)
}
}
func TestClawdbotPaths(t *testing.T) {
c := &Clawdbot{}
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns nil when config missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if paths := c.Paths(); paths != nil {
t.Errorf("expected nil, got %v", paths)
}
})
}
func TestClawdbotModelsEdgeCases(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("corrupted JSON returns nil", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at models level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at providers level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at ollama level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("model entry missing id", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for missing id")
}
})
t.Run("model id is not string", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for non-string id")
}
})
}
func TestClawdbotEditSchemaFields(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
// Verify required schema fields
if entry["reasoning"] != false {
t.Error("reasoning should be false")
}
if entry["input"] == nil {
t.Error("input should be set")
}
if entry["contextWindow"] == nil {
t.Error("contextWindow should be set")
}
if entry["maxTokens"] == nil {
t.Error("maxTokens should be set")
}
cost := entry["cost"].(map[string]any)
if cost["cacheRead"] == nil {
t.Error("cost.cacheRead should be set")
}
if cost["cacheWrite"] == nil {
t.Error("cost.cacheWrite should be set")
}
}
func TestClawdbotEditModelNames(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".clawdbot")) }
t.Run("model with colon tag", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2:70b")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2:70b")
})
t.Run("model with slash", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"library/model:tag"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "library/model:tag")
assertClawdbotPrimaryModel(t, configPath, "ollama/library/model:tag")
})
t.Run("model with hyphen", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"test-model"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "test-model")
})
}
func TestClawdbotEditAgentsPreservation(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("preserve other agent defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature setting was lost")
}
})
t.Run("preserve other agents besides defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent was lost")
}
})
}
const testClawdbotFixture = `{
"theme": "dark",
"mcp": {"servers": {"custom": {"enabled": true}}},
"models": {
"providers": {
"anthropic": {"apiKey": "xxx"},
"ollama": {
"baseUrl": "http://127.0.0.1:11434/v1",
"models": [{"id": "old-model", "customField": "preserved"}]
}
}
},
"agents": {
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
"custom-agent": {"foo": "bar"}
}
}`
func TestClawdbotEdit_RoundTrip(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
// Verify top-level preserved
if cfg["theme"] != "dark" {
t.Error("theme not preserved")
}
mcp := cfg["mcp"].(map[string]any)
servers := mcp["servers"].(map[string]any)
if servers["custom"] == nil {
t.Error("mcp.servers.custom not preserved")
}
// Verify other providers preserved
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider not preserved")
}
// Verify agents preserved
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent not preserved")
}
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature not preserved")
}
}
func TestClawdbotEdit_Idempotent(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
c.Edit([]string{"llama3.2", "mistral"})
firstData, _ := os.ReadFile(configPath)
c.Edit([]string{"llama3.2", "mistral"})
secondData, _ := os.ReadFile(configPath)
if string(firstData) != string(secondData) {
t.Error("repeated edits with same models produced different results")
}
}
func TestClawdbotEdit_MultipleConsecutiveEdits(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
for i := range 10 {
models := []string{"model-a", "model-b"}
if i%2 == 0 {
models = []string{"model-x", "model-y", "model-z"}
}
if err := c.Edit(models); err != nil {
t.Fatalf("edit %d failed: %v", i, err)
}
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
}
if cfg["theme"] != "dark" {
t.Error("theme lost after multiple edits")
}
}
func TestClawdbotEdit_BackupCreated(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
os.MkdirAll(configDir, 0o755)
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
os.WriteFile(configPath, []byte(original), 0o644)
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
backups, _ := filepath.Glob(filepath.Join(backupDir, "clawdbot.json.*"))
foundBackup := false
for _, backup := range backups {
data, _ := os.ReadFile(backup)
if string(data) == original {
foundBackup = true
break
}
}
if !foundBackup {
t.Error("backup with original content not found")
}
}
func TestClawdbotEdit_CreatesDirectoryIfMissing(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
t.Fatal("directory should not exist before test")
}
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(configDir); os.IsNotExist(err) {
t.Fatal("directory was not created")
}
}

cmd/config/codex.go Normal file

@@ -0,0 +1,61 @@
package config
import (
"fmt"
"os"
"os/exec"
"strings"
"golang.org/x/mod/semver"
)
// Codex implements Runner for Codex integration
type Codex struct{}
func (c *Codex) String() string { return "Codex" }
func (c *Codex) args(model string) []string {
args := []string{"--oss"}
if model != "" {
args = append(args, "-m", model)
}
return args
}
func (c *Codex) Run(model string) error {
if err := checkCodexVersion(); err != nil {
return err
}
cmd := exec.Command("codex", c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func checkCodexVersion() error {
if _, err := exec.LookPath("codex"); err != nil {
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
}
out, err := exec.Command("codex", "--version").Output()
if err != nil {
return fmt.Errorf("failed to get codex version: %w", err)
}
// Parse output like "codex-cli 0.87.0"
fields := strings.Fields(strings.TrimSpace(string(out)))
if len(fields) < 2 {
return fmt.Errorf("unexpected codex version output: %s", string(out))
}
version := "v" + fields[len(fields)-1]
minVersion := "v0.81.0"
if semver.Compare(version, minVersion) < 0 {
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
}
return nil
}

cmd/config/codex_test.go Normal file

@@ -0,0 +1,28 @@
package config
import (
"slices"
"testing"
)
func TestCodexArgs(t *testing.T) {
c := &Codex{}
tests := []struct {
name string
model string
want []string
}{
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", []string{"--oss"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}
}

cmd/config/config.go Normal file

@@ -0,0 +1,115 @@
// Package config provides integration configuration for external coding tools
// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
package config
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
)
type integration struct {
Models []string `json:"models"`
}
type config struct {
Integrations map[string]*integration `json:"integrations"`
}
func configPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "config", "config.json"), nil
}
func load() (*config, error) {
path, err := configPath()
if err != nil {
return nil, err
}
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return &config{Integrations: make(map[string]*integration)}, nil
}
return nil, err
}
var cfg config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
}
if cfg.Integrations == nil {
cfg.Integrations = make(map[string]*integration)
}
return &cfg, nil
}
func save(cfg *config) error {
path, err := configPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
return writeWithBackup(path, data)
}
func saveIntegration(appName string, models []string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
cfg.Integrations[strings.ToLower(appName)] = &integration{
Models: models,
}
return save(cfg)
}
func loadIntegration(appName string) (*integration, error) {
cfg, err := load()
if err != nil {
return nil, err
}
ic, ok := cfg.Integrations[strings.ToLower(appName)]
if !ok {
return nil, os.ErrNotExist
}
return ic, nil
}
func listIntegrations() ([]integration, error) {
cfg, err := load()
if err != nil {
return nil, err
}
result := make([]integration, 0, len(cfg.Integrations))
for _, ic := range cfg.Integrations {
result = append(result, *ic)
}
return result, nil
}

cmd/config/config_test.go Normal file

@@ -0,0 +1,373 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
)
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
func setTestHome(t *testing.T, dir string) {
t.Setenv("HOME", dir)
t.Setenv("USERPROFILE", dir)
}
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
func editorPaths(r Runner) []string {
if editor, ok := r.(Editor); ok {
return editor.Paths()
}
return nil
}
func TestIntegrationConfig(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("save and load round-trip", func(t *testing.T) {
models := []string{"llama3.2", "mistral", "qwen2.5"}
if err := saveIntegration("claude", models); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if len(config.Models) != len(models) {
t.Errorf("expected %d models, got %d", len(models), len(config.Models))
}
for i, m := range models {
if config.Models[i] != m {
t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
}
}
})
t.Run("defaultModel returns first model", func(t *testing.T) {
saveIntegration("codex", []string{"model-a", "model-b"})
config, _ := loadIntegration("codex")
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "model-a" {
t.Errorf("expected model-a, got %s", defaultModel)
}
})
t.Run("defaultModel returns empty for no models", func(t *testing.T) {
config := &integration{Models: []string{}}
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "" {
t.Errorf("expected empty string, got %s", defaultModel)
}
})
t.Run("app name is case-insensitive", func(t *testing.T) {
saveIntegration("Claude", []string{"model-x"})
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "model-x" {
t.Errorf("expected model-x, got %s", defaultModel)
}
})
t.Run("multiple integrations in single file", func(t *testing.T) {
saveIntegration("app1", []string{"model-1"})
saveIntegration("app2", []string{"model-2"})
config1, _ := loadIntegration("app1")
config2, _ := loadIntegration("app2")
defaultModel1 := ""
if len(config1.Models) > 0 {
defaultModel1 = config1.Models[0]
}
defaultModel2 := ""
if len(config2.Models) > 0 {
defaultModel2 = config2.Models[0]
}
if defaultModel1 != "model-1" {
t.Errorf("expected model-1, got %s", defaultModel1)
}
if defaultModel2 != "model-2" {
t.Errorf("expected model-2, got %s", defaultModel2)
}
})
}
func TestListIntegrations(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty when no integrations", func(t *testing.T) {
configs, err := listIntegrations()
if err != nil {
t.Fatal(err)
}
if len(configs) != 0 {
t.Errorf("expected 0 integrations, got %d", len(configs))
}
})
t.Run("returns all saved integrations", func(t *testing.T) {
saveIntegration("claude", []string{"model-1"})
saveIntegration("droid", []string{"model-2"})
configs, err := listIntegrations()
if err != nil {
t.Fatal(err)
}
if len(configs) != 2 {
t.Errorf("expected 2 integrations, got %d", len(configs))
}
})
}
func TestEditorPaths(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
r := integrations["claude"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for claude, got %v", paths)
}
})
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
r := integrations["codex"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for codex, got %v", paths)
}
})
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths, got %v", paths)
}
})
t.Run("returns path for droid when config exists", func(t *testing.T) {
settingsDir, _ := os.UserHomeDir()
settingsDir = filepath.Join(settingsDir, ".factory")
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
home, _ := os.UserHomeDir()
configDir := filepath.Join(home, ".config", "opencode")
stateDir := filepath.Join(home, ".local", "state", "opencode")
os.MkdirAll(configDir, 0o755)
os.MkdirAll(stateDir, 0o755)
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
r := integrations["opencode"]
paths := editorPaths(r)
if len(paths) != 2 {
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
}
})
}
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Create corrupted config.json file
dir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(dir, 0o755)
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
// Corrupted config.json fails to parse, so loadIntegration returns an error
_, err := loadIntegration("test")
if err == nil {
t.Error("expected error for nonexistent integration in corrupted file")
}
}
func TestSaveIntegration_NilModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := saveIntegration("test", nil); err != nil {
t.Fatalf("saveIntegration with nil models failed: %v", err)
}
config, err := loadIntegration("test")
if err != nil {
t.Fatalf("loadIntegration failed: %v", err)
}
if config.Models == nil {
// nil is acceptable
} else if len(config.Models) != 0 {
t.Errorf("expected empty or nil models, got %v", config.Models)
}
}
func TestSaveIntegration_EmptyAppName(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
err := saveIntegration("", []string{"model"})
if err == nil {
t.Error("expected error for empty app name, got nil")
}
if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
}
}
func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
_, err := loadIntegration("nonexistent")
if err == nil {
t.Error("expected error for nonexistent integration, got nil")
}
if !os.IsNotExist(err) {
t.Logf("error type is os.ErrNotExist as expected: %v", err)
}
}
func TestConfigPath(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
path, err := configPath()
if err != nil {
t.Fatal(err)
}
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
if path != expected {
t.Errorf("expected %s, got %s", expected, path)
}
}
func TestLoad(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty config when file does not exist", func(t *testing.T) {
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg == nil {
t.Fatal("expected non-nil config")
}
if cfg.Integrations == nil {
t.Error("expected non-nil Integrations map")
}
if len(cfg.Integrations) != 0 {
t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
}
})
t.Run("loads existing config", func(t *testing.T) {
path, _ := configPath()
os.MkdirAll(filepath.Dir(path), 0o755)
os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg.Integrations["test"] == nil {
t.Fatal("expected test integration")
}
if len(cfg.Integrations["test"].Models) != 1 {
t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
}
})
t.Run("returns error for corrupted JSON", func(t *testing.T) {
path, _ := configPath()
os.MkdirAll(filepath.Dir(path), 0o755)
os.WriteFile(path, []byte(`{corrupted`), 0o644)
_, err := load()
if err == nil {
t.Error("expected error for corrupted JSON")
}
})
}
func TestSave(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("creates config file", func(t *testing.T) {
cfg := &config{
Integrations: map[string]*integration{
"test": {Models: []string{"model-a", "model-b"}},
},
}
if err := save(cfg); err != nil {
t.Fatal(err)
}
path, _ := configPath()
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Error("config file was not created")
}
})
t.Run("round-trip preserves data", func(t *testing.T) {
cfg := &config{
Integrations: map[string]*integration{
"claude": {Models: []string{"llama3.2", "mistral"}},
"codex": {Models: []string{"qwen2.5"}},
},
}
if err := save(cfg); err != nil {
t.Fatal(err)
}
loaded, err := load()
if err != nil {
t.Fatal(err)
}
if len(loaded.Integrations) != 2 {
t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
}
if loaded.Integrations["claude"] == nil {
t.Error("missing claude integration")
}
if len(loaded.Integrations["claude"].Models) != 2 {
t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
}
})
}

cmd/config/droid.go Normal file

@@ -0,0 +1,186 @@
package config
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
"github.com/ollama/ollama/envconfig"
)
// Droid implements Runner and Editor for Droid integration
type Droid struct{}
// droidSettings represents the Droid settings.json file (only fields we use)
type droidSettings struct {
CustomModels []modelEntry `json:"customModels"`
SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
}
type sessionSettings struct {
Model string `json:"model"`
ReasoningEffort string `json:"reasoningEffort"`
}
type modelEntry struct {
Model string `json:"model"`
DisplayName string `json:"displayName"`
BaseURL string `json:"baseUrl"`
APIKey string `json:"apiKey"`
Provider string `json:"provider"`
MaxOutputTokens int `json:"maxOutputTokens"`
SupportsImages bool `json:"supportsImages"`
ID string `json:"id"`
Index int `json:"index"`
}
func (d *Droid) String() string { return "Droid" }
func (d *Droid) Run(model string) error {
if _, err := exec.LookPath("droid"); err != nil {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := d.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (d *Droid) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
p := filepath.Join(home, ".factory", "settings.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (d *Droid) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
settingsPath := filepath.Join(home, ".factory", "settings.json")
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
return err
}
// Read file once, unmarshal twice:
// map preserves unknown fields for writing back (including extra fields in model entries)
settingsMap := make(map[string]any)
var settings droidSettings
if data, err := os.ReadFile(settingsPath); err == nil {
if err := json.Unmarshal(data, &settingsMap); err != nil {
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
}
json.Unmarshal(data, &settings) // ignore error, zero values are fine
}
// Keep only non-Ollama models from the raw map (preserves extra fields)
// Rebuild Ollama models
var nonOllamaModels []any
if rawModels, ok := settingsMap["customModels"].([]any); ok {
for _, raw := range rawModels {
if m, ok := raw.(map[string]any); ok {
if m["apiKey"] != "ollama" {
nonOllamaModels = append(nonOllamaModels, raw)
}
}
}
}
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
var newModels []any
var defaultModelID string
for i, model := range models {
modelID := fmt.Sprintf("custom:%s-%d", model, i)
newModels = append(newModels, modelEntry{
Model: model,
DisplayName: model,
BaseURL: envconfig.Host().String() + "/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: 64000,
SupportsImages: false,
ID: modelID,
Index: i,
})
if i == 0 {
defaultModelID = modelID
}
}
settingsMap["customModels"] = append(newModels, nonOllamaModels...)
// Update session default settings (preserve unknown fields in the nested object)
sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
if !ok {
sessionSettings = make(map[string]any)
}
sessionSettings["model"] = defaultModelID
if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
sessionSettings["reasoningEffort"] = "none"
}
settingsMap["sessionDefaultSettings"] = sessionSettings
data, err := json.MarshalIndent(settingsMap, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, data)
}
func (d *Droid) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
if err != nil {
return nil
}
var settings droidSettings
if err := json.Unmarshal(data, &settings); err != nil {
return nil
}
var result []string
for _, m := range settings.CustomModels {
if m.APIKey == "ollama" {
result = append(result, m.Model)
}
}
return result
}
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
func isValidReasoningEffort(effort string) bool {
return slices.Contains(validReasoningEfforts, effort)
}

cmd/config/droid_test.go Normal file

File diff suppressed because it is too large

cmd/config/files.go Normal file

@@ -0,0 +1,99 @@
package config
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
func readJSONFile(path string) (map[string]any, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}
return result, nil
}
func copyFile(src, dst string) error {
info, err := os.Stat(src)
if err != nil {
return err
}
data, err := os.ReadFile(src)
if err != nil {
return err
}
return os.WriteFile(dst, data, info.Mode().Perm())
}
func backupDir() string {
return filepath.Join(os.TempDir(), "ollama-backups")
}
func backupToTmp(srcPath string) (string, error) {
dir := backupDir()
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err
}
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
if err := copyFile(srcPath, backupPath); err != nil {
return "", err
}
return backupPath, nil
}
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
func writeWithBackup(path string, data []byte) error {
var backupPath string
// backup must be created before any writes to the target file
if existingContent, err := os.ReadFile(path); err == nil {
if !bytes.Equal(existingContent, data) {
backupPath, err = backupToTmp(path)
if err != nil {
return fmt.Errorf("backup failed: %w", err)
}
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("read existing file: %w", err)
}
dir := filepath.Dir(path)
tmp, err := os.CreateTemp(dir, ".tmp-*")
if err != nil {
return fmt.Errorf("create temp failed: %w", err)
}
tmpPath := tmp.Name()
if _, err := tmp.Write(data); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpPath)
return fmt.Errorf("write failed: %w", err)
}
if err := tmp.Sync(); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpPath)
return fmt.Errorf("sync failed: %w", err)
}
if err := tmp.Close(); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("close failed: %w", err)
}
if err := os.Rename(tmpPath, path); err != nil {
_ = os.Remove(tmpPath)
if backupPath != "" {
_ = copyFile(backupPath, path)
}
return fmt.Errorf("rename failed: %w", err)
}
return nil
}

cmd/config/files_test.go Normal file

@@ -0,0 +1,502 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
)
func mustMarshal(t *testing.T, v any) []byte {
t.Helper()
data, err := json.MarshalIndent(v, "", " ")
if err != nil {
t.Fatal(err)
}
return data
}
func TestWriteWithBackup(t *testing.T) {
tmpDir := t.TempDir()
t.Run("creates file", func(t *testing.T) {
path := filepath.Join(tmpDir, "new.json")
data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
content, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var result map[string]string
if err := json.Unmarshal(content, &result); err != nil {
t.Fatal(err)
}
if result["key"] != "value" {
t.Errorf("expected value, got %s", result["key"])
}
})
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
path := filepath.Join(tmpDir, "backup.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
data := mustMarshal(t, map[string]bool{"updated": true})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, err := os.ReadDir(backupDir())
if err != nil {
t.Fatal("backup directory not created")
}
var foundBackup bool
for _, entry := range entries {
if filepath.Ext(entry.Name()) != ".json" {
name := entry.Name()
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
backupPath := filepath.Join(backupDir(), name)
backup, err := os.ReadFile(backupPath)
if err == nil {
var backupData map[string]bool
json.Unmarshal(backup, &backupData)
if backupData["original"] {
foundBackup = true
os.Remove(backupPath)
break
}
}
}
}
}
if !foundBackup {
t.Error("backup file not created in /tmp/ollama-backups")
}
current, _ := os.ReadFile(path)
var currentData map[string]bool
json.Unmarshal(current, &currentData)
if !currentData["updated"] {
t.Error("file doesn't contain updated data")
}
})
t.Run("no backup for new file", func(t *testing.T) {
path := filepath.Join(tmpDir, "nobak.json")
data := mustMarshal(t, map[string]string{"new": "file"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(backupDir())
for _, entry := range entries {
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
t.Error("backup should not exist for new file")
}
}
})
t.Run("no backup when content unchanged", func(t *testing.T) {
path := filepath.Join(tmpDir, "unchanged.json")
data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries1, _ := os.ReadDir(backupDir())
countBefore := 0
for _, e := range entries1 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
countBefore++
}
}
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries2, _ := os.ReadDir(backupDir())
countAfter := 0
for _, e := range entries2 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
countAfter++
}
}
if countAfter != countBefore {
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
}
})
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
path := filepath.Join(tmpDir, "timestamped.json")
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
data := mustMarshal(t, map[string]int{"v": 2})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(backupDir())
var found bool
for _, entry := range entries {
name := entry.Name()
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
timestamp := name[len("timestamped.json."):]
for _, c := range timestamp {
if c < '0' || c > '9' {
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
}
}
found = true
os.Remove(filepath.Join(backupDir(), name))
break
}
}
if !found {
t.Error("backup file with timestamp not found")
}
})
}
// Edge case tests for files.go
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if the backup fails, we must not proceed,
// otherwise the user could lose their config with no way to recover.
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "config.json")
// Create original file
originalContent := []byte(`{"original": true}`)
os.WriteFile(path, originalContent, 0o644)
// Make backup directory read-only to force backup failure
backupDir := backupDir()
os.MkdirAll(backupDir, 0o755)
os.Chmod(backupDir, 0o444) // Read-only
defer os.Chmod(backupDir, 0o755)
newContent := []byte(`{"updated": true}`)
err := writeWithBackup(path, newContent)
// Should fail because backup couldn't be created
if err == nil {
t.Error("expected error when backup fails, got nil")
}
// Original file should be preserved
current, _ := os.ReadFile(path)
if string(current) != string(originalContent) {
t.Errorf("original file was modified despite backup failure: got %s", string(current))
}
}
// TestWriteWithBackup_PermissionDenied verifies a clear error when the target directory is not writable.
// Common when the config is owned by root or has the wrong permissions.
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
// Create a read-only directory
readOnlyDir := filepath.Join(tmpDir, "readonly")
os.MkdirAll(readOnlyDir, 0o755)
os.Chmod(readOnlyDir, 0o444)
defer os.Chmod(readOnlyDir, 0o755)
path := filepath.Join(readOnlyDir, "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`))
if err == nil {
t.Error("expected permission error, got nil")
}
}
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
// writeWithBackup doesn't create directories - caller is responsible.
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`))
// Should fail because directory doesn't exist
if err == nil {
t.Error("expected error for nonexistent directory, got nil")
}
}
// TestWriteWithBackup_SymlinkTarget documents behavior when the target is a symlink,
// e.g. a user who symlinks their config file.
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("symlink tests may require admin on Windows")
}
tmpDir := t.TempDir()
realFile := filepath.Join(tmpDir, "real.json")
symlink := filepath.Join(tmpDir, "link.json")
// Create real file and symlink
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
os.Symlink(realFile, symlink)
// Write through symlink
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
if err != nil {
t.Fatalf("writeWithBackup through symlink failed: %v", err)
}
// The rename replaces the symlink itself with a regular file, so reading the original path shows the new content
content, _ := os.ReadFile(symlink)
if string(content) != `{"v": 2}` {
t.Errorf("symlink target not updated correctly: got %s", string(content))
}
}
// TestBackupToTmp_SpecialCharsInFilename verifies backups work for filenames with special characters,
// since users may have config files with unusual names.
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
tmpDir := t.TempDir()
// File with spaces and special chars
path := filepath.Join(tmpDir, "my config (backup).json")
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
backupPath, err := backupToTmp(path)
if err != nil {
t.Fatalf("backupToTmp with special chars failed: %v", err)
}
// Verify backup exists and has correct content
content, err := os.ReadFile(backupPath)
if err != nil {
t.Fatalf("could not read backup: %v", err)
}
if string(content) != `{"test": true}` {
t.Errorf("backup content mismatch: got %s", string(content))
}
os.Remove(backupPath)
}
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
func TestCopyFile_PreservesPermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission preservation tests unreliable on Windows")
}
tmpDir := t.TempDir()
src := filepath.Join(tmpDir, "src.json")
dst := filepath.Join(tmpDir, "dst.json")
// Create source with specific permissions
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
err := copyFile(src, dst)
if err != nil {
t.Fatalf("copyFile failed: %v", err)
}
srcInfo, _ := os.Stat(src)
dstInfo, _ := os.Stat(dst)
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
}
}
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
func TestCopyFile_SourceNotFound(t *testing.T) {
tmpDir := t.TempDir()
src := filepath.Join(tmpDir, "nonexistent.json")
dst := filepath.Join(tmpDir, "dst.json")
err := copyFile(src, dst)
if err == nil {
t.Error("expected error for nonexistent source, got nil")
}
}
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
tmpDir := t.TempDir()
dirPath := filepath.Join(tmpDir, "actualdir")
os.MkdirAll(dirPath, 0o755)
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
if err == nil {
t.Error("expected error when target is a directory, got nil")
}
}
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
func TestWriteWithBackup_EmptyData(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "empty.json")
err := writeWithBackup(path, []byte{})
if err != nil {
t.Fatalf("writeWithBackup with empty data failed: %v", err)
}
content, err := os.ReadFile(path)
if err != nil {
t.Fatalf("could not read file: %v", err)
}
if len(content) != 0 {
t.Errorf("expected empty file, got %d bytes", len(content))
}
}
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
// cannot be read (for backup comparison) but directory is writable.
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "unreadable.json")
// Create file and make it unreadable
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
os.Chmod(path, 0o000)
defer os.Chmod(path, 0o644)
// Should fail because we can't read the file to compare/backup
err := writeWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when file is unreadable, got nil")
}
}
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
// within the same second (timestamp collision scenario).
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "rapid.json")
// Create initial file
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
// Rapid successive writes
for i := 1; i <= 3; i++ {
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
if err := writeWithBackup(path, data); err != nil {
t.Fatalf("write %d failed: %v", i, err)
}
}
// Verify final content
content, _ := os.ReadFile(path)
if string(content) != `{"v": 3}` {
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
}
// Verify at least one backup exists
entries, _ := os.ReadDir(backupDir())
var backupCount int
for _, e := range entries {
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
backupCount++
}
}
if backupCount == 0 {
t.Error("expected at least one backup file from rapid writes")
}
}
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("test modifies system temp directory")
}
// Create a file at the backup directory path
backupPath := backupDir()
// Clean up any existing directory first
os.RemoveAll(backupPath)
// Create a file instead of directory
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
defer func() {
os.Remove(backupPath)
os.MkdirAll(backupPath, 0o755)
}()
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "test.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
err := writeWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when backup dir is a file, got nil")
}
}
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
// Count existing temp files
countTempFiles := func() int {
entries, _ := os.ReadDir(tmpDir)
count := 0
for _, e := range entries {
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
count++
}
}
return count
}
before := countTempFiles()
// A rename failure is hard to provoke portably: the temp file is created in the
// same directory as the target, so if we cannot rename we usually could not have
// created the temp file either. Instead, force a failure by making the target a
// directory and verify no temp file is left behind.
badPath := filepath.Join(tmpDir, "isdir")
os.MkdirAll(badPath, 0o755)
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
after := countTempFiles()
if after > before {
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
}
}

355
cmd/config/integrations.go Normal file
View File

@@ -0,0 +1,355 @@
package config
import (
"context"
"errors"
"fmt"
"maps"
"os"
"os/exec"
"runtime"
"slices"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/spf13/cobra"
)
// Runners launch an integration with a chosen model - claude, codex
// Editors edit the integration's config files and support multi-model selection - opencode, droid
// The interfaces compose: some integrations are both an editor and a runner - opencode, droid
// Runner can run an integration with a model.
type Runner interface {
Run(model string) error
// String returns the human-readable name of the integration
String() string
}
// Editor can edit config files (supports multi-model selection)
type Editor interface {
// Paths returns the paths to the config files for the integration
Paths() []string
// Edit updates the config files for the integration with the given models
Edit(models []string) error
// Models returns the models currently configured for the integration
// TODO(parthsareen): add error return to Models()
Models() []string
}
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
"clawdbot": &Clawdbot{},
"codex": &Codex{},
"droid": &Droid{},
"opencode": &OpenCode{},
}
func selectIntegration() (string, error) {
if len(integrations) == 0 {
return "", fmt.Errorf("no integrations available")
}
names := slices.Sorted(maps.Keys(integrations))
var items []selectItem
for _, name := range names {
r := integrations[name]
description := r.String()
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
}
items = append(items, selectItem{Name: name, Description: description})
}
return selectPrompt("Select integration:", items)
}
// selectModels lets the user select models for an integration
func selectModels(ctx context.Context, name, current string) ([]string, error) {
r, ok := integrations[name]
if !ok {
return nil, fmt.Errorf("unknown integration: %s", name)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
models, err := client.List(ctx)
if err != nil {
return nil, err
}
if len(models.Models) == 0 {
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
var items []selectItem
cloudModels := make(map[string]bool)
for _, m := range models.Models {
if m.RemoteModel != "" {
cloudModels[m.Name] = true
}
items = append(items, selectItem{Name: m.Name})
}
if len(items) == 0 {
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
}
// Get previously configured models (saved config takes precedence)
var preChecked []string
if saved, err := loadIntegration(name); err == nil {
preChecked = saved.Models
} else if editor, ok := r.(Editor); ok {
preChecked = editor.Models()
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
for _, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
// If current model is configured, move to front of preChecked
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
}
// Sort: checked first, then alphabetical
slices.SortFunc(items, func(a, b selectItem) int {
ac, bc := checked[a.Name], checked[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
var selected []string
// only editors support multi-model selection
if _, ok := r.(Editor); ok {
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
if err != nil {
return nil, err
}
} else {
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
if err != nil {
return nil, err
}
selected = []string{model}
}
// if any model in selected is a cloud model, ensure signed in
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) > 0 {
// ensure user is signed in
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return selected, nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return nil, err
}
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return nil, fmt.Errorf("%s requires sign in", modelList)
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
// TODO(parthsareen): extract into auth package for cmd
// Auto-open browser (best effort, fail silently)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return nil, ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return selected, nil
}
}
}
}
}
return selected, nil
}
func runIntegration(name, modelName string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName)
}
// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
var modelFlag string
var configFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
Supported integrations:
claude Claude Code
clawdbot Clawdbot
codex Codex
droid Droid
opencode OpenCode
Examples:
ollama launch
ollama launch claude
ollama launch claude --model <model>
ollama launch droid --config (does not auto-launch)`,
Args: cobra.MaximumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
var name string
if len(args) > 0 {
name = args[0]
} else {
var err error
name, err = selectIntegration()
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
r, ok := integrations[strings.ToLower(name)]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
// If launching without --model, use saved config if available
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0])
}
}
var models []string
if modelFlag != "" {
// When --model is specified, merge with existing models (new model becomes default)
models = []string{modelFlag}
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
for _, m := range existing.Models {
if m != modelFlag {
models = append(models, m)
}
}
}
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
if editor, isEditor := r.(Editor); isEditor {
paths := editor.Paths()
if len(paths) > 0 {
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
for _, p := range paths {
fmt.Fprintf(os.Stderr, " %s\n", p)
}
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
if ok, _ := confirmPrompt("Proceed?"); !ok {
return nil
}
}
}
if err := saveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
if editor, isEditor := r.(Editor); isEditor {
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
}
if _, isEditor := r.(Editor); isEditor {
if len(models) == 1 {
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
} else {
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
}
}
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0])
}
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
return nil
}
return runIntegration(name, models[0])
},
}
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
return cmd
}
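To make the Runner/Editor comments at the top of integrations.go concrete, here is a hedged sketch of a hypothetical launch-only integration; the "mytool" name, binary, and flag are invented and nothing below is part of this change. A launch-only tool implements Runner; a tool whose config files are rewritten would also implement Editor, as OpenCode and Droid do.

package config

import (
    "fmt"
    "os"
    "os/exec"
)

// MyTool is a hypothetical launch-only integration: it implements Runner but
// not Editor, so launching it prompts for a single model rather than a multi-select.
type MyTool struct{}

func (m *MyTool) String() string { return "MyTool" }

func (m *MyTool) Run(model string) error {
    if _, err := exec.LookPath("mytool"); err != nil {
        return fmt.Errorf("mytool is not installed")
    }
    // Assumed CLI contract for this example: the tool accepts a --model flag.
    cmd := exec.Command("mytool", "--model", model)
    cmd.Stdin = os.Stdin
    cmd.Stdout = os.Stdout
    cmd.Stderr = os.Stderr
    return cmd.Run()
}

Registering it would then be one more entry in the integrations map: "mytool": &MyTool{}.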

188
cmd/config/integrations_test.go Normal file
View File

@@ -0,0 +1,188 @@
package config
import (
"slices"
"strings"
"testing"
"github.com/spf13/cobra"
)
func TestIntegrationLookup(t *testing.T) {
tests := []struct {
name string
input string
wantFound bool
wantName string
}{
{"claude lowercase", "claude", true, "Claude Code"},
{"claude uppercase", "CLAUDE", true, "Claude Code"},
{"claude mixed case", "Claude", true, "Claude Code"},
{"codex", "codex", true, "Codex"},
{"droid", "droid", true, "Droid"},
{"opencode", "opencode", true, "OpenCode"},
{"unknown integration", "unknown", false, ""},
{"empty string", "", false, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, found := integrations[strings.ToLower(tt.input)]
if found != tt.wantFound {
t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound)
}
if found && r.String() != tt.wantName {
t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName)
}
})
}
}
func TestIntegrationRegistry(t *testing.T) {
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
for _, name := range expectedIntegrations {
t.Run(name, func(t *testing.T) {
r, ok := integrations[name]
if !ok {
t.Fatalf("integration %q not found in registry", name)
}
if r.String() == "" {
t.Error("integration.String() should not be empty")
}
})
}
}
func TestHasLocalModel(t *testing.T) {
tests := []struct {
name string
models []string
want bool
}{
{"empty list", []string{}, false},
{"single local model", []string{"llama3.2"}, true},
{"single cloud model", []string{"cloud-model"}, false},
{"mixed models", []string{"cloud-model", "llama3.2"}, true},
{"multiple local models", []string{"llama3.2", "qwen2.5"}, true},
{"multiple cloud models", []string{"cloud-a", "cloud-b"}, false},
{"local model first", []string{"llama3.2", "cloud-model"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := slices.ContainsFunc(tt.models, func(m string) bool {
return !strings.Contains(m, "cloud")
})
if got != tt.want {
t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want)
}
})
}
}
func TestLaunchCmd(t *testing.T) {
// Mock checkServerHeartbeat that always succeeds
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
cmd := LaunchCmd(mockCheck)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
}
if cmd.Long == "" {
t.Error("Long description should not be empty")
}
})
t.Run("flags exist", func(t *testing.T) {
modelFlag := cmd.Flags().Lookup("model")
if modelFlag == nil {
t.Error("--model flag should exist")
}
configFlag := cmd.Flags().Lookup("config")
if configFlag == nil {
t.Error("--config flag should exist")
}
})
t.Run("PreRunE is set", func(t *testing.T) {
if cmd.PreRunE == nil {
t.Error("PreRunE should be set to checkServerHeartbeat")
}
})
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model")
if err == nil {
t.Error("expected error for unknown integration, got nil")
}
if !strings.Contains(err.Error(), "unknown integration") {
t.Errorf("error should mention 'unknown integration', got: %v", err)
}
}
func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
tests := []struct {
name string
models []string
want bool
reason string
}{
{"empty list", []string{}, false, "empty list has no local models"},
{"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"},
{"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"},
{"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"},
{"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"},
{"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"},
{"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := slices.ContainsFunc(tt.models, func(m string) bool {
return !strings.Contains(m, "cloud")
})
if got != tt.want {
t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason)
}
})
}
}
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
// This should not panic - cmd creation should work even with nil
cmd := LaunchCmd(nil)
if cmd == nil {
t.Fatal("LaunchCmd returned nil")
}
// PreRunE should be nil when passed nil
if cmd.PreRunE != nil {
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
}
}
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
for name, r := range integrations {
t.Run(name, func(t *testing.T) {
// Test String() doesn't panic and returns non-empty
displayName := r.String()
if displayName == "" {
t.Error("String() should not return empty")
}
// Test Run() exists (we can't call it without actually running the command)
// Just verify the method is available
var _ func(string) error = r.Run
})
}
}

226
cmd/config/opencode.go Normal file
View File

@@ -0,0 +1,226 @@
package config
import (
"encoding/json"
"fmt"
"maps"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string) error {
if _, err := exec.LookPath("opencode"); err != nil {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := o.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("opencode")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (o *OpenCode) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
var paths []string
p := filepath.Join(home, ".config", "opencode", "opencode.json")
if _, err := os.Stat(p); err == nil {
paths = append(paths, p)
}
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
if _, err := os.Stat(sp); err == nil {
paths = append(paths, sp)
}
return paths
}
func (o *OpenCode) Edit(modelList []string) error {
if len(modelList) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
}
config["$schema"] = "https://opencode.ai/config.json"
provider, ok := config["provider"].(map[string]any)
if !ok {
provider = make(map[string]any)
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)",
"options": map[string]any{
"baseURL": envconfig.Host().String() + "/v1",
},
}
}
models, ok := ollama["models"].(map[string]any)
if !ok {
models = make(map[string]any)
}
selectedSet := make(map[string]bool)
for _, m := range modelList {
selectedSet[m] = true
}
for name, cfg := range models {
if cfgMap, ok := cfg.(map[string]any); ok {
if isOllamaModel(cfgMap) && !selectedSet[name] {
delete(models, name)
}
}
}
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
if isOllamaModel(existing) {
existing["_launch"] = true
if name, ok := existing["name"].(string); ok {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
continue
}
models[model] = map[string]any{
"name": model,
"_launch": true,
}
}
ollama["models"] = models
provider["ollama"] = ollama
config["provider"] = provider
configData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
return err
}
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
return err
}
state := map[string]any{
"recent": []any{},
"favorite": []any{},
"variant": map[string]any{},
}
if data, err := os.ReadFile(statePath); err == nil {
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
}
recent, _ := state["recent"].([]any)
modelSet := make(map[string]bool)
for _, m := range modelList {
modelSet[m] = true
}
// Filter out existing Ollama models we're about to re-add
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
e, ok := entry.(map[string]any)
if !ok || e["providerID"] != "ollama" {
return false
}
modelID, _ := e["modelID"].(string)
return modelSet[modelID]
})
// Prepend models in reverse order so first model ends up first
for _, model := range slices.Backward(modelList) {
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
"providerID": "ollama",
"modelID": model,
}))
}
const maxRecentModels = 10
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
state["recent"] = newRecent
stateData, err := json.MarshalIndent(state, "", " ")
if err != nil {
return err
}
return writeWithBackup(statePath, stateData)
}
func (o *OpenCode) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
if err != nil {
return nil
}
provider, _ := config["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if len(models) == 0 {
return nil
}
keys := slices.Collect(maps.Keys(models))
slices.Sort(keys)
return keys
}
// isOllamaModel reports whether a model config entry is managed by us
func isOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
}
// older versions appended a "[Ollama]" suffix to the name of models managed by ollama launch
if name, ok := cfg["name"].(string); ok {
return strings.HasSuffix(name, "[Ollama]")
}
return false
}
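For reference, this is roughly what Edit([]string{"llama3.2"}) leaves in ~/.config/opencode/opencode.json when no config exists yet, assuming the default OLLAMA_HOST so envconfig.Host().String() is http://127.0.0.1:11434. It is a sketch reconstructed from the code above, not captured output; the real file sorts keys alphabetically, and a matching recent entry is also prepended to ~/.local/state/opencode/model.json.

package config

// exampleOpenCodeConfig is illustrative only.
const exampleOpenCodeConfig = `{
  "$schema": "https://opencode.ai/config.json",
  "provider": {
    "ollama": {
      "npm": "@ai-sdk/openai-compatible",
      "name": "Ollama (local)",
      "options": { "baseURL": "http://127.0.0.1:11434/v1" },
      "models": {
        "llama3.2": { "name": "llama3.2", "_launch": true }
      }
    }
  }
}`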

507
cmd/config/opencode_test.go Normal file
View File

@@ -0,0 +1,507 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestOpenCodeIntegration(t *testing.T) {
o := &OpenCode{}
t.Run("String", func(t *testing.T) {
if got := o.String(); got != "OpenCode" {
t.Errorf("String() = %q, want %q", got, "OpenCode")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = o
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = o
})
}
func TestOpenCodeEdit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
cleanup := func() {
os.RemoveAll(configDir)
os.RemoveAll(stateDir)
}
t.Run("fresh install", func(t *testing.T) {
cleanup()
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
})
t.Run("preserve other providers", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
if provider["anthropic"] == nil {
t.Error("anthropic provider was removed")
}
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("preserve other models", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeModelExists(t, configPath, "mistral")
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("update existing model", func(t *testing.T) {
cleanup()
o.Edit([]string{"llama3.2"})
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("preserve top-level keys", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("theme was removed")
}
if cfg["keybindings"] == nil {
t.Error("keybindings was removed")
}
})
t.Run("model state - insert at index 0", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
})
t.Run("model state - preserve favorites and variants", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
if len(state["favorite"].([]any)) != 1 {
t.Error("favorite was modified")
}
if state["variant"].(map[string]any)["a"] != "b" {
t.Error("variant was modified")
}
})
t.Run("model state - deduplicate on re-add", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
recent := state["recent"].([]any)
if len(recent) != 2 {
t.Errorf("expected 2 recent entries, got %d", len(recent))
}
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
})
t.Run("remove model", func(t *testing.T) {
cleanup()
// First add two models
o.Edit([]string{"llama3.2", "mistral"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelExists(t, configPath, "mistral")
// Then remove one by only selecting the other
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelNotExists(t, configPath, "mistral")
})
t.Run("preserve user customizations on managed models", func(t *testing.T) {
cleanup()
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Add custom fields to the model entry (simulating user edits)
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry := models["llama3.2"].(map[string]any)
entry["_myPref"] = "custom-value"
entry["_myNum"] = 42
configData, _ := json.MarshalIndent(cfg, "", " ")
os.WriteFile(configPath, configData, 0o644)
// Re-run Edit — should preserve custom fields
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ = os.ReadFile(configPath)
json.Unmarshal(data, &cfg)
provider = cfg["provider"].(map[string]any)
ollama = provider["ollama"].(map[string]any)
models = ollama["models"].(map[string]any)
entry = models["llama3.2"].(map[string]any)
if entry["_myPref"] != "custom-value" {
t.Errorf("_myPref was lost: got %v", entry["_myPref"])
}
if entry["_myNum"] != float64(42) {
t.Errorf("_myNum was lost: got %v", entry["_myNum"])
}
if v, ok := entry["_launch"].(bool); !ok || !v {
t.Errorf("_launch marker missing or false: got %v", entry["_launch"])
}
})
t.Run("migrate legacy [Ollama] suffix entries", func(t *testing.T) {
cleanup()
// Write a config with a legacy entry (has [Ollama] suffix but no _launch marker)
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"llama3.2":{"name":"llama3.2 [Ollama]"}}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry := models["llama3.2"].(map[string]any)
// _launch marker should be added
if v, ok := entry["_launch"].(bool); !ok || !v {
t.Errorf("_launch marker not added during migration: got %v", entry["_launch"])
}
// [Ollama] suffix should be stripped
if name, ok := entry["name"].(string); !ok || name != "llama3.2" {
t.Errorf("name suffix not stripped: got %q", entry["name"])
}
})
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Add a non-Ollama model manually
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
})
}
func assertOpenCodeModelExists(t *testing.T, path, model string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
provider, ok := cfg["provider"].(map[string]any)
if !ok {
t.Fatal("provider not found")
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
t.Fatal("ollama provider not found")
}
models, ok := ollama["models"].(map[string]any)
if !ok {
t.Fatal("models not found")
}
if models[model] == nil {
t.Errorf("model %s not found", model)
}
}
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
provider, ok := cfg["provider"].(map[string]any)
if !ok {
return // No provider means no model
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
return // No ollama means no model
}
models, ok := ollama["models"].(map[string]any)
if !ok {
return // No models means no model
}
if models[model] != nil {
t.Errorf("model %s should not exist but was found", model)
}
}
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var state map[string]any
if err := json.Unmarshal(data, &state); err != nil {
t.Fatal(err)
}
recent, ok := state["recent"].([]any)
if !ok {
t.Fatal("recent not found")
}
if index >= len(recent) {
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
}
entry, ok := recent[index].(map[string]any)
if !ok {
t.Fatal("entry is not a map")
}
if entry["providerID"] != providerID {
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
}
if entry["modelID"] != modelID {
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
}
}
// Edge case tests for opencode.go
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
// Should not panic - corrupted JSON should be treated as empty
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit failed with corrupted config: %v", err)
}
// Verify valid JSON was created
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Errorf("resulting config is not valid JSON: %v", err)
}
}
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit failed with corrupted state: %v", err)
}
// Verify valid state was created
data, _ := os.ReadFile(statePath)
var state map[string]any
if err := json.Unmarshal(data, &state); err != nil {
t.Errorf("resulting state is not valid JSON: %v", err)
}
}
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit with wrong type provider failed: %v", err)
}
// Verify provider is now correct type
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider, ok := cfg["provider"].(map[string]any)
if !ok {
t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
}
if provider["ollama"] == nil {
t.Error("ollama provider should be created")
}
}
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit with wrong type recent failed: %v", err)
}
// The function should handle this gracefully
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
// recent should be properly set after setup
recent, ok := state["recent"].([]any)
if !ok {
t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
} else if len(recent) == 0 {
t.Logf("Note: recent is empty (documenting behavior)")
}
}
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
os.WriteFile(configPath, []byte(originalContent), 0o644)
// Empty models should be no-op
err := o.Edit([]string{})
if err != nil {
t.Fatalf("Edit with empty models failed: %v", err)
}
// Original content should be preserved (file not modified)
data, _ := os.ReadFile(configPath)
if string(data) != originalContent {
t.Errorf("empty models should not modify file, but content changed")
}
}
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Model name with special characters (though unusual)
specialModel := `model-with-"quotes"`
err := o.Edit([]string{specialModel})
if err != nil {
t.Fatalf("Edit with special chars failed: %v", err)
}
// Verify it was stored correctly
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("resulting config is invalid JSON: %v", err)
}
// Model should be accessible
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if models[specialModel] == nil {
t.Errorf("model with special chars not found in config")
}
}
func TestOpenCodeModels_NoConfig(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
models := o.Models()
if len(models) > 0 {
t.Errorf("expected nil/empty for missing config, got %v", models)
}
}

499
cmd/config/selector.go Normal file
View File

@@ -0,0 +1,499 @@
package config
import (
"errors"
"fmt"
"io"
"os"
"strings"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting.
const (
ansiHideCursor = "\033[?25l"
ansiShowCursor = "\033[?25h"
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiClearDown = "\033[J"
)
const maxDisplayedItems = 10
var errCancelled = errors.New("cancelled")
type selectItem struct {
Name string
Description string
}
type inputEvent int
const (
eventNone inputEvent = iota
eventEnter
eventEscape
eventUp
eventDown
eventTab
eventBackspace
eventChar
)
type selectState struct {
items []selectItem
filter string
selected int
scrollOffset int
}
func newSelectState(items []selectItem) *selectState {
return &selectState{items: items}
}
func (s *selectState) filtered() []selectItem {
return filterItems(s.items, s.filter)
}
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
filtered := s.filtered()
switch event {
case eventEnter:
if len(filtered) > 0 && s.selected < len(filtered) {
return true, filtered[s.selected].Name, nil
}
case eventEscape:
return true, "", errCancelled
case eventBackspace:
if len(s.filter) > 0 {
s.filter = s.filter[:len(s.filter)-1]
s.selected = 0
s.scrollOffset = 0
}
case eventUp:
if s.selected > 0 {
s.selected--
if s.selected < s.scrollOffset {
s.scrollOffset = s.selected
}
}
case eventDown:
if s.selected < len(filtered)-1 {
s.selected++
if s.selected >= s.scrollOffset+maxDisplayedItems {
s.scrollOffset = s.selected - maxDisplayedItems + 1
}
}
case eventChar:
s.filter += string(char)
s.selected = 0
s.scrollOffset = 0
}
return false, "", nil
}
type multiSelectState struct {
items []selectItem
itemIndex map[string]int
filter string
highlighted int
scrollOffset int
checked map[int]bool
checkOrder []int
focusOnButton bool
}
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
s := &multiSelectState{
items: items,
itemIndex: make(map[string]int, len(items)),
checked: make(map[int]bool),
}
for i, item := range items {
s.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := s.itemIndex[name]; ok {
s.checked[idx] = true
s.checkOrder = append(s.checkOrder, idx)
}
}
return s
}
func (s *multiSelectState) filtered() []selectItem {
return filterItems(s.items, s.filter)
}
func (s *multiSelectState) toggleItem() {
filtered := s.filtered()
if len(filtered) == 0 || s.highlighted >= len(filtered) {
return
}
item := filtered[s.highlighted]
origIdx := s.itemIndex[item.Name]
if s.checked[origIdx] {
delete(s.checked, origIdx)
for i, idx := range s.checkOrder {
if idx == origIdx {
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
break
}
}
} else {
s.checked[origIdx] = true
s.checkOrder = append(s.checkOrder, origIdx)
}
}
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
filtered := s.filtered()
switch event {
case eventEnter:
if s.focusOnButton && len(s.checkOrder) > 0 {
var res []string
for _, idx := range s.checkOrder {
res = append(res, s.items[idx].Name)
}
return true, res, nil
} else if !s.focusOnButton {
s.toggleItem()
}
case eventTab:
if len(s.checkOrder) > 0 {
s.focusOnButton = !s.focusOnButton
}
case eventEscape:
return true, nil, errCancelled
case eventBackspace:
if len(s.filter) > 0 {
s.filter = s.filter[:len(s.filter)-1]
s.highlighted = 0
s.scrollOffset = 0
s.focusOnButton = false
}
case eventUp:
if s.focusOnButton {
s.focusOnButton = false
} else if s.highlighted > 0 {
s.highlighted--
if s.highlighted < s.scrollOffset {
s.scrollOffset = s.highlighted
}
}
case eventDown:
if s.focusOnButton {
s.focusOnButton = false
} else if s.highlighted < len(filtered)-1 {
s.highlighted++
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
}
}
case eventChar:
s.filter += string(char)
s.highlighted = 0
s.scrollOffset = 0
s.focusOnButton = false
}
return false, nil, nil
}
func (s *multiSelectState) selectedCount() int {
return len(s.checkOrder)
}
// Terminal I/O handling
type terminalState struct {
fd int
oldState *term.State
}
func enterRawMode() (*terminalState, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return nil, err
}
fmt.Fprint(os.Stderr, ansiHideCursor)
return &terminalState{fd: fd, oldState: oldState}, nil
}
func (t *terminalState) restore() {
fmt.Fprint(os.Stderr, ansiShowCursor)
term.Restore(t.fd, t.oldState)
}
func clearLines(n int) {
if n > 0 {
fmt.Fprintf(os.Stderr, "\033[%dA", n)
fmt.Fprint(os.Stderr, ansiClearDown)
}
}
func parseInput(r io.Reader) (inputEvent, byte, error) {
buf := make([]byte, 3)
n, err := r.Read(buf)
if err != nil {
return 0, 0, err
}
switch {
case n == 1 && buf[0] == 13:
return eventEnter, 0, nil
case n == 1 && (buf[0] == 3 || buf[0] == 27):
return eventEscape, 0, nil
case n == 1 && buf[0] == 9:
return eventTab, 0, nil
case n == 1 && buf[0] == 127:
return eventBackspace, 0, nil
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
return eventUp, 0, nil
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
return eventDown, 0, nil
case n == 1 && buf[0] >= 32 && buf[0] < 127:
return eventChar, buf[0], nil
}
return eventNone, 0, nil
}
// Rendering
func renderSelect(w io.Writer, prompt string, s *selectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
lineCount := 1
if len(filtered) == 0 {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
lineCount++
} else {
displayCount := min(len(filtered), maxDisplayedItems)
for i := range displayCount {
idx := s.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
prefix := " "
if idx == s.selected {
prefix = " " + ansiBold + "> "
}
if item.Description != "" {
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
} else {
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
}
lineCount++
}
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
lineCount++
}
}
return lineCount
}
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
lineCount := 1
if len(filtered) == 0 {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
lineCount++
} else {
displayCount := min(len(filtered), maxDisplayedItems)
for i := range displayCount {
idx := s.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
origIdx := s.itemIndex[item.Name]
checkbox := "[ ]"
if s.checked[origIdx] {
checkbox = "[x]"
}
prefix := " "
suffix := ""
if idx == s.highlighted && !s.focusOnButton {
prefix = "> "
}
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
suffix = " " + ansiGray + "(default)" + ansiReset
}
if idx == s.highlighted && !s.focusOnButton {
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
} else {
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
}
lineCount++
}
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
lineCount++
}
}
fmt.Fprintf(w, "\r\n")
lineCount++
count := s.selectedCount()
switch {
case count == 0:
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
case s.focusOnButton:
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
default:
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
}
lineCount++
return lineCount
}
// selectPrompt prompts the user to select a single item from a list.
func selectPrompt(prompt string, items []selectItem) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
ts, err := enterRawMode()
if err != nil {
return "", err
}
defer ts.restore()
state := newSelectState(items)
var lastLineCount int
render := func() {
clearLines(lastLineCount)
lastLineCount = renderSelect(os.Stderr, prompt, state)
}
render()
for {
event, char, err := parseInput(os.Stdin)
if err != nil {
return "", err
}
done, result, err := state.handleInput(event, char)
if done {
clearLines(lastLineCount)
if err != nil {
return "", err
}
return result, nil
}
render()
}
}
// multiSelectPrompt prompts the user to select multiple items from a list.
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
if len(items) == 0 {
return nil, fmt.Errorf("no items to select from")
}
ts, err := enterRawMode()
if err != nil {
return nil, err
}
defer ts.restore()
state := newMultiSelectState(items, preChecked)
var lastLineCount int
render := func() {
clearLines(lastLineCount)
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
}
render()
for {
event, char, err := parseInput(os.Stdin)
if err != nil {
return nil, err
}
done, result, err := state.handleInput(event, char)
if done {
clearLines(lastLineCount)
if err != nil {
return nil, err
}
return result, nil
}
render()
}
}
func confirmPrompt(prompt string) (bool, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
buf := make([]byte, 1)
for {
if _, err := os.Stdin.Read(buf); err != nil {
return false, err
}
switch buf[0] {
case 'Y', 'y', 13:
fmt.Fprintf(os.Stderr, "yes\r\n")
return true, nil
case 'N', 'n', 27, 3:
fmt.Fprintf(os.Stderr, "no\r\n")
return false, nil
}
}
}
func filterItems(items []selectItem, filter string) []selectItem {
if filter == "" {
return items
}
var result []selectItem
filterLower := strings.ToLower(filter)
for _, item := range items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}
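Because the prompts only consume parseInput events, the selection state machines can be exercised without a real terminal. Below is a small sketch of driving the single-select flow with synthetic keystrokes, mirroring what the tests that follow do; the model names are placeholders. testing/iotest's OneByteReader returns one byte per Read, mimicking the per-keystroke reads parseInput sees in raw mode.

package config

import (
    "fmt"
    "strings"
    "testing/iotest"
)

// exampleSelect is illustrative only: type "q" to filter the list, then press Enter.
func exampleSelect() (string, error) {
    items := []selectItem{{Name: "llama3.2"}, {Name: "qwen2.5"}}
    state := newSelectState(items)
    input := iotest.OneByteReader(strings.NewReader("q\r"))
    for {
        event, char, err := parseInput(input)
        if err != nil {
            return "", err
        }
        done, result, err := state.handleInput(event, char)
        if err != nil {
            return "", err
        }
        if done {
            fmt.Println("selected:", result) // selected: qwen2.5
            return result, nil
        }
    }
}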

913
cmd/config/selector_test.go Normal file
View File

@@ -0,0 +1,913 @@
package config
import (
"bytes"
"strings"
"testing"
)
func TestFilterItems(t *testing.T) {
items := []selectItem{
{Name: "llama3.2:latest"},
{Name: "qwen2.5:7b"},
{Name: "deepseek-v3:cloud"},
{Name: "GPT-OSS:20b"},
}
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
result := filterItems(items, "")
if len(result) != len(items) {
t.Errorf("expected %d items, got %d", len(items), len(result))
}
})
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
result := filterItems(items, "LLAMA")
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
t.Errorf("expected llama3.2:latest, got %v", result)
}
})
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
result := filterItems(items, "gpt")
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
t.Errorf("expected GPT-OSS:20b, got %v", result)
}
})
t.Run("PartialMatch", func(t *testing.T) {
result := filterItems(items, "deep")
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
t.Errorf("expected deepseek-v3:cloud, got %v", result)
}
})
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
result := filterItems(items, "nonexistent")
if len(result) != 0 {
t.Errorf("expected 0 items, got %d", len(result))
}
})
}
func TestSelectState(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
{Name: "item3"},
}
t.Run("InitialState", func(t *testing.T) {
s := newSelectState(items)
if s.selected != 0 {
t.Errorf("expected selected=0, got %d", s.selected)
}
if s.filter != "" {
t.Errorf("expected empty filter, got %q", s.filter)
}
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
}
})
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "item1" || err != nil {
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
s := newSelectState(items)
s.filter = "item3"
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "item3" || err != nil {
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
s := newSelectState(items)
s.filter = "nonexistent"
done, result, err := s.handleInput(eventEnter, 0)
if done || result != "" || err != nil {
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEscape, 0)
if !done || result != "" || err != errCancelled {
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Down_MovesSelection", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventDown, 0)
if s.selected != 1 {
t.Errorf("expected selected=1, got %d", s.selected)
}
})
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventDown, 0)
if s.selected != 2 {
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
}
})
t.Run("Up_MovesSelection", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventUp, 0)
if s.selected != 1 {
t.Errorf("expected selected=1, got %d", s.selected)
}
})
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventUp, 0)
if s.selected != 0 {
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
}
})
t.Run("Char_AppendsToFilter", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventChar, 'i')
s.handleInput(eventChar, 't')
s.handleInput(eventChar, 'e')
s.handleInput(eventChar, 'm')
s.handleInput(eventChar, '2')
if s.filter != "item2" {
t.Errorf("expected filter='item2', got %q", s.filter)
}
filtered := s.filtered()
if len(filtered) != 1 || filtered[0].Name != "item2" {
t.Errorf("expected [item2], got %v", filtered)
}
})
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventChar, 'x')
if s.selected != 0 {
t.Errorf("expected selected=0 after typing, got %d", s.selected)
}
})
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
s := newSelectState(items)
s.filter = "test"
s.handleInput(eventBackspace, 0)
if s.filter != "tes" {
t.Errorf("expected filter='tes', got %q", s.filter)
}
})
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventBackspace, 0)
if s.filter != "" {
t.Errorf("expected filter='', got %q", s.filter)
}
})
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
s := newSelectState(items)
s.filter = "test"
s.selected = 2
s.handleInput(eventBackspace, 0)
if s.selected != 0 {
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
}
})
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
// maxDisplayedItems is 10, so with 15 items we need to scroll
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
// move down 12 times (past the 10-item viewport)
for range 12 {
s.handleInput(eventDown, 0)
}
if s.selected != 12 {
t.Errorf("expected selected=12, got %d", s.selected)
}
if s.scrollOffset != 3 {
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
}
})
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
s.selected = 5
s.scrollOffset = 5
s.handleInput(eventUp, 0)
if s.selected != 4 {
t.Errorf("expected selected=4, got %d", s.selected)
}
if s.scrollOffset != 4 {
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
}
})
}
func TestMultiSelectState(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
{Name: "item3"},
}
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
s := newMultiSelectState(items, nil)
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
if s.selectedCount() != 0 {
t.Errorf("expected 0 selected, got %d", s.selectedCount())
}
if s.focusOnButton {
t.Error("expected focusOnButton=false initially")
}
})
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item3"})
if s.selectedCount() != 2 {
t.Errorf("expected 2 selected, got %d", s.selectedCount())
}
if !s.checked[1] || !s.checked[2] {
t.Error("expected item2 and item3 to be checked")
}
})
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
// order matters: first checked = default model
s := newMultiSelectState(items, []string{"item3", "item1"})
if len(s.checkOrder) != 2 {
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
}
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
}
})
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
if s.selectedCount() != 1 {
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
}
})
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.toggleItem()
if !s.checked[0] {
t.Error("expected item1 to be checked after toggle")
}
})
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.toggleItem()
if s.checked[0] {
t.Error("expected item1 to be unchecked after toggle")
}
})
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
s.highlighted = 1 // toggle item2
s.toggleItem()
if len(s.checkOrder) != 2 {
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
}
// should be [0, 2] (item1, item3) with item2 removed
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
}
})
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventEnter, 0)
if !s.checked[0] {
t.Error("expected item1 to be checked after enter")
}
})
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item1"})
s.focusOnButton = true
done, result, err := s.handleInput(eventEnter, 0)
if !done || err != nil {
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
}
// result should preserve selection order
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
t.Errorf("expected [item2, item1], got %v", result)
}
})
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.focusOnButton = true
done, result, err := s.handleInput(eventEnter, 0)
if done || result != nil || err != nil {
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.handleInput(eventTab, 0)
if !s.focusOnButton {
t.Error("expected focus on button after tab")
}
})
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventTab, 0)
if s.focusOnButton {
t.Error("tab should not focus button when nothing selected")
}
})
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.handleInput(eventTab, 0)
if !s.focusOnButton {
t.Error("expected focus on button after first tab")
}
s.handleInput(eventTab, 0)
if s.focusOnButton {
t.Error("expected focus back on list after second tab")
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
done, result, err := s.handleInput(eventEscape, 0)
if !done || result != nil || err != errCancelled {
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
}
})
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item1"})
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
t.Error("expected item2 (idx 1) to be default (first checked)")
}
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
t.Error("expected item1 (idx 0) to NOT be default")
}
})
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
s := newMultiSelectState(items, nil)
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
t.Error("expected isDefault=false when nothing checked")
}
})
t.Run("Down_MovesHighlight", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventDown, 0)
if s.highlighted != 1 {
t.Errorf("expected highlighted=1, got %d", s.highlighted)
}
})
t.Run("Up_MovesHighlight", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.highlighted = 1
s.handleInput(eventUp, 0)
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
})
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.focusOnButton = true
s.handleInput(eventDown, 0)
if s.focusOnButton {
t.Error("expected focus to return to list on arrow key")
}
})
t.Run("Char_AppendsToFilter", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventChar, 'x')
if s.filter != "x" {
t.Errorf("expected filter='x', got %q", s.filter)
}
})
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newMultiSelectState(manyItems, nil)
s.highlighted = 10
s.scrollOffset = 5
s.handleInput(eventChar, 'x')
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
}
})
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.filter = "test"
s.handleInput(eventBackspace, 0)
if s.filter != "tes" {
t.Errorf("expected filter='tes', got %q", s.filter)
}
})
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.filter = "x"
s.focusOnButton = true
s.handleInput(eventBackspace, 0)
if s.focusOnButton {
t.Error("expected focusOnButton=false after backspace")
}
})
}
func TestParseInput(t *testing.T) {
t.Run("Enter", func(t *testing.T) {
event, char, err := parseInput(bytes.NewReader([]byte{13}))
if err != nil || event != eventEnter || char != 0 {
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
}
})
t.Run("Escape", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27}))
if err != nil || event != eventEscape {
t.Errorf("expected eventEscape, got %v", event)
}
})
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{3}))
if err != nil || event != eventEscape {
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
}
})
t.Run("Tab", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{9}))
if err != nil || event != eventTab {
t.Errorf("expected eventTab, got %v", event)
}
})
t.Run("Backspace", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{127}))
if err != nil || event != eventBackspace {
t.Errorf("expected eventBackspace, got %v", event)
}
})
t.Run("UpArrow", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
if err != nil || event != eventUp {
t.Errorf("expected eventUp, got %v", event)
}
})
t.Run("DownArrow", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
if err != nil || event != eventDown {
t.Errorf("expected eventDown, got %v", event)
}
})
t.Run("PrintableChars", func(t *testing.T) {
tests := []struct {
name string
char byte
}{
{"lowercase", 'a'},
{"uppercase", 'Z'},
{"digit", '5'},
{"space", ' '},
{"tilde", '~'},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
if err != nil || event != eventChar || char != tt.char {
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
}
})
}
})
}
func TestRenderSelect(t *testing.T) {
items := []selectItem{
{Name: "item1", Description: "first item"},
{Name: "item2"},
}
t.Run("ShowsPromptAndItems", func(t *testing.T) {
s := newSelectState(items)
var buf bytes.Buffer
lineCount := renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "Select:") {
t.Error("expected prompt in output")
}
if !strings.Contains(output, "item1") {
t.Error("expected item1 in output")
}
if !strings.Contains(output, "first item") {
t.Error("expected description in output")
}
if !strings.Contains(output, "item2") {
t.Error("expected item2 in output")
}
if lineCount != 3 { // 1 prompt + 2 items
t.Errorf("expected 3 lines, got %d", lineCount)
}
})
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
s := newSelectState(items)
s.filter = "xyz"
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' message")
}
})
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
// 15 items - 10 displayed = 5 more
if !strings.Contains(buf.String(), "5 more") {
t.Error("expected '5 more' indicator")
}
})
}
func TestRenderMultiSelect(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
}
t.Run("ShowsCheckboxes", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "[x]") {
t.Error("expected checked checkbox [x]")
}
if !strings.Contains(output, "[ ]") {
t.Error("expected unchecked checkbox [ ]")
}
})
t.Run("ShowsDefaultMarker", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "(default)") {
t.Error("expected (default) marker for first checked item")
}
})
t.Run("ShowsSelectedCount", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "item2"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "2 selected") {
t.Error("expected '2 selected' in output")
}
})
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
s := newMultiSelectState(items, nil)
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "Select at least one") {
t.Error("expected 'Select at least one' helper text")
}
})
}
func TestErrCancelled(t *testing.T) {
t.Run("NotNil", func(t *testing.T) {
if errCancelled == nil {
t.Error("errCancelled should not be nil")
}
})
t.Run("Message", func(t *testing.T) {
if errCancelled.Error() != "cancelled" {
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
}
})
}
// Edge case tests for selector.go
// TestSelectState_SingleItem verifies that a list with only one item works without crashing.
func TestSelectState_SingleItem(t *testing.T) {
items := []selectItem{{Name: "only-one"}}
s := newSelectState(items)
// Down should do nothing (already at bottom)
s.handleInput(eventDown, 0)
if s.selected != 0 {
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
}
// Up should do nothing (already at top)
s.handleInput(eventUp, 0)
if s.selected != 0 {
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
}
// Enter should select the only item
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "only-one" || err != nil {
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
}
}
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
// List with exactly maxDisplayedItems items should not scroll.
func TestSelectState_ExactlyMaxItems(t *testing.T) {
items := make([]selectItem, maxDisplayedItems)
for i := range items {
items[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(items)
// Move to last item
for range maxDisplayedItems - 1 {
s.handleInput(eventDown, 0)
}
if s.selected != maxDisplayedItems-1 {
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
}
// Should not scroll when exactly at max
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
}
// One more down should do nothing
s.handleInput(eventDown, 0)
if s.selected != maxDisplayedItems-1 {
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
}
}
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
// User typing "model.v1" shouldn't match "modelsv1".
func TestFilterItems_RegexSpecialChars(t *testing.T) {
items := []selectItem{
{Name: "model.v1"},
{Name: "modelsv1"},
{Name: "model-v1"},
}
// Filter with dot should only match literal dot
result := filterItems(items, "model.v1")
if len(result) != 1 {
t.Errorf("expected 1 exact match, got %d", len(result))
}
if len(result) > 0 && result[0].Name != "model.v1" {
t.Errorf("expected 'model.v1', got %s", result[0].Name)
}
// Other regex special chars should be literal too
items2 := []selectItem{
{Name: "test[0]"},
{Name: "test0"},
{Name: "test(1)"},
}
result2 := filterItems(items2, "test[0]")
if len(result2) != 1 || result2[0].Name != "test[0]" {
t.Errorf("expected only 'test[0]', got %v", result2)
}
}
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
// itemIndex uses the item name as its key, so duplicate names collide. This documents
// the current behavior: the last index for a duplicate name is stored.
func TestMultiSelectState_DuplicateNames(t *testing.T) {
// Duplicate names - this is an edge case that shouldn't happen in practice
items := []selectItem{
{Name: "duplicate"},
{Name: "duplicate"},
{Name: "unique"},
}
s := newMultiSelectState(items, nil)
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
// When there are duplicates, only the last occurrence's index is stored
if s.itemIndex["duplicate"] != 1 {
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
}
// Toggle item at highlighted=0 (first "duplicate")
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
// So it actually toggles the SECOND duplicate item, not the first
s.toggleItem()
// This documents the potentially surprising behavior:
// We toggled at highlighted=0, but itemIndex lookup returned 1
if !s.checked[1] {
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
}
if s.checked[0] {
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
}
}
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
// Prevents index-out-of-bounds on next keystroke
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
items := []selectItem{
{Name: "apple"},
{Name: "banana"},
{Name: "cherry"},
}
s := newSelectState(items)
s.selected = 2 // Select "cherry"
// Type a filter that removes cherry from results
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
// Selection should reset to 0
if s.selected != 0 {
t.Errorf("expected selected=0 after filter, got %d", s.selected)
}
filtered := s.filtered()
if len(filtered) != 2 {
t.Errorf("expected 2 filtered items, got %d", len(filtered))
}
}
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
// Model names might contain unicode characters
func TestFilterItems_UnicodeCharacters(t *testing.T) {
items := []selectItem{
{Name: "llama-日本語"},
{Name: "模型-chinese"},
{Name: "émoji-🦙"},
{Name: "regular-model"},
}
t.Run("filter japanese", func(t *testing.T) {
result := filterItems(items, "日本")
if len(result) != 1 || result[0].Name != "llama-日本語" {
t.Errorf("expected llama-日本語, got %v", result)
}
})
t.Run("filter chinese", func(t *testing.T) {
result := filterItems(items, "模型")
if len(result) != 1 || result[0].Name != "模型-chinese" {
t.Errorf("expected 模型-chinese, got %v", result)
}
})
t.Run("filter emoji", func(t *testing.T) {
result := filterItems(items, "🦙")
if len(result) != 1 || result[0].Name != "émoji-🦙" {
t.Errorf("expected émoji-🦙, got %v", result)
}
})
t.Run("filter accented char", func(t *testing.T) {
result := filterItems(items, "émoji")
if len(result) != 1 || result[0].Name != "émoji-🦙" {
t.Errorf("expected émoji-🦙, got %v", result)
}
})
}
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
items := []selectItem{
{Name: "apple"},
{Name: "banana"},
{Name: "cherry"},
}
s := newMultiSelectState(items, nil)
s.highlighted = 2 // Highlight "cherry"
// Type a filter that removes cherry
s.handleInput(eventChar, 'a')
if s.highlighted != 0 {
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
}
}
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
// Empty list should be handled gracefully.
func TestMultiSelectState_EmptyItems(t *testing.T) {
s := newMultiSelectState([]selectItem{}, nil)
// Toggle should not panic on empty list
s.toggleItem()
if s.selectedCount() != 0 {
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
}
// Render should handle empty list
var buf bytes.Buffer
lineCount := renderMultiSelect(&buf, "Select:", s)
if lineCount == 0 {
t.Error("renderMultiSelect should produce output even for empty list")
}
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' for empty list")
}
}
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
func TestSelectState_RenderWithDescriptions(t *testing.T) {
items := []selectItem{
{Name: "item1", Description: "First item description"},
{Name: "item2", Description: ""},
{Name: "item3", Description: "Third item"},
}
s := newSelectState(items)
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "First item description") {
t.Error("expected description to be rendered")
}
if !strings.Contains(output, "item2") {
t.Error("expected item without description to be rendered")
}
}

View File

@@ -159,6 +159,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
scanner.Prompt.UseAlt = true
continue
}

View File

@@ -6,6 +6,10 @@ import (
"log/slog"
"regexp"
"strconv"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
@@ -69,6 +73,9 @@ func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
kv["glm4moelite.attention.key_length_mla"] = p.KVLoraRank + p.QKRopeHeadDim
kv["glm4moelite.attention.value_length_mla"] = p.KVLoraRank
kv["tokenizer.ggml.pre"] = "glm4"
return kv
@@ -100,6 +107,67 @@ func (p *glm4MoeLiteModel) Replacements() []string {
}
}
// repackKVB extracts K or V from the combined KV_B tensor for MLA absorption.
// K output row-major: [n_head, kv_lora_rank, qk_nope] -> GGML ne[]={qk_nope, kv_lora_rank, n_head}
// V output row-major: [n_head, v_head, kv_lora_rank] -> GGML ne[]={kv_lora_rank, v_head, n_head}
func (p *glm4MoeLiteModel) repackKVB(extractK bool, kvFirst bool, numHeads int) Repacker {
qkNope := int(p.QKNopeHeadDim)
vHeadDim := int(p.VHeadDim)
kvLoraRank := int(p.KVLoraRank)
kvPerHead := qkNope + vHeadDim
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
var err error
// Normalize to [n_head * (qk_nope + v_head), kv_lora_rank] layout
if kvFirst {
tt, err = tensor.Transpose(tt, 1, 0)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
}
// Reshape to [n_head, qk_nope + v_head, kv_lora_rank]
if err := tt.Reshape(numHeads, kvPerHead, kvLoraRank); err != nil {
return nil, err
}
if extractK {
// Slice K: [n_head, qk_nope, kv_lora_rank]
tt, err = tt.Slice(nil, tensor.S(0, qkNope), nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
// Transpose to [n_head, kv_lora_rank, qk_nope]
tt, err = tensor.Transpose(tt, 0, 2, 1)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
} else {
// Slice V: [n_head, v_head, kv_lora_rank] - already correct layout
tt, err = tt.Slice(nil, tensor.S(qkNope, kvPerHead), nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
}
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
}
}
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
@@ -139,6 +207,52 @@ func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
// Split attn_kv_b into separate attn_k_b and attn_v_b for MLA absorption
if strings.HasSuffix(t.Name(), ".attn_kv_b.weight") {
qkNope := int(p.QKNopeHeadDim)
vHeadDim := int(p.VHeadDim)
kvLoraRank := int(p.KVLoraRank)
kvPerHead := qkNope + vHeadDim
numHeads := int(p.NumAttentionHeads)
kvFirst := true
if len(t.Shape()) == 2 {
switch {
case int(t.Shape()[0]) == kvLoraRank:
if kvPerHead > 0 && int(t.Shape()[1])%kvPerHead == 0 {
numHeads = int(t.Shape()[1]) / kvPerHead
}
kvFirst = true
case int(t.Shape()[1]) == kvLoraRank:
if kvPerHead > 0 && int(t.Shape()[0])%kvPerHead == 0 {
numHeads = int(t.Shape()[0]) / kvPerHead
}
kvFirst = false
default:
slog.Warn("glm4moelite: unexpected attn_kv_b layout", "name", t.Name(), "shape", t.Shape())
}
}
kTensor := t.Clone()
kTensor.SetRepacker(p.repackKVB(true, kvFirst, numHeads))
out = append(out, &ggml.Tensor{
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_k_b", 1),
Kind: t.Kind(),
Shape: []uint64{uint64(numHeads), uint64(kvLoraRank), uint64(qkNope)},
WriterTo: kTensor,
})
vTensor := t.Clone()
vTensor.SetRepacker(p.repackKVB(false, kvFirst, numHeads))
out = append(out, &ggml.Tensor{
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_v_b", 1),
Kind: t.Kind(),
Shape: []uint64{uint64(numHeads), uint64(vHeadDim), uint64(kvLoraRank)},
WriterTo: vTensor,
})
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),

View File

@@ -4,16 +4,6 @@ title: Anthropic compatibility
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
@@ -22,8 +12,8 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_API_KEY="" # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
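A minimal request sketch (model name and prompt are illustrative; Ollama accepts the standard Anthropic Messages fields and ignores the API key):
```shell
curl -X POST http://localhost:11434/v1/messages \
  -H "content-type: application/json" \
  -H "x-api-key: ollama" \
  -d '{
    "model": "qwen3-coder",
    "max_tokens": 256,
    "messages": [
      {"role": "user", "content": "Why is the sky blue?"}
    ]
  }'
```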
@@ -245,10 +235,41 @@ curl -X POST http://localhost:11434/v1/messages \
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend.
### Recommended models
For coding use cases, models like `glm-4.7`, `minimax-m2.1`, and `qwen3-coder` are recommended.
Download a model before use:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
ollama pull qwen3-coder
```
> Note: Qwen 3 coder is a 30B parameter model requiring at least 24GB of VRAM to run smoothly. More is required for longer context lengths.
```shell
ollama pull glm-4.7:cloud
```
### Quick setup
```shell
ollama launch claude
```
This will prompt you to select a model, configure Claude Code automatically, and launch it. To configure without launching:
```shell
ollama launch claude --config
```
### Manual setup
Set the environment variables and run Claude Code:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
@@ -256,19 +277,13 @@ Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
export ANTHROPIC_API_KEY=""
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints

View File

@@ -8,6 +8,47 @@ title: CLI Reference
ollama run gemma3
```
### Launch integrations
```
ollama launch
```
Configure and launch external applications to use Ollama models. This provides an interactive way to set up and start integrations with supported apps.
#### Supported integrations
- **OpenCode** - Open-source coding assistant
- **Claude Code** - Anthropic's agentic coding tool
- **Codex** - OpenAI's coding assistant
- **Droid** - Factory's AI coding agent
#### Examples
Launch an integration interactively:
```
ollama launch
```
Launch a specific integration:
```
ollama launch claude
```
Launch with a specific model:
```
ollama launch claude --model qwen3-coder
```
Configure without launching:
```
ollama launch droid --config
```
#### Multiline input
For multiline input, you can wrap text with `"""`:
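A sketch of what this looks like at the interactive prompt:
```
>>> """Hello,
... world!
... """
```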

View File

@@ -3,8 +3,6 @@ title: Cloud
sidebarTitle: Cloud
---
<Info>Ollama's cloud is currently in preview.</Info>
## Cloud Models
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
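For example (model name illustrative), a cloud model is run the same way as a local one:
```shell
ollama run glm-4.7:cloud
```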

View File

@@ -8,7 +8,7 @@ Context length is the maximum number of tokens that the model has access to in m
The default context length in Ollama is 4096 tokens.
</Note>
Tasks which require large context like web search, agents, and coding tools should be set to at least 32000 tokens.
For tasks that require large context, such as web search, agents, and coding tools, set the context length to at least 64000 tokens.
## Setting context length
@@ -24,7 +24,7 @@ Change the slider in the Ollama app under settings to your desired context lengt
### CLI
If editing the context length for Ollama is not possible, the context length can also be updated when serving Ollama.
```
OLLAMA_CONTEXT_LENGTH=32000 ollama serve
OLLAMA_CONTEXT_LENGTH=64000 ollama serve
```
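Context length can also be set per request through the API with the `num_ctx` option (a sketch; model and value are illustrative):
```
curl http://localhost:11434/api/generate -d '{
  "model": "qwen3-coder",
  "prompt": "Summarize the repository layout",
  "options": { "num_ctx": 64000 }
}'
```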
### Check allocated context length and model offloading

View File

@@ -102,18 +102,20 @@
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
"/integrations/clawdbot",
"/integrations/cline",
"/integrations/codex",
"/integrations/droid",
"/integrations/goose",
"/integrations/zed",
"/integrations/roo-code",
"/integrations/jetbrains",
"/integrations/marimo",
"/integrations/n8n",
"/integrations/xcode",
"/integrations/onyx",
"/integrations/marimo"
"/integrations/opencode",
"/integrations/roo-code",
"/integrations/vscode",
"/integrations/xcode",
"/integrations/zed"
]
},
{

View File

@@ -9,7 +9,7 @@ sidebarTitle: Welcome
<CardGroup cols={2}>
<Card title="Quickstart" icon="rocket" href="/quickstart">
Get up and running with your first model
Get up and running with your first model or integrate Ollama with your favorite tools
</Card>
<Card
title="Download Ollama"

View File

@@ -4,7 +4,7 @@ title: Claude Code
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `glm-4.7`, `qwen3-coder`, and `gpt-oss`.
![Claude Code with Ollama](https://files.ollama.com/claude-code.png)
@@ -26,12 +26,27 @@ irm https://claude.ai/install.ps1 | iex
## Usage with Ollama
### Quick setup
```shell
ollama launch claude
```
To configure without launching:
```shell
ollama launch claude --config
```
### Manual setup
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_API_KEY=""
export ANTHROPIC_BASE_URL=http://localhost:11434
```
@@ -44,35 +59,17 @@ claude --model gpt-oss:20b
Or run with environment variables inline:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
```
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks

View File

@@ -0,0 +1,48 @@
---
title: Clawdbot
---
Clawdbot is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
## Install
Install [Clawdbot](https://clawd.bot/)
```bash
npm install -g clawdbot@latest
```
Then run the onboarding wizard:
```bash
clawdbot onboard --install-daemon
```
<Note>Clawdbot requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch clawdbot
```
This configures Clawdbot to use Ollama and starts the gateway.
If the gateway is already running, no restart is needed; it automatically reloads the updated configuration.
To configure without launching:
```shell
ollama launch clawdbot --config
```
## Recommended Models
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

View File

@@ -13,7 +13,21 @@ npm install -g @openai/codex
## Usage with Ollama
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 64k tokens.</Note>
### Quick setup
```
ollama launch codex
```
To configure without launching:
```shell
ollama launch codex --config
```
### Manual setup
To use `codex` with Ollama, use the `--oss` flag:
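For instance, a minimal invocation might look like the following sketch (the model flag is an assumption, not taken from this page):
```shell
# run Codex against the local Ollama server; model name illustrative
codex --oss -m gpt-oss:20b
```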

View File

@@ -11,10 +11,24 @@ Install the [Droid CLI](https://factory.ai/):
curl -fsSL https://app.factory.ai/cli | sh
```
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch droid
```
To configure without launching:
```shell
ollama launch droid --config
```
### Manual setup
Add a local configuration block to `~/.factory/config.json`:
```json
@@ -73,4 +87,4 @@ Add the cloud configuration block to `~/.factory/config.json`:
}
```
Run `droid` in a new terminal to load the new settings.
Run `droid` in a new terminal to load the new settings.

View File

@@ -0,0 +1,106 @@
---
title: OpenCode
---
OpenCode is an open-source AI coding assistant that runs in your terminal.
## Install
Install the [OpenCode CLI](https://opencode.ai):
```bash
curl -fsSL https://opencode.ai/install.sh | bash
```
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch opencode
```
To configure without launching:
```shell
ollama launch opencode --config
```
### Manual setup
Add a configuration block to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"qwen3-coder": {
"name": "qwen3-coder"
}
}
}
}
}
```
## Cloud Models
`glm-4.7:cloud` is the recommended model for use with OpenCode.
Add the cloud configuration to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud"
}
}
}
}
}
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama Cloud",
"options": {
"baseURL": "https://ollama.com/v1"
},
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud"
}
}
}
}
}
```
Run `opencode` in a new terminal to load the new settings.

View File

@@ -18,13 +18,13 @@ This quickstart will walk you through running your first model with Ollama. To
<Tab title="CLI">
Open a terminal and run the command:
```
```sh
ollama run gemma3
```
</Tab>
<Tab title="cURL">
```
```sh
ollama pull gemma3
```
@@ -45,13 +45,13 @@ This quickstart will walk you through running your first model with Ollama. To
<Tab title="Python">
Start by downloading a model:
```
```sh
ollama pull gemma3
```
Then install Ollama's Python library:
```
```sh
pip install ollama
```
@@ -101,3 +101,42 @@ This quickstart will walk you through running your first model with Ollama. To
</Tabs>
See a full list of available models [here](https://ollama.com/models).
## Coding
For coding use cases, we recommend using the `glm-4.7-flash` model.
Note: this model requires 23 GB of VRAM with a 64000-token context length.
```sh
ollama pull glm-4.7-flash
```
Alternatively, you can use a more powerful cloud model (with full context length):
```sh
ollama pull glm-4.7:cloud
```
Use `ollama launch` to quickly set up a coding tool with Ollama models:
```sh
ollama launch
```
### Supported integrations
- [OpenCode](/integrations/opencode) - Open-source coding assistant
- [Claude Code](/integrations/claude-code) - Anthropic's agentic coding tool
- [Codex](/integrations/codex) - OpenAI's coding assistant
- [Droid](/integrations/droid) - Factory's AI coding agent
### Launch with a specific model
```sh
ollama launch claude --model glm-4.7-flash
```
### Configure without launching
```sh
ollama launch claude --config
```

View File

@@ -38,6 +38,7 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"lfm2.5-thinking",
"ministral-3",
"qwen3-coder:30b",
"gpt-oss:20b",
@@ -143,6 +144,7 @@ var (
"granite3.3",
"hermes3",
"internlm2",
"lfm2.5-thinking",
"llama-guard3",
"llama-pro",
"llama2-chinese",
@@ -263,6 +265,7 @@ var (
"snowflake-arctic-embed2",
}
libraryToolsModels = []string{
"lfm2.5-thinking",
"qwen3-vl",
"gpt-oss:20b",
"gpt-oss:120b",

View File

@@ -0,0 +1,309 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: nobody <>
Date: Sat, 24 Jan 2026 02:31:01 +0000
Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash
Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
uses head size 576 with gqa_ratio 4, which was previously only supported
for gqa_ratio 16 (DeepSeek).
Metal changes:
- Enable head size 576 for flash attention
- Increase simdgroups to 8 for large heads (>=512)
- Add case 8 kernel dispatch for 8 simdgroups
CUDA changes:
- Add gqa_ratio 4 support for head 576/512
- Add tile configs for (576, 512, 4) and (576, 512, 8)
- Add MMA config cases for ncols 4
- Add template instances for ncols2=4
- Fix nbatch_fa values in nvidia_fp32 config (32->64)
---
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 40 +++++++++++++++----
ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++
ggml/src/ggml-cuda/fattn.cu | 12 ++++--
...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
ggml/src/ggml-metal/ggml-metal-device.m | 8 +---
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
ggml/src/ggml-metal/ggml-metal.metal | 1 +
10 files changed, 64 insertions(+), 19 deletions(-)
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 7bd1044c1..3dea2205e 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
}
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
} else {
- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
- // Wide version of KQ_C is column-major => swap A and B.
- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+ if constexpr (cols_per_warp == 8) {
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+ } else {
+ // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+ // RDNA matrix C is column-major.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+#else
+ // swap A and B for CUDA.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
+ }
}
}
}
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
+#ifdef VOLTA_MMA_AVAILABLE
+ if (ncols1*ncols2 < 32) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // VOLTA_MMA_AVAILABLE
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
+
+// For GLM 4.7 Flash
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
index 7c4d6fe67..371be7442 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
return 0;
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
return 0;
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
return;
}
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
+ return;
+ }
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+ return;
+ }
}
if constexpr (DV <= 256) {
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 015540666..1693479cb 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 576: {
- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
+ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
GGML_ASSERT(V->ne[0] == 512);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
- GGML_ASSERT(gqa_ratio % 16 == 0);
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ GGML_ASSERT(gqa_ratio % 4 == 0);
+ if (gqa_ratio % 16 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ }
} break;
default:
GGML_ABORT("fatal error");
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
- if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
+ if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
index 2074e954a..517993cb0 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
index 24c64cf00..97b19c67a 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
index 1ada657f1..989626dfa 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
index 86d4ffae2..173de7aac 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index f24270bb1..7b5ee968c 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 112 &&
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
- op->src[0]->ne[0] != 256) {
- return false;
- }
- if (op->src[0]->ne[0] == 576) {
- // DeepSeek sizes
- // TODO: disabled for now, until optmized
+ op->src[0]->ne[0] != 256 &&
+ op->src[0]->ne[0] != 576) {
return false;
}
if (op->src[1]->type != op->src[2]->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index e99c1763f..80864f303 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
// simdgroups per threadgroup (a.k.a. warps)
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
- int32_t nsg = 4;
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
const size_t smem = FATTN_SMEM(nsg);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c98d269d1..d33c16079 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
+ case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS


@@ -1,4 +1,4 @@
package server
package manifest
import (
"crypto/sha256"
@@ -14,7 +14,7 @@ type Layer struct {
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
status string
Status string `json:"-"`
}
const (
@@ -22,7 +22,7 @@ const (
)
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := GetBlobsPath("")
blobs, err := BlobsPath("")
if err != nil {
return Layer{}, err
}
@@ -45,7 +45,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
}
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
blob, err := GetBlobsPath(digest)
blob, err := BlobsPath(digest)
if err != nil {
return Layer{}, err
}
@@ -65,7 +65,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
MediaType: mediatype,
Digest: digest,
Size: n,
status: fmt.Sprintf("%s %s", status, digest),
Status: fmt.Sprintf("%s %s", status, digest),
}, nil
}
@@ -74,7 +74,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
return Layer{}, errors.New("creating new layer from layer with empty digest")
}
blob, err := GetBlobsPath(digest)
blob, err := BlobsPath(digest)
if err != nil {
return Layer{}, err
}
@@ -89,7 +89,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
Digest: digest,
Size: fi.Size(),
From: from,
status: fmt.Sprintf("using existing layer %s", digest),
Status: fmt.Sprintf("using existing layer %s", digest),
}, nil
}
@@ -98,7 +98,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
return nil, errors.New("opening layer with empty digest")
}
blob, err := GetBlobsPath(l.Digest)
blob, err := BlobsPath(l.Digest)
if err != nil {
return nil, err
}
@@ -126,7 +126,7 @@ func (l *Layer) Remove() error {
}
}
blob, err := GetBlobsPath(l.Digest)
blob, err := BlobsPath(l.Digest)
if err != nil {
return err
}


@@ -1,10 +1,9 @@
package server
package manifest
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
@@ -33,12 +32,38 @@ func (m *Manifest) Size() (size int64) {
return
}
func (m *Manifest) Digest() string {
return m.digest
}
func (m *Manifest) FileInfo() os.FileInfo {
return m.fi
}
// ReadConfigJSON reads and unmarshals a config layer as JSON.
func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
for _, layer := range m.Layers {
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
blobPath, err := BlobsPath(layer.Digest)
if err != nil {
return err
}
data, err := os.ReadFile(blobPath)
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
}
return fmt.Errorf("config %q not found in manifest", configPath)
}
func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil {
return err
}
manifests, err := GetManifestPath()
manifests, err := Path()
if err != nil {
return err
}
@@ -70,11 +95,11 @@ func (m *Manifest) RemoveLayers() error {
if _, used := inUse[layer.Digest]; used {
continue
}
blob, err := GetBlobsPath(layer.Digest)
blob, err := BlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
if err := os.Remove(blob); os.IsNotExist(err) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
@@ -89,7 +114,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, model.Unqualified(n)
}
manifests, err := GetManifestPath()
manifests, err := Path()
if err != nil {
return nil, err
}
@@ -121,7 +146,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
}
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
manifests, err := GetManifestPath()
manifests, err := Path()
if err != nil {
return err
}
@@ -148,7 +173,7 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
}
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
manifests, err := GetManifestPath()
manifests, err := Path()
if err != nil {
return nil, err
}
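Illustrative sketch (not part of the diff): the new Manifest.ReadConfigJSON helper looks up a JSON config layer by name and unmarshals it into a caller-supplied value. The config layer name and struct below are hypothetical, chosen only to show the call shape.

package example

import "github.com/ollama/ollama/manifest"

// launchConfig is a hypothetical payload stored as a JSON config layer.
type launchConfig struct {
	BaseURL string `json:"base_url"`
}

// readLaunchConfig shows how a caller might read a named config layer.
func readLaunchConfig(m *manifest.Manifest) (*launchConfig, error) {
	var cfg launchConfig
	// "launch.json" is an assumed layer name, for illustration only.
	if err := m.ReadConfigJSON("launch.json", &cfg); err != nil {
		return nil, err
	}
	return &cfg, nil
}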


@@ -1,4 +1,4 @@
package server
package manifest
import (
"encoding/json"

manifest/paths.go (new file, 95 lines)

@@ -0,0 +1,95 @@
package manifest
import (
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
var ErrInvalidDigestFormat = errors.New("invalid digest format")
func Path() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// PathForName returns the path to the manifest file for a specific model name.
func PathForName(n model.Name) (string, error) {
if !n.IsValid() {
return "", os.ErrNotExist
}
manifests, err := Path()
if err != nil {
return "", err
}
return filepath.Join(manifests, n.Filepath()), nil
}
func BlobsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "blobs", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// PruneDirectory removes empty directories recursively.
func PruneDirectory(path string) error {
info, err := os.Lstat(path)
if err != nil {
return err
}
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(path)
if err != nil {
return err
}
for _, entry := range entries {
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
return err
}
}
entries, err = os.ReadDir(path)
if err != nil {
return err
}
if len(entries) > 0 {
return nil
}
return os.Remove(path)
}
return nil
}
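Illustrative sketch (not part of the diff): BlobsPath only accepts well-formed sha256 digests (with either ':' or '-' as separator), normalizes ':' to '-', and creates the blobs directory on demand. The values below are examples, not taken from the change.

package example

import (
	"errors"
	"fmt"

	"github.com/ollama/ollama/manifest"
)

func demoBlobsPath() {
	// A well-formed digest resolves to <models>/blobs/sha256-<hex>.
	p, err := manifest.BlobsPath("sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
	fmt.Println(p, err)

	// Anything that is not a sha256 digest is rejected before touching the filesystem.
	_, err = manifest.BlobsPath("../../etc/passwd")
	fmt.Println(errors.Is(err, manifest.ErrInvalidDigestFormat)) // true
}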


@@ -609,3 +609,49 @@ func ImageGenerationsMiddleware() gin.HandlerFunc {
c.Next()
}
}
func ImageEditsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ImageEditRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Prompt == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.Image == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
return
}
genReq, err := openai.FromImageEditRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ImageWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}
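Illustrative sketch (not part of the diff): ImageEditsMiddleware validates the OpenAI-style edit request, rewrites the body into an Ollama generate request, and wraps the writer so the response is returned in OpenAI image format. The wiring below mirrors the middleware test further down; it assumes the sketch lives in the same package as the middleware, and the generate handler is a stand-in.

package example

import (
	"net/http"

	"github.com/gin-gonic/gin"
)

// newImageEditsRouter wires the middleware in front of a generate handler,
// the same way the middleware test below does.
func newImageEditsRouter(generate gin.HandlerFunc) *gin.Engine {
	r := gin.New()
	r.Use(ImageEditsMiddleware())
	r.Handle(http.MethodPost, "/api/generate", generate)
	return r
}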


@@ -1112,3 +1112,129 @@ func TestImageWriterResponse(t *testing.T) {
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
}
}
func TestImageEditsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.GenerateRequest
err openai.ErrorResponse
}
var capturedRequest *api.GenerateRequest
// Base64-encoded test image (1x1 pixel PNG)
testImage := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
testCases := []testCase{
{
name: "image edit basic",
body: `{
"model": "test-model",
"prompt": "make it blue",
"image": "` + testImage + `"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "make it blue",
Images: []api.ImageData{decodedImage},
},
},
{
name: "image edit with size",
body: `{
"model": "test-model",
"prompt": "make it blue",
"image": "` + testImage + `",
"size": "512x768"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "make it blue",
Images: []api.ImageData{decodedImage},
Width: 512,
Height: 768,
},
},
{
name: "image edit missing prompt",
body: `{
"model": "test-model",
"image": "` + testImage + `"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "prompt is required",
Type: "invalid_request_error",
},
},
},
{
name: "image edit missing model",
body: `{
"prompt": "make it blue",
"image": "` + testImage + `"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "model is required",
Type: "invalid_request_error",
},
},
},
{
name: "image edit missing image",
body: `{
"model": "test-model",
"prompt": "make it blue"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "image is required",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Error.Message != "" {
var errResp openai.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match:\n%s", diff)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match:\n%s", diff)
}
})
}
}


@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
}
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
} else {
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
// Wide version of KQ_C is column-major => swap A and B.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
if constexpr (cols_per_warp == 8) {
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
} else {
// Wide version of KQ_C is column-major
#if defined(AMD_WMMA_AVAILABLE)
// RDNA matrix C is column-major.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
#else
// swap A and B for CUDA.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
#endif // defined(AMD_WMMA_AVAILABLE)
}
}
}
}
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
#ifdef VOLTA_MMA_AVAILABLE
if (ncols1*ncols2 < 32) {
NO_DEVICE_CODE;
return;
}
#endif // VOLTA_MMA_AVAILABLE
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
// For GLM 4.7 Flash
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);


@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
return 0;
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
return 0;
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 8 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 4 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
return;
}
}
if constexpr (DV <= 256) {


@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 576: {
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
// For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
GGML_ASSERT(V->ne[0] == 512);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
GGML_ASSERT(gqa_ratio % 4 == 0);
if (gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
}
} break;
default:
GGML_ABORT("fatal error");
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;


@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);


@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);


@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);


@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);


@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 112 &&
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
op->src[0]->ne[0] != 256) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek sizes
// TODO: disabled for now, until optmized
op->src[0]->ne[0] != 256 &&
op->src[0]->ne[0] != 576) {
return false;
}
if (op->src[1]->type != op->src[2]->type) {


@@ -8967,6 +8967,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS


@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
// simdgroups per threadgroup (a.k.a. warps)
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
int32_t nsg = 4;
int32_t nsg = ne00 >= 512 ? 8 : 4;
const size_t smem = FATTN_SMEM(nsg);


@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS


@@ -39,6 +39,13 @@ type Model interface {
Config() config
}
// Validator is an optional interface that models can implement to perform
// validation after tensors have been loaded. If validation fails, model
// loading will fail with the returned error.
type Validator interface {
Validate() error
}
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
@@ -116,6 +123,13 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem()))
if validator, ok := m.(Validator); ok {
if err := validator.Validate(); err != nil {
return nil, err
}
}
return m, nil
}
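Illustrative sketch (not part of the diff): the Validator interface gives a model a post-load hook; model.New calls Validate after tensors are populated and aborts loading on error (glm4moelite uses this below to reject the old weight layout). The type and check here are hypothetical.

package example

import "errors"

// demoModel is a hypothetical model type; only Validate matters for the sketch.
type demoModel struct {
	hasProjections bool // e.g. whether required projection tensors were loaded
}

// Validate is invoked by model.New after tensors are populated; returning an
// error makes loading fail with that error.
func (m *demoModel) Validate() error {
	if !m.hasProjections {
		return errors.New("model is missing required projection tensors")
	}
	return nil
}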


@@ -1,6 +1,7 @@
package glm4moelite
import (
"errors"
"math"
"github.com/ollama/ollama/fs"
@@ -11,6 +12,8 @@ import (
"github.com/ollama/ollama/model/input"
)
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
type Options struct {
numExpertsUsed int
numExperts int
@@ -47,7 +50,9 @@ type Attention struct {
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
KVB *nn.Linear `gguf:"attn_kv_b"`
KB *nn.Linear `gguf:"attn_k_b"`
VB *nn.Linear `gguf:"attn_v_b"`
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
@@ -78,15 +83,16 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
// MLA absorption: absorb K projection into query
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
qPassAbsorb := attn.KB.Forward(ctx, qPass).Permute(ctx, 0, 2, 1, 3)
query = qRot.Concat(ctx, qPassAbsorb, 0)
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
key := kRot.Concat(ctx, kPass, 0)
attention := nn.AttentionWithVMLA(ctx, query, key, kPass, nil, attn.VB.Weight, opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
@@ -217,7 +223,6 @@ func New(c fs.Config) (model.Model, error) {
keyLength := int(c.Uint("attention.key_length"))
valueLength := int(c.Uint("attention.value_length"))
kqScale := 1.0 / math.Sqrt(float64(keyLength))
var pre []string
@@ -236,7 +241,7 @@ func New(c fs.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
@@ -279,6 +284,15 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func (m *Model) Validate() error {
for _, layer := range m.Layers {
if layer.Attention != nil && (layer.Attention.KB == nil || layer.Attention.VB == nil) {
return ErrOldModelFormat
}
}
return nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
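Worked example (not part of the diff): the attention scale above is 1/sqrt(attention.key_length). With a key length of 256, used here purely as an example value, that gives 1/16 = 0.0625.

package example

import (
	"fmt"
	"math"
)

func main() {
	keyLength := 256 // example value; the real one comes from attention.key_length
	kqScale := 1.0 / math.Sqrt(float64(keyLength))
	fmt.Println(kqScale) // 0.0625
}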


@@ -0,0 +1,73 @@
package glm4moelite
import (
"testing"
"github.com/ollama/ollama/ml/nn"
)
func TestValidate(t *testing.T) {
tests := []struct {
name string
model *Model
wantErr bool
}{
{
name: "valid model with KB and VB",
model: &Model{
Layers: []Layer{
{Attention: &Attention{KB: &nn.Linear{}, VB: &nn.Linear{}}},
},
},
wantErr: false,
},
{
name: "missing KB",
model: &Model{
Layers: []Layer{
{Attention: &Attention{VB: &nn.Linear{}}},
},
},
wantErr: true,
},
{
name: "missing VB",
model: &Model{
Layers: []Layer{
{Attention: &Attention{KB: &nn.Linear{}}},
},
},
wantErr: true,
},
{
name: "missing both KB and VB",
model: &Model{
Layers: []Layer{
{Attention: &Attention{}},
},
},
wantErr: true,
},
{
name: "nil Attention is ok",
model: &Model{
Layers: []Layer{
{Attention: nil},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.model.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr && err != ErrOldModelFormat {
t.Errorf("Validate() error = %v, want %v", err, ErrOldModelFormat)
}
})
}
}


@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
@@ -17,12 +18,34 @@ const (
ministralCollectingToolArgs
)
// ministralEvent represents an event emitted during parsing
type ministralEvent interface {
isMinistralEvent()
}
type ministralEventContent struct {
content string
}
type ministralEventThinking struct {
thinking string
}
type ministralEventToolCall struct {
name string
args string // raw JSON string
}
func (ministralEventContent) isMinistralEvent() {}
func (ministralEventThinking) isMinistralEvent() {}
func (ministralEventToolCall) isMinistralEvent() {}
type MinistralParser struct {
state ministralParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
currentTool *api.Tool
pendingToolName string // stores tool name while collecting args
}
func (p *MinistralParser) HasToolSupport() bool {
@@ -63,74 +86,251 @@ func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
return nil, fmt.Errorf("tool '%s' not found", n)
}
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
const (
ministralToolCallsTag = "[TOOL_CALLS]"
ministralThinkTag = "[THINK]"
ministralThinkEndTag = "[/THINK]"
ministralArgsTag = "[ARGS]"
)
// eat consumes the parser's buffer, and returns a list of any unambiguous
// events from the current parser state. The second return value indicates
// whether to keep looping (true when state transitions, false when waiting
// for more data).
func (p *MinistralParser) eat() ([]ministralEvent, bool) {
var events []ministralEvent
switch p.state {
case ministralCollectingContent:
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
if before != "" {
return before, "", calls, nil
bufStr := p.buffer.String()
// Check for [TOOL_CALLS] tag
if strings.Contains(bufStr, ministralToolCallsTag) {
split := strings.SplitN(bufStr, ministralToolCallsTag, 2)
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
if len(before) > 0 {
events = append(events, ministralEventContent{content: before})
}
after := split[1]
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingToolName
} else if strings.Contains(p.buffer.String(), "[THINK]") {
return events, true
}
// Check for [THINK] tag
if strings.Contains(bufStr, ministralThinkTag) {
split := strings.SplitN(bufStr, ministralThinkTag, 2)
before := strings.TrimRightFunc(split[0], unicode.IsSpace)
if len(before) > 0 {
events = append(events, ministralEventContent{content: before})
}
after := split[1]
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingThinkingContent
return "", "", calls, nil
} else {
p.buffer.Reset()
return s, "", calls, nil
return events, true
}
// Check for partial tag overlap with [TOOL_CALLS] or [THINK]
overlapToolCalls := overlap(bufStr, ministralToolCallsTag)
overlapThink := overlap(bufStr, ministralThinkTag)
maxOverlap := max(overlapToolCalls, overlapThink)
if maxOverlap > 0 {
// Withhold the potential partial tag
beforePartialTag := bufStr[:len(bufStr)-maxOverlap]
trailingWS := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWS
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventContent{content: unambiguous})
}
return events, false
}
// No tag found: emit content but withhold trailing whitespace
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventContent{content: unambiguous})
}
return events, false
case ministralCollectingThinkingContent:
if strings.Contains(p.buffer.String(), "[/THINK]") {
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
p.state = ministralCollectingContent
if after != "" {
p.buffer.Reset()
return after, thinkingContent, calls, nil
}
return "", thinkingContent, calls, nil
} else {
bufStr := p.buffer.String()
if strings.Contains(bufStr, ministralThinkEndTag) {
split := strings.SplitN(bufStr, ministralThinkEndTag, 2)
thinkingContent := split[0]
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.buffer.Reset()
return "", s, calls, nil
}
case ministralCollectingToolName:
if strings.Contains(p.buffer.String(), "[ARGS]") {
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
t, err := toolByName(p.tools, name)
if err != nil {
return "", "", calls, err
p.buffer.WriteString(after)
if len(thinkingContent) > 0 {
events = append(events, ministralEventThinking{thinking: thinkingContent})
}
p.currentTool = t
p.state = ministralCollectingToolArgs
return "", "", calls, nil
}
return "", "", calls, nil
case ministralCollectingToolArgs:
if strings.Contains(p.buffer.String(), "}") {
before, _ := splitAtTag(&p.buffer, "}", false)
before += "}"
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(before), &args); err != nil {
// todo - throw a better error
return "", "", calls, err
}
p.state = ministralCollectingContent
return events, true
}
call := api.ToolCall{
// Check for partial overlap with [/THINK]
if overlapLen := overlap(bufStr, ministralThinkEndTag); overlapLen > 0 {
unambiguous := bufStr[:len(bufStr)-overlapLen]
ambiguous := bufStr[len(bufStr)-overlapLen:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, ministralEventThinking{thinking: unambiguous})
}
return events, false
}
// No tag found: emit all thinking content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, ministralEventThinking{thinking: bufStr})
}
return events, false
case ministralCollectingToolName:
bufStr := p.buffer.String()
if strings.Contains(bufStr, ministralArgsTag) {
split := strings.SplitN(bufStr, ministralArgsTag, 2)
toolName := split[0]
after := split[1]
p.pendingToolName = toolName
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = ministralCollectingToolArgs
return events, true
}
// Wait for more data
return events, false
case ministralCollectingToolArgs:
bufStr := p.buffer.String()
jsonEnd := findJSONEnd(bufStr)
if jsonEnd != -1 {
jsonStr := bufStr[:jsonEnd+1]
remaining := bufStr[jsonEnd+1:]
events = append(events, ministralEventToolCall{
name: p.pendingToolName,
args: jsonStr,
})
p.pendingToolName = ""
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = ministralCollectingContent
return events, true
}
// Wait for more data
return events, false
default:
panic("unexpected ministral event")
}
}
// parseEvents loops calling eat() until it returns false
func (p *MinistralParser) parseEvents() []ministralEvent {
var all []ministralEvent
keepLooping := true
for keepLooping {
var events []ministralEvent
events, keepLooping = p.eat()
all = append(all, events...)
}
return all
}
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var contentBuilder, thinkingBuilder strings.Builder
var toolCalls []api.ToolCall
for _, event := range events {
switch e := event.(type) {
case ministralEventContent:
contentBuilder.WriteString(e.content)
case ministralEventThinking:
thinkingBuilder.WriteString(e.thinking)
case ministralEventToolCall:
// Validate tool exists
tool, toolErr := toolByName(p.tools, e.name)
if toolErr != nil {
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, toolErr
}
// Parse JSON arguments
var args api.ToolCallFunctionArguments
if jsonErr := json.Unmarshal([]byte(e.args), &args); jsonErr != nil {
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, jsonErr
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: p.currentTool.Function.Name,
Name: tool.Function.Name,
Arguments: args,
},
}
calls = append(calls, call)
return "", "", calls, nil
})
}
return "", "", calls, nil
}
return p.buffer.String(), thinking, calls, nil
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, nil
}
// findJSONEnd finds the index of the closing brace that completes a JSON object.
// It properly handles nested objects, arrays, and strings (including escaped characters).
// Returns -1 if the JSON is not yet complete.
func findJSONEnd(s string) int {
depth := 0
inString := false
escaped := false
for i, r := range s {
if inString {
switch {
case escaped:
// If the previous character was a backslash, skip this character
escaped = false
case r == '\\':
// Mark the next character as escaped
escaped = true
case r == '"':
// End of string literal
inString = false
}
continue
}
switch r {
case '"':
// Start of string literal
inString = true
case '{', '[':
// Increase nesting level for objects and arrays
depth++
case '}', ']':
// Decrease nesting level
depth--
if depth == 0 {
// Reached the end of the root JSON structure
return i
}
}
}
return -1
}
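Illustrative sketch (not part of the diff): findJSONEnd is what lets the parser stream tool arguments containing nested objects, arrays, and braces inside strings. It assumes the same package as the parser; the values are examples only.

package parsers

import "fmt"

func demoFindJSONEnd() {
	buf := `{"entities": [{"name": "Jack"}]} trailing text`
	if end := findJSONEnd(buf); end != -1 {
		fmt.Println(buf[:end+1]) // the complete JSON object
		fmt.Println(buf[end+1:]) // " trailing text" (leading space kept)
	}

	// Incomplete input returns -1, so the parser keeps buffering.
	fmt.Println(findJSONEnd(`{"a": {"b": 1}`)) // -1
}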


@@ -0,0 +1,545 @@
package parsers
import (
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestMinistralParserStreaming(t *testing.T) {
type step struct {
input string
wantEvents []ministralEvent
}
cases := []struct {
desc string
tools []api.Tool
steps []step
think bool // whether to enable thinking support
}{
// Content streaming
{
desc: "simple content",
steps: []step{
{input: "Hello, how can I help you?", wantEvents: []ministralEvent{
ministralEventContent{content: "Hello, how can I help you?"},
}},
},
},
{
desc: "streaming content word by word",
steps: []step{
{input: "Hello,", wantEvents: []ministralEvent{ministralEventContent{content: "Hello,"}}},
{input: " how", wantEvents: []ministralEvent{ministralEventContent{content: " how"}}},
{input: " can I help?", wantEvents: []ministralEvent{ministralEventContent{content: " can I help?"}}},
},
},
// Simple tool calls
{
desc: "simple tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
steps: []step{
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "San Francisco"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "get_weather", args: `{"location": "San Francisco"}`},
}},
},
},
{
desc: "tool call with nested object",
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
steps: []step{
{input: `[TOOL_CALLS]create_entities[ARGS]{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
}},
},
},
{
desc: "tool call with deeply nested objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "update_config"}}},
steps: []step{
{input: `[TOOL_CALLS]update_config[ARGS]{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "update_config", args: `{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`},
}},
},
},
{
desc: "tool call with array of objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "process_items"}}},
steps: []step{
{input: `[TOOL_CALLS]process_items[ARGS]{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "process_items", args: `{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`},
}},
},
},
{
desc: "tool call with escaped quotes in string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "search"}}},
steps: []step{
{input: `[TOOL_CALLS]search[ARGS]{"query": "say \"hello\""}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "search", args: `{"query": "say \"hello\""}`},
}},
},
},
{
desc: "tool call with braces inside string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "format"}}},
steps: []step{
{input: `[TOOL_CALLS]format[ARGS]{"template": "Hello {name}!"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "format", args: `{"template": "Hello {name}!"}`},
}},
},
},
{
desc: "empty JSON object",
tools: []api.Tool{{Function: api.ToolFunction{Name: "no_args"}}},
steps: []step{
{input: `[TOOL_CALLS]no_args[ARGS]{}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "no_args", args: `{}`},
}},
},
},
{
desc: "JSON with newlines in string",
tools: []api.Tool{{Function: api.ToolFunction{Name: "write"}}},
steps: []step{
{input: `[TOOL_CALLS]write[ARGS]{"content": "line1\nline2\nline3"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "write", args: `{"content": "line1\nline2\nline3"}`},
}},
},
},
{
desc: "backslash in string value",
tools: []api.Tool{{Function: api.ToolFunction{Name: "path"}}},
steps: []step{
{input: `[TOOL_CALLS]path[ARGS]{"dir": "C:\\Users\\test"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "path", args: `{"dir": "C:\\Users\\test"}`},
}},
},
},
// Content after tool call
{
desc: "content after tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
// NOTE: It's unclear if this is valid Ministral output, but the parser
// currently treats text after a tool call as regular content. This test
// documents that behavior so we notice if it changes.
{input: `[TOOL_CALLS]test[ARGS]{"a": 1}some content after`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{"a": 1}`},
ministralEventContent{content: "some content after"},
}},
},
},
// Multiple tool calls
{
desc: "multiple tool calls in sequence",
tools: []api.Tool{
{Function: api.ToolFunction{Name: "get_weather"}},
{Function: api.ToolFunction{Name: "get_time"}},
},
steps: []step{
{input: `[TOOL_CALLS]get_weather[ARGS]{"location": "NYC"}[TOOL_CALLS]get_time[ARGS]{"timezone": "EST"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "get_weather", args: `{"location": "NYC"}`},
ministralEventToolCall{name: "get_time", args: `{"timezone": "EST"}`},
}},
},
},
{
desc: "multiple tool calls streamed separately",
tools: []api.Tool{
{Function: api.ToolFunction{Name: "tool_a"}},
{Function: api.ToolFunction{Name: "tool_b"}},
},
steps: []step{
{input: `[TOOL_CALLS]tool_a[ARGS]{"x": 1}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "tool_a", args: `{"x": 1}`},
}},
{input: `[TOOL_CALLS]tool_b[ARGS]{"y": 2}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "tool_b", args: `{"y": 2}`},
}},
},
},
// Streaming tool calls
{
desc: "streaming tool call with nested objects",
tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
steps: []step{
{input: "[TOOL_CALLS]create_entities[ARGS]", wantEvents: []ministralEvent{}},
{input: `{"entities": [{"entityType": "Person",`, wantEvents: []ministralEvent{}},
{input: ` "name": "Jack",`, wantEvents: []ministralEvent{}},
{input: ` "observations": ["Works`, wantEvents: []ministralEvent{}},
{input: ` as a baker"]}`, wantEvents: []ministralEvent{}},
{input: `]}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
}},
},
},
{
desc: "streaming with incomplete JSON waits for completion",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "[TOOL_CALLS]test[ARGS]{", wantEvents: []ministralEvent{}},
{input: `"a": {`, wantEvents: []ministralEvent{}},
{input: `"b": 1`, wantEvents: []ministralEvent{}},
{input: `}`, wantEvents: []ministralEvent{}},
{input: `}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{"a": {"b": 1}}`},
}},
},
},
// Partial tag handling
{
desc: "partial tool tag fakeout",
steps: []step{
{input: "abc[TOOL", wantEvents: []ministralEvent{ministralEventContent{content: "abc"}}},
{input: " not a tag", wantEvents: []ministralEvent{ministralEventContent{content: "[TOOL not a tag"}}},
},
},
{
desc: "tool call tag split across chunks",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "[TOOL_", wantEvents: []ministralEvent{}},
{input: "CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "content before tool call",
tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
steps: []step{
{input: "hello [TOOL_CALLS]get_weather[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "hello"},
ministralEventToolCall{name: "get_weather", args: `{}`},
}},
},
},
{
desc: "whitespace between content and tool call is trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "content \n [TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "tabs and newlines before tool call are trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "content\t\n\t[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "non-breaking space before tool call is trimmed",
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
// \u00a0 is non-breaking space, which unicode.IsSpace considers whitespace
{input: "content\u00a0[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
{
desc: "whitespace before THINK tag is trimmed",
steps: []step{
{input: "content \n [THINK]thinking[/THINK]after", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "after"},
}},
},
},
{
desc: "trailing whitespace withheld then emitted",
steps: []step{
{input: "Hello ", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: " world"}}},
},
},
{
desc: "trailing newline withheld then emitted",
steps: []step{
{input: "Hello\n", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
{input: "world", wantEvents: []ministralEvent{ministralEventContent{content: "\nworld"}}},
},
},
// Thinking support
{
desc: "thinking content",
think: true,
steps: []step{
{input: "thinking here[/THINK]", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "thinking here"},
}},
{input: "content after", wantEvents: []ministralEvent{
ministralEventContent{content: "content after"},
}},
},
},
{
desc: "thinking with whitespace after end tag",
think: true,
steps: []step{
{input: "my thoughts[/THINK] \n response", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "my thoughts"},
ministralEventContent{content: "response"},
}},
},
},
{
desc: "non-breaking space after think end tag is trimmed",
think: true,
steps: []step{
// \u00a0 is non-breaking space
{input: "thinking[/THINK]\u00a0response", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "response"},
}},
},
},
{
desc: "partial think end tag",
think: true,
steps: []step{
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
{input: "NK]after", wantEvents: []ministralEvent{ministralEventContent{content: "after"}}},
},
},
{
desc: "think tag fakeout",
think: true,
steps: []step{
{input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
{input: "not end tag", wantEvents: []ministralEvent{ministralEventThinking{thinking: "[/THInot end tag"}}},
},
},
{
desc: "thinking then tool call",
think: true,
tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
steps: []step{
{input: "let me think[/THINK][TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
ministralEventThinking{thinking: "let me think"},
ministralEventToolCall{name: "test", args: `{}`},
}},
},
},
// Content then THINK tag transition
{
desc: "content then think tag",
steps: []step{
{input: "content[THINK]thinking[/THINK]more", wantEvents: []ministralEvent{
ministralEventContent{content: "content"},
ministralEventThinking{thinking: "thinking"},
ministralEventContent{content: "more"},
}},
},
},
// Unicode handling
{
desc: "unicode content",
steps: []step{
{input: "你好 🌍 مرحبا", wantEvents: []ministralEvent{
ministralEventContent{content: "你好 🌍 مرحبا"},
}},
},
},
{
desc: "unicode in tool args",
tools: []api.Tool{{Function: api.ToolFunction{Name: "greet"}}},
steps: []step{
{input: `[TOOL_CALLS]greet[ARGS]{"message": "你好 🌍"}`, wantEvents: []ministralEvent{
ministralEventToolCall{name: "greet", args: `{"message": "你好 🌍"}`},
}},
},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
parser := MinistralParser{}
parser.hasThinkingSupport = tc.think
parser.Init(tc.tools, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
gotEvents := parser.parseEvents()
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
// avoid deep equal on empty vs. nil slices
continue
}
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
}
}
})
}
}
func TestMinistralParser_Errors(t *testing.T) {
t.Run("unknown tool returns error", func(t *testing.T) {
p := &MinistralParser{}
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "known_tool"}}}, nil, nil)
_, _, _, err := p.Add(`[TOOL_CALLS]unknown_tool[ARGS]{"a": 1}`, true)
if err == nil {
t.Fatal("expected error for unknown tool")
}
})
t.Run("invalid JSON returns error", func(t *testing.T) {
p := &MinistralParser{}
p.Init([]api.Tool{{Function: api.ToolFunction{Name: "test"}}}, nil, nil)
_, _, _, err := p.Add(`[TOOL_CALLS]test[ARGS]{invalid json}`, true)
if err == nil {
t.Fatal("expected error for invalid JSON")
}
})
}
func TestFindJSONEnd(t *testing.T) {
tests := []struct {
name string
input string
expected int
}{
{
name: "simple object",
input: `{"a": 1}`,
expected: 7,
},
{
name: "nested object",
input: `{"a": {"b": 2}}`,
expected: 14,
},
{
name: "array inside object",
input: `{"items": [1, 2, 3]}`,
expected: 19,
},
{
name: "braces in string",
input: `{"template": "Hello {name}!"}`,
expected: 28,
},
{
name: "escaped quotes",
input: `{"msg": "say \"hi\""}`,
expected: 20,
},
{
name: "incomplete object",
input: `{"a": {"b": 1}`,
expected: -1,
},
{
name: "deeply nested",
input: `{"a": {"b": {"c": {"d": 1}}}}`,
expected: 28,
},
{
name: "object with trailing content",
input: `{"a": 1} extra`,
expected: 7,
},
{
name: "array",
input: `[{"a": 1}, {"b": 2}]`,
expected: 19,
},
{
name: "escaped backslash before quote",
input: `{"path": "C:\\"}`,
expected: 15,
},
{
name: "empty string",
input: "",
expected: -1,
},
{
name: "no opening brace",
input: "hello world",
expected: -1,
},
{
name: "only opening brace",
input: "{",
expected: -1,
},
{
name: "unclosed string",
input: `{"key": "unclosed`,
expected: -1,
},
{
name: "double escaped backslash then quote",
input: `{"path": "C:\\\\"}`,
expected: 17,
},
{
name: "unicode in key and value",
input: `{"키": "값"}`,
expected: 13,
},
{
name: "nested arrays",
input: `{"matrix": [[1, 2], [3, 4]]}`,
expected: 27,
},
{
name: "mixed nesting",
input: `{"a": [{"b": {"c": [1, 2, 3]}}]}`,
expected: 31,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := findJSONEnd(tt.input)
if result != tt.expected {
t.Errorf("findJSONEnd(%q) = %d, want %d", tt.input, result, tt.expected)
}
})
}
}
func TestMinistralParser_HasToolSupport(t *testing.T) {
p := &MinistralParser{}
if !p.HasToolSupport() {
t.Error("expected HasToolSupport to return true")
}
}
func TestMinistralParser_HasThinkingSupport(t *testing.T) {
p := &MinistralParser{hasThinkingSupport: false}
if p.HasThinkingSupport() {
t.Error("expected HasThinkingSupport to return false")
}
p = &MinistralParser{hasThinkingSupport: true}
if !p.HasThinkingSupport() {
t.Error("expected HasThinkingSupport to return true")
}
}


@@ -3,6 +3,7 @@ package parsers
import (
"strings"
"unicode"
"unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
@@ -114,3 +115,33 @@ func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string
sb.WriteString(after)
return before, after // return events
}
// overlap returns the longest overlap between the suffix of s and the prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
// trailingWhitespaceLen returns the length in bytes of trailing whitespace in s
func trailingWhitespaceLen(s string) int {
remaining := s
total := 0
for len(remaining) > 0 {
r, size := utf8.DecodeLastRuneInString(remaining)
// if it's an invalid utf8 rune, assume it isn't whitespace
if r == utf8.RuneError && size == 1 {
break
}
if !unicode.IsSpace(r) {
break
}
total += size
remaining = remaining[:len(remaining)-size]
}
return total
}
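Illustrative sketch (not part of the diff): these helpers moved into the shared parsers file so the ministral parser can reuse them. overlap reports how many trailing bytes of the buffer could be the start of a tag (so the parser withholds them), and trailingWhitespaceLen measures trailing whitespace in bytes, including multi-byte runes. Example values only, assuming the same package.

package parsers

import "fmt"

func demoHelpers() {
	// The last 5 bytes of the buffer could be the start of "[TOOL_CALLS]",
	// so the parser withholds them until the tag resolves or is ruled out.
	fmt.Println(overlap("Hello [TOOL", "[TOOL_CALLS]")) // 5
	fmt.Println(overlap("Hello", "[TOOL_CALLS]"))       // 0

	// Trailing whitespace is counted in bytes, multi-byte runes included.
	fmt.Println(trailingWhitespaceLen("done \n"))    // 2
	fmt.Println(trailingWhitespaceLen("done\u00a0")) // 2 (non-breaking space)
}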


@@ -11,7 +11,6 @@ import (
"strconv"
"strings"
"unicode"
"unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
@@ -194,36 +193,6 @@ func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
}
}
// TODO(drifkin): move this to a shared location
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
func trailingWhitespaceLen(s string) int {
remaining := s
total := 0
for len(remaining) > 0 {
r, size := utf8.DecodeLastRuneInString(remaining)
// if it's an invalid utf8 rune, assume it isn't whitespace
if r == utf8.RuneError && size == 1 {
break
}
if !unicode.IsSpace(r) {
break
}
total += size
remaining = remaining[:len(remaining)-size]
}
return total
}
type XMLFunctionCall struct {
XMLName xml.Name `xml:"function"`
Name string `xml:"name,attr"`


@@ -794,3 +794,47 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
Data: data,
}
}
// ImageEditRequest is an OpenAI-compatible image edit request.
type ImageEditRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Image string `json:"image"` // Base64-encoded image data
Size string `json:"size,omitempty"` // e.g., "1024x1024"
Seed *int64 `json:"seed,omitempty"`
}
// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
req := api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
}
// Decode the input image
if r.Image != "" {
imgData, err := decodeImageURL(r.Image)
if err != nil {
return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
}
req.Images = append(req.Images, imgData)
}
// Parse size if provided (e.g., "1024x768")
if r.Size != "" {
var w, h int32
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
req.Width = w
req.Height = h
}
}
if r.Seed != nil {
if req.Options == nil {
req.Options = map[string]any{}
}
req.Options["seed"] = *r.Seed
}
return req, nil
}
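Illustrative sketch (not part of the diff): FromImageEditRequest maps the OpenAI fields onto an Ollama GenerateRequest: the base64 data URL becomes Images[0], a "WxH" size becomes Width/Height, and seed goes into Options. The 1x1 PNG below is the same placeholder the tests use; the import path is assumed to be the usual github.com/ollama/ollama/openai.

package example

import (
	"fmt"

	"github.com/ollama/ollama/openai"
)

func demoImageEdit() {
	req := openai.ImageEditRequest{
		Model:  "test-model",
		Prompt: "make it blue",
		Image:  "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=",
		Size:   "512x768",
	}

	gen, err := openai.FromImageEditRequest(req)
	if err != nil {
		panic(err)
	}
	fmt.Println(gen.Model, gen.Width, gen.Height, len(gen.Images)) // test-model 512 768 1
}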


@@ -448,3 +448,86 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
})
}
}
func TestFromImageEditRequest_Basic(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if result.Prompt != "make it blue" {
t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
}
if len(result.Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Images))
}
}
func TestFromImageEditRequest_WithSize(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
Size: "512x768",
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Width != 512 {
t.Errorf("expected width 512, got %d", result.Width)
}
if result.Height != 768 {
t.Errorf("expected height 768, got %d", result.Height)
}
}
func TestFromImageEditRequest_WithSeed(t *testing.T) {
seed := int64(12345)
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
Seed: &seed,
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options == nil {
t.Fatal("expected options to be set")
}
if result.Options["seed"] != seed {
t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
}
}
func TestFromImageEditRequest_InvalidImage(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: "not-valid-base64",
}
_, err := FromImageEditRequest(req)
if err == nil {
t.Error("expected error for invalid image")
}
}

View File

@@ -95,7 +95,21 @@ func (i *Instance) Readline() (string, error) {
var currentLineBuf []rune
// draining tracks if we're processing buffered input from cooked mode.
// In cooked mode Enter sends \n, but in raw mode Ctrl+J sends \n.
// We treat \n from cooked mode as submit, not multiline.
// We check Buffered() after the first read since the bufio buffer is
// empty until then. This is compatible with """ multiline mode in
// interactive.go since each Readline() call is independent.
var draining, stopDraining bool
for {
// Apply deferred state change from previous iteration
if stopDraining {
draining = false
stopDraining = false
}
// don't show placeholder when pasting unless we're in multiline mode
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
if buf.IsEmpty() && showPlaceholder {
@@ -105,6 +119,15 @@ func (i *Instance) Readline() (string, error) {
r, err := i.Terminal.Read()
// After reading, check if there's more buffered data. If so, we're
// processing cooked-mode input. Once buffer empties, the current
// char is the last buffered one (still drain it), then stop next iteration.
if i.Terminal.reader.Buffered() > 0 {
draining = true
} else if draining {
stopDraining = true
}
if buf.IsEmpty() {
fmt.Print(ClearToEOL)
}
@@ -232,15 +255,20 @@ func (i *Instance) Readline() (string, error) {
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
case CharCtrlJ:
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
// If not draining cooked-mode input, treat as multiline
if !draining {
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
}
// Draining cooked-mode input: treat \n as submit
fallthrough
case CharEnter:
output := buf.String()
if len(i.pastedLines) > 0 {

View File
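
Note: the draining flags above are easier to see in isolation. This is a standalone model of that state machine, not the actual readline types: a bufio reader stands in for the terminal, and a newline seen while buffered cooked-mode input is still being replayed submits the line rather than opening a multiline prompt.

package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// Pretend the user pasted two lines while the terminal was still in cooked mode.
	reader := bufio.NewReader(strings.NewReader("pasted line\nsecond line\n"))

	var draining, stopDraining bool
	var line []rune
	for {
		// Deferred reset, matching the top-of-loop check in Readline.
		if stopDraining {
			draining, stopDraining = false, false
		}
		r, _, err := reader.ReadRune()
		if err != nil {
			break
		}
		if reader.Buffered() > 0 {
			draining = true // still replaying buffered cooked-mode input
		} else if draining {
			stopDraining = true // current rune is the last buffered one
		}
		if r == '\n' {
			if draining {
				fmt.Printf("submit: %q\n", string(line)) // cooked-mode newline submits
			} else {
				fmt.Println("multiline prompt") // an interactive Ctrl+J would go multiline
			}
			line = line[:0]
			continue
		}
		line = append(line, r)
	}
	// Prints: submit: "pasted line", then submit: "second line".
}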

@@ -14,8 +14,8 @@
VOL_NAME=${VOL_NAME:-"Ollama"}
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
export CGO_CFLAGS="-mmacosx-version-min=14.0"
export CGO_CXXFLAGS="-mmacosx-version-min=14.0"
export CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=14.0"
export CGO_LDFLAGS="-mmacosx-version-min=14.0"
set -e

View File

@@ -56,6 +56,12 @@ function checkEnv {
$script:DIST_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
$env:CGO_ENABLED="1"
if (-not $env:CGO_CFLAGS) {
$env:CGO_CFLAGS = "-O3"
}
if (-not $env:CGO_CXXFLAGS) {
$env:CGO_CXXFLAGS = "-O3"
}
Write-Output "Checking version"
if (!$env:VERSION) {
$data=(git describe --tags --first-parent --abbrev=7 --long --dirty --always)

View File

@@ -28,6 +28,7 @@ import (
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@@ -90,7 +91,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- resp
}
oldManifest, _ := ParseNamedManifest(name)
oldManifest, _ := manifest.ParseNamedManifest(name)
var baseLayers []*layerGGML
var err error
@@ -123,9 +124,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
}
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
manifest, mErr := ParseNamedManifest(fromName)
if mErr == nil && manifest.Config.Digest != "" {
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
mf, mErr := manifest.ParseNamedManifest(fromName)
if mErr == nil && mf.Config.Digest != "" {
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
if pErr == nil {
if cfgFile, fErr := os.Open(configPath); fErr == nil {
var baseConfig model.ConfigV2
@@ -342,7 +343,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
return "gguf"
} else {
// try to see if we can find a gguf file even without the file extension
blobPath, err := GetBlobsPath(files[fn])
blobPath, err := manifest.BlobsPath(files[fn])
if err != nil {
slog.Error("error getting blobs path", "file", fn)
return ""
@@ -394,7 +395,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
}
blobPath, err := GetBlobsPath(digest)
blobPath, err := manifest.BlobsPath(digest)
if err != nil {
return nil, err
}
@@ -432,7 +433,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, err
}
layer, err := NewLayer(t, mediaType)
layer, err := manifest.NewLayer(t, mediaType)
if err != nil {
return nil, err
}
@@ -465,7 +466,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
}
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
var layers []Layer
var layers []manifest.Layer
for _, layer := range baseLayers {
if layer.GGML != nil {
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
@@ -550,13 +551,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
}
for _, layer := range layers {
if layer.status != "" {
fn(api.ProgressResponse{Status: layer.status})
if layer.Status != "" {
fn(api.ProgressResponse{Status: layer.Status})
}
}
fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, *configLayer, layers); err != nil {
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
return err
}
@@ -577,7 +578,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
return nil, err
}
blob, err := GetBlobsPath(layer.Digest)
blob, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -599,7 +600,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
}
temp.Seek(0, io.SeekStart)
fn(api.ProgressResponse{Status: "verifying conversion"})
newLayer, err := NewLayer(temp, layer.MediaType)
newLayer, err := manifest.NewLayer(temp, layer.MediaType)
if err != nil {
return nil, err
}
@@ -619,7 +620,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
var layers []*layerGGML
fn(api.ProgressResponse{Status: "parsing GGUF"})
blobPath, err := GetBlobsPath(digest)
blobPath, err := manifest.BlobsPath(digest)
if err != nil {
return nil, err
}
@@ -654,7 +655,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
mediatype = "application/vnd.ollama.image.projector"
}
layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
return nil, err
@@ -665,8 +666,8 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
return detectChatTemplate(layers)
}
func removeLayer(layers []Layer, mediatype string) []Layer {
return slices.DeleteFunc(layers, func(layer Layer) bool {
func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
if layer.MediaType != mediatype {
return false
}
@@ -680,7 +681,7 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
})
}
func setTemplate(layers []Layer, t string) ([]Layer, error) {
func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.template")
if _, err := template.Parse(t); err != nil {
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
@@ -690,7 +691,7 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
}
blob := strings.NewReader(t)
layer, err := NewLayer(blob, "application/vnd.ollama.image.template")
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
@@ -699,11 +700,11 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
return layers, nil
}
func setSystem(layers []Layer, s string) ([]Layer, error) {
func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.system")
if s != "" {
blob := strings.NewReader(s)
layer, err := NewLayer(blob, "application/vnd.ollama.image.system")
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
if err != nil {
return nil, err
}
@@ -712,9 +713,9 @@ func setSystem(layers []Layer, s string) ([]Layer, error) {
return layers, nil
}
func setLicense(layers []Layer, l string) ([]Layer, error) {
func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
blob := strings.NewReader(l)
layer, err := NewLayer(blob, "application/vnd.ollama.image.license")
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
if err != nil {
return nil, err
}
@@ -722,7 +723,7 @@ func setLicense(layers []Layer, l string) ([]Layer, error) {
return layers, nil
}
func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
if p == nil {
p = make(map[string]any)
}
@@ -731,7 +732,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
continue
}
digestPath, err := GetBlobsPath(layer.Digest)
digestPath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -765,7 +766,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
if err := json.NewEncoder(&b).Encode(p); err != nil {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}
@@ -773,7 +774,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
return layers, nil
}
func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
// this leaves the old messages intact if no new messages were specified
// which may not be the correct behaviour
if len(m) == 0 {
@@ -786,7 +787,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
if err := json.NewEncoder(&b).Encode(m); err != nil {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
if err != nil {
return nil, err
}
@@ -794,7 +795,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
return layers, nil
}
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
digests := make([]string, len(layers))
for i, layer := range layers {
digests[i] = layer.Digest
@@ -805,7 +806,7 @@ func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
if err := json.NewEncoder(&b).Encode(config); err != nil {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return nil, err
}

View File
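
Note: most of the churn in this file is mechanical — layer and blob helpers that previously lived in the server package are now reached through the manifest package. A compact tour of that call surface as it is used above; signatures and return types are inferred from these call sites, so treat the details as assumptions rather than the package's documented API.

package main

import (
	"strings"

	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/types/model"
)

func sketch() error {
	// Content-addressed layer built from a reader plus a media type.
	layer, err := manifest.NewLayer(strings.NewReader("hello"), "application/octet-stream")
	if err != nil {
		return err
	}

	// On-disk blob location for a digest (replaces GetBlobsPath).
	if _, err := manifest.BlobsPath(layer.Digest); err != nil {
		return err
	}

	// Manifest lookup for a parsed name (replaces the server-local ParseNamedManifest).
	name := model.ParseName("registry.ollama.ai/library/example:latest")
	if mf, err := manifest.ParseNamedManifest(name); err == nil {
		_ = mf.Config.Digest
	}

	// Persist a manifest from a config layer plus data layers (replaces the server-local WriteManifest).
	return manifest.WriteManifest(name, layer, []manifest.Layer{layer})
}

func main() { _ = sketch() }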

@@ -10,6 +10,7 @@ import (
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
)
func TestConvertFromSafetensors(t *testing.T) {
@@ -17,7 +18,7 @@ func TestConvertFromSafetensors(t *testing.T) {
// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}

View File

@@ -24,6 +24,8 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
const maxRetries = 6
@@ -456,7 +458,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
}
type downloadOpts struct {
mp ModelPath
n model.Name
digest string
regOpts *registryOptions
fn func(api.ProgressResponse)
@@ -465,10 +467,10 @@ type downloadOpts struct {
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
if opts.digest == "" {
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
}
fp, err := GetBlobsPath(opts.digest)
fp, err := manifest.BlobsPath(opts.digest)
if err != nil {
return false, err
}
@@ -492,8 +494,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
download := data.(*blobDownload)
if !ok {
requestURL := opts.mp.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
requestURL := opts.n.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest)
return false, err

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -24,6 +23,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
@@ -75,12 +75,6 @@ type Model struct {
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImage}
}
// Check for completion capability
if m.ModelPath != "" {
f, err := gguf.Open(m.ModelPath)
if err == nil {
@@ -274,44 +268,22 @@ func (m *Model) String() string {
return modelfile.String()
}
func GetManifest(mp ModelPath) (*Manifest, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
return nil, "", err
}
f, err := os.Open(fp)
if err != nil {
return nil, "", err
}
defer f.Close()
sha256sum := sha256.New()
var manifest Manifest
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
return nil, "", err
}
return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
}
func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name)
manifest, digest, err := GetManifest(mp)
n := model.ParseName(name)
mf, err := manifest.ParseNamedManifest(n)
if err != nil {
return nil, err
}
model := &Model{
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
m := &Model{
Name: n.String(),
ShortName: n.DisplayShortest(),
Digest: mf.Digest(),
Template: template.DefaultTemplate,
}
if manifest.Config.Digest != "" {
filename, err := GetBlobsPath(manifest.Config.Digest)
if mf.Config.Digest != "" {
filename, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
return nil, err
}
@@ -322,29 +294,29 @@ func GetModel(name string) (*Model, error) {
}
defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
return nil, err
}
}
for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.Digest)
for _, layer := range mf.Layers {
filename, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
switch layer.MediaType {
case "application/vnd.ollama.image.model":
model.ModelPath = filename
model.ParentModel = layer.From
m.ModelPath = filename
m.ParentModel = layer.From
case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename)
m.AdapterPaths = append(m.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename)
m.ProjectorPaths = append(m.ProjectorPaths, filename)
case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
@@ -352,7 +324,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}
model.Template, err = template.Parse(string(bts))
m.Template, err = template.Parse(string(bts))
if err != nil {
return nil, err
}
@@ -362,7 +334,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}
model.System = string(bts)
m.System = string(bts)
case "application/vnd.ollama.image.params":
params, err := os.Open(filename)
if err != nil {
@@ -371,7 +343,7 @@ func GetModel(name string) (*Model, error) {
defer params.Close()
// parse model options parameters into a map so that we can see which fields have been specified explicitly
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
return nil, err
}
case "application/vnd.ollama.image.messages":
@@ -381,7 +353,7 @@ func GetModel(name string) (*Model, error) {
}
defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license":
@@ -389,11 +361,11 @@ func GetModel(name string) (*Model, error) {
if err != nil {
return nil, err
}
model.License = append(model.License, string(bts))
m.License = append(m.License, string(bts))
}
}
return model, nil
return m, nil
}
func CopyModel(src, dst model.Name) error {
@@ -408,7 +380,7 @@ func CopyModel(src, dst model.Name) error {
return nil
}
manifests, err := GetManifestPath()
manifests, err := manifest.Path()
if err != nil {
return err
}
@@ -437,7 +409,7 @@ func CopyModel(src, dst model.Name) error {
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
manifests, err := Manifests(true)
manifests, err := manifest.Manifests(true)
if err != nil {
return err
}
@@ -452,7 +424,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// only delete the files which are still in the deleteMap
for k := range deleteMap {
fp, err := GetBlobsPath(k)
fp, err := manifest.BlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
continue
@@ -468,7 +440,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
func PruneLayers() error {
deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("")
p, err := manifest.BlobsPath("")
if err != nil {
return err
}
@@ -483,9 +455,9 @@ func PruneLayers() error {
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
_, err := GetBlobsPath(name)
_, err := manifest.BlobsPath(name)
if err != nil {
if errors.Is(err, ErrInvalidDigestFormat) {
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
@@ -510,63 +482,30 @@ func PruneLayers() error {
return nil
}
func PruneDirectory(path string) error {
info, err := os.Lstat(path)
if err != nil {
return err
}
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(path)
if err != nil {
return err
}
for _, entry := range entries {
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
return err
}
}
entries, err = os.ReadDir(path)
if err != nil {
return err
}
if len(entries) > 0 {
return nil
}
return os.Remove(path)
}
return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
n := model.ParseName(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
if n.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}
manifest, _, err := GetManifest(mp)
mf, err := manifest.ParseNamedManifest(n)
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
var layers []manifest.Layer
layers = append(layers, mf.Layers...)
if mf.Config.Digest != "" {
layers = append(layers, mf.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := mp.GetManifestPath()
manifestPath, err := manifest.PathForName(n)
if err != nil {
return err
}
@@ -574,7 +513,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if err != nil {
return err
}
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
if err := pushWithTransfer(ctx, n, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -582,17 +521,17 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
if err := uploadBlob(ctx, n, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err
}
}
fn(api.ProgressResponse{Status: "pushing manifest"})
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
manifestJSON, err := json.Marshal(manifest)
manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
@@ -611,44 +550,44 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
n := model.ParseName(name)
// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
manifest, _, err := GetManifest(mp)
existingMf, err := manifest.ParseNamedManifest(n)
if errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil {
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
} else {
for _, l := range manifest.Layers {
for _, l := range existingMf.Layers {
deleteMap[l.Digest] = struct{}{}
}
if manifest.Config.Digest != "" {
deleteMap[manifest.Config.Digest] = struct{}{}
if existingMf.Config.Digest != "" {
deleteMap[existingMf.Config.Digest] = struct{}{}
}
}
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
if n.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}
fn(api.ProgressResponse{Status: "pulling manifest"})
manifest, err = pullModelManifest(ctx, mp, regOpts)
mf, err := pullModelManifest(ctx, n, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
var layers []manifest.Layer
layers = append(layers, mf.Layers...)
if mf.Config.Digest != "" {
layers = append(layers, mf.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
if err := pullWithTransfer(ctx, n, layers, mf, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -658,7 +597,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
mp: mp,
n: n,
digest: layer.Digest,
regOpts: regOpts,
fn: fn,
@@ -677,7 +616,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
fp, err := GetBlobsPath(layer.Digest)
fp, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return err
}
@@ -692,16 +631,16 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
delete(deleteMap, mf.Config.Digest)
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
fp, err := manifest.PathForName(n)
if err != nil {
return err
}
@@ -728,9 +667,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []Layer) bool {
func hasTensorLayers(layers []manifest.Layer) bool {
for _, layer := range layers {
if layer.MediaType == MediaTypeImageTensor {
if layer.MediaType == manifest.MediaTypeImageTensor {
return true
}
}
@@ -738,7 +677,7 @@ func hasTensorLayers(layers []Layer) bool {
}
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, mf *manifest.Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -747,12 +686,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
}
destDir, err := GetBlobsPath("")
destDir, err := manifest.BlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
base := n.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -784,7 +723,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: mp.GetNamespaceRepository(),
Repository: n.DisplayNamespaceModel(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
@@ -795,12 +734,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
fp, err := manifest.PathForName(n)
if err != nil {
return err
}
@@ -812,7 +751,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -822,12 +761,12 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
}
srcDir, err := GetBlobsPath("")
srcDir, err := manifest.BlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
base := n.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -864,13 +803,13 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
ManifestRef: n.Tag,
Repository: n.DisplayNamespaceModel(),
})
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
requestURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
@@ -880,7 +819,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
}
defer resp.Body.Close()
var m Manifest
var m manifest.Manifest
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err
}
@@ -1042,7 +981,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)
fp, err := manifest.BlobsPath(digest)
if err != nil {
return err
}

View File
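
Note: the ModelPath plumbing in this file becomes model.Name. The removed helpers and their replacements line up roughly as sketched below; the exact strings returned are an assumption based on the removed implementations further down in this diff.

package main

import (
	"fmt"

	"github.com/ollama/ollama/types/model"
)

func main() {
	n := model.ParseName("example.com/ns/repo:tag")

	// ParseModelPath(name)        -> model.ParseName(name)
	// mp.GetNamespaceRepository() -> n.DisplayNamespaceModel()  // "ns/repo"
	fmt.Println(n.DisplayNamespaceModel())
	// mp.GetShortTagname()        -> n.DisplayShortest()
	fmt.Println(n.DisplayShortest())
	// mp.GetFullTagname()         -> n.String()
	fmt.Println(n.String())
	// mp.BaseURL()                -> n.BaseURL()                // scheme + registry host
	fmt.Println(n.BaseURL())
	// mp.GetManifestPath()        -> manifest.PathForName(n)    // needs OLLAMA_MODELS, not called here
}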

@@ -56,6 +56,15 @@ func TestModelCapabilities(t *testing.T) {
},
expectedCaps: []model.Capability{model.CapabilityImage},
},
{
name: "model with image and vision capability (image editing)",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image", "vision"},
},
},
expectedCaps: []model.Capability{model.CapabilityImage, model.CapabilityVision},
},
{
name: "model with completion capability",
model: Model{

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
@@ -20,19 +21,19 @@ import (
var intermediateBlobs map[string]string = make(map[string]string)
type layerGGML struct {
Layer
manifest.Layer
*ggml.GGML
}
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
m, err := ParseNamedManifest(name)
m, err := manifest.ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err
}
m, err = ParseNamedManifest(name)
m, err = manifest.ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -41,7 +42,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
}
for _, layer := range m.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil {
return nil, err
}
@@ -50,7 +51,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
case "application/vnd.ollama.image.model",
"application/vnd.ollama.image.projector",
"application/vnd.ollama.image.adapter":
blobpath, err := GetBlobsPath(layer.Digest)
blobpath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -81,12 +82,12 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err, "template", s)
} else {
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
layers = append(layers, &layerGGML{layer, nil})
if t.Parameters != nil {
@@ -95,7 +96,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}

View File

@@ -1,146 +0,0 @@
package server
import (
"errors"
"fmt"
"io/fs"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
type ModelPath struct {
ProtocolScheme string
Registry string
Namespace string
Repository string
Tag string
}
const (
DefaultRegistry = "registry.ollama.ai"
DefaultNamespace = "library"
DefaultTag = "latest"
DefaultProtocolScheme = "https"
)
var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidDigestFormat = errors.New("invalid digest format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
ErrModelPathInvalid = errors.New("invalid model path")
)
func ParseModelPath(name string) ModelPath {
mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "",
Tag: DefaultTag,
}
before, after, found := strings.Cut(name, "://")
if found {
mp.ProtocolScheme = before
name = after
}
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
case 2:
mp.Namespace = parts[0]
mp.Repository = parts[1]
case 1:
mp.Repository = parts[0]
}
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
mp.Repository = repo
mp.Tag = tag
}
return mp
}
func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}
func (mp ModelPath) GetFullTagname() string {
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry {
if mp.Namespace == DefaultNamespace {
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
name := model.Name{
Host: mp.Registry,
Namespace: mp.Namespace,
Model: mp.Repository,
Tag: mp.Tag,
}
if !name.IsValid() {
return "", fs.ErrNotExist
}
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
}
func (mp ModelPath) BaseURL() *url.URL {
return &url.URL{
Scheme: mp.ProtocolScheme,
Host: mp.Registry,
}
}
func GetManifestPath() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
func GetBlobsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "blobs", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}

View File

@@ -1,153 +0,0 @@
package server
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetBlobsPath(t *testing.T) {
// GetBlobsPath expects an actual directory to exist
tempDir := t.TempDir()
tests := []struct {
name string
digest string
expected string
err error
}{
{
"empty digest",
"",
filepath.Join(tempDir, "blobs"),
nil,
},
{
"valid with colon",
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"valid with dash",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"digest too short",
"sha256-45640291",
"",
ErrInvalidDigestFormat,
},
{
"digest too long",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
"",
ErrInvalidDigestFormat,
},
{
"digest invalid chars",
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
"",
ErrInvalidDigestFormat,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv("OLLAMA_MODELS", tempDir)
got, err := GetBlobsPath(tc.digest)
require.ErrorIs(t, tc.err, err, tc.name)
assert.Equal(t, tc.expected, got, tc.name)
})
}
}
func TestParseModelPath(t *testing.T) {
tests := []struct {
name string
arg string
want ModelPath
}{
{
"full path https",
"https://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"full path http",
"http://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "http",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no protocol",
"example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no registry",
"ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no namespace",
"repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: "tag",
},
},
{
"no tag",
"repo",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: DefaultTag,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := ParseModelPath(tc.arg)
if got != tc.want {
t.Errorf("got: %q want: %q", got, tc.want)
}
})
}
}

View File

@@ -95,6 +95,13 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
newType = fsggml.TensorTypeQ8_0
}
} else if strings.Contains(name, "attn_k_b.weight") ||
strings.Contains(name, "attn_v_b.weight") ||
strings.Contains(name, "attn_kv_a_mqa.weight") ||
strings.Contains(name, "attn_q_a.weight") ||
strings.Contains(name, "attn_q_b.weight") {
// MLA tensors need higher precision to avoid quality degradation
newType = fsggml.TensorTypeQ8_0
} else if strings.Contains(name, "ffn_down") {
iLayer := qs.iFfnDown
n_layer := qs.nFfnDown

View File

@@ -39,6 +39,7 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/middleware"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/model/renderers"
@@ -974,7 +975,7 @@ func (s *Server) PushHandler(c *gin.Context) {
// is.
func getExistingName(n model.Name) (model.Name, error) {
var zero model.Name
existing, err := Manifests(true)
existing, err := manifest.Manifests(true)
if err != nil {
return zero, err
}
@@ -1018,7 +1019,7 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return
}
m, err := ParseNamedManifest(n)
m, err := manifest.ParseNamedManifest(n)
if err != nil {
switch {
case os.IsNotExist(err):
@@ -1080,7 +1081,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, ErrModelPathInvalid
return nil, model.Unqualified(name)
}
name, err := getExistingName(name)
if err != nil {
@@ -1112,7 +1113,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
@@ -1121,7 +1122,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
// Get torch_dtype directly from config.json for quantization level
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
@@ -1135,7 +1136,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}
manifest, err := ParseNamedManifest(name)
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -1147,7 +1148,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
Details: modelDetails,
Messages: msgs,
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
ModifiedAt: mf.FileInfo().ModTime(),
Requires: m.Config.Requires,
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
// default we return an empty map.
@@ -1214,7 +1215,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
resp.Tensors = tensors
}
}
@@ -1223,12 +1224,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
resp.Tensors = tensors
}
}
@@ -1285,7 +1286,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
}
func (s *Server) ListHandler(c *gin.Context) {
ms, err := Manifests(true)
ms, err := manifest.Manifests(true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1316,8 +1317,8 @@ func (s *Server) ListHandler(c *gin.Context) {
RemoteModel: cf.RemoteModel,
RemoteHost: cf.RemoteHost,
Size: m.Size(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Digest: m.Digest(),
ModifiedAt: m.FileInfo().ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
@@ -1376,7 +1377,7 @@ func (s *Server) CopyHandler(c *gin.Context) {
}
func (s *Server) HeadBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
path, err := manifest.BlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1392,7 +1393,7 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
func (s *Server) CreateBlobHandler(c *gin.Context) {
if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
p, err := GetBlobsPath(ib)
p, err := manifest.BlobsPath(ib)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1410,7 +1411,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
}
}
path, err := GetBlobsPath(c.Param("digest"))
path, err := manifest.BlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1428,7 +1429,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
return
}
layer, err := NewLayer(c.Request.Body, "")
layer, err := manifest.NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1603,8 +1604,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoint
// OpenAI-compatible image generation endpoints
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
@@ -1628,7 +1630,7 @@ func Serve(ln net.Listener) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
slog.Info("server config", "env", envconfig.Values())
blobsDir, err := GetBlobsPath("")
blobsDir, err := manifest.BlobsPath("")
if err != nil {
return err
}
@@ -1637,7 +1639,7 @@ func Serve(ln net.Listener) error {
}
if !envconfig.NoPrune() {
if _, err := Manifests(false); err != nil {
if _, err := manifest.Manifests(false); err != nil {
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
} else {
// clean up unused layers and manifests
@@ -1645,12 +1647,12 @@ func Serve(ln net.Listener) error {
return err
}
manifestsPath, err := GetManifestPath()
manifestsPath, err := manifest.Path()
if err != nil {
return err
}
if err := PruneDirectory(manifestsPath); err != nil {
if err := manifest.PruneDirectory(manifestsPath); err != nil {
return err
}
}
@@ -2506,8 +2508,14 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
return
}
// Set headers for streaming response
c.Header("Content-Type", "application/x-ndjson")
// Check streaming preference
isStreaming := req.Stream == nil || *req.Stream
contentType := "application/x-ndjson"
if !isStreaming {
contentType = "application/json; charset=utf-8"
}
c.Header("Content-Type", contentType)
// Get seed from options if provided
var seed int64
@@ -2522,13 +2530,21 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
}
}
var images []llm.ImageData
for i, imgData := range req.Images {
images = append(images, llm.ImageData{ID: i, Data: imgData})
}
var streamStarted bool
var finalResponse api.GenerateResponse
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
Images: images,
}, func(cr llm.CompletionResponse) {
streamStarted = true
res := api.GenerateResponse{
@@ -2552,6 +2568,11 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if !isStreaming {
finalResponse = res
return
}
data, _ := json.Marshal(res)
c.Writer.Write(append(data, '\n'))
c.Writer.Flush()
@@ -2561,5 +2582,10 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
if !streamStarted {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
if !isStreaming {
c.JSON(http.StatusOK, finalResponse)
}
}

View File
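
Note: with the handler change above, a non-streaming image generation request gets a single JSON object back instead of ndjson chunks. A hedged client sketch against a local server; the model name is a placeholder, and the fields read from the response are the ones TestImageGenerateStreamFalse checks further down in this diff.

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strings"

	"github.com/ollama/ollama/api"
)

func main() {
	body := `{"model": "my-image-model", "prompt": "a lighthouse at dusk", "stream": false}`
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", strings.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// With stream=false the content type is now application/json; charset=utf-8.
	fmt.Println(resp.Header.Get("Content-Type"))

	// The whole body is one GenerateResponse; Image carries the base64 result.
	var out api.GenerateResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println(out.Done, out.Image != "")
}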

@@ -25,6 +25,7 @@ import (
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
@@ -223,15 +224,15 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
manifest, err := ParseNamedManifest(model.ParseName("child"))
mf, err := manifest.ParseNamedManifest(model.ParseName("child"))
if err != nil {
t.Fatalf("parse manifest: %v", err)
}
if manifest.Config.Digest == "" {
if mf.Config.Digest == "" {
t.Fatalf("unexpected empty config digest for child manifest")
}
configPath, err := GetBlobsPath(manifest.Config.Digest)
configPath, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
t.Fatalf("config blob path: %v", err)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
@@ -93,13 +94,13 @@ func TestDeleteDuplicateLayers(t *testing.T) {
t.Fatal(err)
}
config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatal(err)
}
// create a manifest with duplicate layers
if err := WriteManifest(n, config, []Layer{config}); err != nil {
if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
t.Fatal(err)
}

View File

@@ -19,7 +19,9 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
@@ -71,6 +73,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
return
}
func (mockRunner) Ping(_ context.Context) error { return nil }
func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
return mock, nil
@@ -2193,3 +2197,246 @@ func TestGenerateUnload(t *testing.T) {
}
})
}
func TestGenerateWithImages(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
Files: map[string]string{"file.gguf": digest},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("images passed to completion request", func(t *testing.T) {
testImage := []byte("test-image-data")
mock.CompletionResponse.Content = "Image processed"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Describe this image",
Images: []api.ImageData{testImage},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify images were passed to the completion request
if len(mock.CompletionRequest.Images) != 1 {
t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images))
}
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) {
t.Errorf("image data mismatch in completion request")
}
if mock.CompletionRequest.Images[0].ID != 0 {
t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID)
}
})
t.Run("multiple images passed to completion request", func(t *testing.T) {
testImage1 := []byte("test-image-1")
testImage2 := []byte("test-image-2")
mock.CompletionResponse.Content = "Images processed"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Compare these images",
Images: []api.ImageData{testImage1, testImage2},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify both images were passed
if len(mock.CompletionRequest.Images) != 2 {
t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images))
}
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) {
t.Errorf("first image data mismatch")
}
if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) {
t.Errorf("second image data mismatch")
}
if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 {
t.Errorf("expected image IDs 0 and 1, got %d and %d",
mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID)
}
})
t.Run("no images when none provided", func(t *testing.T) {
mock.CompletionResponse.Content = "No images"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify no images in completion request
if len(mock.CompletionRequest.Images) != 0 {
t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images))
}
})
}
// TestImageGenerateStreamFalse tests that image generation respects stream=false
// and returns a single JSON response instead of streaming ndjson.
func TestImageGenerateStreamFalse(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
mock := mockRunner{}
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
fn(llm.CompletionResponse{Step: 2, TotalSteps: 3, Done: false})
fn(llm.CompletionResponse{Step: 3, TotalSteps: 3, Done: true, DoneReason: llm.DoneReasonStop, Image: "base64image"})
return nil
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
"": {
llama: &mock,
Options: &opts,
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
numParallel: 1,
},
},
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
// Create model manifest with image capability
n := model.ParseName("test-image")
cfg := model.ConfigV2{Capabilities: []string{"image"}}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(&cfg); err != nil {
t.Fatal(err)
}
configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatal(err)
}
if err := manifest.WriteManifest(n, configLayer, nil); err != nil {
t.Fatal(err)
}
streamFalse := false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-image",
Prompt: "test prompt",
Stream: &streamFalse,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
if ct := w.Header().Get("Content-Type"); ct != "application/json; charset=utf-8" {
t.Errorf("expected Content-Type 'application/json; charset=utf-8', got %q", ct)
}
body := w.Body.String()
lines := strings.Split(strings.TrimSpace(body), "\n")
if len(lines) != 1 {
t.Errorf("expected 1 response line, got %d:\n%s", len(lines), body)
}
var resp api.GenerateResponse
if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp.Image != "base64image" {
t.Errorf("expected image 'base64image', got %q", resp.Image)
}
if !resp.Done {
t.Errorf("expected done=true")
}
}
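
For context, a minimal, hypothetical client-side program for the stream=false path exercised above: it sends a generate request to an image model and writes the base64 image from the final response to disk. It assumes the api.Client.Generate callback style and the Image field on api.GenerateResponse shown in this test; the model name and output path are placeholders.

    package main

    import (
        "context"
        "encoding/base64"
        "fmt"
        "log"
        "os"

        "github.com/ollama/ollama/api"
    )

    func main() {
        client, err := api.ClientFromEnvironment()
        if err != nil {
            log.Fatal(err)
        }

        stream := false // single JSON response, as in TestImageGenerateStreamFalse
        req := &api.GenerateRequest{
            Model:  "test-image", // hypothetical image-generation model
            Prompt: "a watercolor fox",
            Stream: &stream,
        }

        err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
            if resp.Done && resp.Image != "" {
                data, err := base64.StdEncoding.DecodeString(resp.Image)
                if err != nil {
                    return err
                }
                return os.WriteFile("out.png", data, 0o644)
            }
            return nil
        })
        if err != nil {
            log.Fatal(err)
        }
        fmt.Println("wrote out.png")
    }

When streaming is enabled instead, the callback is invoked for each update, with the image expected on the final, done response, as in the mock above.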

View File

@@ -21,12 +21,14 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
var blobUploadManager sync.Map
type blobUpload struct {
Layer
manifest.Layer
Total int64
Completed atomic.Int64
@@ -51,7 +53,7 @@ const (
)
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
p, err := GetBlobsPath(b.Digest)
p, err := manifest.BlobsPath(b.Digest)
if err != nil {
return err
}
@@ -59,7 +61,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
if b.From != "" {
values := requestURL.Query()
values.Add("mount", b.Digest)
values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
requestURL.RawQuery = values.Encode()
}
@@ -128,7 +130,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
p, err := GetBlobsPath(b.Digest)
p, err := manifest.BlobsPath(b.Digest)
if err != nil {
b.err = err
return
@@ -364,9 +366,9 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
switch {
@@ -388,8 +390,8 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOp
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
upload := data.(*blobUpload)
if !ok {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
blobUploadManager.Delete(layer.Digest)
return err

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"log/slog"
"net/url"
"path/filepath"
"strings"
)
@@ -35,22 +36,25 @@ func Unqualified(n Name) error {
const MissingPart = "!MISSING!"
const (
defaultHost = "registry.ollama.ai"
defaultNamespace = "library"
defaultTag = "latest"
defaultHost = "registry.ollama.ai"
defaultNamespace = "library"
defaultTag = "latest"
defaultProtocolScheme = "https"
)
// DefaultName returns a name with the default values for the host, namespace,
// and tag parts. The model and digest parts are empty.
// tag, and protocol scheme parts. The model and digest parts are empty.
//
// - The default host is ("registry.ollama.ai")
// - The default namespace is ("library")
// - The default tag is ("latest")
// - The default protocol scheme is ("https")
func DefaultName() Name {
return Name{
Host: defaultHost,
Namespace: defaultNamespace,
Tag: defaultTag,
Host: defaultHost,
Namespace: defaultNamespace,
Tag: defaultTag,
ProtocolScheme: defaultProtocolScheme,
}
}
@@ -87,10 +91,11 @@ func (k partKind) String() string {
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
// is valid.
type Name struct {
Host string
Namespace string
Model string
Tag string
Host string
Namespace string
Model string
Tag string
ProtocolScheme string
}
// ParseName parses and assembles a Name from a name string. The
@@ -160,7 +165,9 @@ func ParseNameBare(s string) Name {
}
scheme, host, ok := strings.Cut(s, "://")
if !ok {
if ok {
n.ProtocolScheme = scheme
} else {
host = scheme
}
n.Host = host
@@ -189,12 +196,13 @@ func ParseNameFromFilepath(s string) (n Name) {
return n
}
// Merge merges the host, namespace, and tag parts of the two names,
// Merge merges the host, namespace, tag, and protocol scheme parts of the two names,
// preferring the non-empty parts of a.
func Merge(a, b Name) Name {
a.Host = cmp.Or(a.Host, b.Host)
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
a.Tag = cmp.Or(a.Tag, b.Tag)
a.ProtocolScheme = cmp.Or(a.ProtocolScheme, b.ProtocolScheme)
return a
}
@@ -305,6 +313,23 @@ func (n Name) EqualFold(o Name) bool {
strings.EqualFold(n.Tag, o.Tag)
}
// BaseURL returns the base URL for the registry.
func (n Name) BaseURL() *url.URL {
return &url.URL{
Scheme: n.ProtocolScheme,
Host: n.Host,
}
}
// DisplayNamespaceModel returns the namespace and model joined by "/".
func (n Name) DisplayNamespaceModel() string {
var b strings.Builder
b.WriteString(n.Namespace)
b.WriteByte('/')
b.WriteString(n.Model)
return b.String()
}
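
A brief in-package sketch (illustration only, not part of this file) of how the pieces added here compose; it mirrors the uploadBlob change earlier in this comparison, and the host, namespace, and digest values are hypothetical.

    n := ParseName("http://localhost:5000/myns/mymodel:latest")
    // ParseName merges with DefaultName(), so a name written without a scheme
    // falls back to the default "https" ProtocolScheme.

    base := n.BaseURL()               // &url.URL{Scheme: "http", Host: "localhost:5000"}
    repo := n.DisplayNamespaceModel() // "myns/mymodel"

    digest := "sha256:abc123" // placeholder digest
    requestURL := base.JoinPath("v2", repo, "blobs", digest)
    _ = requestURL // http://localhost:5000/v2/myns/mymodel/blobs/sha256:abc123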
func isValidLen(kind partKind, s string) bool {
switch kind {
case kindHost:

View File

@@ -32,10 +32,11 @@ func TestParseNameParts(t *testing.T) {
{
in: "scheme://host:port/namespace/model:tag",
want: Name{
Host: "host:port",
Namespace: "namespace",
Model: "model",
Tag: "tag",
Host: "host:port",
Namespace: "namespace",
Model: "model",
Tag: "tag",
ProtocolScheme: "scheme",
},
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
},

View File

@@ -11,9 +11,11 @@ import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
)
@@ -103,7 +105,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
layer, err := manifest.NewLayer(r, mediaType)
if err != nil {
return create.LayerInfo{}, err
}
@@ -141,13 +143,13 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -169,7 +171,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -186,7 +188,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
layer, err := manifest.NewLayer(r, manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -209,10 +211,23 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
return fmt.Errorf("invalid model name: %s", modelName)
}
// TODO: find a better way to detect image input support
// For now, hardcode Flux2KleinPipeline as supporting vision (image input)
caps := capabilities
modelIndex := filepath.Join(opts.ModelDir, "model_index.json")
if data, err := os.ReadFile(modelIndex); err == nil {
var cfg struct {
ClassName string `json:"_class_name"`
}
if json.Unmarshal(data, &cfg) == nil && cfg.ClassName == "Flux2KleinPipeline" {
caps = append(caps, "vision")
}
}
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: capabilities,
Capabilities: caps,
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
@@ -221,15 +236,15 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
configLayer, err := manifest.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer
serverLayers := make([]server.Layer, 0, len(layers))
// Convert LayerInfo to manifest.Layer
manifestLayers := make([]manifest.Layer, 0, len(layers))
for _, l := range layers {
serverLayers = append(serverLayers, server.Layer{
manifestLayers = append(manifestLayers, manifest.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
@@ -243,19 +258,19 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
if err != nil {
return err
}
serverLayers = append(serverLayers, modelfileLayers...)
manifestLayers = append(manifestLayers, modelfileLayers...)
}
return server.WriteManifest(name, configLayer, serverLayers)
return manifest.WriteManifest(name, configLayer, manifestLayers)
}
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
var layers []server.Layer
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
var layers []manifest.Layer
if mf.Template != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
if err != nil {
return nil, fmt.Errorf("failed to create template layer: %w", err)
}
@@ -263,7 +278,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
}
if mf.System != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
if err != nil {
return nil, fmt.Errorf("failed to create system layer: %w", err)
}
@@ -271,7 +286,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
}
if mf.License != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
if err != nil {
return nil, fmt.Errorf("failed to create license layer: %w", err)
}

View File

@@ -9,7 +9,7 @@ import "github.com/ollama/ollama/x/imagegen/mlx"
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
// Supports both single-stream and dual-stream architectures:
// - Single-stream: use Get/Set for the single output per layer
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
@@ -87,7 +87,7 @@ func (c *StepCache) Set(layer int, arr *mlx.Array) {
}
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
// Used for dual-stream architectures like Qwen-Image.
// Used for dual-stream architectures.
func (c *StepCache) Get2(layer int) *mlx.Array {
if layer < len(c.layers2) {
return c.layers2[layer]
@@ -96,7 +96,7 @@ func (c *StepCache) Get2(layer int) *mlx.Array {
}
// Set2 stores a layer output (stream 2), freeing any previous value.
// Used for dual-stream architectures like Qwen-Image.
// Used for dual-stream architectures.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
if layer < len(c.layers2) {
if c.layers2[layer] != nil {

View File

@@ -10,7 +10,10 @@ import (
"errors"
"fmt"
"io"
"net/http"
"os"
"regexp"
"slices"
"strconv"
"strings"
"time"
@@ -75,6 +78,7 @@ Image Generation Flags (experimental):
// RunCLI handles the CLI for image generation models.
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
// Image paths can be included in the prompt and will be extracted automatically.
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
// Get options from flags (with env var defaults)
opts := DefaultOptions()
@@ -111,9 +115,16 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return err
}
// Extract any image paths from the prompt
prompt, images, err := extractFileData(prompt)
if err != nil {
return err
}
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
Images: images,
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
@@ -254,14 +265,33 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
printCurrentSettings(opts)
continue
case strings.HasPrefix(line, "/"):
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
// Check if it's a file path, not a command
args := strings.Fields(line)
isFile := false
for _, f := range extractFileNames(line) {
if strings.HasPrefix(f, args[0]) {
isFile = true
break
}
}
if !isFile {
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", args[0])
continue
}
}
// Extract any image paths from the input
prompt, images, err := extractFileData(line)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
continue
}
// Generate image with current options
req := &api.GenerateRequest{
Model: modelName,
Prompt: line,
Prompt: prompt,
Images: images,
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
@@ -486,3 +516,61 @@ func displayImageInTerminal(imagePath string) bool {
return false
}
}
// extractFileNames finds image file paths in the input string.
func extractFileNames(input string) []string {
// Regex to match file paths with image extensions
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1)
}
// extractFileData extracts image data from file paths found in the input.
// Returns the cleaned prompt (with file paths removed) and the image data.
func extractFileData(input string) (string, []api.ImageData, error) {
filePaths := extractFileNames(input)
var imgs []api.ImageData
for _, fp := range filePaths {
// Normalize shell escapes
nfp := strings.ReplaceAll(fp, "\\ ", " ")
nfp = strings.ReplaceAll(nfp, "\\(", "(")
nfp = strings.ReplaceAll(nfp, "\\)", ")")
nfp = strings.ReplaceAll(nfp, "%20", " ")
data, err := getImageData(nfp)
if errors.Is(err, os.ErrNotExist) {
continue
} else if err != nil {
return "", nil, err
}
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data)
}
return strings.TrimSpace(input), imgs, nil
}
// getImageData reads and validates image data from a file.
func getImageData(filePath string) ([]byte, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer file.Close()
buf := make([]byte, 512)
_, err = file.Read(buf)
if err != nil {
return nil, err
}
contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType)
}
// Re-read the full file
return os.ReadFile(filePath)
}
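
A short in-package usage sketch for the helpers above, assuming a file ./cat.png exists and is a valid PNG: the path is pulled out of the prompt by extractFileNames, validated by content sniffing in getImageData, and returned as image data while the path is stripped from the prompt.

    cleaned, imgs, err := extractFileData("describe ./cat.png in detail")
    if err != nil {
        log.Fatalf("extract images: %v", err)
    }
    fmt.Printf("prompt=%q images=%d\n", cleaned, len(imgs))
    // prompt="describe  in detail" (the path is removed, leaving the surrounding spaces)
    // images=1 when ./cat.png exists; paths that do not exist are silently skipped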

View File

@@ -21,8 +21,6 @@ import (
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
@@ -61,14 +59,11 @@ func main() {
listTensors := flag.Bool("list", false, "List tensors only")
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
@@ -166,60 +161,6 @@ func main() {
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImageEdit:
if len(inputImages) == 0 {
log.Fatal("qwen-image-edit requires at least one -input-image")
}
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
// For image editing, use 0 for dimensions to auto-detect from input image
// unless explicitly overridden from defaults
editWidth := int32(0)
editHeight := int32(0)
if *width != 1024 {
editWidth = int32(*width)
}
if *height != 1024 {
editHeight = int32(*height)
}
cfg := &qwen_image_edit.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: editWidth,
Height: editHeight,
Steps: *steps,
Seed: *seed,
}
var img *mlx.Array
img, err = m.EditFromConfig(inputImages, cfg)
if err == nil {
err = saveImageArray(img, *out)
}
case *listTensors:
err = listModelTensors(*modelPath)
default:

View File

@@ -7,6 +7,8 @@ import (
"encoding/base64"
"fmt"
"image"
"image/color"
"image/draw"
_ "image/jpeg"
"image/png"
"os"
@@ -111,6 +113,7 @@ func clampF(v, min, max float32) float32 {
}
// DecodeImage decodes image bytes with EXIF orientation applied.
// Transparent images are composited onto a white background.
func DecodeImage(data []byte) (image.Image, error) {
orientation := readJPEGOrientation(data)
@@ -119,9 +122,33 @@ func DecodeImage(data []byte) (image.Image, error) {
return nil, err
}
img = flattenAlpha(img)
return applyOrientation(img, orientation), nil
}
// flattenAlpha composites an image onto a white background,
// removing any transparency. This is needed because image
// generation models don't handle alpha channels well.
func flattenAlpha(img image.Image) image.Image {
if _, ok := img.(*image.RGBA); !ok {
if _, ok := img.(*image.NRGBA); !ok {
// No alpha channel, return as-is
return img
}
}
bounds := img.Bounds()
dst := image.NewRGBA(bounds)
// Fill with white background
draw.Draw(dst, bounds, &image.Uniform{color.White}, image.Point{}, draw.Src)
// Composite the image on top
draw.Draw(dst, bounds, img, bounds.Min, draw.Over)
return dst
}
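
A small illustration of the intended effect, using a hypothetical transparent PNG: after DecodeImage, the pixels have been drawn over an opaque white background, so downstream image models never see an alpha channel.

    raw, err := os.ReadFile("transparent-logo.png") // hypothetical input file
    if err != nil {
        log.Fatal(err)
    }
    img, err := DecodeImage(raw)
    if err != nil {
        log.Fatal(err)
    }
    // img is EXIF-oriented and, if it was RGBA/NRGBA, composited onto white.
    _ = img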
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
// Returns 1 (normal) for non-JPEG or if orientation not found.
func readJPEGOrientation(data []byte) int {

View File

@@ -33,6 +33,17 @@ type ModelManifest struct {
BlobDir string
}
func DefaultBlobDir() string {
return filepath.Join(envconfig.Models(), "blobs")
}
// DefaultManifestDir returns the manifest storage directory.
// Respects OLLAMA_MODELS.
func DefaultManifestDir() string {
return filepath.Join(envconfig.Models(), "manifests")
}
// LoadManifest loads a manifest for the given model name.
// Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag"
func LoadManifest(modelName string) (*ModelManifest, error) {
@@ -50,7 +61,7 @@ func LoadManifest(modelName string) (*ModelManifest, error) {
return &ModelManifest{
Manifest: &manifest,
BlobDir: filepath.Join(envconfig.Models(), "blobs"),
BlobDir: DefaultBlobDir(),
}, nil
}
@@ -81,7 +92,7 @@ func resolveManifestPath(modelName string) string {
name = parts[1]
}
return filepath.Join(envconfig.Models(), "manifests", host, namespace, name, tag)
return filepath.Join(DefaultManifestDir(), host, namespace, name, tag)
}
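
For illustration (paths hypothetical), this is how the two directory helpers and the resolver above combine when OLLAMA_MODELS points at /data/ollama; a bare model name falls back to the default host, namespace, and tag, matching the test later in this comparison.

    os.Setenv("OLLAMA_MODELS", "/data/ollama")

    fmt.Println(DefaultManifestDir())
    // /data/ollama/manifests
    fmt.Println(DefaultBlobDir())
    // /data/ollama/blobs
    fmt.Println(resolveManifestPath("test-model"))
    // /data/ollama/manifests/registry.ollama.ai/library/test-model/latest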
// BlobPath returns the full path to a blob given its digest.
@@ -150,6 +161,17 @@ func (m *ModelManifest) HasTensorLayers() bool {
return false
}
// TotalTensorSize returns the total size in bytes of all tensor layers.
func (m *ModelManifest) TotalTensorSize() int64 {
var total int64
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
total += layer.Size
}
}
return total
}
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string

View File

@@ -3,27 +3,55 @@ package imagegen
import (
"path/filepath"
"testing"
"github.com/ollama/ollama/envconfig"
)
func TestLoadManifestRespectsOLLAMA_MODELS(t *testing.T) {
t.Setenv("OLLAMA_MODELS", "/custom/models/path")
// Verify envconfig.Models() returns our custom path
if got := envconfig.Models(); got != "/custom/models/path" {
t.Fatalf("envconfig.Models() = %q, want %q", got, "/custom/models/path")
func TestTotalTensorSize(t *testing.T) {
m := &ModelManifest{
Manifest: &Manifest{
Layers: []ManifestLayer{
{MediaType: "application/vnd.ollama.image.tensor", Size: 1000},
{MediaType: "application/vnd.ollama.image.tensor", Size: 2000},
{MediaType: "application/vnd.ollama.image.json", Size: 500}, // not a tensor
{MediaType: "application/vnd.ollama.image.tensor", Size: 3000},
},
},
}
// LoadManifest will fail (no manifest exists), but we can verify
// the error message contains the custom path
_, err := LoadManifest("test-model")
if err == nil {
t.Fatal("expected error, got nil")
}
expectedPath := filepath.Join("/custom/models/path", "manifests", "registry.ollama.ai", "library", "test-model", "latest")
if got := err.Error(); got != "read manifest: open "+expectedPath+": no such file or directory" {
t.Errorf("error = %q, want path to contain %q", got, expectedPath)
got := m.TotalTensorSize()
want := int64(6000)
if got != want {
t.Errorf("TotalTensorSize() = %d, want %d", got, want)
}
}
func TestTotalTensorSizeEmpty(t *testing.T) {
m := &ModelManifest{
Manifest: &Manifest{
Layers: []ManifestLayer{},
},
}
if got := m.TotalTensorSize(); got != 0 {
t.Errorf("TotalTensorSize() = %d, want 0", got)
}
}
func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
modelsDir := filepath.Join(t.TempDir(), "models")
// Simulate packaged/systemd environment
t.Setenv("OLLAMA_MODELS", modelsDir)
t.Setenv("HOME", "/usr/share/ollama")
// Manifest dir must respect OLLAMA_MODELS
wantManifest := filepath.Join(modelsDir, "manifests")
if got := DefaultManifestDir(); got != wantManifest {
t.Fatalf("DefaultManifestDir() = %q, want %q", got, wantManifest)
}
// Blob dir must respect OLLAMA_MODELS
wantBlobs := filepath.Join(modelsDir, "blobs")
if got := DefaultBlobDir(); got != wantBlobs {
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
}
}

View File

@@ -16,18 +16,9 @@ import (
"runtime"
)
// GB is a convenience constant for gigabytes.
const GB = 1024 * 1024 * 1024
// SupportedBackends lists the backends that support image generation.
var SupportedBackends = []string{"metal", "cuda", "cpu"}
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 20 * GB, // ~20GB for Flux
}
// CheckPlatformSupport validates that image generation is supported on the current platform.
// Returns nil if supported, or an error describing why it's not supported.
func CheckPlatformSupport() error {
@@ -47,17 +38,6 @@ func CheckPlatformSupport() error {
}
}
// CheckMemoryRequirements validates that there's enough memory for image generation.
// Returns nil if memory is sufficient, or an error if not.
func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
required := EstimateVRAM(modelName)
if availableMemory < required {
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
required/GB, availableMemory/GB)
}
return nil
}
// ResolveModelName checks if a model name is a known image generation model.
// Returns the normalized model name if found, empty string otherwise.
func ResolveModelName(modelName string) string {
@@ -68,16 +48,6 @@ func ResolveModelName(modelName string) string {
return ""
}
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
// Returns a conservative default of 21GB if the model type cannot be determined.
func EstimateVRAM(modelName string) uint64 {
className := DetectModelType(modelName)
if estimate, ok := modelVRAMEstimates[className]; ok {
return estimate
}
return 21 * GB
}
// DetectModelType reads model_index.json and returns the model type.
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
// Returns empty string if detection fails.

View File

@@ -30,69 +30,6 @@ func TestCheckPlatformSupport(t *testing.T) {
}
}
func TestCheckMemoryRequirements(t *testing.T) {
tests := []struct {
name string
availableMemory uint64
wantErr bool
}{
{
name: "sufficient memory",
availableMemory: 32 * GB,
wantErr: false,
},
{
name: "exactly enough memory",
availableMemory: 21 * GB,
wantErr: false,
},
{
name: "insufficient memory",
availableMemory: 16 * GB,
wantErr: true,
},
{
name: "zero memory",
availableMemory: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Use a non-existent model name which will default to 21GB estimate
err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
if (err != nil) != tt.wantErr {
t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestModelVRAMEstimates(t *testing.T) {
// Verify the VRAM estimates map has expected entries
expected := map[string]uint64{
"ZImagePipeline": 21 * GB,
"FluxPipeline": 20 * GB,
}
for name, expectedVRAM := range expected {
if actual, ok := modelVRAMEstimates[name]; !ok {
t.Errorf("Missing VRAM estimate for %s", name)
} else if actual != expectedVRAM {
t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
}
}
}
func TestEstimateVRAMDefault(t *testing.T) {
// Non-existent model should return default 21GB
vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
if vram != 21*GB {
t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
}
}
func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string
result := ResolveModelName("nonexistent-model")

View File

@@ -177,6 +177,20 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
})
}
// GenerateImageWithInputs implements runner.ImageEditModel interface.
// It generates an image conditioned on the provided input images for image editing.
func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
InputImages: inputImages,
Progress: progress,
})
}
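
A hedged sketch of how a caller might feed request images into the new method; the surrounding handler, the reqImages variable, and the parameter values are hypothetical, and error handling is simplified.

    // Inside a hypothetical request handler, with m already loaded:
    inputs := make([]image.Image, 0, len(reqImages))
    for _, raw := range reqImages { // reqImages: [][]byte of encoded input images
        img, _, err := image.Decode(bytes.NewReader(raw))
        if err != nil {
            return nil, fmt.Errorf("decode input image: %w", err)
        }
        inputs = append(inputs, img)
    }
    out, err := m.GenerateImageWithInputs(ctx, prompt, 1024, 1024, 30, 42, inputs,
        func(step, total int) { log.Printf("step %d/%d", step, total) })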
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
const MaxOutputPixels = 2048 * 2048

View File

@@ -1,87 +0,0 @@
//go:build mlx
package qwen_image
import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := mlx.InitMLX(); err != nil {
fmt.Printf("Skipping qwen_image tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestPipelineOutput runs the full pipeline (integration test).
// Skips if model weights not found. Requires ~50GB VRAM.
func TestPipelineOutput(t *testing.T) {
modelPath := "../../../weights/Qwen-Image-2512"
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + modelPath)
}
// Load model
pm, err := LoadPersistent(modelPath)
if err != nil {
t.Skipf("Skipping: failed to load model: %v", err)
}
// Run 2-step pipeline (minimum for stable scheduler)
cfg := &GenerateConfig{
Prompt: "a cat",
Width: 256,
Height: 256,
Steps: 2,
Seed: 42,
}
output, err := pm.GenerateFromConfig(cfg)
if err != nil {
t.Fatalf("Pipeline failed: %v", err)
}
mlx.Eval(output)
// Verify output shape [1, C, H, W]
shape := output.Shape()
if len(shape) != 4 {
t.Errorf("Expected 4D output, got %v", shape)
}
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
}
// Verify values in expected range [0, 1]
data := output.Data()
minVal, maxVal := float32(1.0), float32(0.0)
for _, v := range data {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
if minVal < -0.1 || maxVal > 1.1 {
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
}
}

View File

File diff suppressed because it is too large

View File

@@ -1,367 +0,0 @@
//go:build mlx
// Package qwen_image implements the Qwen-Image diffusion transformer model.
package qwen_image
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
// Layer caching (DeepCache/Learning-to-Cache speedup)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 25)
}
// Model represents a Qwen-Image diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen25VL
Transformer *Transformer
VAEDecoder *VAEDecoder
}
// Load loads the Qwen-Image model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err := tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
m.TextEncoder = &Qwen25VL{}
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
}
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
imgSeqLen := pH * pW
// Text encoding
var posEmb, negEmb *mlx.Array
{
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
if useCFG {
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
} else {
mlx.Keep(posEmb)
mlx.Eval(posEmb)
}
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
// Init latents [B, C, T, H, W]
var latents *mlx.Array
{
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
}
// RoPE cache
var ropeCache *RoPECache
{
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs)
}
// Layer cache for DeepCache/Learning-to-Cache speedup
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
latents2D := mlx.Squeeze(latents, 2)
patches := PackLatents(latents2D, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
// CFG Batching: single forward pass with batch=2
// Note: layer caching with CFG is not supported yet (would need 2 caches)
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
L := batchedOutput.Shape()[1]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
combPred := mlx.Add(negOutput, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
} else if stepCache != nil {
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
} else {
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
}
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep cached arrays alive across cleanup
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
}
// Free denoising temporaries before VAE decode
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode (Decode manages its own pools for staged memory)
decoded := m.VAEDecoder.Decode(latents)
latents.Free()
// Post-process: squeeze temporal dim and rescale to [0, 1]
{
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
mlx.Eval(decoded)
}
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is an alias for backward compatibility.
// Use m := &Model{}; m.Load(path) instead.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}

View File

@@ -1,218 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.5
MaxShift float32 `json:"max_shift"` // 0.9
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
}
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.5,
MaxShift: 0.9, // Matches scheduler_config.json
BaseImageSeqLen: 256,
MaxImageSeqLen: 8192,
ShiftTerminal: 0.02,
UseDynamicShift: true,
}
}
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32
Sigmas []float32
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// CalculateShift computes the dynamic shift based on image sequence length
// This matches Python's calculate_shift function
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
b := baseShift - m*float32(baseSeqLen)
mu := float32(imageSeqLen)*m + b
return mu
}
// SetTimesteps sets up the scheduler for the given number of inference steps
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
// 1. Create sigmas from sigma_max to sigma_min (linspace)
// 2. Apply time_shift with mu (if dynamic shifting)
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
s.NumSteps = numSteps
// Calculate mu for dynamic shifting
var mu float32
if s.Config.UseDynamicShift {
mu = CalculateShift(
imageSeqLen,
s.Config.BaseImageSeqLen,
s.Config.MaxImageSeqLen,
s.Config.BaseShift,
s.Config.MaxShift,
)
}
// Step 1: Create sigmas from 1.0 to 1/num_steps
// Python (pipeline_qwenimage.py:639):
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
sigmas := make([]float32, numSteps)
sigmaMax := float32(1.0)
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
if numSteps == 1 {
sigmas[0] = sigmaMax
} else {
for i := 0; i < numSteps; i++ {
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
}
}
// Step 2: Apply time shift if using dynamic shifting
if s.Config.UseDynamicShift && mu != 0 {
for i := range sigmas {
sigmas[i] = s.timeShift(mu, sigmas[i])
}
}
// Step 3: Apply stretch_shift_to_terminal
if s.Config.ShiftTerminal > 0 {
sigmas = s.stretchShiftToTerminal(sigmas)
}
// Step 4: Append terminal sigma (0) and store
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
// before passing to transformer. We skip both steps and just use sigmas directly.
s.Sigmas = make([]float32, numSteps+1)
s.Timesteps = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
s.Sigmas[i] = sigmas[i]
s.Timesteps[i] = sigmas[i]
}
s.Sigmas[numSteps] = 0.0
s.Timesteps[numSteps] = 0.0
}
// stretchShiftToTerminal stretches and shifts the timestep schedule
// so the final value equals shift_terminal (matches Python behavior)
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
if len(sigmas) == 0 {
return sigmas
}
// one_minus_z = 1 - t
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
// stretched_t = 1 - (one_minus_z / scale_factor)
lastSigma := sigmas[len(sigmas)-1]
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
if scaleFactor < 1e-6 {
return sigmas
}
result := make([]float32, len(sigmas))
for i, t := range sigmas {
oneMinusZ := 1.0 - t
result[i] = 1.0 - (oneMinusZ / scaleFactor)
}
return result
}
// timeShift applies the dynamic time shift (exponential)
// exp(mu) / (exp(mu) + (1/t - 1))
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
// modelOutput: predicted velocity from the transformer
// sample: current noisy sample
// timestepIdx: current timestep index
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
// Get current and next sigma
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 to avoid precision issues (matches Python diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to original dtype
return mlx.ToBFloat16(result)
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormal(shape, uint64(seed))
}
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
// This matches how Python diffusers generates noise - directly in packed space.
// Generating in unpacked format and then packing produces different spatial
// correlation structure, which affects model output quality.
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
shape := []int32{batchSize, seqLen, channels}
return mlx.RandomNormal(shape, uint64(seed))
}
// GetLatentShape returns the latent shape for a given image size
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
func GetLatentShape(batchSize, height, width int32) []int32 {
latentH := height / 8
latentW := width / 8
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
}
// GetPatchedLatentShape returns the patchified latent shape
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
latentH := height / 8
latentW := width / 8
pH := latentH / patchSize
pW := latentW / patchSize
inChannels := int32(64) // 16 * patch_size^2
return []int32{batchSize, pH * pW, inChannels}
}

View File

@@ -1,135 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"testing"
)
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
// Golden values generated via:
//
// python3 -c "
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
// import numpy as np
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
// sigmas = np.linspace(1.0, 1.0/30, 30)
// s.set_timesteps(sigmas=sigmas, mu=mu)
// print(s.sigmas.numpy())"
func TestSchedulerSetTimesteps(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Golden values from Python diffusers (first 3, last 3 before terminal)
wantFirst := []float32{1.000000, 0.982251, 0.963889}
wantLast := []float32{0.142924, 0.083384, 0.020000}
// Check first 3
for i, want := range wantFirst {
got := scheduler.Sigmas[i]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
}
}
// Check last 3 (indices 27, 28, 29)
for i, want := range wantLast {
idx := 27 + i
got := scheduler.Sigmas[idx]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
}
}
// Check terminal is 0
if scheduler.Sigmas[30] != 0.0 {
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
}
// Check length
if len(scheduler.Sigmas) != 31 {
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
}
}
// TestSchedulerProperties tests mathematical invariants of the scheduler.
func TestSchedulerProperties(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Property: sigmas monotonically decreasing
for i := 1; i < len(scheduler.Sigmas); i++ {
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
}
}
// Property: first sigma should be ~1.0 (with time shift)
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
}
// Property: terminal sigma should be exactly 0
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
}
// Property: last non-terminal sigma should be shift_terminal (0.02)
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
if abs32(lastNonTerminal-0.02) > 1e-5 {
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
}
// Property: length = steps + 1
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
t.Errorf("sigmas length should be steps+1: got %d, want %d",
len(scheduler.Sigmas), scheduler.NumSteps+1)
}
}
// TestCalculateShift verifies the mu calculation against Python reference.
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
func TestCalculateShift(t *testing.T) {
cases := []struct {
imgSeqLen int32
want float32
}{
{256, 0.5}, // base case
{8192, 0.9}, // max case
{4096, 0.6935}, // middle case (rounded)
}
for _, c := range cases {
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
if abs32(got-c.want) > 0.001 {
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
}
}
}
// TestSchedulerStep verifies the Euler step formula.
func TestSchedulerStep(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Verify dt calculation for first step
sigma0 := scheduler.Sigmas[0]
sigma1 := scheduler.Sigmas[1]
expectedDt := sigma1 - sigma0
// dt should be negative (sigmas decrease)
if expectedDt >= 0 {
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
}
}
func abs32(x float32) float32 {
return float32(math.Abs(float64(x)))
}

View File

@@ -1,174 +0,0 @@
//go:build mlx
package qwen_image
import (
"encoding/json"
"math"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TinyTextEncoderConfig holds config for the tiny test text encoder
type TinyTextEncoderConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
MRoPESection []int32 `json:"mrope_section"`
}
// loadTinyTextEncoder loads the tiny text encoder from testdata
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
t.Helper()
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
// Load config
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
if err != nil {
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
}
var tinyCfg TinyTextEncoderConfig
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
t.Fatalf("Failed to parse config: %v", err)
}
// Create encoder config (using Qwen25VLConfig)
cfg := &Qwen25VLConfig{
HiddenSize: tinyCfg.HiddenSize,
NumHiddenLayers: tinyCfg.NumHiddenLayers,
IntermediateSize: tinyCfg.IntermediateSize,
NumAttentionHeads: tinyCfg.NumAttentionHeads,
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
VocabSize: tinyCfg.VocabSize,
RMSNormEps: tinyCfg.RMSNormEps,
RopeTheta: tinyCfg.RopeTheta,
HeadDim: tinyCfg.HeadDim,
MRoPESection: tinyCfg.MRoPESection,
}
// Load weights
weights, err := safetensors.LoadModelWeights(testdataDir)
if err != nil {
t.Fatalf("Failed to load weights: %v", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
t.Fatalf("Failed to bulk load weights: %v", err)
}
// Build encoder
embedding, err := weights.Get("model.embed_tokens.weight")
if err != nil {
t.Fatalf("Failed to get embedding: %v", err)
}
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
block, err := newVLTextBlock(weights, int(i), cfg)
if err != nil {
t.Fatalf("Failed to load block %d: %v", i, err)
}
blocks[i] = block
}
finalNorm, err := weights.Get("model.norm.weight")
if err != nil {
t.Fatalf("Failed to get final norm: %v", err)
}
encoder := &Qwen25VL{
Config: cfg,
Embedding: embedding,
Blocks: blocks,
FinalNorm: finalNorm,
HasVision: false, // Text-only mode
}
return encoder, &tinyCfg
}
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
func TestTextEncoderForward(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// Create test tokens (within vocab range)
tokens := []int32{1, 2, 3, 4, 5}
// Forward pass using EncodeTextOnly
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
// Verify output shape: [batch, seq_len, hidden_size]
wantShape := []int32{1, 5, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
}
// Verify output is finite (not NaN or Inf)
data := out.Data()
for i, v := range data {
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
t.Errorf("output[%d] is not finite: %v", i, v)
break
}
}
}
// TestTextEncoderBatch tests batch processing.
func TestTextEncoderBatch(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// EncodeTextOnly does not support batching, so this test runs a second,
// shorter single sequence and checks that the output shape still matches.
tokens := []int32{1, 2, 3}
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
wantShape := []int32{1, 3, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
}
}
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
func TestMRoPEComputation(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
cossin := encoder.computeTextRoPE(10, 1)
mlx.Eval(cossin[0], cossin[1])
// Verify shapes: [3, B, L, head_dim]
wantShape := []int32{3, 1, 10, cfg.HeadDim}
if !slices.Equal(cossin[0].Shape(), wantShape) {
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
}
if !slices.Equal(cossin[1].Shape(), wantShape) {
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
}
// Verify cos/sin values are in valid range [-1, 1]
cosData := cossin[0].Data()
sinData := cossin[1].Data()
for i := 0; i < min(100, len(cosData)); i++ {
if cosData[i] < -1.01 || cosData[i] > 1.01 {
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
}
if sinData[i] < -1.01 || sinData[i] > 1.01 {
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
}
}
}


@@ -1,868 +0,0 @@
//go:build mlx
package qwen_image
import (
"fmt"
"math"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Qwen-Image transformer configuration
type TransformerConfig struct {
HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
NHeads int32 `json:"num_attention_heads"` // 24
HeadDim int32 `json:"attention_head_dim"` // 128
NLayers int32 `json:"num_layers"` // 60
InChannels int32 `json:"in_channels"` // 64
OutChannels int32 `json:"out_channels"` // 16
PatchSize int32 `json:"patch_size"` // 2
JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
NormEps float32 `json:"norm_eps"` // 1e-6
AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
GuidanceEmbeds bool `json:"guidance_embeds"` // false
}
// defaultTransformerConfig returns config for Qwen-Image transformer
func defaultTransformerConfig() *TransformerConfig {
return &TransformerConfig{
HiddenDim: 3072, // 24 * 128
NHeads: 24,
HeadDim: 128,
NLayers: 60,
InChannels: 64,
OutChannels: 16,
PatchSize: 2,
JointAttentionDim: 3584,
NormEps: 1e-6,
AxesDimsRope: []int32{16, 56, 56},
GuidanceEmbeds: false,
}
}
// TimestepEmbedder creates timestep embeddings
type TimestepEmbedder struct {
Linear1Weight *mlx.Array // [256, hidden_dim]
Linear1Bias *mlx.Array
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
Linear2Bias *mlx.Array
}
// newTimestepEmbedder creates a timestep embedder from weights
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
if err != nil {
return nil, err
}
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
if err != nil {
return nil, err
}
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
if err != nil {
return nil, err
}
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
if err != nil {
return nil, err
}
return &TimestepEmbedder{
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
Linear1Bias: linear1Bias,
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
Linear2Bias: linear2Bias,
}, nil
}
// Forward computes timestep embeddings
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
half := int32(128) // embedding_dim / 2
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
tExpanded := mlx.ExpandDims(t, 1)
args := mlx.Mul(tExpanded, freqsArr)
args = mlx.MulScalar(args, 1000.0) // scale
// [cos, sin] (flip_sin_to_cos=True)
sinArgs := mlx.Sin(args)
cosArgs := mlx.Cos(args)
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
// MLP: linear1 -> silu -> linear2
h := mlx.Linear(embedding, te.Linear1Weight)
h = mlx.Add(h, te.Linear1Bias)
h = mlx.SiLU(h)
h = mlx.Linear(h, te.Linear2Weight)
h = mlx.Add(h, te.Linear2Bias)
return h
}
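// Editor's sketch (not part of the original file): the sinusoidal embedding
// built by Forward above, written against plain float32 so the math is easy
// to follow. For a normalized timestep t it returns the 256-dim [cos, sin]
// vector that the linear_1 -> SiLU -> linear_2 MLP then projects.
func sinusoidalTimestepSketch(t float32) []float32 {
	const half = 128
	emb := make([]float32, 2*half)
	for i := 0; i < half; i++ {
		freq := math.Exp(-math.Log(10000.0) * float64(i) / float64(half))
		arg := float64(t) * 1000.0 * freq // scale=1000; flip_sin_to_cos puts cos first
		emb[i] = float32(math.Cos(arg))
		emb[half+i] = float32(math.Sin(arg))
	}
	return emb
}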
// JointAttention implements dual-stream joint attention
type JointAttention struct {
// Image projections
ToQ *mlx.Array
ToQB *mlx.Array
ToK *mlx.Array
ToKB *mlx.Array
ToV *mlx.Array
ToVB *mlx.Array
ToOut *mlx.Array
ToOutB *mlx.Array
NormQ *mlx.Array
NormK *mlx.Array
// Text (added) projections
AddQProj *mlx.Array
AddQProjB *mlx.Array
AddKProj *mlx.Array
AddKProjB *mlx.Array
AddVProj *mlx.Array
AddVProjB *mlx.Array
ToAddOut *mlx.Array
ToAddOutB *mlx.Array
NormAddQ *mlx.Array
NormAddK *mlx.Array
NHeads int32
HeadDim int32
Scale float32
}
// newJointAttention creates a joint attention layer
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
return &JointAttention{
ToQ: mlx.Transpose(toQ, 1, 0),
ToQB: toQB,
ToK: mlx.Transpose(toK, 1, 0),
ToKB: toKB,
ToV: mlx.Transpose(toV, 1, 0),
ToVB: toVB,
ToOut: mlx.Transpose(toOut, 1, 0),
ToOutB: toOutB,
NormQ: normQ,
NormK: normK,
AddQProj: mlx.Transpose(addQProj, 1, 0),
AddQProjB: addQProjB,
AddKProj: mlx.Transpose(addKProj, 1, 0),
AddKProjB: addKProjB,
AddVProj: mlx.Transpose(addVProj, 1, 0),
AddVProjB: addVProjB,
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
ToAddOutB: toAddOutB,
NormAddQ: normAddQ,
NormAddK: normAddK,
NHeads: cfg.NHeads,
HeadDim: cfg.HeadDim,
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
}, nil
}
// Forward computes joint attention
// img: [B, L_img, D], txt: [B, L_txt, D]
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
D := imgShape[2]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// === Image Q/K/V ===
imgFlat := mlx.Reshape(img, B*Limg, D)
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
// QK norm (RMSNorm per head)
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
// Apply RoPE
if imgFreqs != nil {
qImg = applyRoPE(qImg, imgFreqs)
kImg = applyRoPE(kImg, imgFreqs)
}
// === Text Q/K/V ===
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
if txtFreqs != nil {
qTxt = applyRoPE(qTxt, txtFreqs)
kTxt = applyRoPE(kTxt, txtFreqs)
}
// Concatenate for joint attention: [txt, img] order
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
// Transpose to [B, nheads, L, head_dim]
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
// SDPA
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
// Transpose back and split
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
// Output projections
outImg = mlx.Reshape(outImg, B*Limg, D)
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
outImg = mlx.Reshape(outImg, B, Limg, D)
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
return outImg, outTxt
}
// applyRoPE applies rotary embeddings using complex multiplication
// x: [B, L, nheads, head_dim]
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
halfDim := headDim / 2
// Reshape x to pairs: [B, L, nheads, half, 2]
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
// Extract real/imag parts
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
xReal = mlx.Squeeze(xReal, 4)
xImag = mlx.Squeeze(xImag, 4)
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
freqReal = mlx.Squeeze(freqReal, 4)
freqImag = mlx.Squeeze(freqImag, 4)
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
// Interleave back
outReal = mlx.ExpandDims(outReal, 4)
outImag = mlx.ExpandDims(outImag, 4)
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
return mlx.Reshape(out, B, L, nheads, headDim)
}
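// Editor's sketch (not part of the original file): the per-pair rotation that
// applyRoPE performs. Each (real, imag) pair of the head dimension is
// multiplied by the unit complex number (cos, sin) stored in freqs, i.e.
// rotated by the angle precomputed for its position.
func rotatePairSketch(xr, xi, cos, sin float32) (float32, float32) {
	// (xr + xi*i) * (cos + sin*i) = (xr*cos - xi*sin) + (xr*sin + xi*cos)*i
	return xr*cos - xi*sin, xr*sin + xi*cos
}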
// MLP implements GELU MLP (not GEGLU)
type MLP struct {
ProjWeight *mlx.Array
ProjBias *mlx.Array
OutWeight *mlx.Array
OutBias *mlx.Array
}
// newMLP creates a GELU MLP
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
outWeight, _ := weights.Get(prefix + ".net.2.weight")
outBias, _ := weights.Get(prefix + ".net.2.bias")
return &MLP{
ProjWeight: mlx.Transpose(projWeight, 1, 0),
ProjBias: projBias,
OutWeight: mlx.Transpose(outWeight, 1, 0),
OutBias: outBias,
}, nil
}
// Forward applies GELU MLP
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
xFlat := mlx.Reshape(x, B*L, D)
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
h = geluApprox(h)
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
}
// geluApprox implements approximate GELU
func geluApprox(x *mlx.Array) *mlx.Array {
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
x3 := mlx.Mul(mlx.Mul(x, x), x)
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
inner = mlx.MulScalar(inner, sqrt2OverPi)
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}
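// Editor's sketch (not part of the original file): the exact GELU for a
// scalar, 0.5*x*(1+erf(x/sqrt(2))), which the tanh approximation above
// tracks to within roughly 1e-3 over the typical activation range.
func geluExactSketch(x float64) float64 {
	return 0.5 * x * (1.0 + math.Erf(x/math.Sqrt2))
}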
// TransformerBlock is a single dual-stream transformer block
type TransformerBlock struct {
Attention *JointAttention
ImgMLP *MLP
TxtMLP *MLP
ImgModWeight *mlx.Array
ImgModBias *mlx.Array
TxtModWeight *mlx.Array
TxtModBias *mlx.Array
HiddenDim int32
NormEps float32
}
// newTransformerBlock creates a transformer block
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
attn, err := newJointAttention(weights, prefix, cfg)
if err != nil {
return nil, err
}
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
return &TransformerBlock{
Attention: attn,
ImgMLP: imgMLP,
TxtMLP: txtMLP,
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
ImgModBias: imgModBias,
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
TxtModBias: txtModBias,
HiddenDim: cfg.HiddenDim,
NormEps: cfg.NormEps,
}, nil
}
// Forward applies the transformer block
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
siluT := mlx.SiLU(temb)
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
imgModParts := splitMod6(imgMod, tb.HiddenDim)
txtModParts := splitMod6(txtMod, tb.HiddenDim)
// Pre-attention: norm + modulate
imgNorm := layerNormNoAffine(img, tb.NormEps)
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
txtNorm := layerNormNoAffine(txt, tb.NormEps)
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
// Joint attention
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
// Residual with gate
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
// Pre-MLP: norm + modulate
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
// MLP
mlpImg := tb.ImgMLP.Forward(imgNorm2)
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
// Residual with gate
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
return img, txt
}
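// Editor's sketch (not part of the original file): the adaLN-style modulation
// used twice in Forward above, reduced to scalars. The normed activations are
// scaled and shifted before each sub-block, and the sub-block output is added
// back through a gate.
func adaLNModulateSketch(xNormed, shift, scale float32) float32 {
	return (1+scale)*xNormed + shift
}

func gatedResidualSketch(x, gate, subBlockOut float32) float32 {
	return x + gate*subBlockOut
}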
// splitMod6 splits modulation into 6 parts each [B, 1, D]
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
shape := mod.Shape()
B := shape[0]
parts := make([]*mlx.Array, 6)
for i := int32(0); i < 6; i++ {
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
parts[i] = mlx.ExpandDims(part, 1)
}
return parts
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}
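// Editor's sketch (not part of the original file): the same normalization on
// a plain float32 slice - subtract the mean, divide by sqrt(variance+eps),
// with no learnable gamma/beta.
func layerNormNoAffineSliceSketch(x []float32, eps float32) []float32 {
	var mean float64
	for _, v := range x {
		mean += float64(v)
	}
	mean /= float64(len(x))
	var variance float64
	for _, v := range x {
		d := float64(v) - mean
		variance += d * d
	}
	variance /= float64(len(x))
	inv := 1.0 / math.Sqrt(variance+float64(eps))
	out := make([]float32, len(x))
	for i, v := range x {
		out[i] = float32((float64(v) - mean) * inv)
	}
	return out
}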
// Transformer is the full Qwen-Image transformer model
type Transformer struct {
Config *TransformerConfig
ImgIn *mlx.Array
ImgInBias *mlx.Array
TxtIn *mlx.Array
TxtInBias *mlx.Array
TxtNorm *mlx.Array
TEmbed *TimestepEmbedder
Layers []*TransformerBlock
NormOutWeight *mlx.Array
NormOutBias *mlx.Array
ProjOut *mlx.Array
ProjOutBias *mlx.Array
}
// Load loads the transformer from a directory
func (m *Transformer) Load(path string) error {
fmt.Println("Loading Qwen-Image transformer...")
cfg := defaultTransformerConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Bulk load all weights as bf16
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading input projections... ")
imgIn, _ := weights.Get("img_in.weight")
imgInBias, _ := weights.Get("img_in.bias")
txtIn, _ := weights.Get("txt_in.weight")
txtInBias, _ := weights.Get("txt_in.bias")
txtNorm, _ := weights.Get("txt_norm.weight")
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
m.ImgInBias = imgInBias
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
m.TxtInBias = txtInBias
m.TxtNorm = txtNorm
fmt.Println("✓")
fmt.Print(" Loading timestep embedder... ")
m.TEmbed, err = newTimestepEmbedder(weights)
if err != nil {
return fmt.Errorf("timestep embedder: %w", err)
}
fmt.Println("✓")
m.Layers = make([]*TransformerBlock, cfg.NLayers)
for i := int32(0); i < cfg.NLayers; i++ {
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
prefix := fmt.Sprintf("transformer_blocks.%d", i)
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
if err != nil {
return fmt.Errorf("layer %d: %w", i, err)
}
}
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
fmt.Print(" Loading output layers... ")
normOutWeight, _ := weights.Get("norm_out.linear.weight")
normOutBias, _ := weights.Get("norm_out.linear.bias")
projOut, _ := weights.Get("proj_out.weight")
projOutBias, _ := weights.Get("proj_out.bias")
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
m.NormOutBias = normOutBias
m.ProjOut = mlx.Transpose(projOut, 1, 0)
m.ProjOutBias = projOutBias
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// LoadTransformerFromPath is a convenience function that loads the transformer
// from the "transformer" subdirectory of path.
func LoadTransformerFromPath(path string) (*Transformer, error) {
m := &Transformer{}
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
return nil, err
}
return m, nil
}
// Forward runs the transformer
// img: [B, L_img, in_channels] patchified latents
// txt: [B, L_txt, joint_attention_dim] text embeddings
// t: [B] timesteps (0-1)
// imgFreqs, txtFreqs: RoPE frequencies
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// Timestep embedding
temb := tr.TEmbed.Forward(t)
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
// Project text: RMSNorm then linear
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
for _, layer := range tr.Layers {
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
}
// Final norm with modulation (AdaLayerNormContinuous)
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
modShape := finalMod.Shape()
halfDim := modShape[1] / 2
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
return mlx.Reshape(out, B, Limg, outChannels)
}
// ForwardWithCache runs the transformer with layer caching for speedup.
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
// shallow layers change little between denoising steps, so we cache their
// outputs and reuse them on non-refresh steps.
//
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
// step: current denoising step (0-indexed)
// cacheInterval: refresh cache every N steps (e.g., 3)
// cacheLayers: number of shallow layers to cache (e.g., 15)
func (tr *Transformer) ForwardWithCache(
img, txt, t *mlx.Array,
imgFreqs, txtFreqs *mlx.Array,
stepCache *cache.StepCache,
step, cacheInterval, cacheLayers int,
) *mlx.Array {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// Timestep embedding
temb := tr.TEmbed.Forward(t)
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
// Project text: RMSNorm then linear
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
// Check if we should refresh the cache
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
for i, layer := range tr.Layers {
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
// Use cached outputs for shallow layers
imgH = stepCache.Get(i)
txtH = stepCache.Get2(i)
} else {
// Compute layer
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
// Cache shallow layers on refresh steps
if i < cacheLayers && refreshCache {
stepCache.Set(i, imgH)
stepCache.Set2(i, txtH)
}
}
}
// Final norm with modulation (AdaLayerNormContinuous)
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
modShape := finalMod.Shape()
halfDim := modShape[1] / 2
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
return mlx.Reshape(out, B, Limg, outChannels)
}
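// Editor's sketch (not part of the original file): how a denoising loop might
// wire up the step cache. numSteps, latents, txtEmb, patchH, patchW, txtLen,
// timestepFor and schedulerStep are placeholders for whatever the caller uses;
// only the cache plumbing is the point here.
//
//	const cacheInterval, cacheLayers = 3, 15
//	stepCache := cache.NewStepCache(cacheLayers)
//	rope := PrepareRoPE(patchH, patchW, txtLen, tr.Config.AxesDimsRope) // patch-grid dims (placeholders)
//	for step := 0; step < numSteps; step++ {
//		t := timestepFor(step) // [B] normalized timestep (placeholder)
//		noisePred := tr.ForwardWithCache(latents, txtEmb, t,
//			rope.ImgFreqs, rope.TxtFreqs, stepCache, step, cacheInterval, cacheLayers)
//		latents = schedulerStep(latents, noisePred, step) // placeholder scheduler update
//	}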
// RoPECache holds precomputed RoPE frequencies
type RoPECache struct {
ImgFreqs *mlx.Array // [L_img, head_dim]
TxtFreqs *mlx.Array // [L_txt, head_dim]
}
// PrepareRoPE computes RoPE for image and text sequences
// This matches Python's QwenEmbedRope with scale_rope=True
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
theta := float64(10000)
maxIdx := int32(4096)
// Compute base frequencies for each axis dimension
freqsT := ComputeAxisFreqs(axesDims[0], theta)
freqsH := ComputeAxisFreqs(axesDims[1], theta)
freqsW := ComputeAxisFreqs(axesDims[2], theta)
// Build frequency lookup tables
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
// Image frequencies with scale_rope=True
imgLen := imgH * imgW
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
imgFreqsData := make([]float32, imgLen*headDim)
hHalf := imgH / 2
wHalf := imgW / 2
idx := int32(0)
for y := int32(0); y < imgH; y++ {
for x := int32(0); x < imgW; x++ {
// Frame = 0
for i := 0; i < len(freqsT)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
}
idx += int32(len(freqsT) * 2)
// Height: scale_rope pattern
hNegCount := imgH - hHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
}
}
idx += int32(len(freqsH) * 2)
// Width: scale_rope pattern
wNegCount := imgW - wHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
}
}
idx += int32(len(freqsW) * 2)
}
}
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
imgFreqs = mlx.ToBFloat16(imgFreqs)
// Text frequencies
maxVidIdx := max(hHalf, wHalf)
txtFreqsData := make([]float32, txtLen*headDim)
idx = 0
for t := int32(0); t < txtLen; t++ {
pos := maxVidIdx + t
for i := 0; i < len(freqsT)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
}
idx += int32(len(freqsT) * 2)
for i := 0; i < len(freqsH)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
}
idx += int32(len(freqsH) * 2)
for i := 0; i < len(freqsW)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
}
idx += int32(len(freqsW) * 2)
}
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
txtFreqs = mlx.ToBFloat16(txtFreqs)
return &RoPECache{
ImgFreqs: imgFreqs,
TxtFreqs: txtFreqs,
}
}
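// Editor's sketch (not part of the original file): the scale_rope position
// indexing used above, extracted for a single axis of length n. The first
// n-n/2 entries land on negative positions and the rest on 0..n/2-1, so the
// grid is roughly centered around zero.
func scaleRopePositionsSketch(n int32) []int32 {
	negCount := n - n/2
	pos := make([]int32, n)
	for i := int32(0); i < n; i++ {
		pos[i] = i - negCount // -(n - n/2) .. n/2-1
	}
	return pos
}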
// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
halfDim := dim / 2
freqs := make([]float64, halfDim)
for i := int32(0); i < halfDim; i++ {
freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
}
return freqs
}
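// Editor's note (not part of the original file): for example, with dim=16 and
// theta=10000 this returns the 8 frequencies 1/10000^(i/8) for i=0..7, i.e.
// roughly 1.0, 0.316, 0.1, ... down to about 3.2e-4.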
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
table := make([][]float32, maxIdx)
for idx := int32(0); idx < maxIdx; idx++ {
var pos float64
if negative {
pos = float64(-maxIdx + int32(idx))
} else {
pos = float64(idx)
}
row := make([]float32, len(baseFreqs)*2)
for i, f := range baseFreqs {
angle := pos * f
row[i*2] = float32(math.Cos(angle))
row[i*2+1] = float32(math.Sin(angle))
}
table[idx] = row
}
return table
}
func max(a, b int32) int32 {
if a > b {
return a
}
return b
}
// PackLatents converts [B, C, H, W] latents to [B, L, C*patchSize*patchSize] patches (C*4 when patchSize=2)
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize
pW := W / patchSize
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
// -> [B, pH, pW, C, 2, 2]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// -> [B, pH*pW, C*4]
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
// UnpackLatents converts [B, L, C*patchSize*patchSize] patches back to [B, C, 1, H, W] (5D for the VAE)
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
channels := shape[2] / (patchSize * patchSize)
pH := H / patchSize
pW := W / patchSize
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
// -> [B, C, pH, 2, pW, 2]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// -> [B, C, H, W]
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
// Add temporal dimension for VAE: [B, C, 1, H, W]
return mlx.ExpandDims(x, 2)
}
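// Editor's note (not part of the original file): shape trace for the
// pack/unpack round trip with patch_size=2, e.g. a 16-channel 128x128 latent:
//
//	latents  [1, 16, 128, 128]
//	packed := PackLatents(latents, 2)              // [1, 4096, 64]  (64x64 patches, C*4 = 64)
//	restored := UnpackLatents(packed, 128, 128, 2) // [1, 16, 1, 128, 128] (temporal dim added for the VAE)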
