Compare commits

...

14 Commits

Author SHA1 Message Date
Patrick Devine
49905784f1 fixup 2026-01-23 18:01:17 -08:00
Patrick Devine
a00721f586 add runner to glm 4.7 flash MLX implementation 2026-01-23 17:15:59 -08:00
Patrick Devine
98ca1c3904 x models: add glm 4.7 flash to mlx engine 2026-01-23 17:15:59 -08:00
Jeffrey Morgan
2eda97f1c3 Revert "model: add MLA absorption for glm4moelite (#13810)" (#13869)
This reverts commit 1044b0419a.
2026-01-23 17:14:15 -08:00
Jeffrey Morgan
66831dcf70 x/imagegen: fix image editing support (#13866)
- Fix panic in ollama show for image gen models (safe type assertion)
- Add vision capability for Flux2KleinPipeline models at create time
- Flatten transparent PNG images onto white background for better results
2026-01-23 15:37:17 -08:00
Jeffrey Morgan
1044b0419a model: add MLA absorption for glm4moelite (#13810)
* model: add MLA absorption for glm4moelite

Split the combined KV_B tensor into separate K_B and V_B tensors
during conversion, enabling MLA (Multi-head Latent Attention)
absorption which compresses the KV cache for improved efficiency.

* ggml: enable MLA flash attention for GLM-4.7-flash

Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
uses head size 576 with gqa_ratio 4, which was previously only supported
for gqa_ratio 16 (DeepSeek).

Metal changes:
- Enable head size 576 for flash attention
- Increase simdgroups to 8 for large heads (>=512)
- Add case 8 kernel dispatch for 8 simdgroups

CUDA changes:
- Add gqa_ratio 4 support for head 576/512
- Add tile configs for (576, 512, 4) and (576, 512, 8)
- Add MMA config cases for ncols 4
- Add template instances for ncols2=4

* model: add compatibility validation for glm4moelite architecture
2026-01-23 14:47:42 -08:00
Parth Sareen
771d9280ec cmd: ollama config fix droid model name configuration (#13856) 2026-01-23 11:44:22 -08:00
Jeffrey Morgan
862bc0a3bf x/imagegen: respect stream=false in /api/generate (#13853)
When stream=false is set for image generation requests, return a single
JSON response instead of streaming multiple ndjson progress updates.
2026-01-22 22:16:39 -08:00
Jeffrey Morgan
c01608b6a1 x/imagegen: add image edit capabilities (#13846) 2026-01-22 20:35:08 -08:00
Parth Sareen
199c41e16e cmd: ollama config command to help configure integrations to use Ollama (#13712) 2026-01-22 20:17:11 -08:00
Jeffrey Morgan
3b3bf6c217 x/imagegen: replace memory estimation with actual weight size (#13848)
Remove static VRAM estimation (EstimateVRAM, CheckMemoryRequirements)
which wasn't helpful. Instead, report the actual tensor weight size
from the manifest for ollama ps.

- Remove memory estimation check from runner startup
- Remove EstimateVRAM, CheckMemoryRequirements, modelVRAMEstimates
- Add TotalTensorSize() to get actual weight size from manifest
- Use weight size for Server.vramSize instead of estimates

Note: This is better than showing 0 or inaccurate estimates, but the
weight size is a drastic underestimation of actual memory usage since
it doesn't account for activations, intermediate tensors, or MLX
overhead. Future work should query real-time memory from MLX
(e.g., MetalGetActiveMemory) for accurate reporting.
2026-01-22 18:32:41 -08:00
Parth Sareen
f52c21f457 fix: handle Enter key pressed during model loading (#13839) 2026-01-22 18:32:02 -08:00
Jeffrey Morgan
b5d0f72f16 x/imagegen: remove qwen_image and qwen_image_edit models (#13827)
Remove the Qwen image generation and image editing model packages
to clean up the codebase. These models will be reintroduced later.

- Delete x/imagegen/models/qwen_image/ (10 files)
- Delete x/imagegen/models/qwen_image_edit/ (5 files)
- Remove related CLI flags and imports from cmd/engine/main.go
- Update comments in cache/step.go to remove Qwen-specific references
2026-01-21 13:37:08 -08:00
Patrick Devine
148a1be0a3 Clean up the manifest and modelpath (#13807) 2026-01-21 11:46:17 -08:00
98 changed files with 10019 additions and 10536 deletions

View File

@@ -35,6 +35,7 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -1018,8 +1019,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
}
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
arch, _ := resp.ModelInfo["general.architecture"].(string)
if arch != "" {
rows = append(rows, []string{"", "architecture", arch})
}
var paramStr string
if resp.Details.ParameterSize != "" {
@@ -1029,7 +1032,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
paramStr = format.HumanNumber(uint64(f))
}
}
rows = append(rows, []string{"", "parameters", paramStr})
if paramStr != "" {
rows = append(rows, []string{"", "parameters", paramStr})
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
@@ -2026,6 +2031,7 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
config.ConfigCmd(checkServerHeartbeat),
)
return rootCmd

36
cmd/config/claude.go Normal file
View File

@@ -0,0 +1,36 @@
package config
import (
"fmt"
"os"
"os/exec"
)
// Claude implements Runner for Claude Code integration
type Claude struct{}
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string) []string {
if model != "" {
return []string{"--model", model}
}
return nil
}
func (c *Claude) Run(model string) error {
if _, err := exec.LookPath("claude"); err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command("claude", c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"ANTHROPIC_BASE_URL=http://localhost:11434",
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)
return cmd.Run()
}

42
cmd/config/claude_test.go Normal file
View File

@@ -0,0 +1,42 @@
package config
import (
"slices"
"testing"
)
func TestClaudeIntegration(t *testing.T) {
c := &Claude{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Claude Code" {
t.Errorf("String() = %q, want %q", got, "Claude Code")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
}
func TestClaudeArgs(t *testing.T) {
c := &Claude{}
tests := []struct {
name string
model string
want []string
}{
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
{"empty model", "", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}
}

61
cmd/config/codex.go Normal file
View File

@@ -0,0 +1,61 @@
package config
import (
"fmt"
"os"
"os/exec"
"strings"
"golang.org/x/mod/semver"
)
// Codex implements Runner for Codex integration
type Codex struct{}
func (c *Codex) String() string { return "Codex" }
func (c *Codex) args(model string) []string {
args := []string{"--oss"}
if model != "" {
args = append(args, "-m", model)
}
return args
}
func (c *Codex) Run(model string) error {
if err := checkCodexVersion(); err != nil {
return err
}
cmd := exec.Command("codex", c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func checkCodexVersion() error {
if _, err := exec.LookPath("codex"); err != nil {
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
}
out, err := exec.Command("codex", "--version").Output()
if err != nil {
return fmt.Errorf("failed to get codex version: %w", err)
}
// Parse output like "codex-cli 0.87.0"
fields := strings.Fields(strings.TrimSpace(string(out)))
if len(fields) < 2 {
return fmt.Errorf("unexpected codex version output: %s", string(out))
}
version := "v" + fields[len(fields)-1]
minVersion := "v0.81.0"
if semver.Compare(version, minVersion) < 0 {
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
}
return nil
}

28
cmd/config/codex_test.go Normal file
View File

@@ -0,0 +1,28 @@
package config
import (
"slices"
"testing"
)
func TestCodexArgs(t *testing.T) {
c := &Codex{}
tests := []struct {
name string
model string
want []string
}{
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", []string{"--oss"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}
}

115
cmd/config/config.go Normal file
View File

@@ -0,0 +1,115 @@
// Package config provides integration configuration for external coding tools
// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
package config
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
)
type integration struct {
Models []string `json:"models"`
}
type config struct {
Integrations map[string]*integration `json:"integrations"`
}
func configPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "config", "config.json"), nil
}
func load() (*config, error) {
path, err := configPath()
if err != nil {
return nil, err
}
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return &config{Integrations: make(map[string]*integration)}, nil
}
return nil, err
}
var cfg config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
}
if cfg.Integrations == nil {
cfg.Integrations = make(map[string]*integration)
}
return &cfg, nil
}
func save(cfg *config) error {
path, err := configPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
return writeWithBackup(path, data)
}
func saveIntegration(appName string, models []string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
cfg.Integrations[strings.ToLower(appName)] = &integration{
Models: models,
}
return save(cfg)
}
func loadIntegration(appName string) (*integration, error) {
cfg, err := load()
if err != nil {
return nil, err
}
ic, ok := cfg.Integrations[strings.ToLower(appName)]
if !ok {
return nil, os.ErrNotExist
}
return ic, nil
}
func listIntegrations() ([]integration, error) {
cfg, err := load()
if err != nil {
return nil, err
}
result := make([]integration, 0, len(cfg.Integrations))
for _, ic := range cfg.Integrations {
result = append(result, *ic)
}
return result, nil
}

373
cmd/config/config_test.go Normal file
View File

@@ -0,0 +1,373 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
)
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
func setTestHome(t *testing.T, dir string) {
t.Setenv("HOME", dir)
t.Setenv("USERPROFILE", dir)
}
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
func editorPaths(r Runner) []string {
if editor, ok := r.(Editor); ok {
return editor.Paths()
}
return nil
}
func TestIntegrationConfig(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("save and load round-trip", func(t *testing.T) {
models := []string{"llama3.2", "mistral", "qwen2.5"}
if err := saveIntegration("claude", models); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if len(config.Models) != len(models) {
t.Errorf("expected %d models, got %d", len(models), len(config.Models))
}
for i, m := range models {
if config.Models[i] != m {
t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
}
}
})
t.Run("defaultModel returns first model", func(t *testing.T) {
saveIntegration("codex", []string{"model-a", "model-b"})
config, _ := loadIntegration("codex")
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "model-a" {
t.Errorf("expected model-a, got %s", defaultModel)
}
})
t.Run("defaultModel returns empty for no models", func(t *testing.T) {
config := &integration{Models: []string{}}
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "" {
t.Errorf("expected empty string, got %s", defaultModel)
}
})
t.Run("app name is case-insensitive", func(t *testing.T) {
saveIntegration("Claude", []string{"model-x"})
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "model-x" {
t.Errorf("expected model-x, got %s", defaultModel)
}
})
t.Run("multiple integrations in single file", func(t *testing.T) {
saveIntegration("app1", []string{"model-1"})
saveIntegration("app2", []string{"model-2"})
config1, _ := loadIntegration("app1")
config2, _ := loadIntegration("app2")
defaultModel1 := ""
if len(config1.Models) > 0 {
defaultModel1 = config1.Models[0]
}
defaultModel2 := ""
if len(config2.Models) > 0 {
defaultModel2 = config2.Models[0]
}
if defaultModel1 != "model-1" {
t.Errorf("expected model-1, got %s", defaultModel1)
}
if defaultModel2 != "model-2" {
t.Errorf("expected model-2, got %s", defaultModel2)
}
})
}
func TestListIntegrations(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty when no integrations", func(t *testing.T) {
configs, err := listIntegrations()
if err != nil {
t.Fatal(err)
}
if len(configs) != 0 {
t.Errorf("expected 0 integrations, got %d", len(configs))
}
})
t.Run("returns all saved integrations", func(t *testing.T) {
saveIntegration("claude", []string{"model-1"})
saveIntegration("droid", []string{"model-2"})
configs, err := listIntegrations()
if err != nil {
t.Fatal(err)
}
if len(configs) != 2 {
t.Errorf("expected 2 integrations, got %d", len(configs))
}
})
}
func TestEditorPaths(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
r := integrations["claude"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for claude, got %v", paths)
}
})
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
r := integrations["codex"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for codex, got %v", paths)
}
})
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths, got %v", paths)
}
})
t.Run("returns path for droid when config exists", func(t *testing.T) {
settingsDir, _ := os.UserHomeDir()
settingsDir = filepath.Join(settingsDir, ".factory")
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
home, _ := os.UserHomeDir()
configDir := filepath.Join(home, ".config", "opencode")
stateDir := filepath.Join(home, ".local", "state", "opencode")
os.MkdirAll(configDir, 0o755)
os.MkdirAll(stateDir, 0o755)
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
r := integrations["opencode"]
paths := editorPaths(r)
if len(paths) != 2 {
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
}
})
}
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Create corrupted config.json file
dir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(dir, 0o755)
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
// Corrupted file is treated as empty, so loadIntegration returns not found
_, err := loadIntegration("test")
if err == nil {
t.Error("expected error for nonexistent integration in corrupted file")
}
}
func TestSaveIntegration_NilModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := saveIntegration("test", nil); err != nil {
t.Fatalf("saveIntegration with nil models failed: %v", err)
}
config, err := loadIntegration("test")
if err != nil {
t.Fatalf("loadIntegration failed: %v", err)
}
if config.Models == nil {
// nil is acceptable
} else if len(config.Models) != 0 {
t.Errorf("expected empty or nil models, got %v", config.Models)
}
}
func TestSaveIntegration_EmptyAppName(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
err := saveIntegration("", []string{"model"})
if err == nil {
t.Error("expected error for empty app name, got nil")
}
if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
}
}
func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
_, err := loadIntegration("nonexistent")
if err == nil {
t.Error("expected error for nonexistent integration, got nil")
}
if !os.IsNotExist(err) {
t.Logf("error type is os.ErrNotExist as expected: %v", err)
}
}
func TestConfigPath(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
path, err := configPath()
if err != nil {
t.Fatal(err)
}
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
if path != expected {
t.Errorf("expected %s, got %s", expected, path)
}
}
func TestLoad(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty config when file does not exist", func(t *testing.T) {
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg == nil {
t.Fatal("expected non-nil config")
}
if cfg.Integrations == nil {
t.Error("expected non-nil Integrations map")
}
if len(cfg.Integrations) != 0 {
t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
}
})
t.Run("loads existing config", func(t *testing.T) {
path, _ := configPath()
os.MkdirAll(filepath.Dir(path), 0o755)
os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg.Integrations["test"] == nil {
t.Fatal("expected test integration")
}
if len(cfg.Integrations["test"].Models) != 1 {
t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
}
})
t.Run("returns error for corrupted JSON", func(t *testing.T) {
path, _ := configPath()
os.MkdirAll(filepath.Dir(path), 0o755)
os.WriteFile(path, []byte(`{corrupted`), 0o644)
_, err := load()
if err == nil {
t.Error("expected error for corrupted JSON")
}
})
}
func TestSave(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("creates config file", func(t *testing.T) {
cfg := &config{
Integrations: map[string]*integration{
"test": {Models: []string{"model-a", "model-b"}},
},
}
if err := save(cfg); err != nil {
t.Fatal(err)
}
path, _ := configPath()
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Error("config file was not created")
}
})
t.Run("round-trip preserves data", func(t *testing.T) {
cfg := &config{
Integrations: map[string]*integration{
"claude": {Models: []string{"llama3.2", "mistral"}},
"codex": {Models: []string{"qwen2.5"}},
},
}
if err := save(cfg); err != nil {
t.Fatal(err)
}
loaded, err := load()
if err != nil {
t.Fatal(err)
}
if len(loaded.Integrations) != 2 {
t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
}
if loaded.Integrations["claude"] == nil {
t.Error("missing claude integration")
}
if len(loaded.Integrations["claude"].Models) != 2 {
t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
}
})
}

184
cmd/config/droid.go Normal file
View File

@@ -0,0 +1,184 @@
package config
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
)
// Droid implements Runner and Editor for Droid integration
type Droid struct{}
// droidSettings represents the Droid settings.json file (only fields we use)
type droidSettings struct {
CustomModels []modelEntry `json:"customModels"`
SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
}
type sessionSettings struct {
Model string `json:"model"`
ReasoningEffort string `json:"reasoningEffort"`
}
type modelEntry struct {
Model string `json:"model"`
DisplayName string `json:"displayName"`
BaseURL string `json:"baseUrl"`
APIKey string `json:"apiKey"`
Provider string `json:"provider"`
MaxOutputTokens int `json:"maxOutputTokens"`
SupportsImages bool `json:"supportsImages"`
ID string `json:"id"`
Index int `json:"index"`
}
func (d *Droid) String() string { return "Droid" }
func (d *Droid) Run(model string) error {
if _, err := exec.LookPath("droid"); err != nil {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := d.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (d *Droid) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
p := filepath.Join(home, ".factory", "settings.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (d *Droid) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
settingsPath := filepath.Join(home, ".factory", "settings.json")
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
return err
}
// Read file once, unmarshal twice:
// map preserves unknown fields for writing back (including extra fields in model entries)
settingsMap := make(map[string]any)
var settings droidSettings
if data, err := os.ReadFile(settingsPath); err == nil {
if err := json.Unmarshal(data, &settingsMap); err != nil {
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
}
json.Unmarshal(data, &settings) // ignore error, zero values are fine
}
// Keep only non-Ollama models from the raw map (preserves extra fields)
// Rebuild Ollama models
var nonOllamaModels []any
if rawModels, ok := settingsMap["customModels"].([]any); ok {
for _, raw := range rawModels {
if m, ok := raw.(map[string]any); ok {
if m["apiKey"] != "ollama" {
nonOllamaModels = append(nonOllamaModels, raw)
}
}
}
}
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
var newModels []any
var defaultModelID string
for i, model := range models {
modelID := fmt.Sprintf("custom:%s-%d", model, i)
newModels = append(newModels, modelEntry{
Model: model,
DisplayName: model,
BaseURL: "http://localhost:11434/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: 64000,
SupportsImages: false,
ID: modelID,
Index: i,
})
if i == 0 {
defaultModelID = modelID
}
}
settingsMap["customModels"] = append(newModels, nonOllamaModels...)
// Update session default settings (preserve unknown fields in the nested object)
sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
if !ok {
sessionSettings = make(map[string]any)
}
sessionSettings["model"] = defaultModelID
if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
sessionSettings["reasoningEffort"] = "none"
}
settingsMap["sessionDefaultSettings"] = sessionSettings
data, err := json.MarshalIndent(settingsMap, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, data)
}
func (d *Droid) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
if err != nil {
return nil
}
var settings droidSettings
if err := json.Unmarshal(data, &settings); err != nil {
return nil
}
var result []string
for _, m := range settings.CustomModels {
if m.APIKey == "ollama" {
result = append(result, m.Model)
}
}
return result
}
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
func isValidReasoningEffort(effort string) bool {
return slices.Contains(validReasoningEfforts, effort)
}

1302
cmd/config/droid_test.go Normal file
View File

File diff suppressed because it is too large Load Diff

99
cmd/config/files.go Normal file
View File

@@ -0,0 +1,99 @@
package config
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
func readJSONFile(path string) (map[string]any, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}
return result, nil
}
func copyFile(src, dst string) error {
info, err := os.Stat(src)
if err != nil {
return err
}
data, err := os.ReadFile(src)
if err != nil {
return err
}
return os.WriteFile(dst, data, info.Mode().Perm())
}
func backupDir() string {
return filepath.Join(os.TempDir(), "ollama-backups")
}
func backupToTmp(srcPath string) (string, error) {
dir := backupDir()
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err
}
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
if err := copyFile(srcPath, backupPath); err != nil {
return "", err
}
return backupPath, nil
}
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
func writeWithBackup(path string, data []byte) error {
var backupPath string
// backup must be created before any writes to the target file
if existingContent, err := os.ReadFile(path); err == nil {
if !bytes.Equal(existingContent, data) {
backupPath, err = backupToTmp(path)
if err != nil {
return fmt.Errorf("backup failed: %w", err)
}
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("read existing file: %w", err)
}
dir := filepath.Dir(path)
tmp, err := os.CreateTemp(dir, ".tmp-*")
if err != nil {
return fmt.Errorf("create temp failed: %w", err)
}
tmpPath := tmp.Name()
if _, err := tmp.Write(data); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpPath)
return fmt.Errorf("write failed: %w", err)
}
if err := tmp.Sync(); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpPath)
return fmt.Errorf("sync failed: %w", err)
}
if err := tmp.Close(); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("close failed: %w", err)
}
if err := os.Rename(tmpPath, path); err != nil {
_ = os.Remove(tmpPath)
if backupPath != "" {
_ = copyFile(backupPath, path)
}
return fmt.Errorf("rename failed: %w", err)
}
return nil
}

502
cmd/config/files_test.go Normal file
View File

@@ -0,0 +1,502 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
)
func mustMarshal(t *testing.T, v any) []byte {
t.Helper()
data, err := json.MarshalIndent(v, "", " ")
if err != nil {
t.Fatal(err)
}
return data
}
func TestWriteWithBackup(t *testing.T) {
tmpDir := t.TempDir()
t.Run("creates file", func(t *testing.T) {
path := filepath.Join(tmpDir, "new.json")
data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
content, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var result map[string]string
if err := json.Unmarshal(content, &result); err != nil {
t.Fatal(err)
}
if result["key"] != "value" {
t.Errorf("expected value, got %s", result["key"])
}
})
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
path := filepath.Join(tmpDir, "backup.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
data := mustMarshal(t, map[string]bool{"updated": true})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, err := os.ReadDir(backupDir())
if err != nil {
t.Fatal("backup directory not created")
}
var foundBackup bool
for _, entry := range entries {
if filepath.Ext(entry.Name()) != ".json" {
name := entry.Name()
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
backupPath := filepath.Join(backupDir(), name)
backup, err := os.ReadFile(backupPath)
if err == nil {
var backupData map[string]bool
json.Unmarshal(backup, &backupData)
if backupData["original"] {
foundBackup = true
os.Remove(backupPath)
break
}
}
}
}
}
if !foundBackup {
t.Error("backup file not created in /tmp/ollama-backups")
}
current, _ := os.ReadFile(path)
var currentData map[string]bool
json.Unmarshal(current, &currentData)
if !currentData["updated"] {
t.Error("file doesn't contain updated data")
}
})
t.Run("no backup for new file", func(t *testing.T) {
path := filepath.Join(tmpDir, "nobak.json")
data := mustMarshal(t, map[string]string{"new": "file"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(backupDir())
for _, entry := range entries {
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
t.Error("backup should not exist for new file")
}
}
})
t.Run("no backup when content unchanged", func(t *testing.T) {
path := filepath.Join(tmpDir, "unchanged.json")
data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries1, _ := os.ReadDir(backupDir())
countBefore := 0
for _, e := range entries1 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
countBefore++
}
}
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries2, _ := os.ReadDir(backupDir())
countAfter := 0
for _, e := range entries2 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
countAfter++
}
}
if countAfter != countBefore {
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
}
})
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
path := filepath.Join(tmpDir, "timestamped.json")
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
data := mustMarshal(t, map[string]int{"v": 2})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(backupDir())
var found bool
for _, entry := range entries {
name := entry.Name()
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
timestamp := name[len("timestamped.json."):]
for _, c := range timestamp {
if c < '0' || c > '9' {
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
}
}
found = true
os.Remove(filepath.Join(backupDir(), name))
break
}
}
if !found {
t.Error("backup file with timestamp not found")
}
})
}
// Edge case tests for files.go
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
// User could lose their config with no way to recover.
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "config.json")
// Create original file
originalContent := []byte(`{"original": true}`)
os.WriteFile(path, originalContent, 0o644)
// Make backup directory read-only to force backup failure
backupDir := backupDir()
os.MkdirAll(backupDir, 0o755)
os.Chmod(backupDir, 0o444) // Read-only
defer os.Chmod(backupDir, 0o755)
newContent := []byte(`{"updated": true}`)
err := writeWithBackup(path, newContent)
// Should fail because backup couldn't be created
if err == nil {
t.Error("expected error when backup fails, got nil")
}
// Original file should be preserved
current, _ := os.ReadFile(path)
if string(current) != string(originalContent) {
t.Errorf("original file was modified despite backup failure: got %s", string(current))
}
}
// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
// Common issue when config owned by root or wrong perms.
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
// Create a read-only directory
readOnlyDir := filepath.Join(tmpDir, "readonly")
os.MkdirAll(readOnlyDir, 0o755)
os.Chmod(readOnlyDir, 0o444)
defer os.Chmod(readOnlyDir, 0o755)
path := filepath.Join(readOnlyDir, "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`))
if err == nil {
t.Error("expected permission error, got nil")
}
}
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
// writeWithBackup doesn't create directories - caller is responsible.
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`))
// Should fail because directory doesn't exist
if err == nil {
t.Error("expected error for nonexistent directory, got nil")
}
}
// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
// Documents what happens if user symlinks their config file.
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("symlink tests may require admin on Windows")
}
tmpDir := t.TempDir()
realFile := filepath.Join(tmpDir, "real.json")
symlink := filepath.Join(tmpDir, "link.json")
// Create real file and symlink
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
os.Symlink(realFile, symlink)
// Write through symlink
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
if err != nil {
t.Fatalf("writeWithBackup through symlink failed: %v", err)
}
// The real file should be updated (symlink followed for temp file creation)
content, _ := os.ReadFile(symlink)
if string(content) != `{"v": 2}` {
t.Errorf("symlink target not updated correctly: got %s", string(content))
}
}
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
// User may have config files with unusual names.
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
tmpDir := t.TempDir()
// File with spaces and special chars
path := filepath.Join(tmpDir, "my config (backup).json")
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
backupPath, err := backupToTmp(path)
if err != nil {
t.Fatalf("backupToTmp with special chars failed: %v", err)
}
// Verify backup exists and has correct content
content, err := os.ReadFile(backupPath)
if err != nil {
t.Fatalf("could not read backup: %v", err)
}
if string(content) != `{"test": true}` {
t.Errorf("backup content mismatch: got %s", string(content))
}
os.Remove(backupPath)
}
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
func TestCopyFile_PreservesPermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission preservation tests unreliable on Windows")
}
tmpDir := t.TempDir()
src := filepath.Join(tmpDir, "src.json")
dst := filepath.Join(tmpDir, "dst.json")
// Create source with specific permissions
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
err := copyFile(src, dst)
if err != nil {
t.Fatalf("copyFile failed: %v", err)
}
srcInfo, _ := os.Stat(src)
dstInfo, _ := os.Stat(dst)
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
}
}
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
func TestCopyFile_SourceNotFound(t *testing.T) {
tmpDir := t.TempDir()
src := filepath.Join(tmpDir, "nonexistent.json")
dst := filepath.Join(tmpDir, "dst.json")
err := copyFile(src, dst)
if err == nil {
t.Error("expected error for nonexistent source, got nil")
}
}
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
tmpDir := t.TempDir()
dirPath := filepath.Join(tmpDir, "actualdir")
os.MkdirAll(dirPath, 0o755)
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
if err == nil {
t.Error("expected error when target is a directory, got nil")
}
}
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
func TestWriteWithBackup_EmptyData(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "empty.json")
err := writeWithBackup(path, []byte{})
if err != nil {
t.Fatalf("writeWithBackup with empty data failed: %v", err)
}
content, err := os.ReadFile(path)
if err != nil {
t.Fatalf("could not read file: %v", err)
}
if len(content) != 0 {
t.Errorf("expected empty file, got %d bytes", len(content))
}
}
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
// cannot be read (for backup comparison) but directory is writable.
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "unreadable.json")
// Create file and make it unreadable
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
os.Chmod(path, 0o000)
defer os.Chmod(path, 0o644)
// Should fail because we can't read the file to compare/backup
err := writeWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when file is unreadable, got nil")
}
}
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
// within the same second (timestamp collision scenario).
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "rapid.json")
// Create initial file
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
// Rapid successive writes
for i := 1; i <= 3; i++ {
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
if err := writeWithBackup(path, data); err != nil {
t.Fatalf("write %d failed: %v", i, err)
}
}
// Verify final content
content, _ := os.ReadFile(path)
if string(content) != `{"v": 3}` {
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
}
// Verify at least one backup exists
entries, _ := os.ReadDir(backupDir())
var backupCount int
for _, e := range entries {
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
backupCount++
}
}
if backupCount == 0 {
t.Error("expected at least one backup file from rapid writes")
}
}
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("test modifies system temp directory")
}
// Create a file at the backup directory path
backupPath := backupDir()
// Clean up any existing directory first
os.RemoveAll(backupPath)
// Create a file instead of directory
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
defer func() {
os.Remove(backupPath)
os.MkdirAll(backupPath, 0o755)
}()
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "test.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
err := writeWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when backup dir is a file, got nil")
}
}
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
// Count existing temp files
countTempFiles := func() int {
entries, _ := os.ReadDir(tmpDir)
count := 0
for _, e := range entries {
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
count++
}
}
return count
}
before := countTempFiles()
// Create a file, then make directory read-only to cause rename failure
path := filepath.Join(tmpDir, "orphan.json")
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
// Make a subdirectory and try to write there after making parent read-only
subDir := filepath.Join(tmpDir, "subdir")
os.MkdirAll(subDir, 0o755)
subPath := filepath.Join(subDir, "config.json")
os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
// Make subdir read-only after creating temp file would succeed but rename would fail
// This is tricky to test - the temp file is created in the same dir, so if we can't
// rename, we also couldn't create. Let's just verify normal failure cleanup works.
// Force a failure by making the target a directory
badPath := filepath.Join(tmpDir, "isdir")
os.MkdirAll(badPath, 0o755)
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
after := countTempFiles()
if after > before {
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
}
}

362
cmd/config/integrations.go Normal file
View File

@@ -0,0 +1,362 @@
package config
import (
"context"
"errors"
"fmt"
"maps"
"os"
"os/exec"
"runtime"
"slices"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/spf13/cobra"
)
// Runners execute the launching of a model with the integration - claude, codex
// Editors can edit config files (supports multi-model selection) - opencode, droid
// They are composable interfaces where in some cases an editor is also a runner - opencode, droid
// Runner can run an integration with a model.
type Runner interface {
Run(model string) error
// String returns the human-readable name of the integration
String() string
}
// Editor can edit config files (supports multi-model selection)
type Editor interface {
// Paths returns the paths to the config files for the integration
Paths() []string
// Edit updates the config files for the integration with the given models
Edit(models []string) error
// Models returns the models currently configured for the integration
// TODO(parthsareen): add error return to Models()
Models() []string
}
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
"codex": &Codex{},
"droid": &Droid{},
"opencode": &OpenCode{},
}
func selectIntegration() (string, error) {
if len(integrations) == 0 {
return "", fmt.Errorf("no integrations available")
}
names := slices.Sorted(maps.Keys(integrations))
var items []selectItem
for _, name := range names {
r := integrations[name]
description := r.String()
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
}
items = append(items, selectItem{Name: name, Description: description})
}
return selectPrompt("Select integration:", items)
}
// selectModels lets the user select models for an integration
func selectModels(ctx context.Context, name, current string) ([]string, error) {
r, ok := integrations[name]
if !ok {
return nil, fmt.Errorf("unknown integration: %s", name)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
models, err := client.List(ctx)
if err != nil {
return nil, err
}
if len(models.Models) == 0 {
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
var items []selectItem
cloudModels := make(map[string]bool)
for _, m := range models.Models {
if m.RemoteModel != "" {
cloudModels[m.Name] = true
}
items = append(items, selectItem{Name: m.Name})
}
if len(items) == 0 {
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
}
// Get previously configured models (saved config takes precedence)
var preChecked []string
if saved, err := loadIntegration(name); err == nil {
preChecked = saved.Models
} else if editor, ok := r.(Editor); ok {
preChecked = editor.Models()
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
for _, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
// If current model is configured, move to front of preChecked
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
}
// Sort: checked first, then alphabetical
slices.SortFunc(items, func(a, b selectItem) int {
ac, bc := checked[a.Name], checked[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
var selected []string
// only editors support multi-model selection
if _, ok := r.(Editor); ok {
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
if err != nil {
return nil, err
}
} else {
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
if err != nil {
return nil, err
}
selected = []string{model}
}
// if any model in selected is a cloud model, ensure signed in
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) > 0 {
// ensure user is signed in
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return selected, nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return nil, err
}
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return nil, fmt.Errorf("%s requires sign in", modelList)
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
// TODO(parthsareen): extract into auth package for cmd
// Auto-open browser (best effort, fail silently)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return nil, ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return selected, nil
}
}
}
}
}
return selected, nil
}
func runIntegration(name, modelName string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName)
}
// ConfigCmd returns the cobra command for configuring integrations.
func ConfigCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
var modelFlag string
var launchFlag bool
cmd := &cobra.Command{
Use: "config [INTEGRATION]",
Short: "Configure an external integration to use Ollama",
Long: `Configure an external application to use Ollama models.
Supported integrations:
claude Claude Code
codex Codex
droid Droid
opencode OpenCode
Examples:
ollama config
ollama config claude
ollama config droid --launch`,
Args: cobra.MaximumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
var name string
if len(args) > 0 {
name = args[0]
} else {
var err error
name, err = selectIntegration()
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
r, ok := integrations[strings.ToLower(name)]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
// If --launch without --model, use saved config if available
if launchFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0])
}
}
var models []string
if modelFlag != "" {
// When --model is specified, merge with existing models (new model becomes default)
models = []string{modelFlag}
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
for _, m := range existing.Models {
if m != modelFlag {
models = append(models, m)
}
}
}
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
if editor, isEditor := r.(Editor); isEditor {
paths := editor.Paths()
if len(paths) > 0 {
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
for _, p := range paths {
fmt.Fprintf(os.Stderr, " %s\n", p)
}
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
if ok, _ := confirmPrompt("Proceed?"); !ok {
return nil
}
}
}
if err := saveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
if editor, isEditor := r.(Editor); isEditor {
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
}
if _, isEditor := r.(Editor); isEditor {
if len(models) == 1 {
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
} else {
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
}
}
if slices.ContainsFunc(models, func(m string) bool {
return !strings.HasSuffix(m, "cloud")
}) {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "Coding agents work best with at least 64k context. Either:")
fmt.Fprintln(os.Stderr, " - Set the context slider in Ollama app settings")
fmt.Fprintln(os.Stderr, " - Run: OLLAMA_CONTEXT_LENGTH=64000 ollama serve")
}
if launchFlag {
return runIntegration(name, models[0])
}
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0])
}
fmt.Fprintf(os.Stderr, "Run 'ollama config %s --launch' to start with %s\n", strings.ToLower(name), models[0])
return nil
},
}
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
cmd.Flags().BoolVar(&launchFlag, "launch", false, "Launch the integration after configuring")
return cmd
}

View File

@@ -0,0 +1,188 @@
package config
import (
"slices"
"strings"
"testing"
"github.com/spf13/cobra"
)
func TestIntegrationLookup(t *testing.T) {
tests := []struct {
name string
input string
wantFound bool
wantName string
}{
{"claude lowercase", "claude", true, "Claude Code"},
{"claude uppercase", "CLAUDE", true, "Claude Code"},
{"claude mixed case", "Claude", true, "Claude Code"},
{"codex", "codex", true, "Codex"},
{"droid", "droid", true, "Droid"},
{"opencode", "opencode", true, "OpenCode"},
{"unknown integration", "unknown", false, ""},
{"empty string", "", false, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, found := integrations[strings.ToLower(tt.input)]
if found != tt.wantFound {
t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound)
}
if found && r.String() != tt.wantName {
t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName)
}
})
}
}
func TestIntegrationRegistry(t *testing.T) {
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
for _, name := range expectedIntegrations {
t.Run(name, func(t *testing.T) {
r, ok := integrations[name]
if !ok {
t.Fatalf("integration %q not found in registry", name)
}
if r.String() == "" {
t.Error("integration.String() should not be empty")
}
})
}
}
func TestHasLocalModel(t *testing.T) {
tests := []struct {
name string
models []string
want bool
}{
{"empty list", []string{}, false},
{"single local model", []string{"llama3.2"}, true},
{"single cloud model", []string{"cloud-model"}, false},
{"mixed models", []string{"cloud-model", "llama3.2"}, true},
{"multiple local models", []string{"llama3.2", "qwen2.5"}, true},
{"multiple cloud models", []string{"cloud-a", "cloud-b"}, false},
{"local model first", []string{"llama3.2", "cloud-model"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := slices.ContainsFunc(tt.models, func(m string) bool {
return !strings.Contains(m, "cloud")
})
if got != tt.want {
t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want)
}
})
}
}
func TestConfigCmd(t *testing.T) {
// Mock checkServerHeartbeat that always succeeds
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
cmd := ConfigCmd(mockCheck)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "config [INTEGRATION]" {
t.Errorf("Use = %q, want %q", cmd.Use, "config [INTEGRATION]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
}
if cmd.Long == "" {
t.Error("Long description should not be empty")
}
})
t.Run("flags exist", func(t *testing.T) {
modelFlag := cmd.Flags().Lookup("model")
if modelFlag == nil {
t.Error("--model flag should exist")
}
launchFlag := cmd.Flags().Lookup("launch")
if launchFlag == nil {
t.Error("--launch flag should exist")
}
})
t.Run("PreRunE is set", func(t *testing.T) {
if cmd.PreRunE == nil {
t.Error("PreRunE should be set to checkServerHeartbeat")
}
})
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model")
if err == nil {
t.Error("expected error for unknown integration, got nil")
}
if !strings.Contains(err.Error(), "unknown integration") {
t.Errorf("error should mention 'unknown integration', got: %v", err)
}
}
func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
tests := []struct {
name string
models []string
want bool
reason string
}{
{"empty list", []string{}, false, "empty list has no local models"},
{"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"},
{"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"},
{"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"},
{"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"},
{"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"},
{"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := slices.ContainsFunc(tt.models, func(m string) bool {
return !strings.Contains(m, "cloud")
})
if got != tt.want {
t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason)
}
})
}
}
func TestConfigCmd_NilHeartbeat(t *testing.T) {
// This should not panic - cmd creation should work even with nil
cmd := ConfigCmd(nil)
if cmd == nil {
t.Fatal("ConfigCmd returned nil")
}
// PreRunE should be nil when passed nil
if cmd.PreRunE != nil {
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
}
}
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
for name, r := range integrations {
t.Run(name, func(t *testing.T) {
// Test String() doesn't panic and returns non-empty
displayName := r.String()
if displayName == "" {
t.Error("String() should not return empty")
}
// Test Run() exists (we can't call it without actually running the command)
// Just verify the method is available
var _ func(string) error = r.Run
})
}
}

203
cmd/config/opencode.go Normal file
View File

@@ -0,0 +1,203 @@
package config
import (
"encoding/json"
"fmt"
"maps"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string) error {
if _, err := exec.LookPath("opencode"); err != nil {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := o.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("opencode")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (o *OpenCode) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
var paths []string
p := filepath.Join(home, ".config", "opencode", "opencode.json")
if _, err := os.Stat(p); err == nil {
paths = append(paths, p)
}
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
if _, err := os.Stat(sp); err == nil {
paths = append(paths, sp)
}
return paths
}
func (o *OpenCode) Edit(modelList []string) error {
if len(modelList) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
}
config["$schema"] = "https://opencode.ai/config.json"
provider, ok := config["provider"].(map[string]any)
if !ok {
provider = make(map[string]any)
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)",
"options": map[string]any{
"baseURL": "http://localhost:11434/v1",
},
}
}
models, ok := ollama["models"].(map[string]any)
if !ok {
models = make(map[string]any)
}
selectedSet := make(map[string]bool)
for _, m := range modelList {
selectedSet[m] = true
}
for name, cfg := range models {
if cfgMap, ok := cfg.(map[string]any); ok {
if displayName, ok := cfgMap["name"].(string); ok {
if strings.HasSuffix(displayName, "[Ollama]") && !selectedSet[name] {
delete(models, name)
}
}
}
}
for _, model := range modelList {
models[model] = map[string]any{
"name": fmt.Sprintf("%s [Ollama]", model),
}
}
ollama["models"] = models
provider["ollama"] = ollama
config["provider"] = provider
configData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
return err
}
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
return err
}
state := map[string]any{
"recent": []any{},
"favorite": []any{},
"variant": map[string]any{},
}
if data, err := os.ReadFile(statePath); err == nil {
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
}
recent, _ := state["recent"].([]any)
modelSet := make(map[string]bool)
for _, m := range modelList {
modelSet[m] = true
}
// Filter out existing Ollama models we're about to re-add
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
e, ok := entry.(map[string]any)
if !ok || e["providerID"] != "ollama" {
return false
}
modelID, _ := e["modelID"].(string)
return modelSet[modelID]
})
// Prepend models in reverse order so first model ends up first
for _, model := range slices.Backward(modelList) {
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
"providerID": "ollama",
"modelID": model,
}))
}
const maxRecentModels = 10
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
state["recent"] = newRecent
stateData, err := json.MarshalIndent(state, "", " ")
if err != nil {
return err
}
return writeWithBackup(statePath, stateData)
}
func (o *OpenCode) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
if err != nil {
return nil
}
provider, _ := config["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if len(models) == 0 {
return nil
}
keys := slices.Collect(maps.Keys(models))
slices.Sort(keys)
return keys
}

437
cmd/config/opencode_test.go Normal file
View File

@@ -0,0 +1,437 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestOpenCodeIntegration(t *testing.T) {
o := &OpenCode{}
t.Run("String", func(t *testing.T) {
if got := o.String(); got != "OpenCode" {
t.Errorf("String() = %q, want %q", got, "OpenCode")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = o
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = o
})
}
func TestOpenCodeEdit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
cleanup := func() {
os.RemoveAll(configDir)
os.RemoveAll(stateDir)
}
t.Run("fresh install", func(t *testing.T) {
cleanup()
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
})
t.Run("preserve other providers", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
if provider["anthropic"] == nil {
t.Error("anthropic provider was removed")
}
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("preserve other models", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeModelExists(t, configPath, "mistral")
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("update existing model", func(t *testing.T) {
cleanup()
o.Edit([]string{"llama3.2"})
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("preserve top-level keys", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("theme was removed")
}
if cfg["keybindings"] == nil {
t.Error("keybindings was removed")
}
})
t.Run("model state - insert at index 0", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
})
t.Run("model state - preserve favorites and variants", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
if len(state["favorite"].([]any)) != 1 {
t.Error("favorite was modified")
}
if state["variant"].(map[string]any)["a"] != "b" {
t.Error("variant was modified")
}
})
t.Run("model state - deduplicate on re-add", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
recent := state["recent"].([]any)
if len(recent) != 2 {
t.Errorf("expected 2 recent entries, got %d", len(recent))
}
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
})
t.Run("remove model", func(t *testing.T) {
cleanup()
// First add two models
o.Edit([]string{"llama3.2", "mistral"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelExists(t, configPath, "mistral")
// Then remove one by only selecting the other
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelNotExists(t, configPath, "mistral")
})
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Add a non-Ollama model manually
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
})
}
func assertOpenCodeModelExists(t *testing.T, path, model string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
provider, ok := cfg["provider"].(map[string]any)
if !ok {
t.Fatal("provider not found")
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
t.Fatal("ollama provider not found")
}
models, ok := ollama["models"].(map[string]any)
if !ok {
t.Fatal("models not found")
}
if models[model] == nil {
t.Errorf("model %s not found", model)
}
}
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
provider, ok := cfg["provider"].(map[string]any)
if !ok {
return // No provider means no model
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
return // No ollama means no model
}
models, ok := ollama["models"].(map[string]any)
if !ok {
return // No models means no model
}
if models[model] != nil {
t.Errorf("model %s should not exist but was found", model)
}
}
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var state map[string]any
if err := json.Unmarshal(data, &state); err != nil {
t.Fatal(err)
}
recent, ok := state["recent"].([]any)
if !ok {
t.Fatal("recent not found")
}
if index >= len(recent) {
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
}
entry, ok := recent[index].(map[string]any)
if !ok {
t.Fatal("entry is not a map")
}
if entry["providerID"] != providerID {
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
}
if entry["modelID"] != modelID {
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
}
}
// Edge case tests for opencode.go
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
// Should not panic - corrupted JSON should be treated as empty
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit failed with corrupted config: %v", err)
}
// Verify valid JSON was created
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Errorf("resulting config is not valid JSON: %v", err)
}
}
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit failed with corrupted state: %v", err)
}
// Verify valid state was created
data, _ := os.ReadFile(statePath)
var state map[string]any
if err := json.Unmarshal(data, &state); err != nil {
t.Errorf("resulting state is not valid JSON: %v", err)
}
}
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit with wrong type provider failed: %v", err)
}
// Verify provider is now correct type
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider, ok := cfg["provider"].(map[string]any)
if !ok {
t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
}
if provider["ollama"] == nil {
t.Error("ollama provider should be created")
}
}
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit with wrong type recent failed: %v", err)
}
// The function should handle this gracefully
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
// recent should be properly set after setup
recent, ok := state["recent"].([]any)
if !ok {
t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
} else if len(recent) == 0 {
t.Logf("Note: recent is empty (documenting behavior)")
}
}
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
os.WriteFile(configPath, []byte(originalContent), 0o644)
// Empty models should be no-op
err := o.Edit([]string{})
if err != nil {
t.Fatalf("Edit with empty models failed: %v", err)
}
// Original content should be preserved (file not modified)
data, _ := os.ReadFile(configPath)
if string(data) != originalContent {
t.Errorf("empty models should not modify file, but content changed")
}
}
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Model name with special characters (though unusual)
specialModel := `model-with-"quotes"`
err := o.Edit([]string{specialModel})
if err != nil {
t.Fatalf("Edit with special chars failed: %v", err)
}
// Verify it was stored correctly
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("resulting config is invalid JSON: %v", err)
}
// Model should be accessible
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if models[specialModel] == nil {
t.Errorf("model with special chars not found in config")
}
}
func TestOpenCodeModels_NoConfig(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
models := o.Models()
if len(models) > 0 {
t.Errorf("expected nil/empty for missing config, got %v", models)
}
}

499
cmd/config/selector.go Normal file
View File

@@ -0,0 +1,499 @@
package config
import (
"errors"
"fmt"
"io"
"os"
"strings"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting.
const (
ansiHideCursor = "\033[?25l"
ansiShowCursor = "\033[?25h"
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiClearDown = "\033[J"
)
const maxDisplayedItems = 10
var errCancelled = errors.New("cancelled")
type selectItem struct {
Name string
Description string
}
type inputEvent int
const (
eventNone inputEvent = iota
eventEnter
eventEscape
eventUp
eventDown
eventTab
eventBackspace
eventChar
)
type selectState struct {
items []selectItem
filter string
selected int
scrollOffset int
}
func newSelectState(items []selectItem) *selectState {
return &selectState{items: items}
}
func (s *selectState) filtered() []selectItem {
return filterItems(s.items, s.filter)
}
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
filtered := s.filtered()
switch event {
case eventEnter:
if len(filtered) > 0 && s.selected < len(filtered) {
return true, filtered[s.selected].Name, nil
}
case eventEscape:
return true, "", errCancelled
case eventBackspace:
if len(s.filter) > 0 {
s.filter = s.filter[:len(s.filter)-1]
s.selected = 0
s.scrollOffset = 0
}
case eventUp:
if s.selected > 0 {
s.selected--
if s.selected < s.scrollOffset {
s.scrollOffset = s.selected
}
}
case eventDown:
if s.selected < len(filtered)-1 {
s.selected++
if s.selected >= s.scrollOffset+maxDisplayedItems {
s.scrollOffset = s.selected - maxDisplayedItems + 1
}
}
case eventChar:
s.filter += string(char)
s.selected = 0
s.scrollOffset = 0
}
return false, "", nil
}
type multiSelectState struct {
items []selectItem
itemIndex map[string]int
filter string
highlighted int
scrollOffset int
checked map[int]bool
checkOrder []int
focusOnButton bool
}
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
s := &multiSelectState{
items: items,
itemIndex: make(map[string]int, len(items)),
checked: make(map[int]bool),
}
for i, item := range items {
s.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := s.itemIndex[name]; ok {
s.checked[idx] = true
s.checkOrder = append(s.checkOrder, idx)
}
}
return s
}
func (s *multiSelectState) filtered() []selectItem {
return filterItems(s.items, s.filter)
}
func (s *multiSelectState) toggleItem() {
filtered := s.filtered()
if len(filtered) == 0 || s.highlighted >= len(filtered) {
return
}
item := filtered[s.highlighted]
origIdx := s.itemIndex[item.Name]
if s.checked[origIdx] {
delete(s.checked, origIdx)
for i, idx := range s.checkOrder {
if idx == origIdx {
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
break
}
}
} else {
s.checked[origIdx] = true
s.checkOrder = append(s.checkOrder, origIdx)
}
}
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
filtered := s.filtered()
switch event {
case eventEnter:
if s.focusOnButton && len(s.checkOrder) > 0 {
var res []string
for _, idx := range s.checkOrder {
res = append(res, s.items[idx].Name)
}
return true, res, nil
} else if !s.focusOnButton {
s.toggleItem()
}
case eventTab:
if len(s.checkOrder) > 0 {
s.focusOnButton = !s.focusOnButton
}
case eventEscape:
return true, nil, errCancelled
case eventBackspace:
if len(s.filter) > 0 {
s.filter = s.filter[:len(s.filter)-1]
s.highlighted = 0
s.scrollOffset = 0
s.focusOnButton = false
}
case eventUp:
if s.focusOnButton {
s.focusOnButton = false
} else if s.highlighted > 0 {
s.highlighted--
if s.highlighted < s.scrollOffset {
s.scrollOffset = s.highlighted
}
}
case eventDown:
if s.focusOnButton {
s.focusOnButton = false
} else if s.highlighted < len(filtered)-1 {
s.highlighted++
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
}
}
case eventChar:
s.filter += string(char)
s.highlighted = 0
s.scrollOffset = 0
s.focusOnButton = false
}
return false, nil, nil
}
func (s *multiSelectState) selectedCount() int {
return len(s.checkOrder)
}
// Terminal I/O handling
type terminalState struct {
fd int
oldState *term.State
}
func enterRawMode() (*terminalState, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return nil, err
}
fmt.Fprint(os.Stderr, ansiHideCursor)
return &terminalState{fd: fd, oldState: oldState}, nil
}
func (t *terminalState) restore() {
fmt.Fprint(os.Stderr, ansiShowCursor)
term.Restore(t.fd, t.oldState)
}
func clearLines(n int) {
if n > 0 {
fmt.Fprintf(os.Stderr, "\033[%dA", n)
fmt.Fprint(os.Stderr, ansiClearDown)
}
}
func parseInput(r io.Reader) (inputEvent, byte, error) {
buf := make([]byte, 3)
n, err := r.Read(buf)
if err != nil {
return 0, 0, err
}
switch {
case n == 1 && buf[0] == 13:
return eventEnter, 0, nil
case n == 1 && (buf[0] == 3 || buf[0] == 27):
return eventEscape, 0, nil
case n == 1 && buf[0] == 9:
return eventTab, 0, nil
case n == 1 && buf[0] == 127:
return eventBackspace, 0, nil
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
return eventUp, 0, nil
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
return eventDown, 0, nil
case n == 1 && buf[0] >= 32 && buf[0] < 127:
return eventChar, buf[0], nil
}
return eventNone, 0, nil
}
// Rendering
func renderSelect(w io.Writer, prompt string, s *selectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
lineCount := 1
if len(filtered) == 0 {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
lineCount++
} else {
displayCount := min(len(filtered), maxDisplayedItems)
for i := range displayCount {
idx := s.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
prefix := " "
if idx == s.selected {
prefix = " " + ansiBold + "> "
}
if item.Description != "" {
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
} else {
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
}
lineCount++
}
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
lineCount++
}
}
return lineCount
}
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
lineCount := 1
if len(filtered) == 0 {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
lineCount++
} else {
displayCount := min(len(filtered), maxDisplayedItems)
for i := range displayCount {
idx := s.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
origIdx := s.itemIndex[item.Name]
checkbox := "[ ]"
if s.checked[origIdx] {
checkbox = "[x]"
}
prefix := " "
suffix := ""
if idx == s.highlighted && !s.focusOnButton {
prefix = "> "
}
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
suffix = " " + ansiGray + "(default)" + ansiReset
}
if idx == s.highlighted && !s.focusOnButton {
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
} else {
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
}
lineCount++
}
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
lineCount++
}
}
fmt.Fprintf(w, "\r\n")
lineCount++
count := s.selectedCount()
switch {
case count == 0:
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
case s.focusOnButton:
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
default:
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
}
lineCount++
return lineCount
}
// selectPrompt prompts the user to select a single item from a list.
func selectPrompt(prompt string, items []selectItem) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
ts, err := enterRawMode()
if err != nil {
return "", err
}
defer ts.restore()
state := newSelectState(items)
var lastLineCount int
render := func() {
clearLines(lastLineCount)
lastLineCount = renderSelect(os.Stderr, prompt, state)
}
render()
for {
event, char, err := parseInput(os.Stdin)
if err != nil {
return "", err
}
done, result, err := state.handleInput(event, char)
if done {
clearLines(lastLineCount)
if err != nil {
return "", err
}
return result, nil
}
render()
}
}
// multiSelectPrompt prompts the user to select multiple items from a list.
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
if len(items) == 0 {
return nil, fmt.Errorf("no items to select from")
}
ts, err := enterRawMode()
if err != nil {
return nil, err
}
defer ts.restore()
state := newMultiSelectState(items, preChecked)
var lastLineCount int
render := func() {
clearLines(lastLineCount)
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
}
render()
for {
event, char, err := parseInput(os.Stdin)
if err != nil {
return nil, err
}
done, result, err := state.handleInput(event, char)
if done {
clearLines(lastLineCount)
if err != nil {
return nil, err
}
return result, nil
}
render()
}
}
func confirmPrompt(prompt string) (bool, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
fmt.Fprintf(os.Stderr, "%s [y/n] ", prompt)
buf := make([]byte, 1)
for {
if _, err := os.Stdin.Read(buf); err != nil {
return false, err
}
switch buf[0] {
case 'Y', 'y', 13:
fmt.Fprintf(os.Stderr, "yes\r\n")
return true, nil
case 'N', 'n', 27, 3:
fmt.Fprintf(os.Stderr, "no\r\n")
return false, nil
}
}
}
func filterItems(items []selectItem, filter string) []selectItem {
if filter == "" {
return items
}
var result []selectItem
filterLower := strings.ToLower(filter)
for _, item := range items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}

913
cmd/config/selector_test.go Normal file
View File

@@ -0,0 +1,913 @@
package config
import (
"bytes"
"strings"
"testing"
)
func TestFilterItems(t *testing.T) {
items := []selectItem{
{Name: "llama3.2:latest"},
{Name: "qwen2.5:7b"},
{Name: "deepseek-v3:cloud"},
{Name: "GPT-OSS:20b"},
}
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
result := filterItems(items, "")
if len(result) != len(items) {
t.Errorf("expected %d items, got %d", len(items), len(result))
}
})
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
result := filterItems(items, "LLAMA")
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
t.Errorf("expected llama3.2:latest, got %v", result)
}
})
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
result := filterItems(items, "gpt")
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
t.Errorf("expected GPT-OSS:20b, got %v", result)
}
})
t.Run("PartialMatch", func(t *testing.T) {
result := filterItems(items, "deep")
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
t.Errorf("expected deepseek-v3:cloud, got %v", result)
}
})
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
result := filterItems(items, "nonexistent")
if len(result) != 0 {
t.Errorf("expected 0 items, got %d", len(result))
}
})
}
func TestSelectState(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
{Name: "item3"},
}
t.Run("InitialState", func(t *testing.T) {
s := newSelectState(items)
if s.selected != 0 {
t.Errorf("expected selected=0, got %d", s.selected)
}
if s.filter != "" {
t.Errorf("expected empty filter, got %q", s.filter)
}
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
}
})
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "item1" || err != nil {
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
s := newSelectState(items)
s.filter = "item3"
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "item3" || err != nil {
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
s := newSelectState(items)
s.filter = "nonexistent"
done, result, err := s.handleInput(eventEnter, 0)
if done || result != "" || err != nil {
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEscape, 0)
if !done || result != "" || err != errCancelled {
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Down_MovesSelection", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventDown, 0)
if s.selected != 1 {
t.Errorf("expected selected=1, got %d", s.selected)
}
})
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventDown, 0)
if s.selected != 2 {
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
}
})
t.Run("Up_MovesSelection", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventUp, 0)
if s.selected != 1 {
t.Errorf("expected selected=1, got %d", s.selected)
}
})
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventUp, 0)
if s.selected != 0 {
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
}
})
t.Run("Char_AppendsToFilter", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventChar, 'i')
s.handleInput(eventChar, 't')
s.handleInput(eventChar, 'e')
s.handleInput(eventChar, 'm')
s.handleInput(eventChar, '2')
if s.filter != "item2" {
t.Errorf("expected filter='item2', got %q", s.filter)
}
filtered := s.filtered()
if len(filtered) != 1 || filtered[0].Name != "item2" {
t.Errorf("expected [item2], got %v", filtered)
}
})
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventChar, 'x')
if s.selected != 0 {
t.Errorf("expected selected=0 after typing, got %d", s.selected)
}
})
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
s := newSelectState(items)
s.filter = "test"
s.handleInput(eventBackspace, 0)
if s.filter != "tes" {
t.Errorf("expected filter='tes', got %q", s.filter)
}
})
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventBackspace, 0)
if s.filter != "" {
t.Errorf("expected filter='', got %q", s.filter)
}
})
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
s := newSelectState(items)
s.filter = "test"
s.selected = 2
s.handleInput(eventBackspace, 0)
if s.selected != 0 {
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
}
})
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
// maxDisplayedItems is 10, so with 15 items we need to scroll
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
// move down 12 times (past the 10-item viewport)
for range 12 {
s.handleInput(eventDown, 0)
}
if s.selected != 12 {
t.Errorf("expected selected=12, got %d", s.selected)
}
if s.scrollOffset != 3 {
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
}
})
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
s.selected = 5
s.scrollOffset = 5
s.handleInput(eventUp, 0)
if s.selected != 4 {
t.Errorf("expected selected=4, got %d", s.selected)
}
if s.scrollOffset != 4 {
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
}
})
}
func TestMultiSelectState(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
{Name: "item3"},
}
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
s := newMultiSelectState(items, nil)
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
if s.selectedCount() != 0 {
t.Errorf("expected 0 selected, got %d", s.selectedCount())
}
if s.focusOnButton {
t.Error("expected focusOnButton=false initially")
}
})
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item3"})
if s.selectedCount() != 2 {
t.Errorf("expected 2 selected, got %d", s.selectedCount())
}
if !s.checked[1] || !s.checked[2] {
t.Error("expected item2 and item3 to be checked")
}
})
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
// order matters: first checked = default model
s := newMultiSelectState(items, []string{"item3", "item1"})
if len(s.checkOrder) != 2 {
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
}
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
}
})
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
if s.selectedCount() != 1 {
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
}
})
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.toggleItem()
if !s.checked[0] {
t.Error("expected item1 to be checked after toggle")
}
})
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.toggleItem()
if s.checked[0] {
t.Error("expected item1 to be unchecked after toggle")
}
})
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
s.highlighted = 1 // toggle item2
s.toggleItem()
if len(s.checkOrder) != 2 {
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
}
// should be [0, 2] (item1, item3) with item2 removed
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
}
})
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventEnter, 0)
if !s.checked[0] {
t.Error("expected item1 to be checked after enter")
}
})
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item1"})
s.focusOnButton = true
done, result, err := s.handleInput(eventEnter, 0)
if !done || err != nil {
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
}
// result should preserve selection order
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
t.Errorf("expected [item2, item1], got %v", result)
}
})
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.focusOnButton = true
done, result, err := s.handleInput(eventEnter, 0)
if done || result != nil || err != nil {
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.handleInput(eventTab, 0)
if !s.focusOnButton {
t.Error("expected focus on button after tab")
}
})
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventTab, 0)
if s.focusOnButton {
t.Error("tab should not focus button when nothing selected")
}
})
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.handleInput(eventTab, 0)
if !s.focusOnButton {
t.Error("expected focus on button after first tab")
}
s.handleInput(eventTab, 0)
if s.focusOnButton {
t.Error("expected focus back on list after second tab")
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
done, result, err := s.handleInput(eventEscape, 0)
if !done || result != nil || err != errCancelled {
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
}
})
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item1"})
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
t.Error("expected item2 (idx 1) to be default (first checked)")
}
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
t.Error("expected item1 (idx 0) to NOT be default")
}
})
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
s := newMultiSelectState(items, nil)
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
t.Error("expected isDefault=false when nothing checked")
}
})
t.Run("Down_MovesHighlight", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventDown, 0)
if s.highlighted != 1 {
t.Errorf("expected highlighted=1, got %d", s.highlighted)
}
})
t.Run("Up_MovesHighlight", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.highlighted = 1
s.handleInput(eventUp, 0)
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
})
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.focusOnButton = true
s.handleInput(eventDown, 0)
if s.focusOnButton {
t.Error("expected focus to return to list on arrow key")
}
})
t.Run("Char_AppendsToFilter", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventChar, 'x')
if s.filter != "x" {
t.Errorf("expected filter='x', got %q", s.filter)
}
})
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newMultiSelectState(manyItems, nil)
s.highlighted = 10
s.scrollOffset = 5
s.handleInput(eventChar, 'x')
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
}
})
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.filter = "test"
s.handleInput(eventBackspace, 0)
if s.filter != "tes" {
t.Errorf("expected filter='tes', got %q", s.filter)
}
})
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.filter = "x"
s.focusOnButton = true
s.handleInput(eventBackspace, 0)
if s.focusOnButton {
t.Error("expected focusOnButton=false after backspace")
}
})
}
func TestParseInput(t *testing.T) {
t.Run("Enter", func(t *testing.T) {
event, char, err := parseInput(bytes.NewReader([]byte{13}))
if err != nil || event != eventEnter || char != 0 {
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
}
})
t.Run("Escape", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27}))
if err != nil || event != eventEscape {
t.Errorf("expected eventEscape, got %v", event)
}
})
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{3}))
if err != nil || event != eventEscape {
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
}
})
t.Run("Tab", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{9}))
if err != nil || event != eventTab {
t.Errorf("expected eventTab, got %v", event)
}
})
t.Run("Backspace", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{127}))
if err != nil || event != eventBackspace {
t.Errorf("expected eventBackspace, got %v", event)
}
})
t.Run("UpArrow", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
if err != nil || event != eventUp {
t.Errorf("expected eventUp, got %v", event)
}
})
t.Run("DownArrow", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
if err != nil || event != eventDown {
t.Errorf("expected eventDown, got %v", event)
}
})
t.Run("PrintableChars", func(t *testing.T) {
tests := []struct {
name string
char byte
}{
{"lowercase", 'a'},
{"uppercase", 'Z'},
{"digit", '5'},
{"space", ' '},
{"tilde", '~'},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
if err != nil || event != eventChar || char != tt.char {
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
}
})
}
})
}
func TestRenderSelect(t *testing.T) {
items := []selectItem{
{Name: "item1", Description: "first item"},
{Name: "item2"},
}
t.Run("ShowsPromptAndItems", func(t *testing.T) {
s := newSelectState(items)
var buf bytes.Buffer
lineCount := renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "Select:") {
t.Error("expected prompt in output")
}
if !strings.Contains(output, "item1") {
t.Error("expected item1 in output")
}
if !strings.Contains(output, "first item") {
t.Error("expected description in output")
}
if !strings.Contains(output, "item2") {
t.Error("expected item2 in output")
}
if lineCount != 3 { // 1 prompt + 2 items
t.Errorf("expected 3 lines, got %d", lineCount)
}
})
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
s := newSelectState(items)
s.filter = "xyz"
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' message")
}
})
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
// 15 items - 10 displayed = 5 more
if !strings.Contains(buf.String(), "5 more") {
t.Error("expected '5 more' indicator")
}
})
}
func TestRenderMultiSelect(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
}
t.Run("ShowsCheckboxes", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "[x]") {
t.Error("expected checked checkbox [x]")
}
if !strings.Contains(output, "[ ]") {
t.Error("expected unchecked checkbox [ ]")
}
})
t.Run("ShowsDefaultMarker", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "(default)") {
t.Error("expected (default) marker for first checked item")
}
})
t.Run("ShowsSelectedCount", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "item2"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "2 selected") {
t.Error("expected '2 selected' in output")
}
})
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
s := newMultiSelectState(items, nil)
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "Select at least one") {
t.Error("expected 'Select at least one' helper text")
}
})
}
func TestErrCancelled(t *testing.T) {
t.Run("NotNil", func(t *testing.T) {
if errCancelled == nil {
t.Error("errCancelled should not be nil")
}
})
t.Run("Message", func(t *testing.T) {
if errCancelled.Error() != "cancelled" {
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
}
})
}
// Edge case tests for selector.go
// TestSelectState_SingleItem verifies that single item list works without crash.
// List with only one item should still work.
func TestSelectState_SingleItem(t *testing.T) {
items := []selectItem{{Name: "only-one"}}
s := newSelectState(items)
// Down should do nothing (already at bottom)
s.handleInput(eventDown, 0)
if s.selected != 0 {
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
}
// Up should do nothing (already at top)
s.handleInput(eventUp, 0)
if s.selected != 0 {
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
}
// Enter should select the only item
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "only-one" || err != nil {
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
}
}
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
// List with exactly maxDisplayedItems items should not scroll.
func TestSelectState_ExactlyMaxItems(t *testing.T) {
items := make([]selectItem, maxDisplayedItems)
for i := range items {
items[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(items)
// Move to last item
for range maxDisplayedItems - 1 {
s.handleInput(eventDown, 0)
}
if s.selected != maxDisplayedItems-1 {
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
}
// Should not scroll when exactly at max
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
}
// One more down should do nothing
s.handleInput(eventDown, 0)
if s.selected != maxDisplayedItems-1 {
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
}
}
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
// User typing "model.v1" shouldn't match "modelsv1".
func TestFilterItems_RegexSpecialChars(t *testing.T) {
items := []selectItem{
{Name: "model.v1"},
{Name: "modelsv1"},
{Name: "model-v1"},
}
// Filter with dot should only match literal dot
result := filterItems(items, "model.v1")
if len(result) != 1 {
t.Errorf("expected 1 exact match, got %d", len(result))
}
if len(result) > 0 && result[0].Name != "model.v1" {
t.Errorf("expected 'model.v1', got %s", result[0].Name)
}
// Other regex special chars should be literal too
items2 := []selectItem{
{Name: "test[0]"},
{Name: "test0"},
{Name: "test(1)"},
}
result2 := filterItems(items2, "test[0]")
if len(result2) != 1 || result2[0].Name != "test[0]" {
t.Errorf("expected only 'test[0]', got %v", result2)
}
}
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
// itemIndex uses name as key - duplicates cause collision. This documents
// the current behavior: the last index for a duplicate name is stored
func TestMultiSelectState_DuplicateNames(t *testing.T) {
// Duplicate names - this is an edge case that shouldn't happen in practice
items := []selectItem{
{Name: "duplicate"},
{Name: "duplicate"},
{Name: "unique"},
}
s := newMultiSelectState(items, nil)
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
// When there are duplicates, only the last occurrence's index is stored
if s.itemIndex["duplicate"] != 1 {
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
}
// Toggle item at highlighted=0 (first "duplicate")
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
// So it actually toggles the SECOND duplicate item, not the first
s.toggleItem()
// This documents the potentially surprising behavior:
// We toggled at highlighted=0, but itemIndex lookup returned 1
if !s.checked[1] {
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
}
if s.checked[0] {
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
}
}
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
// Prevents index-out-of-bounds on next keystroke
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
items := []selectItem{
{Name: "apple"},
{Name: "banana"},
{Name: "cherry"},
}
s := newSelectState(items)
s.selected = 2 // Select "cherry"
// Type a filter that removes cherry from results
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
// Selection should reset to 0
if s.selected != 0 {
t.Errorf("expected selected=0 after filter, got %d", s.selected)
}
filtered := s.filtered()
if len(filtered) != 2 {
t.Errorf("expected 2 filtered items, got %d", len(filtered))
}
}
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
// Model names might contain unicode characters
func TestFilterItems_UnicodeCharacters(t *testing.T) {
items := []selectItem{
{Name: "llama-日本語"},
{Name: "模型-chinese"},
{Name: "émoji-🦙"},
{Name: "regular-model"},
}
t.Run("filter japanese", func(t *testing.T) {
result := filterItems(items, "日本")
if len(result) != 1 || result[0].Name != "llama-日本語" {
t.Errorf("expected llama-日本語, got %v", result)
}
})
t.Run("filter chinese", func(t *testing.T) {
result := filterItems(items, "模型")
if len(result) != 1 || result[0].Name != "模型-chinese" {
t.Errorf("expected 模型-chinese, got %v", result)
}
})
t.Run("filter emoji", func(t *testing.T) {
result := filterItems(items, "🦙")
if len(result) != 1 || result[0].Name != "émoji-🦙" {
t.Errorf("expected émoji-🦙, got %v", result)
}
})
t.Run("filter accented char", func(t *testing.T) {
result := filterItems(items, "émoji")
if len(result) != 1 || result[0].Name != "émoji-🦙" {
t.Errorf("expected émoji-🦙, got %v", result)
}
})
}
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
items := []selectItem{
{Name: "apple"},
{Name: "banana"},
{Name: "cherry"},
}
s := newMultiSelectState(items, nil)
s.highlighted = 2 // Highlight "cherry"
// Type a filter that removes cherry
s.handleInput(eventChar, 'a')
if s.highlighted != 0 {
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
}
}
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
// Empty list should be handled gracefully.
func TestMultiSelectState_EmptyItems(t *testing.T) {
s := newMultiSelectState([]selectItem{}, nil)
// Toggle should not panic on empty list
s.toggleItem()
if s.selectedCount() != 0 {
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
}
// Render should handle empty list
var buf bytes.Buffer
lineCount := renderMultiSelect(&buf, "Select:", s)
if lineCount == 0 {
t.Error("renderMultiSelect should produce output even for empty list")
}
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' for empty list")
}
}
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
func TestSelectState_RenderWithDescriptions(t *testing.T) {
items := []selectItem{
{Name: "item1", Description: "First item description"},
{Name: "item2", Description: ""},
{Name: "item3", Description: "Third item"},
}
s := newSelectState(items)
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "First item description") {
t.Error("expected description to be rendered")
}
if !strings.Contains(output, "item2") {
t.Error("expected item without description to be rendered")
}
}

View File

@@ -159,6 +159,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
scanner.Prompt.UseAlt = true
continue
}

View File

@@ -1,4 +1,4 @@
package server
package manifest
import (
"crypto/sha256"
@@ -14,7 +14,7 @@ type Layer struct {
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
status string
Status string `json:"-"`
}
const (
@@ -22,7 +22,7 @@ const (
)
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := GetBlobsPath("")
blobs, err := BlobsPath("")
if err != nil {
return Layer{}, err
}
@@ -45,7 +45,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
}
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
blob, err := GetBlobsPath(digest)
blob, err := BlobsPath(digest)
if err != nil {
return Layer{}, err
}
@@ -65,7 +65,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
MediaType: mediatype,
Digest: digest,
Size: n,
status: fmt.Sprintf("%s %s", status, digest),
Status: fmt.Sprintf("%s %s", status, digest),
}, nil
}
@@ -74,7 +74,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
return Layer{}, errors.New("creating new layer from layer with empty digest")
}
blob, err := GetBlobsPath(digest)
blob, err := BlobsPath(digest)
if err != nil {
return Layer{}, err
}
@@ -89,7 +89,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
Digest: digest,
Size: fi.Size(),
From: from,
status: fmt.Sprintf("using existing layer %s", digest),
Status: fmt.Sprintf("using existing layer %s", digest),
}, nil
}
@@ -98,7 +98,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
return nil, errors.New("opening layer with empty digest")
}
blob, err := GetBlobsPath(l.Digest)
blob, err := BlobsPath(l.Digest)
if err != nil {
return nil, err
}
@@ -126,7 +126,7 @@ func (l *Layer) Remove() error {
}
}
blob, err := GetBlobsPath(l.Digest)
blob, err := BlobsPath(l.Digest)
if err != nil {
return err
}

View File

@@ -1,10 +1,9 @@
package server
package manifest
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
@@ -33,12 +32,38 @@ func (m *Manifest) Size() (size int64) {
return
}
func (m *Manifest) Digest() string {
return m.digest
}
func (m *Manifest) FileInfo() os.FileInfo {
return m.fi
}
// ReadConfigJSON reads and unmarshals a config layer as JSON.
func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
for _, layer := range m.Layers {
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
blobPath, err := BlobsPath(layer.Digest)
if err != nil {
return err
}
data, err := os.ReadFile(blobPath)
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
}
return fmt.Errorf("config %q not found in manifest", configPath)
}
func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil {
return err
}
manifests, err := GetManifestPath()
manifests, err := Path()
if err != nil {
return err
}
@@ -70,11 +95,11 @@ func (m *Manifest) RemoveLayers() error {
if _, used := inUse[layer.Digest]; used {
continue
}
blob, err := GetBlobsPath(layer.Digest)
blob, err := BlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
if err := os.Remove(blob); os.IsNotExist(err) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
@@ -89,7 +114,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, model.Unqualified(n)
}
manifests, err := GetManifestPath()
manifests, err := Path()
if err != nil {
return nil, err
}
@@ -121,7 +146,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
}
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
manifests, err := GetManifestPath()
manifests, err := Path()
if err != nil {
return err
}
@@ -148,7 +173,7 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
}
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
manifests, err := GetManifestPath()
manifests, err := Path()
if err != nil {
return nil, err
}

View File

@@ -1,4 +1,4 @@
package server
package manifest
import (
"encoding/json"

95
manifest/paths.go Normal file
View File

@@ -0,0 +1,95 @@
package manifest
import (
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
var ErrInvalidDigestFormat = errors.New("invalid digest format")
func Path() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// PathForName returns the path to the manifest file for a specific model name.
func PathForName(n model.Name) (string, error) {
if !n.IsValid() {
return "", os.ErrNotExist
}
manifests, err := Path()
if err != nil {
return "", err
}
return filepath.Join(manifests, n.Filepath()), nil
}
func BlobsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "blobs", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// PruneDirectory removes empty directories recursively.
func PruneDirectory(path string) error {
info, err := os.Lstat(path)
if err != nil {
return err
}
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(path)
if err != nil {
return err
}
for _, entry := range entries {
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
return err
}
}
entries, err = os.ReadDir(path)
if err != nil {
return err
}
if len(entries) > 0 {
return nil
}
return os.Remove(path)
}
return nil
}

View File

@@ -609,3 +609,49 @@ func ImageGenerationsMiddleware() gin.HandlerFunc {
c.Next()
}
}
func ImageEditsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ImageEditRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Prompt == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.Image == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
return
}
genReq, err := openai.FromImageEditRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ImageWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}

View File

@@ -1112,3 +1112,129 @@ func TestImageWriterResponse(t *testing.T) {
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
}
}
func TestImageEditsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.GenerateRequest
err openai.ErrorResponse
}
var capturedRequest *api.GenerateRequest
// Base64-encoded test image (1x1 pixel PNG)
testImage := ""
decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
testCases := []testCase{
{
name: "image edit basic",
body: `{
"model": "test-model",
"prompt": "make it blue",
"image": "` + testImage + `"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "make it blue",
Images: []api.ImageData{decodedImage},
},
},
{
name: "image edit with size",
body: `{
"model": "test-model",
"prompt": "make it blue",
"image": "` + testImage + `",
"size": "512x768"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "make it blue",
Images: []api.ImageData{decodedImage},
Width: 512,
Height: 768,
},
},
{
name: "image edit missing prompt",
body: `{
"model": "test-model",
"image": "` + testImage + `"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "prompt is required",
Type: "invalid_request_error",
},
},
},
{
name: "image edit missing model",
body: `{
"prompt": "make it blue",
"image": "` + testImage + `"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "model is required",
Type: "invalid_request_error",
},
},
},
{
name: "image edit missing image",
body: `{
"model": "test-model",
"prompt": "make it blue"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "image is required",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Error.Message != "" {
var errResp openai.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match:\n%s", diff)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match:\n%s", diff)
}
})
}
}

View File

@@ -794,3 +794,47 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
Data: data,
}
}
// ImageEditRequest is an OpenAI-compatible image edit request.
type ImageEditRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Image string `json:"image"` // Base64-encoded image data
Size string `json:"size,omitempty"` // e.g., "1024x1024"
Seed *int64 `json:"seed,omitempty"`
}
// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
req := api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
}
// Decode the input image
if r.Image != "" {
imgData, err := decodeImageURL(r.Image)
if err != nil {
return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
}
req.Images = append(req.Images, imgData)
}
// Parse size if provided (e.g., "1024x768")
if r.Size != "" {
var w, h int32
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
req.Width = w
req.Height = h
}
}
if r.Seed != nil {
if req.Options == nil {
req.Options = map[string]any{}
}
req.Options["seed"] = *r.Seed
}
return req, nil
}

View File

@@ -448,3 +448,86 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
})
}
}
func TestFromImageEditRequest_Basic(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if result.Prompt != "make it blue" {
t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
}
if len(result.Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Images))
}
}
func TestFromImageEditRequest_WithSize(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
Size: "512x768",
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Width != 512 {
t.Errorf("expected width 512, got %d", result.Width)
}
if result.Height != 768 {
t.Errorf("expected height 768, got %d", result.Height)
}
}
func TestFromImageEditRequest_WithSeed(t *testing.T) {
seed := int64(12345)
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
Seed: &seed,
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options == nil {
t.Fatal("expected options to be set")
}
if result.Options["seed"] != seed {
t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
}
}
func TestFromImageEditRequest_InvalidImage(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: "not-valid-base64",
}
_, err := FromImageEditRequest(req)
if err == nil {
t.Error("expected error for invalid image")
}
}

View File

@@ -95,7 +95,21 @@ func (i *Instance) Readline() (string, error) {
var currentLineBuf []rune
// draining tracks if we're processing buffered input from cooked mode.
// In cooked mode Enter sends \n, but in raw mode Ctrl+J sends \n.
// We treat \n from cooked mode as submit, not multiline.
// We check Buffered() after the first read since the bufio buffer is
// empty until then. This is compatible with """ multiline mode in
// interactive.go since each Readline() call is independent.
var draining, stopDraining bool
for {
// Apply deferred state change from previous iteration
if stopDraining {
draining = false
stopDraining = false
}
// don't show placeholder when pasting unless we're in multiline mode
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
if buf.IsEmpty() && showPlaceholder {
@@ -105,6 +119,15 @@ func (i *Instance) Readline() (string, error) {
r, err := i.Terminal.Read()
// After reading, check if there's more buffered data. If so, we're
// processing cooked-mode input. Once buffer empties, the current
// char is the last buffered one (still drain it), then stop next iteration.
if i.Terminal.reader.Buffered() > 0 {
draining = true
} else if draining {
stopDraining = true
}
if buf.IsEmpty() {
fmt.Print(ClearToEOL)
}
@@ -232,15 +255,20 @@ func (i *Instance) Readline() (string, error) {
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
case CharCtrlJ:
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
// If not draining cooked-mode input, treat as multiline
if !draining {
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
}
// Draining cooked-mode input: treat \n as submit
fallthrough
case CharEnter:
output := buf.String()
if len(i.pastedLines) > 0 {

View File

@@ -3,7 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
"github.com/ollama/ollama/x/mlxrunner"
)
func Execute(args []string) error {
@@ -12,18 +12,18 @@ func Execute(args []string) error {
}
var newRunner bool
var imageRunner bool
var mlxRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--image-engine" {
if len(args) > 0 && args[0] == "--mlx-engine" {
args = args[1:]
imageRunner = true
mlxRunner = true
}
if imageRunner {
return imagerunner.Execute(args)
if mlxRunner {
return mlxrunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {

View File

@@ -28,6 +28,7 @@ import (
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@@ -90,7 +91,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- resp
}
oldManifest, _ := ParseNamedManifest(name)
oldManifest, _ := manifest.ParseNamedManifest(name)
var baseLayers []*layerGGML
var err error
@@ -123,9 +124,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
}
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
manifest, mErr := ParseNamedManifest(fromName)
if mErr == nil && manifest.Config.Digest != "" {
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
mf, mErr := manifest.ParseNamedManifest(fromName)
if mErr == nil && mf.Config.Digest != "" {
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
if pErr == nil {
if cfgFile, fErr := os.Open(configPath); fErr == nil {
var baseConfig model.ConfigV2
@@ -342,7 +343,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
return "gguf"
} else {
// try to see if we can find a gguf file even without the file extension
blobPath, err := GetBlobsPath(files[fn])
blobPath, err := manifest.BlobsPath(files[fn])
if err != nil {
slog.Error("error getting blobs path", "file", fn)
return ""
@@ -394,7 +395,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
}
blobPath, err := GetBlobsPath(digest)
blobPath, err := manifest.BlobsPath(digest)
if err != nil {
return nil, err
}
@@ -432,7 +433,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, err
}
layer, err := NewLayer(t, mediaType)
layer, err := manifest.NewLayer(t, mediaType)
if err != nil {
return nil, err
}
@@ -465,7 +466,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
}
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
var layers []Layer
var layers []manifest.Layer
for _, layer := range baseLayers {
if layer.GGML != nil {
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
@@ -550,13 +551,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
}
for _, layer := range layers {
if layer.status != "" {
fn(api.ProgressResponse{Status: layer.status})
if layer.Status != "" {
fn(api.ProgressResponse{Status: layer.Status})
}
}
fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, *configLayer, layers); err != nil {
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
return err
}
@@ -577,7 +578,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
return nil, err
}
blob, err := GetBlobsPath(layer.Digest)
blob, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -599,7 +600,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
}
temp.Seek(0, io.SeekStart)
fn(api.ProgressResponse{Status: "verifying conversion"})
newLayer, err := NewLayer(temp, layer.MediaType)
newLayer, err := manifest.NewLayer(temp, layer.MediaType)
if err != nil {
return nil, err
}
@@ -619,7 +620,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
var layers []*layerGGML
fn(api.ProgressResponse{Status: "parsing GGUF"})
blobPath, err := GetBlobsPath(digest)
blobPath, err := manifest.BlobsPath(digest)
if err != nil {
return nil, err
}
@@ -654,7 +655,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
mediatype = "application/vnd.ollama.image.projector"
}
layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
return nil, err
@@ -665,8 +666,8 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
return detectChatTemplate(layers)
}
func removeLayer(layers []Layer, mediatype string) []Layer {
return slices.DeleteFunc(layers, func(layer Layer) bool {
func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
if layer.MediaType != mediatype {
return false
}
@@ -680,7 +681,7 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
})
}
func setTemplate(layers []Layer, t string) ([]Layer, error) {
func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.template")
if _, err := template.Parse(t); err != nil {
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
@@ -690,7 +691,7 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
}
blob := strings.NewReader(t)
layer, err := NewLayer(blob, "application/vnd.ollama.image.template")
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
@@ -699,11 +700,11 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
return layers, nil
}
func setSystem(layers []Layer, s string) ([]Layer, error) {
func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.system")
if s != "" {
blob := strings.NewReader(s)
layer, err := NewLayer(blob, "application/vnd.ollama.image.system")
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
if err != nil {
return nil, err
}
@@ -712,9 +713,9 @@ func setSystem(layers []Layer, s string) ([]Layer, error) {
return layers, nil
}
func setLicense(layers []Layer, l string) ([]Layer, error) {
func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
blob := strings.NewReader(l)
layer, err := NewLayer(blob, "application/vnd.ollama.image.license")
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
if err != nil {
return nil, err
}
@@ -722,7 +723,7 @@ func setLicense(layers []Layer, l string) ([]Layer, error) {
return layers, nil
}
func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
if p == nil {
p = make(map[string]any)
}
@@ -731,7 +732,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
continue
}
digestPath, err := GetBlobsPath(layer.Digest)
digestPath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -765,7 +766,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
if err := json.NewEncoder(&b).Encode(p); err != nil {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}
@@ -773,7 +774,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
return layers, nil
}
func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
// this leaves the old messages intact if no new messages were specified
// which may not be the correct behaviour
if len(m) == 0 {
@@ -786,7 +787,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
if err := json.NewEncoder(&b).Encode(m); err != nil {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
if err != nil {
return nil, err
}
@@ -794,7 +795,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
return layers, nil
}
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
digests := make([]string, len(layers))
for i, layer := range layers {
digests[i] = layer.Digest
@@ -805,7 +806,7 @@ func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
if err := json.NewEncoder(&b).Encode(config); err != nil {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return nil, err
}

View File

@@ -10,6 +10,7 @@ import (
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
)
func TestConvertFromSafetensors(t *testing.T) {
@@ -17,7 +18,7 @@ func TestConvertFromSafetensors(t *testing.T) {
// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}

View File

@@ -24,6 +24,8 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
const maxRetries = 6
@@ -456,7 +458,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
}
type downloadOpts struct {
mp ModelPath
n model.Name
digest string
regOpts *registryOptions
fn func(api.ProgressResponse)
@@ -465,10 +467,10 @@ type downloadOpts struct {
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
if opts.digest == "" {
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
}
fp, err := GetBlobsPath(opts.digest)
fp, err := manifest.BlobsPath(opts.digest)
if err != nil {
return false, err
}
@@ -492,8 +494,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
download := data.(*blobDownload)
if !ok {
requestURL := opts.mp.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
requestURL := opts.n.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest)
return false, err

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -24,6 +23,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
@@ -31,6 +31,7 @@ import (
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen/transfer"
xserver "github.com/ollama/ollama/x/server"
)
var (
@@ -75,12 +76,6 @@ type Model struct {
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImage}
}
// Check for completion capability
if m.ModelPath != "" {
f, err := gguf.Open(m.ModelPath)
if err == nil {
@@ -135,6 +130,14 @@ func (m *Model) Capabilities() []model.Capability {
return capabilities
}
// Check for thinking capability in safetensors LLM models based on architecture
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if xserver.IsSafetensorsThinkingModel(model.ParseName(m.Name)) {
capabilities = append(capabilities, model.CapabilityThinking)
return capabilities
}
}
// Check for thinking capability
openingTag, closingTag := thinking.InferTags(m.Template.Template)
hasTags := openingTag != "" && closingTag != ""
@@ -274,44 +277,22 @@ func (m *Model) String() string {
return modelfile.String()
}
func GetManifest(mp ModelPath) (*Manifest, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
return nil, "", err
}
f, err := os.Open(fp)
if err != nil {
return nil, "", err
}
defer f.Close()
sha256sum := sha256.New()
var manifest Manifest
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
return nil, "", err
}
return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
}
func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name)
manifest, digest, err := GetManifest(mp)
n := model.ParseName(name)
mf, err := manifest.ParseNamedManifest(n)
if err != nil {
return nil, err
}
model := &Model{
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
m := &Model{
Name: n.String(),
ShortName: n.DisplayShortest(),
Digest: mf.Digest(),
Template: template.DefaultTemplate,
}
if manifest.Config.Digest != "" {
filename, err := GetBlobsPath(manifest.Config.Digest)
if mf.Config.Digest != "" {
filename, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
return nil, err
}
@@ -322,29 +303,29 @@ func GetModel(name string) (*Model, error) {
}
defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
return nil, err
}
}
for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.Digest)
for _, layer := range mf.Layers {
filename, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
switch layer.MediaType {
case "application/vnd.ollama.image.model":
model.ModelPath = filename
model.ParentModel = layer.From
m.ModelPath = filename
m.ParentModel = layer.From
case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename)
m.AdapterPaths = append(m.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename)
m.ProjectorPaths = append(m.ProjectorPaths, filename)
case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
@@ -352,7 +333,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}
model.Template, err = template.Parse(string(bts))
m.Template, err = template.Parse(string(bts))
if err != nil {
return nil, err
}
@@ -362,7 +343,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}
model.System = string(bts)
m.System = string(bts)
case "application/vnd.ollama.image.params":
params, err := os.Open(filename)
if err != nil {
@@ -371,7 +352,7 @@ func GetModel(name string) (*Model, error) {
defer params.Close()
// parse model options parameters into a map so that we can see which fields have been specified explicitly
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
return nil, err
}
case "application/vnd.ollama.image.messages":
@@ -381,7 +362,7 @@ func GetModel(name string) (*Model, error) {
}
defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license":
@@ -389,11 +370,11 @@ func GetModel(name string) (*Model, error) {
if err != nil {
return nil, err
}
model.License = append(model.License, string(bts))
m.License = append(m.License, string(bts))
}
}
return model, nil
return m, nil
}
func CopyModel(src, dst model.Name) error {
@@ -408,7 +389,7 @@ func CopyModel(src, dst model.Name) error {
return nil
}
manifests, err := GetManifestPath()
manifests, err := manifest.Path()
if err != nil {
return err
}
@@ -437,7 +418,7 @@ func CopyModel(src, dst model.Name) error {
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
manifests, err := Manifests(true)
manifests, err := manifest.Manifests(true)
if err != nil {
return err
}
@@ -452,7 +433,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// only delete the files which are still in the deleteMap
for k := range deleteMap {
fp, err := GetBlobsPath(k)
fp, err := manifest.BlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
continue
@@ -468,7 +449,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
func PruneLayers() error {
deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("")
p, err := manifest.BlobsPath("")
if err != nil {
return err
}
@@ -483,9 +464,9 @@ func PruneLayers() error {
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
_, err := GetBlobsPath(name)
_, err := manifest.BlobsPath(name)
if err != nil {
if errors.Is(err, ErrInvalidDigestFormat) {
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
@@ -510,63 +491,30 @@ func PruneLayers() error {
return nil
}
func PruneDirectory(path string) error {
info, err := os.Lstat(path)
if err != nil {
return err
}
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(path)
if err != nil {
return err
}
for _, entry := range entries {
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
return err
}
}
entries, err = os.ReadDir(path)
if err != nil {
return err
}
if len(entries) > 0 {
return nil
}
return os.Remove(path)
}
return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
n := model.ParseName(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
if n.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}
manifest, _, err := GetManifest(mp)
mf, err := manifest.ParseNamedManifest(n)
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
var layers []manifest.Layer
layers = append(layers, mf.Layers...)
if mf.Config.Digest != "" {
layers = append(layers, mf.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := mp.GetManifestPath()
manifestPath, err := manifest.PathForName(n)
if err != nil {
return err
}
@@ -574,7 +522,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if err != nil {
return err
}
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
if err := pushWithTransfer(ctx, n, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -582,17 +530,17 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
if err := uploadBlob(ctx, n, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err
}
}
fn(api.ProgressResponse{Status: "pushing manifest"})
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
manifestJSON, err := json.Marshal(manifest)
manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
@@ -611,44 +559,44 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
n := model.ParseName(name)
// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
manifest, _, err := GetManifest(mp)
existingMf, err := manifest.ParseNamedManifest(n)
if errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil {
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
} else {
for _, l := range manifest.Layers {
for _, l := range existingMf.Layers {
deleteMap[l.Digest] = struct{}{}
}
if manifest.Config.Digest != "" {
deleteMap[manifest.Config.Digest] = struct{}{}
if existingMf.Config.Digest != "" {
deleteMap[existingMf.Config.Digest] = struct{}{}
}
}
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
if n.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}
fn(api.ProgressResponse{Status: "pulling manifest"})
manifest, err = pullModelManifest(ctx, mp, regOpts)
mf, err := pullModelManifest(ctx, n, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
var layers []manifest.Layer
layers = append(layers, mf.Layers...)
if mf.Config.Digest != "" {
layers = append(layers, mf.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
if err := pullWithTransfer(ctx, n, layers, mf, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -658,7 +606,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
mp: mp,
n: n,
digest: layer.Digest,
regOpts: regOpts,
fn: fn,
@@ -677,7 +625,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
fp, err := GetBlobsPath(layer.Digest)
fp, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return err
}
@@ -692,16 +640,16 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
delete(deleteMap, mf.Config.Digest)
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
fp, err := manifest.PathForName(n)
if err != nil {
return err
}
@@ -728,9 +676,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []Layer) bool {
func hasTensorLayers(layers []manifest.Layer) bool {
for _, layer := range layers {
if layer.MediaType == MediaTypeImageTensor {
if layer.MediaType == manifest.MediaTypeImageTensor {
return true
}
}
@@ -738,7 +686,7 @@ func hasTensorLayers(layers []Layer) bool {
}
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, mf *manifest.Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -747,12 +695,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
}
destDir, err := GetBlobsPath("")
destDir, err := manifest.BlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
base := n.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -784,7 +732,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: mp.GetNamespaceRepository(),
Repository: n.DisplayNamespaceModel(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
@@ -795,12 +743,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
fp, err := manifest.PathForName(n)
if err != nil {
return err
}
@@ -812,7 +760,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -822,12 +770,12 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
}
srcDir, err := GetBlobsPath("")
srcDir, err := manifest.BlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
base := n.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -864,13 +812,13 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
ManifestRef: n.Tag,
Repository: n.DisplayNamespaceModel(),
})
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
requestURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
@@ -880,7 +828,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
}
defer resp.Body.Close()
var m Manifest
var m manifest.Manifest
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err
}
@@ -1042,7 +990,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)
fp, err := manifest.BlobsPath(digest)
if err != nil {
return err
}

View File

@@ -56,6 +56,15 @@ func TestModelCapabilities(t *testing.T) {
},
expectedCaps: []model.Capability{model.CapabilityImage},
},
{
name: "model with image and vision capability (image editing)",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image", "vision"},
},
},
expectedCaps: []model.Capability{model.CapabilityImage, model.CapabilityVision},
},
{
name: "model with completion capability",
model: Model{

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
@@ -20,19 +21,19 @@ import (
var intermediateBlobs map[string]string = make(map[string]string)
type layerGGML struct {
Layer
manifest.Layer
*ggml.GGML
}
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
m, err := ParseNamedManifest(name)
m, err := manifest.ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err
}
m, err = ParseNamedManifest(name)
m, err = manifest.ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -41,7 +42,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
}
for _, layer := range m.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil {
return nil, err
}
@@ -50,7 +51,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
case "application/vnd.ollama.image.model",
"application/vnd.ollama.image.projector",
"application/vnd.ollama.image.adapter":
blobpath, err := GetBlobsPath(layer.Digest)
blobpath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -81,12 +82,12 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err, "template", s)
} else {
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
layers = append(layers, &layerGGML{layer, nil})
if t.Parameters != nil {
@@ -95,7 +96,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}

View File

@@ -1,146 +0,0 @@
package server
import (
"errors"
"fmt"
"io/fs"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
type ModelPath struct {
ProtocolScheme string
Registry string
Namespace string
Repository string
Tag string
}
const (
DefaultRegistry = "registry.ollama.ai"
DefaultNamespace = "library"
DefaultTag = "latest"
DefaultProtocolScheme = "https"
)
var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidDigestFormat = errors.New("invalid digest format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
ErrModelPathInvalid = errors.New("invalid model path")
)
func ParseModelPath(name string) ModelPath {
mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "",
Tag: DefaultTag,
}
before, after, found := strings.Cut(name, "://")
if found {
mp.ProtocolScheme = before
name = after
}
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
case 2:
mp.Namespace = parts[0]
mp.Repository = parts[1]
case 1:
mp.Repository = parts[0]
}
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
mp.Repository = repo
mp.Tag = tag
}
return mp
}
func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}
func (mp ModelPath) GetFullTagname() string {
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry {
if mp.Namespace == DefaultNamespace {
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
name := model.Name{
Host: mp.Registry,
Namespace: mp.Namespace,
Model: mp.Repository,
Tag: mp.Tag,
}
if !name.IsValid() {
return "", fs.ErrNotExist
}
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
}
func (mp ModelPath) BaseURL() *url.URL {
return &url.URL{
Scheme: mp.ProtocolScheme,
Host: mp.Registry,
}
}
func GetManifestPath() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
func GetBlobsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "blobs", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}

View File

@@ -1,153 +0,0 @@
package server
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetBlobsPath(t *testing.T) {
// GetBlobsPath expects an actual directory to exist
tempDir := t.TempDir()
tests := []struct {
name string
digest string
expected string
err error
}{
{
"empty digest",
"",
filepath.Join(tempDir, "blobs"),
nil,
},
{
"valid with colon",
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"valid with dash",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"digest too short",
"sha256-45640291",
"",
ErrInvalidDigestFormat,
},
{
"digest too long",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
"",
ErrInvalidDigestFormat,
},
{
"digest invalid chars",
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
"",
ErrInvalidDigestFormat,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv("OLLAMA_MODELS", tempDir)
got, err := GetBlobsPath(tc.digest)
require.ErrorIs(t, tc.err, err, tc.name)
assert.Equal(t, tc.expected, got, tc.name)
})
}
}
func TestParseModelPath(t *testing.T) {
tests := []struct {
name string
arg string
want ModelPath
}{
{
"full path https",
"https://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"full path http",
"http://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "http",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no protocol",
"example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no registry",
"ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no namespace",
"repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: "tag",
},
},
{
"no tag",
"repo",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: DefaultTag,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := ParseModelPath(tc.arg)
if got != tc.want {
t.Errorf("got: %q want: %q", got, tc.want)
}
})
}
}

View File

@@ -39,6 +39,7 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/middleware"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/model/renderers"
@@ -974,7 +975,7 @@ func (s *Server) PushHandler(c *gin.Context) {
// is.
func getExistingName(n model.Name) (model.Name, error) {
var zero model.Name
existing, err := Manifests(true)
existing, err := manifest.Manifests(true)
if err != nil {
return zero, err
}
@@ -1018,7 +1019,7 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return
}
m, err := ParseNamedManifest(n)
m, err := manifest.ParseNamedManifest(n)
if err != nil {
switch {
case os.IsNotExist(err):
@@ -1080,7 +1081,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, ErrModelPathInvalid
return nil, model.Unqualified(name)
}
name, err := getExistingName(name)
if err != nil {
@@ -1112,7 +1113,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
@@ -1121,7 +1122,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
// Get torch_dtype directly from config.json for quantization level
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
@@ -1135,7 +1136,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}
manifest, err := ParseNamedManifest(name)
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -1147,7 +1148,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
Details: modelDetails,
Messages: msgs,
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
ModifiedAt: mf.FileInfo().ModTime(),
Requires: m.Config.Requires,
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
// default we return an empty map.
@@ -1214,7 +1215,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
resp.Tensors = tensors
}
}
@@ -1223,12 +1224,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
resp.Tensors = tensors
}
}
@@ -1285,7 +1286,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
}
func (s *Server) ListHandler(c *gin.Context) {
ms, err := Manifests(true)
ms, err := manifest.Manifests(true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1316,8 +1317,8 @@ func (s *Server) ListHandler(c *gin.Context) {
RemoteModel: cf.RemoteModel,
RemoteHost: cf.RemoteHost,
Size: m.Size(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Digest: m.Digest(),
ModifiedAt: m.FileInfo().ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
@@ -1376,7 +1377,7 @@ func (s *Server) CopyHandler(c *gin.Context) {
}
func (s *Server) HeadBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
path, err := manifest.BlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1392,7 +1393,7 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
func (s *Server) CreateBlobHandler(c *gin.Context) {
if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
p, err := GetBlobsPath(ib)
p, err := manifest.BlobsPath(ib)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1410,7 +1411,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
}
}
path, err := GetBlobsPath(c.Param("digest"))
path, err := manifest.BlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1428,7 +1429,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
return
}
layer, err := NewLayer(c.Request.Body, "")
layer, err := manifest.NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1603,8 +1604,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoint
// OpenAI-compatible image generation endpoints
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
@@ -1628,7 +1630,7 @@ func Serve(ln net.Listener) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
slog.Info("server config", "env", envconfig.Values())
blobsDir, err := GetBlobsPath("")
blobsDir, err := manifest.BlobsPath("")
if err != nil {
return err
}
@@ -1637,7 +1639,7 @@ func Serve(ln net.Listener) error {
}
if !envconfig.NoPrune() {
if _, err := Manifests(false); err != nil {
if _, err := manifest.Manifests(false); err != nil {
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
} else {
// clean up unused layers and manifests
@@ -1645,12 +1647,12 @@ func Serve(ln net.Listener) error {
return err
}
manifestsPath, err := GetManifestPath()
manifestsPath, err := manifest.Path()
if err != nil {
return err
}
if err := PruneDirectory(manifestsPath); err != nil {
if err := manifest.PruneDirectory(manifestsPath); err != nil {
return err
}
}
@@ -2506,8 +2508,14 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
return
}
// Set headers for streaming response
c.Header("Content-Type", "application/x-ndjson")
// Check streaming preference
isStreaming := req.Stream == nil || *req.Stream
contentType := "application/x-ndjson"
if !isStreaming {
contentType = "application/json; charset=utf-8"
}
c.Header("Content-Type", contentType)
// Get seed from options if provided
var seed int64
@@ -2522,13 +2530,21 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
}
}
var images []llm.ImageData
for i, imgData := range req.Images {
images = append(images, llm.ImageData{ID: i, Data: imgData})
}
var streamStarted bool
var finalResponse api.GenerateResponse
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
Images: images,
}, func(cr llm.CompletionResponse) {
streamStarted = true
res := api.GenerateResponse{
@@ -2552,6 +2568,11 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if !isStreaming {
finalResponse = res
return
}
data, _ := json.Marshal(res)
c.Writer.Write(append(data, '\n'))
c.Writer.Flush()
@@ -2561,5 +2582,10 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
if !streamStarted {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
if !isStreaming {
c.JSON(http.StatusOK, finalResponse)
}
}

View File

@@ -25,6 +25,7 @@ import (
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
@@ -223,15 +224,15 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
manifest, err := ParseNamedManifest(model.ParseName("child"))
mf, err := manifest.ParseNamedManifest(model.ParseName("child"))
if err != nil {
t.Fatalf("parse manifest: %v", err)
}
if manifest.Config.Digest == "" {
if mf.Config.Digest == "" {
t.Fatalf("unexpected empty config digest for child manifest")
}
configPath, err := GetBlobsPath(manifest.Config.Digest)
configPath, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
t.Fatalf("config blob path: %v", err)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
@@ -93,13 +94,13 @@ func TestDeleteDuplicateLayers(t *testing.T) {
t.Fatal(err)
}
config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatal(err)
}
// create a manifest with duplicate layers
if err := WriteManifest(n, config, []Layer{config}); err != nil {
if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
t.Fatal(err)
}

View File

@@ -19,7 +19,9 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
@@ -71,6 +73,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
return
}
func (mockRunner) Ping(_ context.Context) error { return nil }
func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
return mock, nil
@@ -2193,3 +2197,246 @@ func TestGenerateUnload(t *testing.T) {
}
})
}
func TestGenerateWithImages(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
Files: map[string]string{"file.gguf": digest},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("images passed to completion request", func(t *testing.T) {
testImage := []byte("test-image-data")
mock.CompletionResponse.Content = "Image processed"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Describe this image",
Images: []api.ImageData{testImage},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify images were passed to the completion request
if len(mock.CompletionRequest.Images) != 1 {
t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images))
}
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) {
t.Errorf("image data mismatch in completion request")
}
if mock.CompletionRequest.Images[0].ID != 0 {
t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID)
}
})
t.Run("multiple images passed to completion request", func(t *testing.T) {
testImage1 := []byte("test-image-1")
testImage2 := []byte("test-image-2")
mock.CompletionResponse.Content = "Images processed"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Compare these images",
Images: []api.ImageData{testImage1, testImage2},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify both images were passed
if len(mock.CompletionRequest.Images) != 2 {
t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images))
}
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) {
t.Errorf("first image data mismatch")
}
if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) {
t.Errorf("second image data mismatch")
}
if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 {
t.Errorf("expected image IDs 0 and 1, got %d and %d",
mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID)
}
})
t.Run("no images when none provided", func(t *testing.T) {
mock.CompletionResponse.Content = "No images"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify no images in completion request
if len(mock.CompletionRequest.Images) != 0 {
t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images))
}
})
}
// TestImageGenerateStreamFalse tests that image generation respects stream=false
// and returns a single JSON response instead of streaming ndjson.
func TestImageGenerateStreamFalse(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
mock := mockRunner{}
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
fn(llm.CompletionResponse{Step: 2, TotalSteps: 3, Done: false})
fn(llm.CompletionResponse{Step: 3, TotalSteps: 3, Done: true, DoneReason: llm.DoneReasonStop, Image: "base64image"})
return nil
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
"": {
llama: &mock,
Options: &opts,
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
numParallel: 1,
},
},
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
// Create model manifest with image capability
n := model.ParseName("test-image")
cfg := model.ConfigV2{Capabilities: []string{"image"}}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(&cfg); err != nil {
t.Fatal(err)
}
configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatal(err)
}
if err := manifest.WriteManifest(n, configLayer, nil); err != nil {
t.Fatal(err)
}
streamFalse := false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-image",
Prompt: "test prompt",
Stream: &streamFalse,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
if ct := w.Header().Get("Content-Type"); ct != "application/json; charset=utf-8" {
t.Errorf("expected Content-Type 'application/json; charset=utf-8', got %q", ct)
}
body := w.Body.String()
lines := strings.Split(strings.TrimSpace(body), "\n")
if len(lines) != 1 {
t.Errorf("expected 1 response line, got %d:\n%s", len(lines), body)
}
var resp api.GenerateResponse
if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp.Image != "base64image" {
t.Errorf("expected image 'base64image', got %q", resp.Image)
}
if !resp.Done {
t.Errorf("expected done=true")
}
}

View File

@@ -21,7 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
type LlmRequest struct {
@@ -195,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation model before attempting GGML load
// Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadImageGen(pending) {
if s.loadMLX(pending) {
break
}
continue
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
break
}
continue
}
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -552,11 +563,20 @@ iGPUScan:
return false
}
// loadImageGen loads an image generation model.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Use model name for imagegen (it resolves manifests by name, not file path)
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode mlxrunner.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {
mode = mlxrunner.ModeImageGen
} else {
mode = mlxrunner.ModeLLM
}
// Use model name for MLX (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName)
server, err := mlxrunner.NewServer(modelName, mode)
if err != nil {
req.errCh <- err
return true

View File

@@ -21,12 +21,14 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
var blobUploadManager sync.Map
type blobUpload struct {
Layer
manifest.Layer
Total int64
Completed atomic.Int64
@@ -51,7 +53,7 @@ const (
)
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
p, err := GetBlobsPath(b.Digest)
p, err := manifest.BlobsPath(b.Digest)
if err != nil {
return err
}
@@ -59,7 +61,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
if b.From != "" {
values := requestURL.Query()
values.Add("mount", b.Digest)
values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
requestURL.RawQuery = values.Encode()
}
@@ -128,7 +130,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
p, err := GetBlobsPath(b.Digest)
p, err := manifest.BlobsPath(b.Digest)
if err != nil {
b.err = err
return
@@ -364,9 +366,9 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
switch {
@@ -388,8 +390,8 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOp
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
upload := data.(*blobUpload)
if !ok {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
blobUploadManager.Delete(layer.Digest)
return err

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"log/slog"
"net/url"
"path/filepath"
"strings"
)
@@ -35,22 +36,25 @@ func Unqualified(n Name) error {
const MissingPart = "!MISSING!"
const (
defaultHost = "registry.ollama.ai"
defaultNamespace = "library"
defaultTag = "latest"
defaultHost = "registry.ollama.ai"
defaultNamespace = "library"
defaultTag = "latest"
defaultProtocolScheme = "https"
)
// DefaultName returns a name with the default values for the host, namespace,
// and tag parts. The model and digest parts are empty.
// tag, and protocol scheme parts. The model and digest parts are empty.
//
// - The default host is ("registry.ollama.ai")
// - The default namespace is ("library")
// - The default tag is ("latest")
// - The default protocol scheme is ("https")
func DefaultName() Name {
return Name{
Host: defaultHost,
Namespace: defaultNamespace,
Tag: defaultTag,
Host: defaultHost,
Namespace: defaultNamespace,
Tag: defaultTag,
ProtocolScheme: defaultProtocolScheme,
}
}
@@ -87,10 +91,11 @@ func (k partKind) String() string {
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
// is valid.
type Name struct {
Host string
Namespace string
Model string
Tag string
Host string
Namespace string
Model string
Tag string
ProtocolScheme string
}
// ParseName parses and assembles a Name from a name string. The
@@ -160,7 +165,9 @@ func ParseNameBare(s string) Name {
}
scheme, host, ok := strings.Cut(s, "://")
if !ok {
if ok {
n.ProtocolScheme = scheme
} else {
host = scheme
}
n.Host = host
@@ -189,12 +196,13 @@ func ParseNameFromFilepath(s string) (n Name) {
return n
}
// Merge merges the host, namespace, and tag parts of the two names,
// Merge merges the host, namespace, tag, and protocol scheme parts of the two names,
// preferring the non-empty parts of a.
func Merge(a, b Name) Name {
a.Host = cmp.Or(a.Host, b.Host)
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
a.Tag = cmp.Or(a.Tag, b.Tag)
a.ProtocolScheme = cmp.Or(a.ProtocolScheme, b.ProtocolScheme)
return a
}
@@ -305,6 +313,23 @@ func (n Name) EqualFold(o Name) bool {
strings.EqualFold(n.Tag, o.Tag)
}
// BaseURL returns the base URL for the registry.
func (n Name) BaseURL() *url.URL {
return &url.URL{
Scheme: n.ProtocolScheme,
Host: n.Host,
}
}
// DisplayNamespaceModel returns the namespace and model joined by "/".
func (n Name) DisplayNamespaceModel() string {
var b strings.Builder
b.WriteString(n.Namespace)
b.WriteByte('/')
b.WriteString(n.Model)
return b.String()
}
func isValidLen(kind partKind, s string) bool {
switch kind {
case kindHost:

View File

@@ -32,10 +32,11 @@ func TestParseNameParts(t *testing.T) {
{
in: "scheme://host:port/namespace/model:tag",
want: Name{
Host: "host:port",
Namespace: "namespace",
Model: "model",
Tag: "tag",
Host: "host:port",
Namespace: "namespace",
Model: "model",
Tag: "tag",
ProtocolScheme: "scheme",
},
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
},

View File

@@ -11,9 +11,12 @@ import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
)
@@ -51,10 +54,20 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
var parserName, rendererName string
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = []string{"completion"}
// Check if model supports thinking based on architecture
if supportsThinking(opts.ModelDir) {
capabilities = append(capabilities, "thinking")
}
// Set parser and renderer name based on architecture
parserName = getParserName(opts.ModelDir)
rendererName = getRendererName(opts.ModelDir)
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
@@ -79,14 +92,14 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
newManifestWriter(opts, capabilities, parserName, rendererName),
progressFn,
)
} else {
err = create.CreateImageGenModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
newManifestWriter(opts, capabilities, "", ""),
progressFn,
)
}
@@ -103,7 +116,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
layer, err := manifest.NewLayer(r, mediaType)
if err != nil {
return create.LayerInfo{}, err
}
@@ -141,13 +154,13 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -169,7 +182,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -186,7 +199,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
layer, err := manifest.NewLayer(r, manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -202,18 +215,33 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
}
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// TODO: find a better way to detect image input support
// For now, hardcode Flux2KleinPipeline as supporting vision (image input)
caps := capabilities
modelIndex := filepath.Join(opts.ModelDir, "model_index.json")
if data, err := os.ReadFile(modelIndex); err == nil {
var cfg struct {
ClassName string `json:"_class_name"`
}
if json.Unmarshal(data, &cfg) == nil && cfg.ClassName == "Flux2KleinPipeline" {
caps = append(caps, "vision")
}
}
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: capabilities,
Capabilities: caps,
Requires: MinOllamaVersion,
Parser: parserName,
Renderer: rendererName,
}
configJSON, err := json.Marshal(configData)
if err != nil {
@@ -221,15 +249,15 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
configLayer, err := manifest.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer
serverLayers := make([]server.Layer, 0, len(layers))
// Convert LayerInfo to manifest.Layer
manifestLayers := make([]manifest.Layer, 0, len(layers))
for _, l := range layers {
serverLayers = append(serverLayers, server.Layer{
manifestLayers = append(manifestLayers, manifest.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
@@ -243,19 +271,19 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
if err != nil {
return err
}
serverLayers = append(serverLayers, modelfileLayers...)
manifestLayers = append(manifestLayers, modelfileLayers...)
}
return server.WriteManifest(name, configLayer, serverLayers)
return manifest.WriteManifest(name, configLayer, manifestLayers)
}
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
var layers []server.Layer
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
var layers []manifest.Layer
if mf.Template != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
if err != nil {
return nil, fmt.Errorf("failed to create template layer: %w", err)
}
@@ -263,7 +291,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
}
if mf.System != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
if err != nil {
return nil, fmt.Errorf("failed to create system layer: %w", err)
}
@@ -271,7 +299,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
}
if mf.License != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
if err != nil {
return nil, fmt.Errorf("failed to create license layer: %w", err)
}
@@ -280,3 +308,146 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
return layers, nil
}
// supportsThinking checks if the model supports thinking mode based on its architecture.
// This reads the config.json from the model directory and checks the architectures field.
func supportsThinking(modelDir string) bool {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return false
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
// Check architectures that support thinking
thinkingArchitectures := []string{
"glm4moe", // GLM-4 MoE models
"deepseek", // DeepSeek models
"qwen3", // Qwen3 models
}
// Check the architecture list
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(archLower, thinkArch) {
return true
}
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(typeLower, thinkArch) {
return true
}
}
}
return false
}
// getParserName returns the parser name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate parser.
func getParserName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
}
// Check architectures for known parsers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3-coder"
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
}
}
return ""
}
// getRendererName returns the renderer name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate renderer.
func getRendererName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
}
// Check architectures for known renderers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3-coder"
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
}
}
return ""
}

View File

@@ -13,7 +13,10 @@ import (
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases.
// Supported quantization types: "fp8" (affine 8-bit)
// Supported quantization types:
// - "fp4": affine 4-bit, group_size=32 (with qbiases)
// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
// - "fp8": affine 8-bit, group_size=32 (with qbiases)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
@@ -55,10 +58,13 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
var qweight, scales, qbiases *mlx.Array
switch quantize {
case "fp4":
// affine mode: group_size=32, bits=4
// affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
case "nvfp4":
// NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
case "fp8":
// affine mode: group_size=32, bits=8
// affine mode: group_size=32, bits=8 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)

View File

@@ -262,9 +262,10 @@ func ShouldQuantize(name, component string) bool {
return strings.HasSuffix(name, ".weight")
}
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
// This is a more detailed check that also considers tensor dimensions.
func ShouldQuantizeTensor(name string, shape []int32) bool {
// The quantize parameter specifies the quantization type (e.g., "fp4", "nvfp4", "fp8").
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
// Use basic name-based check first
if !ShouldQuantize(name, "") {
return false
@@ -280,8 +281,13 @@ func ShouldQuantizeTensor(name string, shape []int32) bool {
return false
}
// MLX quantization requires last dimension to be divisible by group size (32)
if shape[len(shape)-1]%32 != 0 {
// MLX quantization requires last dimension to be divisible by group size
// NVFP4 uses group_size=16, all other modes use group_size=32
groupSize := int32(32)
if strings.ToUpper(quantize) == "NVFP4" {
groupSize = 16
}
if shape[len(shape)-1]%groupSize != 0 {
return false
}
@@ -331,7 +337,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
// Determine quantization type for this tensor (empty string if not quantizing)
quantizeType := ""
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape, quantize) {
quantizeType = quantize
}
@@ -388,6 +394,22 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return fmt.Errorf("config.json not found in %s", modelDir)
}
// Create model_index.json with quantization info if quantizing
if quantize != "" {
modelIndex := map[string]any{
"quantization": strings.ToUpper(quantize),
}
indexData, err := json.MarshalIndent(modelIndex, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal model_index.json: %w", err)
}
indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json")
if err != nil {
return fmt.Errorf("failed to create model_index.json layer: %w", err)
}
layers = append(layers, indexLayer)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {

View File

@@ -536,41 +536,51 @@ func TestShouldQuantize(t *testing.T) {
func TestShouldQuantizeTensor(t *testing.T) {
tests := []struct {
name string
tensor string
shape []int32
want bool
name string
tensor string
shape []int32
quantize string
want bool
}{
// 2D tensors with sufficient size should be quantized
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
// Small tensors should not be quantized (< 1024 elements)
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
{"small 2D weight", "small.weight", []int32{31, 31}, false},
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
{"small 2D weight", "small.weight", []int32{31, 31}, "fp8", false},
// 1D tensors should not be quantized
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
{"1D tensor", "layer_norm.weight", []int32{4096}, "fp8", false},
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
// Norms should not be quantized regardless of shape
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, "fp8", false},
// Biases should not be quantized
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
// Group size divisibility tests
// FP8/FP4 require divisible by 32
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
// NVFP4 requires divisible by 16
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
got := ShouldQuantizeTensor(tt.tensor, tt.shape, tt.quantize)
if got != tt.want {
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
t.Errorf("ShouldQuantizeTensor(%q, %v, %q) = %v, want %v", tt.tensor, tt.shape, tt.quantize, got, tt.want)
}
})
}

View File

@@ -15,15 +15,15 @@ import (
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
// Supported quantization types: fp8 (or empty for no quantization).
// Supported quantization types: fp4, fp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
case "", "fp4", "fp8":
case "", "fp4", "fp8", "nvfp4":
// valid
default:
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8, nvfp4", quantize)
}
var layers []LayerInfo
@@ -89,7 +89,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
// Determine quantization type for this tensor (empty string if not quantizing)
quantizeType := ""
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape, quantize) {
quantizeType = quantize
}
@@ -213,10 +213,15 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
}
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
// MLX requires the last dimension to be divisible by the group size (32).
func canQuantizeShape(shape []int32) bool {
// MLX requires the last dimension to be divisible by the group size.
// NVFP4 uses group_size=16, all other modes use group_size=32.
func canQuantizeShape(shape []int32, quantize string) bool {
if len(shape) < 2 {
return false
}
return shape[len(shape)-1]%32 == 0
groupSize := int32(32)
if strings.ToUpper(quantize) == "NVFP4" {
groupSize = 16
}
return shape[len(shape)-1]%groupSize == 0
}

View File

@@ -9,6 +9,7 @@ type Cache interface {
Offset() int
Len() int
State() []*mlx.Array
Reset()
}
type KVCache struct {
@@ -63,6 +64,13 @@ func (c *KVCache) State() []*mlx.Array {
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// Reset clears the cache state for a new generation session
func (c *KVCache) Reset() {
c.keys = nil
c.values = nil
c.offset = 0
}
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
keys, values *mlx.Array
@@ -154,3 +162,11 @@ func (c *RotatingKVCache) State() []*mlx.Array {
func (c *RotatingKVCache) Offset() int { return c.offset }
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
// Reset clears the cache state for a new generation session
func (c *RotatingKVCache) Reset() {
c.keys = nil
c.values = nil
c.offset = 0
c.idx = 0
}

View File

@@ -9,7 +9,7 @@ import "github.com/ollama/ollama/x/imagegen/mlx"
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
// Supports both single-stream and dual-stream architectures:
// - Single-stream: use Get/Set for the single output per layer
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
@@ -87,7 +87,7 @@ func (c *StepCache) Set(layer int, arr *mlx.Array) {
}
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
// Used for dual-stream architectures like Qwen-Image.
// Used for dual-stream architectures.
func (c *StepCache) Get2(layer int) *mlx.Array {
if layer < len(c.layers2) {
return c.layers2[layer]
@@ -96,7 +96,7 @@ func (c *StepCache) Get2(layer int) *mlx.Array {
}
// Set2 stores a layer output (stream 2), freeing any previous value.
// Used for dual-stream architectures like Qwen-Image.
// Used for dual-stream architectures.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
if layer < len(c.layers2) {
if c.layers2[layer] != nil {

View File

@@ -10,7 +10,10 @@ import (
"errors"
"fmt"
"io"
"net/http"
"os"
"regexp"
"slices"
"strconv"
"strings"
"time"
@@ -75,6 +78,7 @@ Image Generation Flags (experimental):
// RunCLI handles the CLI for image generation models.
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
// Image paths can be included in the prompt and will be extracted automatically.
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
// Get options from flags (with env var defaults)
opts := DefaultOptions()
@@ -111,9 +115,16 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return err
}
// Extract any image paths from the prompt
prompt, images, err := extractFileData(prompt)
if err != nil {
return err
}
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
Images: images,
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
@@ -254,14 +265,33 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
printCurrentSettings(opts)
continue
case strings.HasPrefix(line, "/"):
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
// Check if it's a file path, not a command
args := strings.Fields(line)
isFile := false
for _, f := range extractFileNames(line) {
if strings.HasPrefix(f, args[0]) {
isFile = true
break
}
}
if !isFile {
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", args[0])
continue
}
}
// Extract any image paths from the input
prompt, images, err := extractFileData(line)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
continue
}
// Generate image with current options
req := &api.GenerateRequest{
Model: modelName,
Prompt: line,
Prompt: prompt,
Images: images,
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
@@ -486,3 +516,61 @@ func displayImageInTerminal(imagePath string) bool {
return false
}
}
// extractFileNames finds image file paths in the input string.
func extractFileNames(input string) []string {
// Regex to match file paths with image extensions
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1)
}
// extractFileData extracts image data from file paths found in the input.
// Returns the cleaned prompt (with file paths removed) and the image data.
func extractFileData(input string) (string, []api.ImageData, error) {
filePaths := extractFileNames(input)
var imgs []api.ImageData
for _, fp := range filePaths {
// Normalize shell escapes
nfp := strings.ReplaceAll(fp, "\\ ", " ")
nfp = strings.ReplaceAll(nfp, "\\(", "(")
nfp = strings.ReplaceAll(nfp, "\\)", ")")
nfp = strings.ReplaceAll(nfp, "%20", " ")
data, err := getImageData(nfp)
if errors.Is(err, os.ErrNotExist) {
continue
} else if err != nil {
return "", nil, err
}
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data)
}
return strings.TrimSpace(input), imgs, nil
}
// getImageData reads and validates image data from a file.
func getImageData(filePath string) ([]byte, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer file.Close()
buf := make([]byte, 512)
_, err = file.Read(buf)
if err != nil {
return nil, err
}
contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType)
}
// Re-read the full file
return os.ReadFile(filePath)
}

View File

@@ -19,10 +19,9 @@ import (
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
@@ -61,14 +60,11 @@ func main() {
listTensors := flag.Bool("list", false, "List tensors only")
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
@@ -166,60 +162,6 @@ func main() {
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImageEdit:
if len(inputImages) == 0 {
log.Fatal("qwen-image-edit requires at least one -input-image")
}
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
// For image editing, use 0 for dimensions to auto-detect from input image
// unless explicitly overridden from defaults
editWidth := int32(0)
editHeight := int32(0)
if *width != 1024 {
editWidth = int32(*width)
}
if *height != 1024 {
editHeight = int32(*height)
}
cfg := &qwen_image_edit.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: editWidth,
Height: editHeight,
Steps: *steps,
Seed: *seed,
}
var img *mlx.Array
img, err = m.EditFromConfig(inputImages, cfg)
if err == nil {
err = saveImageArray(img, *out)
}
case *listTensors:
err = listModelTensors(*modelPath)
default:
@@ -301,6 +243,8 @@ func load(modelPath string) (Model, error) {
return gemma3.Load(modelPath)
case "gemma3_text":
return gemma3.LoadText(modelPath)
case "glm4_moe_lite":
return glm4_moe_lite.Load(modelPath)
default:
return llama.Load(modelPath)
}

View File

@@ -7,6 +7,8 @@ import (
"encoding/base64"
"fmt"
"image"
"image/color"
"image/draw"
_ "image/jpeg"
"image/png"
"os"
@@ -111,6 +113,7 @@ func clampF(v, min, max float32) float32 {
}
// DecodeImage decodes image bytes with EXIF orientation applied.
// Transparent images are composited onto a white background.
func DecodeImage(data []byte) (image.Image, error) {
orientation := readJPEGOrientation(data)
@@ -119,9 +122,33 @@ func DecodeImage(data []byte) (image.Image, error) {
return nil, err
}
img = flattenAlpha(img)
return applyOrientation(img, orientation), nil
}
// flattenAlpha composites an image onto a white background,
// removing any transparency. This is needed because image
// generation models don't handle alpha channels well.
func flattenAlpha(img image.Image) image.Image {
if _, ok := img.(*image.RGBA); !ok {
if _, ok := img.(*image.NRGBA); !ok {
// No alpha channel, return as-is
return img
}
}
bounds := img.Bounds()
dst := image.NewRGBA(bounds)
// Fill with white background
draw.Draw(dst, bounds, &image.Uniform{color.White}, image.Point{}, draw.Src)
// Composite the image on top
draw.Draw(dst, bounds, img, bounds.Min, draw.Over)
return dst
}
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
// Returns 1 (normal) for non-JPEG or if orientation not found.
func readJPEGOrientation(data []byte) int {

View File

@@ -116,6 +116,18 @@ func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
return layers
}
// GetAllTensorLayers returns all tensor layers without component filtering.
// Used for LLM models where tensors don't have a component prefix.
func (m *ModelManifest) GetAllTensorLayers() []ManifestLayer {
var layers []ManifestLayer
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
layers = append(layers, layer)
}
}
return layers
}
// GetConfigLayer returns the config layer for a given path.
func (m *ModelManifest) GetConfigLayer(configPath string) *ManifestLayer {
for _, layer := range m.Manifest.Layers {
@@ -161,6 +173,17 @@ func (m *ModelManifest) HasTensorLayers() bool {
return false
}
// TotalTensorSize returns the total size in bytes of all tensor layers.
func (m *ModelManifest) TotalTensorSize() int64 {
var total int64
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
total += layer.Size
}
}
return total
}
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string

View File

@@ -5,6 +5,37 @@ import (
"testing"
)
func TestTotalTensorSize(t *testing.T) {
m := &ModelManifest{
Manifest: &Manifest{
Layers: []ManifestLayer{
{MediaType: "application/vnd.ollama.image.tensor", Size: 1000},
{MediaType: "application/vnd.ollama.image.tensor", Size: 2000},
{MediaType: "application/vnd.ollama.image.json", Size: 500}, // not a tensor
{MediaType: "application/vnd.ollama.image.tensor", Size: 3000},
},
},
}
got := m.TotalTensorSize()
want := int64(6000)
if got != want {
t.Errorf("TotalTensorSize() = %d, want %d", got, want)
}
}
func TestTotalTensorSizeEmpty(t *testing.T) {
m := &ModelManifest{
Manifest: &Manifest{
Layers: []ManifestLayer{},
},
}
if got := m.TotalTensorSize(); got != 0 {
t.Errorf("TotalTensorSize() = %d, want 0", got)
}
}
func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
modelsDir := filepath.Join(t.TempDir(), "models")

View File

@@ -16,18 +16,9 @@ import (
"runtime"
)
// GB is a convenience constant for gigabytes.
const GB = 1024 * 1024 * 1024
// SupportedBackends lists the backends that support image generation.
var SupportedBackends = []string{"metal", "cuda", "cpu"}
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 20 * GB, // ~20GB for Flux
}
// CheckPlatformSupport validates that image generation is supported on the current platform.
// Returns nil if supported, or an error describing why it's not supported.
func CheckPlatformSupport() error {
@@ -47,17 +38,6 @@ func CheckPlatformSupport() error {
}
}
// CheckMemoryRequirements validates that there's enough memory for image generation.
// Returns nil if memory is sufficient, or an error if not.
func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
required := EstimateVRAM(modelName)
if availableMemory < required {
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
required/GB, availableMemory/GB)
}
return nil
}
// ResolveModelName checks if a model name is a known image generation model.
// Returns the normalized model name if found, empty string otherwise.
func ResolveModelName(modelName string) string {
@@ -68,16 +48,6 @@ func ResolveModelName(modelName string) string {
return ""
}
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
// Returns a conservative default of 21GB if the model type cannot be determined.
func EstimateVRAM(modelName string) uint64 {
className := DetectModelType(modelName)
if estimate, ok := modelVRAMEstimates[className]; ok {
return estimate
}
return 21 * GB
}
// DetectModelType reads model_index.json and returns the model type.
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
// Returns empty string if detection fails.

View File

@@ -30,69 +30,6 @@ func TestCheckPlatformSupport(t *testing.T) {
}
}
func TestCheckMemoryRequirements(t *testing.T) {
tests := []struct {
name string
availableMemory uint64
wantErr bool
}{
{
name: "sufficient memory",
availableMemory: 32 * GB,
wantErr: false,
},
{
name: "exactly enough memory",
availableMemory: 21 * GB,
wantErr: false,
},
{
name: "insufficient memory",
availableMemory: 16 * GB,
wantErr: true,
},
{
name: "zero memory",
availableMemory: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Use a non-existent model name which will default to 21GB estimate
err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
if (err != nil) != tt.wantErr {
t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestModelVRAMEstimates(t *testing.T) {
// Verify the VRAM estimates map has expected entries
expected := map[string]uint64{
"ZImagePipeline": 21 * GB,
"FluxPipeline": 20 * GB,
}
for name, expectedVRAM := range expected {
if actual, ok := modelVRAMEstimates[name]; !ok {
t.Errorf("Missing VRAM estimate for %s", name)
} else if actual != expectedVRAM {
t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
}
}
}
func TestEstimateVRAMDefault(t *testing.T) {
// Non-existent model should return default 21GB
vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
if vram != 21*GB {
t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
}
}
func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string
result := ResolveModelName("nonexistent-model")

View File

@@ -991,6 +991,19 @@ func Concat(a, b *Array, axis int) *Array {
return Concatenate([]*Array{a, b}, axis)
}
// Stack stacks arrays along a new axis (axis 0 by default)
func Stack(arrays []*Array, axis int) *Array {
handles := make([]C.mlx_array, len(arrays))
for i, arr := range arrays {
handles[i] = arr.c
}
vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
res := C.mlx_array_new()
C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream())
C.mlx_vector_array_free(vec)
return newArray(res)
}
// Slice slices the array
func Slice(a *Array, start, stop []int32) *Array {
n := len(start)

View File

@@ -177,6 +177,20 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
})
}
// GenerateImageWithInputs implements runner.ImageEditModel interface.
// It generates an image conditioned on the provided input images for image editing.
func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
InputImages: inputImages,
Progress: progress,
})
}
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
const MaxOutputPixels = 2048 * 2048

View File

@@ -0,0 +1,709 @@
//go:build mlx
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
package glm4_moe_lite
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Config holds GLM4-MoE-Lite model configuration
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
AttentionBias bool `json:"attention_bias"`
// MLA (Multi-head Latent Attention) parameters
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
// MoE parameters
NRoutedExperts int32 `json:"n_routed_experts"`
NSharedExperts int32 `json:"n_shared_experts"`
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
NormTopKProb bool `json:"norm_topk_prob"`
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
NGroup int32 `json:"n_group"`
TopKGroup int32 `json:"topk_group"`
// Computed fields
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
Scale float32 `json:"-"` // 1/sqrt(QHeadDim)
}
// MLAAttention implements Multi-head Latent Attention
type MLAAttention struct {
// Low-rank query projections
QAProj nn.LinearLayer `weight:"self_attn.q_a_proj"`
QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"`
QBProj nn.LinearLayer `weight:"self_attn.q_b_proj"`
// Low-rank KV projections (with shared rope component)
KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"`
KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"`
KVBProj nn.LinearLayer `weight:"self_attn.kv_b_proj"`
// Output projection
OProj nn.LinearLayer `weight:"self_attn.o_proj"`
}
// Forward computes MLA attention output
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Query path: q_a_proj -> layernorm -> q_b_proj
q := a.QAProj.Forward(x)
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
q = a.QBProj.Forward(q)
// Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
// Split Q into nope and rope parts
qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
// KV path: kv_a_proj_with_mqa -> split -> layernorm -> kv_b_proj
compressedKV := a.KVAProjWithMQA.Forward(x)
// Split into compressed_kv and k_pe (shared rope component)
kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
// k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
// Apply layernorm and project KV
kvCompressed = a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
kv := a.KVBProj.Forward(kvCompressed)
// Reshape KV: [B, L, num_heads * (qk_nope_head_dim + v_head_dim)]
kv = mlx.Reshape(kv, B, L, cfg.NumAttentionHeads, cfg.QKNopeHeadDim+cfg.VHeadDim)
kv = mlx.Transpose(kv, 0, 2, 1, 3)
// Split into k_nope and values
kNope := mlx.Slice(kv, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
values := mlx.Slice(kv, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim + cfg.VHeadDim})
// Apply RoPE to the rope parts only
offset := 0
if c != nil {
offset = c.Offset()
}
qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
// Repeat k_pe across all heads
kPE = mlx.Tile(kPE, []int32{1, cfg.NumAttentionHeads, 1, 1})
// Concatenate nope and rope parts
queries := mlx.Concatenate([]*mlx.Array{qNope, qPE}, 3)
keys := mlx.Concatenate([]*mlx.Array{kNope, kPE}, 3)
// Update KV cache
if c != nil {
keys, values = c.Update(keys, values, int(L))
}
// Scaled dot product attention
out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)
// Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
return a.OProj.Forward(out)
}
// DenseMLP implements the standard SwiGLU MLP for dense layers
type DenseMLP struct {
GateProj nn.LinearLayer `weight:"mlp.gate_proj"`
UpProj nn.LinearLayer `weight:"mlp.up_proj"`
DownProj nn.LinearLayer `weight:"mlp.down_proj"`
}
// Forward applies the SwiGLU MLP
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(m.GateProj.Forward(x))
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
}
// MoEGate implements the expert gating mechanism
type MoEGate struct {
Gate nn.LinearLayer `weight:"mlp.gate"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
}
// Forward computes expert selection indices and scores
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
// Compute gate logits through linear layer (handles both quantized and non-quantized)
gates := g.Gate.Forward(x)
// Sigmoid scoring
scores := mlx.Sigmoid(gates)
origScores := scores
// Add correction bias if present
if g.EScoreCorrectionBias != nil {
scores = mlx.Add(scores, g.EScoreCorrectionBias)
}
// Group-wise expert selection (simplified for n_group=1)
// Select top-k experts
topK := cfg.NumExpertsPerTok
negScores := mlx.Neg(scores)
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
shape := inds.Shape()
inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})
// Get scores for selected experts
scores = mlx.TakeAlongAxis(origScores, inds, -1)
// Normalize if configured
if topK > 1 && cfg.NormTopKProb {
sumScores := mlx.Sum(scores, -1, true)
scores = mlx.Div(scores, sumScores)
}
// Apply routing scaling factor
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
return inds, scores
}
// SwitchMLP implements the MoE expert computation using stacked weights
// Note: No weight tags - these are populated manually by stacking expert weights
type SwitchMLP struct {
GateWeight *mlx.Array
UpWeight *mlx.Array
DownWeight *mlx.Array
}
// Forward applies the switched expert MLP
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
shape := x.Shape()
B, L := shape[0], shape[1]
topK := cfg.NumExpertsPerTok
// Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
// Flatten for gather_mm: [B*L, 1, 1, D]
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
// Flatten indices: [B, L, topK] -> [B*L, topK]
idxFlat := mlx.Reshape(indices, B*L, topK)
// Sort for efficient gather (when we have many tokens)
doSort := B*L >= 64
var invOrder *mlx.Array
n := B * L * topK
if doSort {
idxAll := mlx.Flatten(idxFlat)
order := mlx.Argsort(idxAll, 0)
invOrder = mlx.Argsort(order, 0)
// Reorder x based on sorted indices
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
}
// Expert computation using gather_mm
// gate: x @ gate_weight.T (indices are on the rhs/weight side)
gate := mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
// up: x @ up_weight.T
up := mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
// SwiGLU activation
hidden := mlx.Mul(mlx.SiLU(gate), up)
// down: hidden @ down_weight.T
down := mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
// Unsort if we sorted
if doSort {
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
} else {
down = mlx.Squeeze(down, 2)
}
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}
// SharedExperts implements the shared expert MLP
type SharedExperts struct {
GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"`
UpProj nn.LinearLayer `weight:"mlp.shared_experts.up_proj"`
DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"`
}
// Forward applies the shared expert MLP
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(s.GateProj.Forward(x))
up := s.UpProj.Forward(x)
return s.DownProj.Forward(mlx.Mul(gate, up))
}
// MoE implements the full Mixture of Experts layer
type MoE struct {
Gate *MoEGate
SwitchMLP *SwitchMLP
SharedExperts *SharedExperts
}
// Forward applies the MoE layer
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
shape := x.Shape()
B, L := shape[0], shape[1]
// Get expert indices and scores
inds, scores := m.Gate.Forward(x, cfg)
// Apply routed experts
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
// Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
scoresExpanded := mlx.ExpandDims(scores, -1)
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
// Add shared experts if present
if m.SharedExperts != nil {
y = mlx.Add(y, m.SharedExperts.Forward(x))
}
return mlx.Reshape(y, B, L, cfg.HiddenSize)
}
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
type DenseBlock struct {
Attention *MLAAttention
MLP *DenseMLP
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the dense block
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Pre-norm attention with residual
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
h := mlx.Add(x, r)
// Pre-norm MLP with residual
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
return mlx.Add(h, r)
}
// MoEBlock represents a MoE transformer block
type MoEBlock struct {
Attention *MLAAttention
MoE *MoE
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the MoE block
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Pre-norm attention with residual
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
h := mlx.Add(x, r)
// Pre-norm MoE with residual
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
return mlx.Add(h, r)
}
// Block interface for both dense and MoE blocks
type Block interface {
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}
// Model represents the complete GLM4-MoE-Lite model
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []Block `weight:"-"` // Loaded manually due to different block types
Norm *nn.RMSNorm `weight:"model.norm"`
LMHead nn.LinearLayer `weight:"lm_head"`
tok *tokenizer.Tokenizer
*Config
}
// loadExpertWeight loads an expert weight, dequantizing if necessary.
// GatherMM doesn't support quantized weights, so we must dequantize for MoE.
func loadExpertWeight(weights safetensors.WeightSource, path string) *mlx.Array {
w, _ := weights.GetTensor(path + ".weight")
if w == nil {
return nil
}
// Check if this is a quantized weight by looking for scales
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
scales, _ := weights.GetTensor(scalePath)
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
// Dequantize using the model's quantization parameters
groupSize, bits, mode := safetensors.QuantizationParams(weights.Quantization())
return mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
}
return w
}
// sanitizeExpertWeights stacks individual expert weights into a single tensor.
// For quantized models, expert weights are dequantized since GatherMM doesn't support quantized weights.
func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
var gateWeights, upWeights, downWeights []*mlx.Array
for e := int32(0); e < numExperts; e++ {
gw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.gate_proj", prefix, e))
uw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.up_proj", prefix, e))
dw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.down_proj", prefix, e))
if gw != nil {
gateWeights = append(gateWeights, gw)
}
if uw != nil {
upWeights = append(upWeights, uw)
}
if dw != nil {
downWeights = append(downWeights, dw)
}
}
var stackedGate, stackedUp, stackedDown *mlx.Array
if len(gateWeights) > 0 {
stackedGate = mlx.Stack(gateWeights, 0)
}
if len(upWeights) > 0 {
stackedUp = mlx.Stack(upWeights, 0)
}
if len(downWeights) > 0 {
stackedDown = mlx.Stack(downWeights, 0)
}
return stackedGate, stackedUp, stackedDown
}
// Load loads a GLM4-MoE-Lite model from the given path
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Compute derived fields
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.QHeadDim)))
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
Layers: make([]Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
// Load embedding, norm, and lm_head
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Load layers manually due to different block types
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
// Load attention (same for both block types)
attn := &MLAAttention{}
if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d attention: %w", i, err)
}
if i < cfg.FirstKDenseReplace {
// Dense block
block := &DenseBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d dense: %w", i, err)
}
m.Layers[i] = block
} else {
// MoE block
block := &MoEBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
}
// Stack expert weights
gateW, upW, downW := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts)
block.MoE = &MoE{
Gate: &MoEGate{},
SwitchMLP: &SwitchMLP{
GateWeight: gateW,
UpWeight: upW,
DownWeight: downW,
},
}
// Load gate weights
if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d gate: %w", i, err)
}
// Load shared experts if present
if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{}
if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
}
}
m.Layers[i] = block
}
}
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
return m, nil
}
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
// Read config from manifest
configData, err := manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Compute derived fields
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.QHeadDim)))
// Load weights from manifest blobs
weights, err := imagegen.LoadAllWeightsFromManifest(manifest)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
// Debug: print quantization info and sample tensor names
fmt.Printf("GLM4: quantization=%q, num_tensors=%d\n", weights.Quantization(), len(weights.ListTensors()))
tensors := weights.ListTensors()
for i, name := range tensors {
if i < 20 { // Print first 20 tensor names
fmt.Printf(" tensor[%d]: %s\n", i, name)
}
}
if err := weights.Load(0); err != nil {
return nil, fmt.Errorf("load weight data: %w", err)
}
// Load tokenizer from manifest with config files for EOS token detection
tokData, err := manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
// Build tokenizer config with companion files for EOS/BOS token loading
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData, // Already loaded above, contains eos_token_id
}
// Try to load generation_config.json if available (preferred source for EOS)
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
// Try to load tokenizer_config.json if available
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
// Load embedding, norm, and lm_head
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Load layers manually due to different block types
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
// Load attention (same for both block types)
attn := &MLAAttention{}
if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d attention: %w", i, err)
}
if i < cfg.FirstKDenseReplace {
// Dense block
block := &DenseBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d dense: %w", i, err)
}
m.Layers[i] = block
} else {
// MoE block
block := &MoEBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
}
// Stack expert weights
gateW, upW, downW := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts)
block.MoE = &MoE{
Gate: &MoEGate{},
SwitchMLP: &SwitchMLP{
GateWeight: gateW,
UpWeight: upW,
DownWeight: downW,
},
}
// Load gate weights
if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d gate: %w", i, err)
}
// Load shared experts if present
if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{}
if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
}
}
m.Layers[i] = block
}
}
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
return m, nil
}
// Forward computes the forward pass of the model
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
}
h = m.Norm.Forward(h, m.RMSNormEps)
return m.LMHead.Forward(h)
}
// Interface methods
// NumLayers returns the number of transformer layers
func (m *Model) NumLayers() int { return len(m.Layers) }
// MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
// Tokenizer returns the model's tokenizer
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
// NewCache creates a new KV cache for the model
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
// This follows the GLM-4.7 format with <think> tag for reasoning mode.
func (m *Model) FormatPrompt(prompt string) string {
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}
// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
// When think is true, the prompt ends with <think> to enable reasoning mode.
// When think is false, the prompt ends with </think> to skip reasoning.
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
if think {
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
}
// NewRenderer returns a new Renderer for formatting multi-turn conversations.
func (m *Model) NewRenderer() *Renderer {
return &Renderer{}
}
// NewParser returns a new Parser for extracting thinking and tool calls from output.
func (m *Model) NewParser() *Parser {
return &Parser{}
}

View File

@@ -0,0 +1,479 @@
//go:build mlx
package glm4_moe_lite
import (
"context"
"encoding/json"
"encoding/xml"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type parserState int
const (
parserState_LookingForThinkingOpen parserState = iota
parserState_ThinkingStartedEatingWhitespace
parserState_CollectingThinking
parserState_ThinkingDoneEatingWhitespace
parserState_CollectingContent
parserState_ToolStartedEatingWhitespace
parserState_CollectingToolContent
)
const (
thinkingOpenTag = "<think>"
thinkingCloseTag = "</think>"
toolOpenTag = "<tool_call>"
toolCloseTag = "</tool_call>"
)
// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type Parser struct {
state parserState
buffer strings.Builder
tools []api.Tool
}
// HasToolSupport returns true as GLM4 supports tool calling.
func (p *Parser) HasToolSupport() bool {
return true
}
// HasThinkingSupport returns true as GLM4 supports thinking mode.
func (p *Parser) HasThinkingSupport() bool {
return true
}
// Init initializes the parser with tools and thinking configuration.
func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
// When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() {
p.state = parserState_CollectingThinking
}
return tools
}
type parserEvent interface {
isParserEvent()
}
type eventContent struct {
content string
}
func (eventContent) isParserEvent() {}
type eventRawToolCall struct {
raw string
}
func (eventRawToolCall) isParserEvent() {}
type eventThinkingContent struct {
content string
}
func (eventThinkingContent) isParserEvent() {}
// Add processes new output text and returns parsed content, thinking, and tool calls.
func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case eventRawToolCall:
toolCall, err := parseToolCall(event, p.tools)
if err != nil {
slog.Warn("glm-4 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case eventThinkingContent:
thinkingSb.WriteString(event.content)
case eventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *Parser) parseEvents() []parserEvent {
var all []parserEvent
keepLooping := true
for keepLooping {
var events []parserEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false // Still only whitespace, keep waiting for more input
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true // Successfully transitioned
}
// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
// the content after (optionally trimmed of leading whitespace), and updates the buffer
func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
p.buffer.Reset()
p.buffer.WriteString(after)
return before, after
}
func (p *Parser) eat() ([]parserEvent, bool) {
var events []parserEvent
switch p.state {
case parserState_LookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, thinkingOpenTag) {
// Found <think> opening tag
after := strings.TrimPrefix(trimmed, thinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = parserState_ThinkingStartedEatingWhitespace
} else {
p.state = parserState_CollectingThinking
}
return events, true
} else if strings.HasPrefix(thinkingOpenTag, trimmed) {
// Partial opening tag seen, keep accumulating
return events, false
} else if trimmed == "" {
// Only whitespace, keep accumulating
return events, false
} else {
// No thinking tag found, skip to content collection
p.state = parserState_CollectingContent
// Don't trim - we want to keep the original content
return events, true
}
case parserState_ThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)
case parserState_CollectingThinking:
acc := p.buffer.String()
if strings.Contains(acc, thinkingCloseTag) {
thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, eventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = parserState_ThinkingDoneEatingWhitespace
} else {
p.state = parserState_CollectingContent
}
return events, true
} else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
// Partial closing tag - withhold it along with any trailing whitespace before it
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventThinkingContent{content: unambiguous})
}
return events, false
} else {
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventThinkingContent{content: unambiguous})
}
return events, false
}
case parserState_ThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)
case parserState_CollectingContent:
if strings.Contains(p.buffer.String(), toolOpenTag) {
before, after := p.splitAtTag(toolOpenTag, true)
if len(before) > 0 {
events = append(events, eventContent{content: before})
}
if after == "" {
p.state = parserState_ToolStartedEatingWhitespace
} else {
p.state = parserState_CollectingToolContent
}
return events, true
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventContent{content: unambiguous})
}
return events, false
}
case parserState_ToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)
case parserState_CollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, toolCloseTag) {
toolContent, _ := p.splitAtTag(toolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("glm4 tool call closing tag found but no content before it")
}
events = append(events, eventRawToolCall{raw: toolContent})
p.state = parserState_CollectingContent
return events, true
} else {
// Keep accumulating - tool calls are not streamed
// We just wait for the closing tag
return events, false
}
default:
panic("unreachable")
}
}
// overlap returns the length of the overlap between the end of s and the start of tag.
func overlap(s, tag string) int {
for i := 1; i <= len(tag) && i <= len(s); i++ {
if strings.HasSuffix(s, tag[:i]) {
return i
}
}
return 0
}
// trailingWhitespaceLen returns the length of trailing whitespace in s.
func trailingWhitespaceLen(s string) int {
trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
return len(s) - len(trimmed)
}
// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
type ToolCallXML struct {
XMLName xml.Name `xml:"tool_call"`
Content string `xml:",chardata"` // Function name (text nodes between tags)
Keys []string `xml:"arg_key"` // All arg_key elements in document order
Values []string `xml:"arg_value"` // All arg_value elements in document order
}
// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeContent(s string) string {
var result strings.Builder
inTag := false
for i := range len(s) {
ch := s[i]
if ch == '<' {
// Check if this is a known tag
if strings.HasPrefix(s[i:], "<arg_key>") ||
strings.HasPrefix(s[i:], "</arg_key>") ||
strings.HasPrefix(s[i:], "<arg_value>") ||
strings.HasPrefix(s[i:], "</arg_value>") {
inTag = true
}
}
if inTag {
result.WriteByte(ch)
if ch == '>' {
inTag = false
}
} else {
// Escape special characters in text content
switch ch {
case '&':
result.WriteString("&amp;")
case '<':
result.WriteString("&lt;")
case '>':
result.WriteString("&gt;")
default:
result.WriteByte(ch)
}
}
}
return result.String()
}
func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
// Escape any unescaped entities in text content
escaped := escapeContent(raw.raw)
// Wrap the content in a root element to make it valid XML
xmlString := "<tool_call>" + escaped + "</tool_call>"
// Parse XML into struct
var parsed ToolCallXML
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
}
// Extract and trim function name
functionName := strings.TrimSpace(parsed.Content)
if functionName == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
// Verify keys and values are paired correctly
if len(parsed.Keys) != len(parsed.Values) {
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionName {
matchedTool = &tools[i]
break
}
}
// Build arguments map by pairing keys and values
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: functionName,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for i := range parsed.Keys {
key := strings.TrimSpace(parsed.Keys[i])
value := parsed.Values[i] // Don't trim here - parseValue handles it
// Look up parameter type
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
// Handle anyOf by collecting all types from the union
if len(prop.AnyOf) > 0 {
for _, anyOfProp := range prop.AnyOf {
paramType = append(paramType, anyOfProp.Type...)
}
} else {
paramType = prop.Type
}
}
}
// Parse value with type coercion
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
}
return toolCall, nil
}
// parseValue parses a string value and coerces it to the appropriate type based on paramType.
func parseValue(value string, paramType api.PropertyType) any {
value = strings.TrimSpace(value)
// If no type specified, return as string
if len(paramType) == 0 {
return value
}
// Try to parse based on specified types
for _, t := range paramType {
switch t {
case "boolean":
if value == "true" {
return true
}
if value == "false" {
return false
}
case "integer":
var i int64
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
return i
}
case "number":
var f float64
if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
return f
}
case "array", "object":
// Try to parse as JSON
var result any
if err := json.Unmarshal([]byte(value), &result); err == nil {
return result
}
}
}
// Default to string
return value
}

View File

@@ -0,0 +1,192 @@
//go:build mlx
package glm4_moe_lite
import (
"testing"
"github.com/ollama/ollama/api"
)
func TestParserThinking(t *testing.T) {
tests := []struct {
name string
input string
thinkEnabled bool
wantContent string
wantThinking string
wantToolCalls int
}{
{
name: "thinking enabled - simple thinking then content",
input: "Let me think about this...</think>Here is my answer.",
thinkEnabled: true,
wantThinking: "Let me think about this...",
wantContent: "Here is my answer.",
},
{
name: "thinking enabled - only thinking",
input: "I need to consider multiple factors...",
thinkEnabled: true,
wantThinking: "I need to consider multiple factors...",
wantContent: "",
},
{
name: "thinking disabled - direct content",
input: "Here is my direct answer.",
thinkEnabled: false,
wantThinking: "",
wantContent: "Here is my direct answer.",
},
{
name: "thinking with tool call",
input: "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
thinkEnabled: true,
wantThinking: "Let me search for that...",
wantContent: "I'll use a tool.",
wantToolCalls: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Parser{}
var thinkValue *api.ThinkValue
if tt.thinkEnabled {
thinkValue = &api.ThinkValue{Value: true}
} else {
thinkValue = &api.ThinkValue{Value: false}
}
// Define tools for tool call tests
props := api.NewToolPropertiesMap()
props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "search",
Parameters: api.ToolFunctionParameters{
Properties: props,
},
},
},
}
p.Init(tools, nil, thinkValue)
content, thinking, calls, err := p.Add(tt.input, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if thinking != tt.wantThinking {
t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
}
if content != tt.wantContent {
t.Errorf("content = %q, want %q", content, tt.wantContent)
}
if len(calls) != tt.wantToolCalls {
t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
}
})
}
}
func TestParserToolCall(t *testing.T) {
p := &Parser{}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: props,
},
},
},
}
// Initialize with thinking disabled
tv := &api.ThinkValue{Value: false}
p.Init(tools, nil, tv)
input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"
_, _, calls, err := p.Add(input, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
call := calls[0]
if call.Function.Name != "get_weather" {
t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
}
location, ok := call.Function.Arguments.Get("location")
if !ok || location != "San Francisco" {
t.Errorf("location = %v, want %q", location, "San Francisco")
}
unit, ok := call.Function.Arguments.Get("unit")
if !ok || unit != "celsius" {
t.Errorf("unit = %v, want %q", unit, "celsius")
}
}
func TestOverlap(t *testing.T) {
tests := []struct {
s string
tag string
want int
}{
{"hello<", "</think>", 1},
{"hello</", "</think>", 2},
{"hello</t", "</think>", 3},
{"hello</th", "</think>", 4},
{"hello</thi", "</think>", 5},
{"hello</thin", "</think>", 6},
{"hello</think", "</think>", 7},
{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
{"hello", "</think>", 0},
{"", "</think>", 0},
}
for _, tt := range tests {
t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
got := overlap(tt.s, tt.tag)
if got != tt.want {
t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
}
})
}
}
func TestTrailingWhitespaceLen(t *testing.T) {
tests := []struct {
s string
want int
}{
{"hello ", 3},
{"hello\n\t ", 3},
{"hello", 0},
{"", 0},
{" ", 3},
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
got := trailingWhitespaceLen(tt.s)
if got != tt.want {
t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,175 @@
//go:build mlx
package glm4_moe_lite
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
// Renderer renders messages for GLM4-MoE-Lite models.
//
// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses "enable_thinking" parameter:
// - enable_thinking=true: outputs <think> to start reasoning
// - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
// included in <think>...</think> blocks, equivalent to clear_thinking=false)
// - Users can disable thinking per-turn via thinkValue=false
type Renderer struct{}
// Render renders messages into the GLM4 chat format.
func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
think := true
if thinkValue != nil && !thinkValue.Bool() {
think = false
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>")
sb.WriteString(message.Content)
case "assistant":
sb.WriteString("<|assistant|>")
if message.Thinking != "" {
sb.WriteString("<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("</think>")
}
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("<tool_response>")
sb.WriteString(message.Content)
sb.WriteString("</tool_response>")
case "system":
sb.WriteString("<|system|>")
sb.WriteString(message.Content)
}
}
sb.WriteString("<|assistant|>")
if think {
sb.WriteString("<think>")
} else {
sb.WriteString("</think>")
}
return sb.String(), nil
}
// renderToolArguments converts tool call arguments to GLM4 XML format.
func renderToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}
// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and ,
func formatToolJSON(raw []byte) string {
var sb strings.Builder
sb.Grow(len(raw) + len(raw)/10)
inString := false
escaped := false
for i := range raw {
ch := raw[i]
sb.WriteByte(ch)
if inString {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
inString = false
}
continue
}
if ch == '"' {
inString = true
continue
}
if ch == ':' || ch == ',' {
sb.WriteByte(' ')
}
}
return sb.String()
}

View File

@@ -0,0 +1,205 @@
//go:build mlx
package glm4_moe_lite
import (
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestRendererSimple(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
// Thinking enabled (default)
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := "[gMASK]<sop><|user|>Hello<|assistant|><think>"
if result != expected {
t.Errorf("result = %q, want %q", result, expected)
}
}
func TestRendererThinkingDisabled(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
tv := &api.ThinkValue{Value: false}
result, err := r.Render(messages, nil, tv)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := "[gMASK]<sop><|user|>Hello<|assistant|></think>"
if result != expected {
t.Errorf("result = %q, want %q", result, expected)
}
}
func TestRendererMultiTurn(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
{Role: "user", Content: "And 3+3?"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check key parts
if !strings.Contains(result, "[gMASK]<sop>") {
t.Error("missing [gMASK]<sop> prefix")
}
if !strings.Contains(result, "<|user|>What is 2+2?") {
t.Error("missing first user message")
}
if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
t.Error("missing assistant message with thinking")
}
if !strings.Contains(result, "<|user|>And 3+3?") {
t.Error("missing second user message")
}
if !strings.HasSuffix(result, "<|assistant|><think>") {
t.Errorf("should end with <|assistant|><think>, got suffix: %q", result[len(result)-30:])
}
}
func TestRendererWithSystem(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
t.Error("missing system message")
}
}
func TestRendererWithTools(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "What's the weather?"},
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the weather for a location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: []string{"location"},
},
},
},
}
result, err := r.Render(messages, tools, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check for tool system prompt
if !strings.Contains(result, "<|system|>") {
t.Error("missing system tag for tools")
}
if !strings.Contains(result, "# Tools") {
t.Error("missing tools header")
}
if !strings.Contains(result, "<tools>") {
t.Error("missing tools tag")
}
if !strings.Contains(result, "get_weather") {
t.Error("missing tool name")
}
if !strings.Contains(result, "</tools>") {
t.Error("missing closing tools tag")
}
}
func TestRendererWithToolCalls(t *testing.T) {
r := &Renderer{}
args := api.NewToolCallFunctionArguments()
args.Set("location", "San Francisco")
messages := []api.Message{
{Role: "user", Content: "What's the weather in SF?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
},
},
},
{Role: "tool", Content: "Sunny, 72F"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "<tool_call>get_weather") {
t.Error("missing tool call")
}
if !strings.Contains(result, "<arg_key>location</arg_key>") {
t.Error("missing arg_key")
}
if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
t.Error("missing arg_value")
}
if !strings.Contains(result, "</tool_call>") {
t.Error("missing tool call closing tag")
}
if !strings.Contains(result, "<|observation|>") {
t.Error("missing observation tag")
}
if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
t.Error("missing tool response")
}
}
func TestFormatToolJSON(t *testing.T) {
input := []byte(`{"name":"test","value":123}`)
result := formatToolJSON(input)
// Should add spaces after : and ,
if !strings.Contains(result, ": ") {
t.Error("should add space after colon")
}
if !strings.Contains(result, ", ") {
t.Error("should add space after comma")
}
}

View File

@@ -1,87 +0,0 @@
//go:build mlx
package qwen_image
import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := mlx.InitMLX(); err != nil {
fmt.Printf("Skipping qwen_image tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestPipelineOutput runs the full pipeline (integration test).
// Skips if model weights not found. Requires ~50GB VRAM.
func TestPipelineOutput(t *testing.T) {
modelPath := "../../../weights/Qwen-Image-2512"
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + modelPath)
}
// Load model
pm, err := LoadPersistent(modelPath)
if err != nil {
t.Skipf("Skipping: failed to load model: %v", err)
}
// Run 2-step pipeline (minimum for stable scheduler)
cfg := &GenerateConfig{
Prompt: "a cat",
Width: 256,
Height: 256,
Steps: 2,
Seed: 42,
}
output, err := pm.GenerateFromConfig(cfg)
if err != nil {
t.Fatalf("Pipeline failed: %v", err)
}
mlx.Eval(output)
// Verify output shape [1, C, H, W]
shape := output.Shape()
if len(shape) != 4 {
t.Errorf("Expected 4D output, got %v", shape)
}
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
}
// Verify values in expected range [0, 1]
data := output.Data()
minVal, maxVal := float32(1.0), float32(0.0)
for _, v := range data {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
if minVal < -0.1 || maxVal > 1.1 {
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
}
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,367 +0,0 @@
//go:build mlx
// Package qwen_image implements the Qwen-Image diffusion transformer model.
package qwen_image
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
// Layer caching (DeepCache/Learning-to-Cache speedup)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 25)
}
// Model represents a Qwen-Image diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen25VL
Transformer *Transformer
VAEDecoder *VAEDecoder
}
// Load loads the Qwen-Image model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err := tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
m.TextEncoder = &Qwen25VL{}
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
}
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
imgSeqLen := pH * pW
// Text encoding
var posEmb, negEmb *mlx.Array
{
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
if useCFG {
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
} else {
mlx.Keep(posEmb)
mlx.Eval(posEmb)
}
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
// Init latents [B, C, T, H, W]
var latents *mlx.Array
{
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
}
// RoPE cache
var ropeCache *RoPECache
{
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs)
}
// Layer cache for DeepCache/Learning-to-Cache speedup
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
latents2D := mlx.Squeeze(latents, 2)
patches := PackLatents(latents2D, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
// CFG Batching: single forward pass with batch=2
// Note: layer caching with CFG is not supported yet (would need 2 caches)
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
L := batchedOutput.Shape()[1]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
combPred := mlx.Add(negOutput, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
} else if stepCache != nil {
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
} else {
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
}
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep cached arrays alive across cleanup
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
}
// Free denoising temporaries before VAE decode
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode (Decode manages its own pools for staged memory)
decoded := m.VAEDecoder.Decode(latents)
latents.Free()
// Post-process: squeeze temporal dim and rescale to [0, 1]
{
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
mlx.Eval(decoded)
}
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is an alias for backward compatibility.
// Use m := &Model{}; m.Load(path) instead.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}

View File

@@ -1,218 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.5
MaxShift float32 `json:"max_shift"` // 0.9
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
}
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.5,
MaxShift: 0.9, // Matches scheduler_config.json
BaseImageSeqLen: 256,
MaxImageSeqLen: 8192,
ShiftTerminal: 0.02,
UseDynamicShift: true,
}
}
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32
Sigmas []float32
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// CalculateShift computes the dynamic shift based on image sequence length
// This matches Python's calculate_shift function
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
b := baseShift - m*float32(baseSeqLen)
mu := float32(imageSeqLen)*m + b
return mu
}
// SetTimesteps sets up the scheduler for the given number of inference steps
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
// 1. Create sigmas from sigma_max to sigma_min (linspace)
// 2. Apply time_shift with mu (if dynamic shifting)
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
s.NumSteps = numSteps
// Calculate mu for dynamic shifting
var mu float32
if s.Config.UseDynamicShift {
mu = CalculateShift(
imageSeqLen,
s.Config.BaseImageSeqLen,
s.Config.MaxImageSeqLen,
s.Config.BaseShift,
s.Config.MaxShift,
)
}
// Step 1: Create sigmas from 1.0 to 1/num_steps
// Python (pipeline_qwenimage.py:639):
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
sigmas := make([]float32, numSteps)
sigmaMax := float32(1.0)
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
if numSteps == 1 {
sigmas[0] = sigmaMax
} else {
for i := 0; i < numSteps; i++ {
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
}
}
// Step 2: Apply time shift if using dynamic shifting
if s.Config.UseDynamicShift && mu != 0 {
for i := range sigmas {
sigmas[i] = s.timeShift(mu, sigmas[i])
}
}
// Step 3: Apply stretch_shift_to_terminal
if s.Config.ShiftTerminal > 0 {
sigmas = s.stretchShiftToTerminal(sigmas)
}
// Step 4: Append terminal sigma (0) and store
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
// before passing to transformer. We skip both steps and just use sigmas directly.
s.Sigmas = make([]float32, numSteps+1)
s.Timesteps = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
s.Sigmas[i] = sigmas[i]
s.Timesteps[i] = sigmas[i]
}
s.Sigmas[numSteps] = 0.0
s.Timesteps[numSteps] = 0.0
}
// stretchShiftToTerminal stretches and shifts the timestep schedule
// so the final value equals shift_terminal (matches Python behavior)
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
if len(sigmas) == 0 {
return sigmas
}
// one_minus_z = 1 - t
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
// stretched_t = 1 - (one_minus_z / scale_factor)
lastSigma := sigmas[len(sigmas)-1]
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
if scaleFactor < 1e-6 {
return sigmas
}
result := make([]float32, len(sigmas))
for i, t := range sigmas {
oneMinusZ := 1.0 - t
result[i] = 1.0 - (oneMinusZ / scaleFactor)
}
return result
}
// timeShift applies the dynamic time shift (exponential)
// exp(mu) / (exp(mu) + (1/t - 1))
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
// modelOutput: predicted velocity from the transformer
// sample: current noisy sample
// timestepIdx: current timestep index
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
// Get current and next sigma
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 to avoid precision issues (matches Python diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to original dtype
return mlx.ToBFloat16(result)
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormal(shape, uint64(seed))
}
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
// This matches how Python diffusers generates noise - directly in packed space.
// Generating in unpacked format and then packing produces different spatial
// correlation structure, which affects model output quality.
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
shape := []int32{batchSize, seqLen, channels}
return mlx.RandomNormal(shape, uint64(seed))
}
// GetLatentShape returns the latent shape for a given image size
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
func GetLatentShape(batchSize, height, width int32) []int32 {
latentH := height / 8
latentW := width / 8
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
}
// GetPatchedLatentShape returns the patchified latent shape
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
latentH := height / 8
latentW := width / 8
pH := latentH / patchSize
pW := latentW / patchSize
inChannels := int32(64) // 16 * patch_size^2
return []int32{batchSize, pH * pW, inChannels}
}

View File

@@ -1,135 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"testing"
)
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
// Golden values generated via:
//
// python3 -c "
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
// import numpy as np
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
// sigmas = np.linspace(1.0, 1.0/30, 30)
// s.set_timesteps(sigmas=sigmas, mu=mu)
// print(s.sigmas.numpy())"
func TestSchedulerSetTimesteps(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Golden values from Python diffusers (first 3, last 3 before terminal)
wantFirst := []float32{1.000000, 0.982251, 0.963889}
wantLast := []float32{0.142924, 0.083384, 0.020000}
// Check first 3
for i, want := range wantFirst {
got := scheduler.Sigmas[i]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
}
}
// Check last 3 (indices 27, 28, 29)
for i, want := range wantLast {
idx := 27 + i
got := scheduler.Sigmas[idx]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
}
}
// Check terminal is 0
if scheduler.Sigmas[30] != 0.0 {
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
}
// Check length
if len(scheduler.Sigmas) != 31 {
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
}
}
// TestSchedulerProperties tests mathematical invariants of the scheduler.
func TestSchedulerProperties(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Property: sigmas monotonically decreasing
for i := 1; i < len(scheduler.Sigmas); i++ {
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
}
}
// Property: first sigma should be ~1.0 (with time shift)
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
}
// Property: terminal sigma should be exactly 0
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
}
// Property: last non-terminal sigma should be shift_terminal (0.02)
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
if abs32(lastNonTerminal-0.02) > 1e-5 {
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
}
// Property: length = steps + 1
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
t.Errorf("sigmas length should be steps+1: got %d, want %d",
len(scheduler.Sigmas), scheduler.NumSteps+1)
}
}
// TestCalculateShift verifies the mu calculation against Python reference.
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
func TestCalculateShift(t *testing.T) {
cases := []struct {
imgSeqLen int32
want float32
}{
{256, 0.5}, // base case
{8192, 0.9}, // max case
{4096, 0.6935}, // middle case (rounded)
}
for _, c := range cases {
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
if abs32(got-c.want) > 0.001 {
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
}
}
}
// TestSchedulerStep verifies the Euler step formula.
func TestSchedulerStep(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Verify dt calculation for first step
sigma0 := scheduler.Sigmas[0]
sigma1 := scheduler.Sigmas[1]
expectedDt := sigma1 - sigma0
// dt should be negative (sigmas decrease)
if expectedDt >= 0 {
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
}
}
func abs32(x float32) float32 {
return float32(math.Abs(float64(x)))
}

View File

@@ -1,174 +0,0 @@
//go:build mlx
package qwen_image
import (
"encoding/json"
"math"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TinyTextEncoderConfig holds config for the tiny test text encoder
type TinyTextEncoderConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
MRoPESection []int32 `json:"mrope_section"`
}
// loadTinyTextEncoder loads the tiny text encoder from testdata
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
t.Helper()
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
// Load config
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
if err != nil {
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
}
var tinyCfg TinyTextEncoderConfig
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
t.Fatalf("Failed to parse config: %v", err)
}
// Create encoder config (using Qwen25VLConfig)
cfg := &Qwen25VLConfig{
HiddenSize: tinyCfg.HiddenSize,
NumHiddenLayers: tinyCfg.NumHiddenLayers,
IntermediateSize: tinyCfg.IntermediateSize,
NumAttentionHeads: tinyCfg.NumAttentionHeads,
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
VocabSize: tinyCfg.VocabSize,
RMSNormEps: tinyCfg.RMSNormEps,
RopeTheta: tinyCfg.RopeTheta,
HeadDim: tinyCfg.HeadDim,
MRoPESection: tinyCfg.MRoPESection,
}
// Load weights
weights, err := safetensors.LoadModelWeights(testdataDir)
if err != nil {
t.Fatalf("Failed to load weights: %v", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
t.Fatalf("Failed to bulk load weights: %v", err)
}
// Build encoder
embedding, err := weights.Get("model.embed_tokens.weight")
if err != nil {
t.Fatalf("Failed to get embedding: %v", err)
}
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
block, err := newVLTextBlock(weights, int(i), cfg)
if err != nil {
t.Fatalf("Failed to load block %d: %v", i, err)
}
blocks[i] = block
}
finalNorm, err := weights.Get("model.norm.weight")
if err != nil {
t.Fatalf("Failed to get final norm: %v", err)
}
encoder := &Qwen25VL{
Config: cfg,
Embedding: embedding,
Blocks: blocks,
FinalNorm: finalNorm,
HasVision: false, // Text-only mode
}
return encoder, &tinyCfg
}
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
func TestTextEncoderForward(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// Create test tokens (within vocab range)
tokens := []int32{1, 2, 3, 4, 5}
// Forward pass using EncodeTextOnly
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
// Verify output shape: [batch, seq_len, hidden_size]
wantShape := []int32{1, 5, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
}
// Verify output is finite (not NaN or Inf)
data := out.Data()
for i, v := range data {
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
t.Errorf("output[%d] is not finite: %v", i, v)
break
}
}
}
// TestTextEncoderBatch tests batch processing.
func TestTextEncoderBatch(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// For batch test, we'll use EncodeTextOnly with a single sequence
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
tokens := []int32{1, 2, 3}
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
wantShape := []int32{1, 3, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
}
}
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
func TestMRoPEComputation(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
cossin := encoder.computeTextRoPE(10, 1)
mlx.Eval(cossin[0], cossin[1])
// Verify shapes: [3, B, L, head_dim]
wantShape := []int32{3, 1, 10, cfg.HeadDim}
if !slices.Equal(cossin[0].Shape(), wantShape) {
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
}
if !slices.Equal(cossin[1].Shape(), wantShape) {
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
}
// Verify cos/sin values are in valid range [-1, 1]
cosData := cossin[0].Data()
sinData := cossin[1].Data()
for i := 0; i < min(100, len(cosData)); i++ {
if cosData[i] < -1.01 || cosData[i] > 1.01 {
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
}
if sinData[i] < -1.01 || sinData[i] > 1.01 {
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
}
}
}

View File

@@ -1,868 +0,0 @@
//go:build mlx
package qwen_image
import (
"fmt"
"math"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Qwen-Image transformer configuration
type TransformerConfig struct {
HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
NHeads int32 `json:"num_attention_heads"` // 24
HeadDim int32 `json:"attention_head_dim"` // 128
NLayers int32 `json:"num_layers"` // 60
InChannels int32 `json:"in_channels"` // 64
OutChannels int32 `json:"out_channels"` // 16
PatchSize int32 `json:"patch_size"` // 2
JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
NormEps float32 `json:"norm_eps"` // 1e-6
AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
GuidanceEmbeds bool `json:"guidance_embeds"` // false
}
// defaultTransformerConfig returns config for Qwen-Image transformer
func defaultTransformerConfig() *TransformerConfig {
return &TransformerConfig{
HiddenDim: 3072, // 24 * 128
NHeads: 24,
HeadDim: 128,
NLayers: 60,
InChannels: 64,
OutChannels: 16,
PatchSize: 2,
JointAttentionDim: 3584,
NormEps: 1e-6,
AxesDimsRope: []int32{16, 56, 56},
GuidanceEmbeds: false,
}
}
// TimestepEmbedder creates timestep embeddings
type TimestepEmbedder struct {
Linear1Weight *mlx.Array // [256, hidden_dim]
Linear1Bias *mlx.Array
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
Linear2Bias *mlx.Array
}
// newTimestepEmbedder creates a timestep embedder from weights
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
if err != nil {
return nil, err
}
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
if err != nil {
return nil, err
}
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
if err != nil {
return nil, err
}
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
if err != nil {
return nil, err
}
return &TimestepEmbedder{
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
Linear1Bias: linear1Bias,
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
Linear2Bias: linear2Bias,
}, nil
}
// Forward computes timestep embeddings
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
half := int32(128) // embedding_dim / 2
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
tExpanded := mlx.ExpandDims(t, 1)
args := mlx.Mul(tExpanded, freqsArr)
args = mlx.MulScalar(args, 1000.0) // scale
// [cos, sin] (flip_sin_to_cos=True)
sinArgs := mlx.Sin(args)
cosArgs := mlx.Cos(args)
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
// MLP: linear1 -> silu -> linear2
h := mlx.Linear(embedding, te.Linear1Weight)
h = mlx.Add(h, te.Linear1Bias)
h = mlx.SiLU(h)
h = mlx.Linear(h, te.Linear2Weight)
h = mlx.Add(h, te.Linear2Bias)
return h
}
// JointAttention implements dual-stream joint attention
type JointAttention struct {
// Image projections
ToQ *mlx.Array
ToQB *mlx.Array
ToK *mlx.Array
ToKB *mlx.Array
ToV *mlx.Array
ToVB *mlx.Array
ToOut *mlx.Array
ToOutB *mlx.Array
NormQ *mlx.Array
NormK *mlx.Array
// Text (added) projections
AddQProj *mlx.Array
AddQProjB *mlx.Array
AddKProj *mlx.Array
AddKProjB *mlx.Array
AddVProj *mlx.Array
AddVProjB *mlx.Array
ToAddOut *mlx.Array
ToAddOutB *mlx.Array
NormAddQ *mlx.Array
NormAddK *mlx.Array
NHeads int32
HeadDim int32
Scale float32
}
// newJointAttention creates a joint attention layer
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
return &JointAttention{
ToQ: mlx.Transpose(toQ, 1, 0),
ToQB: toQB,
ToK: mlx.Transpose(toK, 1, 0),
ToKB: toKB,
ToV: mlx.Transpose(toV, 1, 0),
ToVB: toVB,
ToOut: mlx.Transpose(toOut, 1, 0),
ToOutB: toOutB,
NormQ: normQ,
NormK: normK,
AddQProj: mlx.Transpose(addQProj, 1, 0),
AddQProjB: addQProjB,
AddKProj: mlx.Transpose(addKProj, 1, 0),
AddKProjB: addKProjB,
AddVProj: mlx.Transpose(addVProj, 1, 0),
AddVProjB: addVProjB,
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
ToAddOutB: toAddOutB,
NormAddQ: normAddQ,
NormAddK: normAddK,
NHeads: cfg.NHeads,
HeadDim: cfg.HeadDim,
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
}, nil
}
// Forward computes joint attention
// img: [B, L_img, D], txt: [B, L_txt, D]
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
D := imgShape[2]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// === Image Q/K/V ===
imgFlat := mlx.Reshape(img, B*Limg, D)
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
// QK norm (RMSNorm per head)
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
// Apply RoPE
if imgFreqs != nil {
qImg = applyRoPE(qImg, imgFreqs)
kImg = applyRoPE(kImg, imgFreqs)
}
// === Text Q/K/V ===
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
if txtFreqs != nil {
qTxt = applyRoPE(qTxt, txtFreqs)
kTxt = applyRoPE(kTxt, txtFreqs)
}
// Concatenate for joint attention: [txt, img] order
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
// Transpose to [B, nheads, L, head_dim]
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
// SDPA
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
// Transpose back and split
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
// Output projections
outImg = mlx.Reshape(outImg, B*Limg, D)
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
outImg = mlx.Reshape(outImg, B, Limg, D)
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
return outImg, outTxt
}
// applyRoPE applies rotary embeddings using complex multiplication
// x: [B, L, nheads, head_dim]
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
halfDim := headDim / 2
// Reshape x to pairs: [B, L, nheads, half, 2]
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
// Extract real/imag parts
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
xReal = mlx.Squeeze(xReal, 4)
xImag = mlx.Squeeze(xImag, 4)
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
freqReal = mlx.Squeeze(freqReal, 4)
freqImag = mlx.Squeeze(freqImag, 4)
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
// Interleave back
outReal = mlx.ExpandDims(outReal, 4)
outImag = mlx.ExpandDims(outImag, 4)
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
return mlx.Reshape(out, B, L, nheads, headDim)
}
// MLP implements GELU MLP (not GEGLU)
type MLP struct {
ProjWeight *mlx.Array
ProjBias *mlx.Array
OutWeight *mlx.Array
OutBias *mlx.Array
}
// newMLP creates a GELU MLP
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
outWeight, _ := weights.Get(prefix + ".net.2.weight")
outBias, _ := weights.Get(prefix + ".net.2.bias")
return &MLP{
ProjWeight: mlx.Transpose(projWeight, 1, 0),
ProjBias: projBias,
OutWeight: mlx.Transpose(outWeight, 1, 0),
OutBias: outBias,
}, nil
}
// Forward applies GELU MLP
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
xFlat := mlx.Reshape(x, B*L, D)
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
h = geluApprox(h)
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
}
// geluApprox implements approximate GELU
func geluApprox(x *mlx.Array) *mlx.Array {
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
x3 := mlx.Mul(mlx.Mul(x, x), x)
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
inner = mlx.MulScalar(inner, sqrt2OverPi)
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}
// TransformerBlock is a single dual-stream transformer block
type TransformerBlock struct {
Attention *JointAttention
ImgMLP *MLP
TxtMLP *MLP
ImgModWeight *mlx.Array
ImgModBias *mlx.Array
TxtModWeight *mlx.Array
TxtModBias *mlx.Array
HiddenDim int32
NormEps float32
}
// newTransformerBlock creates a transformer block
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
attn, err := newJointAttention(weights, prefix, cfg)
if err != nil {
return nil, err
}
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
return &TransformerBlock{
Attention: attn,
ImgMLP: imgMLP,
TxtMLP: txtMLP,
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
ImgModBias: imgModBias,
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
TxtModBias: txtModBias,
HiddenDim: cfg.HiddenDim,
NormEps: cfg.NormEps,
}, nil
}
// Forward applies the transformer block
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
siluT := mlx.SiLU(temb)
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
imgModParts := splitMod6(imgMod, tb.HiddenDim)
txtModParts := splitMod6(txtMod, tb.HiddenDim)
// Pre-attention: norm + modulate
imgNorm := layerNormNoAffine(img, tb.NormEps)
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
txtNorm := layerNormNoAffine(txt, tb.NormEps)
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
// Joint attention
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
// Residual with gate
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
// Pre-MLP: norm + modulate
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
// MLP
mlpImg := tb.ImgMLP.Forward(imgNorm2)
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
// Residual with gate
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
return img, txt
}
// splitMod6 splits modulation into 6 parts each [B, 1, D]
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
shape := mod.Shape()
B := shape[0]
parts := make([]*mlx.Array, 6)
for i := int32(0); i < 6; i++ {
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
parts[i] = mlx.ExpandDims(part, 1)
}
return parts
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}
// Transformer is the full Qwen-Image transformer model
type Transformer struct {
Config *TransformerConfig
ImgIn *mlx.Array
ImgInBias *mlx.Array
TxtIn *mlx.Array
TxtInBias *mlx.Array
TxtNorm *mlx.Array
TEmbed *TimestepEmbedder
Layers []*TransformerBlock
NormOutWeight *mlx.Array
NormOutBias *mlx.Array
ProjOut *mlx.Array
ProjOutBias *mlx.Array
}
// Load loads the transformer from a directory
func (m *Transformer) Load(path string) error {
fmt.Println("Loading Qwen-Image transformer...")
cfg := defaultTransformerConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Bulk load all weights as bf16
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading input projections... ")
imgIn, _ := weights.Get("img_in.weight")
imgInBias, _ := weights.Get("img_in.bias")
txtIn, _ := weights.Get("txt_in.weight")
txtInBias, _ := weights.Get("txt_in.bias")
txtNorm, _ := weights.Get("txt_norm.weight")
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
m.ImgInBias = imgInBias
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
m.TxtInBias = txtInBias
m.TxtNorm = txtNorm
fmt.Println("✓")
fmt.Print(" Loading timestep embedder... ")
m.TEmbed, err = newTimestepEmbedder(weights)
if err != nil {
return fmt.Errorf("timestep embedder: %w", err)
}
fmt.Println("✓")
m.Layers = make([]*TransformerBlock, cfg.NLayers)
for i := int32(0); i < cfg.NLayers; i++ {
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
prefix := fmt.Sprintf("transformer_blocks.%d", i)
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
if err != nil {
return fmt.Errorf("layer %d: %w", i, err)
}
}
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
fmt.Print(" Loading output layers... ")
normOutWeight, _ := weights.Get("norm_out.linear.weight")
normOutBias, _ := weights.Get("norm_out.linear.bias")
projOut, _ := weights.Get("proj_out.weight")
projOutBias, _ := weights.Get("proj_out.bias")
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
m.NormOutBias = normOutBias
m.ProjOut = mlx.Transpose(projOut, 1, 0)
m.ProjOutBias = projOutBias
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// LoadFromPath is a convenience function to load transformer from path
func LoadTransformerFromPath(path string) (*Transformer, error) {
m := &Transformer{}
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
return nil, err
}
return m, nil
}
// Forward runs the transformer
// img: [B, L_img, in_channels] patchified latents
// txt: [B, L_txt, joint_attention_dim] text embeddings
// t: [B] timesteps (0-1)
// imgFreqs, txtFreqs: RoPE frequencies
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// Timestep embedding
temb := tr.TEmbed.Forward(t)
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
// Project text: RMSNorm then linear
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
for _, layer := range tr.Layers {
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
}
// Final norm with modulation (AdaLayerNormContinuous)
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
modShape := finalMod.Shape()
halfDim := modShape[1] / 2
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
return mlx.Reshape(out, B, Limg, outChannels)
}
// ForwardWithCache runs the transformer with layer caching for speedup.
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
// shallow layers change little between denoising steps, so we cache their
// outputs and reuse them on non-refresh steps.
//
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
// step: current denoising step (0-indexed)
// cacheInterval: refresh cache every N steps (e.g., 3)
// cacheLayers: number of shallow layers to cache (e.g., 15)
func (tr *Transformer) ForwardWithCache(
img, txt, t *mlx.Array,
imgFreqs, txtFreqs *mlx.Array,
stepCache *cache.StepCache,
step, cacheInterval, cacheLayers int,
) *mlx.Array {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// Timestep embedding
temb := tr.TEmbed.Forward(t)
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
// Project text: RMSNorm then linear
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
// Check if we should refresh the cache
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
for i, layer := range tr.Layers {
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
// Use cached outputs for shallow layers
imgH = stepCache.Get(i)
txtH = stepCache.Get2(i)
} else {
// Compute layer
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
// Cache shallow layers on refresh steps
if i < cacheLayers && refreshCache {
stepCache.Set(i, imgH)
stepCache.Set2(i, txtH)
}
}
}
// Final norm with modulation (AdaLayerNormContinuous)
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
modShape := finalMod.Shape()
halfDim := modShape[1] / 2
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
return mlx.Reshape(out, B, Limg, outChannels)
}
// RoPECache holds precomputed RoPE frequencies
type RoPECache struct {
ImgFreqs *mlx.Array // [L_img, head_dim]
TxtFreqs *mlx.Array // [L_txt, head_dim]
}
// PrepareRoPE computes RoPE for image and text sequences
// This matches Python's QwenEmbedRope with scale_rope=True
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
theta := float64(10000)
maxIdx := int32(4096)
// Compute base frequencies for each axis dimension
freqsT := ComputeAxisFreqs(axesDims[0], theta)
freqsH := ComputeAxisFreqs(axesDims[1], theta)
freqsW := ComputeAxisFreqs(axesDims[2], theta)
// Build frequency lookup tables
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
// Image frequencies with scale_rope=True
imgLen := imgH * imgW
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
imgFreqsData := make([]float32, imgLen*headDim)
hHalf := imgH / 2
wHalf := imgW / 2
idx := int32(0)
for y := int32(0); y < imgH; y++ {
for x := int32(0); x < imgW; x++ {
// Frame = 0
for i := 0; i < len(freqsT)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
}
idx += int32(len(freqsT) * 2)
// Height: scale_rope pattern
hNegCount := imgH - hHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
}
}
idx += int32(len(freqsH) * 2)
// Width: scale_rope pattern
wNegCount := imgW - wHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
}
}
idx += int32(len(freqsW) * 2)
}
}
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
imgFreqs = mlx.ToBFloat16(imgFreqs)
// Text frequencies
maxVidIdx := max(hHalf, wHalf)
txtFreqsData := make([]float32, txtLen*headDim)
idx = 0
for t := int32(0); t < txtLen; t++ {
pos := maxVidIdx + t
for i := 0; i < len(freqsT)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
}
idx += int32(len(freqsT) * 2)
for i := 0; i < len(freqsH)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
}
idx += int32(len(freqsH) * 2)
for i := 0; i < len(freqsW)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
}
idx += int32(len(freqsW) * 2)
}
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
txtFreqs = mlx.ToBFloat16(txtFreqs)
return &RoPECache{
ImgFreqs: imgFreqs,
TxtFreqs: txtFreqs,
}
}
// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
halfDim := dim / 2
freqs := make([]float64, halfDim)
for i := int32(0); i < halfDim; i++ {
freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
}
return freqs
}
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
table := make([][]float32, maxIdx)
for idx := int32(0); idx < maxIdx; idx++ {
var pos float64
if negative {
pos = float64(-maxIdx + int32(idx))
} else {
pos = float64(idx)
}
row := make([]float32, len(baseFreqs)*2)
for i, f := range baseFreqs {
angle := pos * f
row[i*2] = float32(math.Cos(angle))
row[i*2+1] = float32(math.Sin(angle))
}
table[idx] = row
}
return table
}
func max(a, b int32) int32 {
if a > b {
return a
}
return b
}
// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize
pW := W / patchSize
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
// -> [B, pH, pW, C, 2, 2]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// -> [B, pH*pW, C*4]
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
channels := shape[2] / (patchSize * patchSize)
pH := H / patchSize
pW := W / patchSize
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
// -> [B, C, pH, 2, pW, 2]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// -> [B, C, H, W]
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
// Add temporal dimension for VAE: [B, C, 1, H, W]
return mlx.ExpandDims(x, 2)
}

View File

@@ -1,119 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestTransformerConfig tests configuration invariants.
func TestTransformerConfig(t *testing.T) {
cfg := defaultTransformerConfig()
// Property: hidden_dim = n_heads * head_dim
if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
}
// Property: axes_dims_rope sums to head_dim
var ropeSum int32
for _, d := range cfg.AxesDimsRope {
ropeSum += d
}
if ropeSum != cfg.HeadDim {
t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
}
// Property: in_channels = out_channels * patch_size^2
expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
if cfg.InChannels != expectedIn {
t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
}
}
// TestTransformerRoPE tests RoPE frequency computation produces valid values.
func TestTransformerRoPE(t *testing.T) {
cfg := defaultTransformerConfig()
// Test with small image dimensions
imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
txtLen := int32(5)
ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Verify shapes: [seq_len, head_dim]
imgSeqLen := imgH * imgW
if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
}
if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
}
if ropeCache.TxtFreqs.Shape()[0] != txtLen {
t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
}
// Verify values are finite
imgData := ropeCache.ImgFreqs.Data()
for i := 0; i < min(100, len(imgData)); i++ {
if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
break
}
}
}
// TestTransformerForward tests full forward pass (integration test).
// Skips if model weights are not available.
func TestTransformerForward(t *testing.T) {
weightsPath := "../../../weights/Qwen-Image-2512/transformer"
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + weightsPath)
}
transformer := &Transformer{}
if err := transformer.Load(weightsPath); err != nil {
t.Fatalf("Failed to load transformer: %v", err)
}
mlx.Keep(mlx.Collect(transformer)...)
cfg := transformer.Config
// Small test inputs
batchSize := int32(1)
imgH, imgW := int32(4), int32(4)
imgSeqLen := imgH * imgW
txtSeqLen := int32(5)
hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})
ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)
// Forward pass
out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(out)
// Verify output shape: [batch, img_seq_len, in_channels]
wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
gotShape := out.Shape()
if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
}
// Verify output is finite
outData := out.Data()
for i := 0; i < min(100, len(outData)); i++ {
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
t.Errorf("output[%d] not finite: %v", i, outData[i])
break
}
}
}

View File

@@ -1,854 +0,0 @@
//go:build mlx
package qwen_image
import (
"fmt"
"math"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds Qwen-Image VAE configuration
type VAEConfig struct {
ZDim int32 `json:"z_dim"` // 16
BaseDim int32 `json:"base_dim"` // 96
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
NumResBlocks int32 `json:"num_res_blocks"` // 2
LatentsMean []float32 `json:"latents_mean"` // 16 values
LatentsStd []float32 `json:"latents_std"` // 16 values
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
}
// defaultVAEConfig returns config for Qwen-Image VAE
func defaultVAEConfig() *VAEConfig {
return &VAEConfig{
ZDim: 16,
BaseDim: 96,
DimMult: []int32{1, 2, 4, 4},
NumResBlocks: 2,
LatentsMean: []float32{
-0.7571, -0.7089, -0.9113, 0.1075,
-0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632,
-0.1922, -0.9497, 0.2503, -0.2921,
},
LatentsStd: []float32{
2.8184, 1.4541, 2.3275, 2.6558,
1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579,
1.6382, 1.1253, 2.8251, 1.916,
},
TemperalDownsample: []bool{false, true, true},
}
}
// CausalConv3d is a causal 3D convolution (for temporal causality)
type CausalConv3d struct {
Weight *mlx.Array
Bias *mlx.Array
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
KernelT int32
}
// newCausalConv3d creates a 3D causal conv
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
weight, err := weights.Get(prefix + ".weight")
if err != nil {
return nil, fmt.Errorf("weight not found: %s", prefix)
}
bias, _ := weights.Get(prefix + ".bias")
kernelT := weight.Shape()[2]
outC := weight.Shape()[0]
var biasReshaped *mlx.Array
if bias != nil {
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
}
return &CausalConv3d{
Weight: weight,
Bias: bias,
BiasReshaped: biasReshaped,
KernelT: kernelT,
}, nil
}
// Forward applies causal 3D convolution
// x: [B, T, H, W, C] (channels-last, MLX format)
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
kernelT := shape[2]
kernelH := shape[3]
kernelW := shape[4]
// Causal temporal padding, same spatial padding
// Input is channels-last: [B, T, H, W, C]
padT := kernelT - 1
padH := kernelH / 2
padW := kernelW / 2
// Stage 1: Pad
{
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
mlx.Eval(x)
}
// Stage 2: Conv + bias
var out *mlx.Array
{
prev := x
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
out = mlx.Add(out, bias)
}
prev.Free()
mlx.Eval(out)
}
return out
}
// RMSNorm3D applies RMS normalization over channels
// Works with channels-last [B, T, H, W, C] format
type RMSNorm3D struct {
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
}
// newRMSNorm3D creates an RMS norm
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
gamma, err := weights.Get(prefix + ".gamma")
if err != nil {
return nil, err
}
// Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
return &RMSNorm3D{Gamma: gamma}, nil
}
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
// RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
normalized := mlx.RMSNormNoWeight(x, 1e-6)
return mlx.Mul(normalized, n.Gamma)
}
// ResBlock is a residual block with RMS norm and causal convs
type ResBlock struct {
Norm1 *RMSNorm3D
Conv1 *CausalConv3d
Norm2 *RMSNorm3D
Conv2 *CausalConv3d
Shortcut *CausalConv3d
}
// newResBlock creates a residual block
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
if err != nil {
return nil, err
}
conv1, err := newCausalConv3d(weights, prefix+".conv1")
if err != nil {
return nil, err
}
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
if err != nil {
return nil, err
}
conv2, err := newCausalConv3d(weights, prefix+".conv2")
if err != nil {
return nil, err
}
var shortcut *CausalConv3d
if inDim != outDim {
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
if err != nil {
return nil, err
}
}
return &ResBlock{
Norm1: norm1,
Conv1: conv1,
Norm2: norm2,
Conv2: conv2,
Shortcut: shortcut,
}, nil
}
// Forward applies the residual block
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
// Use h as working variable, keep x intact for residual (caller will free x)
// Conv handles its own pools, so we just need pools for non-conv operations
var h *mlx.Array
// Keep x so it survives Eval() cleanup - needed for residual connection
mlx.Keep(x)
// Stage 1: norm1 + silu
{
h = r.Norm1.Forward(x)
h = silu3D(h)
mlx.Eval(h)
}
// Stage 2: conv1 (handles its own pools)
{
prev := h
h = r.Conv1.Forward(h)
prev.Free()
}
// Stage 3: norm2 + silu
{
prev := h
h = r.Norm2.Forward(h)
h = silu3D(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: conv2 (handles its own pools)
{
prev := h
h = r.Conv2.Forward(h)
prev.Free()
}
// Residual connection (shortcut handles its own pools if present)
if r.Shortcut != nil {
shortcut := r.Shortcut.Forward(x)
h = mlx.Add(h, shortcut)
mlx.Eval(h)
} else {
h = mlx.Add(h, x)
mlx.Eval(h)
}
return h
}
// AttentionBlock is a 2D attention block
type AttentionBlock struct {
Norm *RMSNorm3D
ToQKV *mlx.Array
ToQKVBias *mlx.Array
Proj *mlx.Array
ProjBias *mlx.Array
Dim int32
}
// newAttentionBlock creates an attention block
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
if err != nil {
return nil, err
}
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
proj, _ := weights.Get(prefix + ".proj.weight")
projBias, _ := weights.Get(prefix + ".proj.bias")
return &AttentionBlock{
Norm: norm,
ToQKV: toQKV,
ToQKVBias: toQKVBias,
Proj: proj,
ProjBias: projBias,
Dim: dim,
}, nil
}
// Forward applies 2D attention
// Input: [B, T, H, W, C] (channels-last)
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
identity := x
// Flatten to [B*T, 1, H, W, C] for norm
x = mlx.Reshape(x, B*T, 1, H, W, C)
x = a.Norm.Forward(x)
x = mlx.Reshape(x, B*T, H, W, C)
// Flatten spatial to [B*T, H*W, C]
x = mlx.Reshape(x, B*T, H*W, C)
// Linear to get Q, K, V: [B*T, H*W, 3*C]
// Weight is [outC, inC] or [outC, inC, 1, 1]
wShape := a.ToQKV.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
} else {
w = a.ToQKV
}
w = mlx.Transpose(w, 1, 0) // [inC, outC]
qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
if a.ToQKVBias != nil {
qkv = mlx.Add(qkv, a.ToQKVBias)
}
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
// out: [B*T, 1, H*W, C]
out = mlx.Reshape(out, B*T, H*W, C)
// Project back
pShape := a.Proj.Shape()
var p *mlx.Array
if len(pShape) == 4 {
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
} else {
p = a.Proj
}
p = mlx.Transpose(p, 1, 0) // [inC, outC]
out = mlx.Linear(out, p) // [B*T, H*W, C]
if a.ProjBias != nil {
out = mlx.Add(out, a.ProjBias)
}
out = mlx.Reshape(out, B, T, H, W, C)
return mlx.Add(out, identity)
}
// UpBlock handles upsampling in decoder
type UpBlock struct {
ResBlocks []*ResBlock
Upsampler *Upsample
}
// newUpBlock creates an up block
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var upsampler *Upsample
if upsampleMode != "" {
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
}
return &UpBlock{
ResBlocks: resBlocks,
Upsampler: upsampler,
}, nil
}
// Forward applies up block with staged memory management
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
// ResBlocks handle their own pools
for _, block := range u.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
// Upsampler handles its own pools
if u.Upsampler != nil {
prev := x
x = u.Upsampler.Forward(x)
prev.Free()
}
return x
}
// Upsample handles spatial upsampling
type Upsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newUpsample creates an upsampler
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Upsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies upsampling to channels-last input [B, T, H, W, C]
// Uses staged pools to reduce peak memory during 2x upsampling
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := u.Conv.Shape()[0]
// Stage 1: 2x nearest neighbor upsample
{
x = mlx.Reshape(x, B*T, H, W, C)
x = upsample2xChannelsLast(x)
mlx.Eval(x)
}
// Stage 2: Conv + bias
{
prev := x
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
x = conv2D3x3PaddedChannelsLast(x, weight)
if u.Bias != nil {
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
prev.Free()
mlx.Eval(x)
}
return x
}
// MidBlock is the middle block of decoder
type MidBlock struct {
ResBlock1 *ResBlock
Attention *AttentionBlock
ResBlock2 *ResBlock
}
// newMidBlock creates a mid block
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
if err != nil {
return nil, err
}
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
if err != nil {
return nil, err
}
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
if err != nil {
return nil, err
}
return &MidBlock{
ResBlock1: res1,
Attention: attn,
ResBlock2: res2,
}, nil
}
// Forward applies mid block
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
// Each component handles its own pools; we just free inputs
prev := x
x = m.ResBlock1.Forward(x)
prev.Free()
prev = x
x = m.Attention.Forward(x)
prev.Free()
prev = x
x = m.ResBlock2.Forward(x)
prev.Free()
return x
}
// VAEDecoder is the full VAE decoder
type VAEDecoder struct {
Config *VAEConfig
PostQuantConv *CausalConv3d
ConvIn *CausalConv3d
MidBlock *MidBlock
UpBlocks []*UpBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
}
// Load loads the VAE decoder from a directory
func (m *VAEDecoder) Load(path string) error {
fmt.Println("Loading Qwen-Image VAE decoder...")
cfg := defaultVAEConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Bulk load all weights as bf16
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("failed to load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading post_quant_conv... ")
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
if err != nil {
return err
}
m.PostQuantConv = postQuantConv
fmt.Println("✓")
fmt.Print(" Loading conv_in... ")
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
if err != nil {
return err
}
m.ConvIn = convIn
fmt.Println("✓")
// Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
fmt.Print(" Loading mid_block... ")
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
if err != nil {
return err
}
m.MidBlock = midBlock
fmt.Println("✓")
// Up blocks (reversed dim_mult)
fmt.Print(" Loading up_blocks... ")
numUpBlocks := len(cfg.DimMult)
m.UpBlocks = make([]*UpBlock, numUpBlocks)
dimsMult := make([]int32, numUpBlocks+1)
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
for i := 0; i < numUpBlocks; i++ {
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
}
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
for i := range cfg.TemperalDownsample {
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
}
for i := 0; i < numUpBlocks; i++ {
inDim := cfg.BaseDim * dimsMult[i]
outDim := cfg.BaseDim * dimsMult[i+1]
if i > 0 {
inDim = inDim / 2
}
upsampleMode := ""
if i < numUpBlocks-1 {
if temporalUpsample[i] {
upsampleMode = "upsample3d"
} else {
upsampleMode = "upsample2d"
}
}
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
if err != nil {
return err
}
m.UpBlocks[i] = upBlock
}
fmt.Printf("✓ [%d blocks]\n", numUpBlocks)
fmt.Print(" Loading output layers... ")
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
if err != nil {
return err
}
m.NormOut = normOut
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
if err != nil {
return err
}
m.ConvOut = convOut
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// LoadVAEDecoderFromPath is a convenience function to load VAE from path
func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
m := &VAEDecoder{}
if err := m.Load(filepath.Join(path, "vae")); err != nil {
return nil, err
}
return m, nil
}
// Decode converts latents to image
// z: [B, C, T, H, W] normalized latents
// Uses staged pools to free intermediate arrays and reduce peak memory.
func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
var x *mlx.Array
// Stage 1a: Denormalize and transpose
{
z = vae.Denormalize(z)
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
mlx.Eval(z)
}
// Stage 1b: PostQuantConv (handles its own pools)
x = vae.PostQuantConv.Forward(z)
z.Free()
// Stage 1c: ConvIn (handles its own pools)
{
prev := x
x = vae.ConvIn.Forward(x)
prev.Free()
}
// Stage 2: Mid block (handles its own pools)
x = vae.MidBlock.Forward(x)
// Stage 3: Up blocks (each handles its own pools)
for _, upBlock := range vae.UpBlocks {
x = upBlock.Forward(x)
}
// Stage 4a: NormOut + silu
{
prev := x
x = vae.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// Stage 4b: ConvOut (handles its own pools)
{
prev := x
x = vae.ConvOut.Forward(x)
prev.Free()
}
// Stage 4c: Post-processing
{
prev := x
// Clamp to [-1, 1]
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
// Convert back from channels-last to channels-first
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
prev.Free()
mlx.Eval(x)
}
return x
}
// Denormalize reverses the normalization applied during encoding
func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
mean = mlx.ToBFloat16(mean)
std = mlx.ToBFloat16(std)
return mlx.Add(mlx.Mul(z, std), mean)
}
// Helper functions
func silu3D(x *mlx.Array) *mlx.Array {
return mlx.Mul(x, mlx.Sigmoid(x))
}
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
// Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
}
func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
}
func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[2]
W := shape[3]
x = mlx.Transpose(x, 0, 2, 3, 1)
x = mlx.Reshape(x, B*H*W, shape[1])
wShape := weight.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(weight, wShape[0], wShape[1])
} else {
w = weight
}
w = mlx.Transpose(w, 1, 0)
out := mlx.Linear(x, w)
outC := w.Dim(1)
out = mlx.Reshape(out, B, H, W, outC)
return mlx.Transpose(out, 0, 3, 1, 2)
}
func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
x = pad2D(x, 1, 1, 1, 1)
return conv2D(x, weight, 1, 1)
}
func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
x = mlx.Transpose(x, 0, 2, 3, 1)
w = mlx.Transpose(w, 0, 2, 3, 1)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
wShape := w.Shape()
Cout := wShape[0]
kH := wShape[1]
kW := wShape[2]
outH := (H-kH)/strideH + 1
outW := (W-kW)/strideW + 1
patches := extractPatches2D(x, kH, kW, strideH, strideW)
wFlat := mlx.Reshape(w, Cout, -1)
patches = mlx.Reshape(patches, B*outH*outW, -1)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
out = mlx.Reshape(out, B, outH, outW, Cout)
return mlx.Transpose(out, 0, 3, 1, 2)
}
func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
outH := (H-kH)/strideH + 1
outW := (W-kW)/strideW + 1
patches := make([]*mlx.Array, outH*outW)
idx := 0
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startH := i * strideH
startW := j * strideW
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
patch = mlx.Reshape(patch, B, kH*kW*C)
patches[idx] = patch
idx++
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[2]
W := shape[3]
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
x = mlx.Take(x, rowIdx, 2)
x = mlx.Take(x, colIdx, 3)
return x
}
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
// Create repeat indices for rows
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
// Create repeat indices for columns
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
// Take along H (axis 1) then W (axis 2)
x = mlx.Take(x, rowIdx, 1)
x = mlx.Take(x, colIdx, 2)
return x
}
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
// weight: [outC, kH, kW, inC] (MLX channels-last format)
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
// Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
// Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
// stride=1, padding=0 (we already padded manually)
return mlx.Conv2d(x, weight, 1, 0)
}

View File

@@ -1,114 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestVAEConfig tests configuration invariants.
func TestVAEConfig(t *testing.T) {
cfg := defaultVAEConfig()
// Property: latents_mean and latents_std have z_dim elements
if int32(len(cfg.LatentsMean)) != cfg.ZDim {
t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim)
}
if int32(len(cfg.LatentsStd)) != cfg.ZDim {
t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim)
}
// Property: dim_mult defines 4 stages
if len(cfg.DimMult) != 4 {
t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult))
}
// Property: temperal_downsample has 3 elements (for 3 transitions)
if len(cfg.TemperalDownsample) != 3 {
t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample))
}
}
// TestVAELatentsNormalization tests the latent denormalization values.
func TestVAELatentsNormalization(t *testing.T) {
cfg := defaultVAEConfig()
// Verify latents_std values are all positive
for i, std := range cfg.LatentsStd {
if std <= 0 {
t.Errorf("latents_std[%d] should be positive: %v", i, std)
}
}
// Verify values are in reasonable range (from actual model)
for i, mean := range cfg.LatentsMean {
if math.Abs(float64(mean)) > 5 {
t.Errorf("latents_mean[%d] seems too large: %v", i, mean)
}
}
for i, std := range cfg.LatentsStd {
if std > 10 {
t.Errorf("latents_std[%d] seems too large: %v", i, std)
}
}
}
// TestVAEDecoderForward tests full forward pass (integration test).
// Skips if model weights are not available.
func TestVAEDecoderForward(t *testing.T) {
weightsPath := "../../../weights/Qwen-Image-2512/vae"
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + weightsPath)
}
vae := &VAEDecoder{}
if err := vae.Load(weightsPath); err != nil {
t.Fatalf("Failed to load VAE decoder: %v", err)
}
mlx.Keep(mlx.Collect(vae)...)
// Small test input: [B, C, T, H, W]
// After 4 upsampling stages (2x each), H/W multiply by 16
batchSize := int32(1)
channels := int32(16)
frames := int32(1)
latentH := int32(4)
latentW := int32(4)
latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0)
// Decode
out := vae.Decode(latents)
mlx.Eval(out)
// Verify output shape: [B, 3, T, H*16, W*16]
outShape := out.Shape()
if outShape[0] != batchSize {
t.Errorf("batch size: got %d, want %d", outShape[0], batchSize)
}
if outShape[1] != 3 {
t.Errorf("channels: got %d, want 3", outShape[1])
}
if outShape[2] != frames {
t.Errorf("frames: got %d, want %d", outShape[2], frames)
}
expectedH := latentH * 16 // 4 stages of 2x upsampling
expectedW := latentW * 16
if outShape[3] != expectedH || outShape[4] != expectedW {
t.Errorf("spatial dims: got [%d, %d], want [%d, %d]",
outShape[3], outShape[4], expectedH, expectedW)
}
// Verify output is in valid range (should be clamped to [0, 1] by decode)
outData := out.Data()
for i := 0; i < min(100, len(outData)); i++ {
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
t.Errorf("output[%d] not finite: %v", i, outData[i])
break
}
}
}

View File

@@ -1,682 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// CausalConv3d is a causal 3D convolution (for temporal causality)
type CausalConv3d struct {
Weight *mlx.Array
Bias *mlx.Array
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
KernelT int32
}
// newCausalConv3d creates a 3D causal conv
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
weight, err := weights.Get(prefix + ".weight")
if err != nil {
return nil, fmt.Errorf("weight not found: %s", prefix)
}
bias, _ := weights.Get(prefix + ".bias")
kernelT := weight.Shape()[2]
outC := weight.Shape()[0]
var biasReshaped *mlx.Array
if bias != nil {
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
}
return &CausalConv3d{
Weight: weight,
Bias: bias,
BiasReshaped: biasReshaped,
KernelT: kernelT,
}, nil
}
// Forward applies causal 3D convolution (or 2D if weight is 4D)
// x: [B, T, H, W, C] (channels-last, MLX format)
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
shape := c.Weight.Shape()
// Handle both 5D (3D conv) and 4D (2D conv) weights
if len(shape) == 4 {
// 2D conv: [O, I, kH, kW] - need to apply per-frame
return c.forward2D(x)
}
// 3D conv: [O, I, kT, kH, kW]
kernelT := shape[2]
kernelH := shape[3]
kernelW := shape[4]
// Causal temporal padding, same spatial padding
padT := kernelT - 1
padH := kernelH / 2
padW := kernelW / 2
// Stage 1: Pad
{
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
mlx.Eval(x)
}
// Stage 2: Conv + bias
var out *mlx.Array
{
prev := x
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
out = mlx.Add(out, bias)
}
prev.Free()
mlx.Eval(out)
}
return out
}
// forward2D applies 2D conv per-frame for [B, T, H, W, C] input
func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
xShape := x.Shape()
B := xShape[0]
T := xShape[1]
H := xShape[2]
W := xShape[3]
C := xShape[4]
wShape := c.Weight.Shape() // [O, I, kH, kW]
kernelH := wShape[2]
kernelW := wShape[3]
outC := wShape[0]
padH := kernelH / 2
padW := kernelW / 2
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Pad spatially
x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})
// Apply 2D conv
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
x = mlx.Conv2d(x, weight, 1, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
// Get output spatial dims
outH := H
outW := W
// Reshape back to [B, T, H, W, C]
x = mlx.Reshape(x, B, T, outH, outW, outC)
mlx.Eval(x)
return x
}
// RMSNorm3D applies RMS normalization over channels
type RMSNorm3D struct {
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
}
// newRMSNorm3D creates an RMS norm
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
gamma, err := weights.Get(prefix + ".gamma")
if err != nil {
return nil, err
}
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
return &RMSNorm3D{Gamma: gamma}, nil
}
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
normalized := mlx.RMSNormNoWeight(x, 1e-6)
return mlx.Mul(normalized, n.Gamma)
}
// ResBlock is a residual block with RMS norm and causal convs
type ResBlock struct {
Norm1 *RMSNorm3D
Conv1 *CausalConv3d
Norm2 *RMSNorm3D
Conv2 *CausalConv3d
Shortcut *CausalConv3d
}
// newResBlock creates a residual block
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
if err != nil {
return nil, err
}
conv1, err := newCausalConv3d(weights, prefix+".conv1")
if err != nil {
return nil, err
}
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
if err != nil {
return nil, err
}
conv2, err := newCausalConv3d(weights, prefix+".conv2")
if err != nil {
return nil, err
}
var shortcut *CausalConv3d
if inDim != outDim {
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
if err != nil {
return nil, err
}
}
return &ResBlock{
Norm1: norm1,
Conv1: conv1,
Norm2: norm2,
Conv2: conv2,
Shortcut: shortcut,
}, nil
}
// Forward applies the residual block
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
var h *mlx.Array
mlx.Keep(x)
// Stage 1: norm1 + silu
{
h = r.Norm1.Forward(x)
h = silu3D(h)
mlx.Eval(h)
}
// Stage 2: conv1
{
prev := h
h = r.Conv1.Forward(h)
prev.Free()
}
// Stage 3: norm2 + silu
{
prev := h
h = r.Norm2.Forward(h)
h = silu3D(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: conv2
{
prev := h
h = r.Conv2.Forward(h)
prev.Free()
}
// Residual connection
if r.Shortcut != nil {
shortcut := r.Shortcut.Forward(x)
h = mlx.Add(h, shortcut)
mlx.Eval(h)
} else {
h = mlx.Add(h, x)
mlx.Eval(h)
}
return h
}
// AttentionBlock is a 2D attention block
type AttentionBlock struct {
Norm *RMSNorm3D
ToQKV *mlx.Array
ToQKVBias *mlx.Array
Proj *mlx.Array
ProjBias *mlx.Array
Dim int32
}
// newAttentionBlock creates an attention block
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
if err != nil {
return nil, err
}
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
proj, _ := weights.Get(prefix + ".proj.weight")
projBias, _ := weights.Get(prefix + ".proj.bias")
return &AttentionBlock{
Norm: norm,
ToQKV: toQKV,
ToQKVBias: toQKVBias,
Proj: proj,
ProjBias: projBias,
Dim: dim,
}, nil
}
// Forward applies 2D attention
// Input: [B, T, H, W, C] (channels-last)
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
identity := x
// Flatten to [B*T, 1, H, W, C] for norm
x = mlx.Reshape(x, B*T, 1, H, W, C)
x = a.Norm.Forward(x)
x = mlx.Reshape(x, B*T, H, W, C)
// Flatten spatial to [B*T, H*W, C]
x = mlx.Reshape(x, B*T, H*W, C)
// Linear to get Q, K, V
wShape := a.ToQKV.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
} else {
w = a.ToQKV
}
w = mlx.Transpose(w, 1, 0)
qkv := mlx.Linear(x, w)
if a.ToQKVBias != nil {
qkv = mlx.Add(qkv, a.ToQKVBias)
}
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
out = mlx.Reshape(out, B*T, H*W, C)
// Project back
pShape := a.Proj.Shape()
var p *mlx.Array
if len(pShape) == 4 {
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
} else {
p = a.Proj
}
p = mlx.Transpose(p, 1, 0)
out = mlx.Linear(out, p)
if a.ProjBias != nil {
out = mlx.Add(out, a.ProjBias)
}
out = mlx.Reshape(out, B, T, H, W, C)
return mlx.Add(out, identity)
}
// UpBlock handles upsampling in decoder
type UpBlock struct {
ResBlocks []*ResBlock
Upsampler *Upsample
}
// newUpBlock creates an up block
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var upsampler *Upsample
if upsampleMode != "" {
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
}
return &UpBlock{
ResBlocks: resBlocks,
Upsampler: upsampler,
}, nil
}
// Forward applies up block
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
for _, block := range u.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
if u.Upsampler != nil {
prev := x
x = u.Upsampler.Forward(x)
prev.Free()
}
return x
}
// Upsample handles spatial upsampling
type Upsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newUpsample creates an upsampler
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Upsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies upsampling to channels-last input [B, T, H, W, C]
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := u.Conv.Shape()[0]
// Stage 1: 2x nearest neighbor upsample
{
x = mlx.Reshape(x, B*T, H, W, C)
x = upsample2xChannelsLast(x)
mlx.Eval(x)
}
// Stage 2: Conv + bias
{
prev := x
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
x = conv2D3x3PaddedChannelsLast(x, weight)
if u.Bias != nil {
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
prev.Free()
mlx.Eval(x)
}
return x
}
// MidBlock is the middle block
type MidBlock struct {
ResBlock1 *ResBlock
Attention *AttentionBlock
ResBlock2 *ResBlock
}
// newMidBlock creates a mid block
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
if err != nil {
return nil, err
}
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
if err != nil {
return nil, err
}
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
if err != nil {
return nil, err
}
return &MidBlock{
ResBlock1: res1,
Attention: attn,
ResBlock2: res2,
}, nil
}
// Forward applies mid block
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
prev := x
x = m.ResBlock1.Forward(x)
prev.Free()
prev = x
x = m.Attention.Forward(x)
prev.Free()
prev = x
x = m.ResBlock2.Forward(x)
prev.Free()
return x
}
// Helper functions
func silu3D(x *mlx.Array) *mlx.Array {
return mlx.Mul(x, mlx.Sigmoid(x))
}
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
}
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
x = mlx.Take(x, rowIdx, 1)
x = mlx.Take(x, colIdx, 2)
return x
}
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
return mlx.Conv2d(x, weight, 1, 0)
}
// conv2DStrided applies conv with stride > 1 using manual patch extraction
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
wShape := weight.Shape()
Cout := wShape[0]
kH := wShape[1]
kW := wShape[2]
outH := (H - kH) / stride + 1
outW := (W - kW) / stride + 1
patches := extractPatches2DStrided(x, kH, kW, stride)
wFlat := mlx.Reshape(weight, Cout, -1)
patches = mlx.Reshape(patches, B*outH*outW, -1)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
return mlx.Reshape(out, B, outH, outW, Cout)
}
// conv3DStrided applies 3D conv with strides using manual patch extraction
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
// strideT, strideH, strideW are the strides for each dimension
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
wShape := weight.Shape()
Cout := wShape[0]
// I := wShape[1]
kT := wShape[2]
kH := wShape[3]
kW := wShape[4]
// For temporal: if T < kT, we need to repeat frames temporally
// For single image with T=1 and kT=2, we duplicate the frame to T=kT
// Python Qwen2.5-VL duplicates the frame, not zero-pads
if T < kT {
// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
T = kT
}
outT := (T - kT) / strideT + 1
outH := (H - kH) / strideH + 1
outW := (W - kW) / strideW + 1
// Extract 3D patches in [C, T, H, W] order to match Python
patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
// patches shape: [B, outT, outH, outW, C*kT*kH*kW]
// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
return mlx.Reshape(out, B, outT, outH, outW, Cout)
}
// extractPatches3DStrided extracts 3D patches with given strides
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outT := (T - kT) / strideT + 1
outH := (H - kH) / strideH + 1
outW := (W - kW) / strideW + 1
numPatches := outT * outH * outW
patches := make([]*mlx.Array, numPatches)
idx := 0
for t := int32(0); t < outT; t++ {
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startT := t * strideT
startH := i * strideH
startW := j * strideW
// Extract patch: [B, kT, kH, kW, C]
patch := mlx.Slice(x,
[]int32{0, startT, startH, startW, 0},
[]int32{B, startT + kT, startH + kH, startW + kW, C})
// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
// Flatten to [B, C*T*H*W]
patch = mlx.Reshape(patch, B, C*kT*kH*kW)
patches[idx] = patch
idx++
}
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
}
// extractPatches2DStrided extracts patches with given stride
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
outH := (H - kH) / stride + 1
outW := (W - kW) / stride + 1
patches := make([]*mlx.Array, outH*outW)
idx := 0
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startH := i * stride
startW := j * stride
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
patch = mlx.Reshape(patch, B, kH*kW*C)
patches[idx] = patch
idx++
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}

View File

@@ -1,475 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"fmt"
"image"
"image/color"
_ "image/jpeg"
_ "image/png"
"math"
"os"
"github.com/ollama/ollama/x/imagegen/mlx"
"golang.org/x/image/draw"
_ "golang.org/x/image/webp"
)
// loadImageFile loads an image from disk
func loadImageFile(path string) (image.Image, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open image: %w", err)
}
defer f.Close()
img, _, err := image.Decode(f)
if err != nil {
return nil, fmt.Errorf("decode image: %w", err)
}
return img, nil
}
// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range
func imageToFloat32Pixels(img image.Image, width, height int) []float32 {
pixels := make([]float32, width*height*3)
idx := 0
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
r, g, b, _ := img.At(x, y).RGBA()
pixels[idx] = float32(r) / 65535.0
pixels[idx+1] = float32(g) / 65535.0
pixels[idx+2] = float32(b) / 65535.0
idx += 3
}
}
return pixels
}
// normalizeImageNet applies ImageNet normalization to an image tensor
func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array {
mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3})
std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3})
return mlx.Div(mlx.Sub(arr, mean), std)
}
// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16
func prepareImageTensor(arr *mlx.Array) *mlx.Array {
// Transpose to [C, H, W] and make contiguous
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
// Add batch dimension [1, C, H, W]
arr = mlx.ExpandDims(arr, 0)
// Convert to bf16
arr = mlx.ToBFloat16(arr)
mlx.Eval(arr)
return arr
}
// clampFloat clamps a value to [0, 255] and returns uint8
func clampFloat(v, weightSum float64) uint8 {
v /= weightSum
if v < 0 {
v = 0
}
if v > 255 {
v = 255
}
return uint8(math.Round(v))
}
// ImageDims holds dimensions for a preprocessed image
type ImageDims struct {
// Original image dimensions
OrigW, OrigH int32
// Condition image dimensions (for vision encoder)
CondW, CondH int32
// VAE image dimensions
VaeW, VaeH int32
// Latent dimensions (VAE dims / vae_scale_factor)
LatentW, LatentH int32
// Patch dimensions (latent dims / patch_size)
PatchW, PatchH int32
}
// ProcessorConfig holds image processor configuration
type ProcessorConfig struct {
// Condition image size (target pixel area for vision encoder input)
// Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456
// Pipeline resizes image to this area before passing to encode_prompt
ConditionImageSize int32
// VAE image size (target pixel area)
// Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576
VAEImageSize int32
// Image normalization (ImageNet stats for vision encoder)
ImageMean []float32
ImageStd []float32
}
// defaultProcessorConfig returns default processor config
func defaultProcessorConfig() *ProcessorConfig {
return &ProcessorConfig{
ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE
VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE
ImageMean: []float32{0.48145466, 0.4578275, 0.40821073},
ImageStd: []float32{0.26862954, 0.26130258, 0.27577711},
}
}
// Processor handles image preprocessing for Qwen-Image-Edit
type Processor struct {
Config *ProcessorConfig
}
// Load loads the processor config
func (p *Processor) Load(path string) error {
p.Config = defaultProcessorConfig()
return nil
}
// LoadAndPreprocess loads an image and preprocesses it for both paths
// Returns: condImage (for vision encoder), vaeImage (for VAE encoding)
func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) {
img, err := loadImageFile(imagePath)
if err != nil {
return nil, nil, err
}
bounds := img.Bounds()
origW := bounds.Dx()
origH := bounds.Dy()
ratio := float64(origW) / float64(origH)
// Calculate dimensions for condition image (vision encoder)
// Python pipeline does TWO resizes:
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
// Calculate dimensions for VAE image (1024x1024 area)
// Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32)
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
// Preprocess for condition (vision encoder) - two-step resize
condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH)
// Preprocess for VAE ([-1, 1] range, 5D tensor)
vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH)
return condImage, vaeImage, nil
}
// preprocessImageLanczos does single-step Lanczos resize for vision encoder
// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default)
// Used by edit_plus pipeline for multi-image input
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array {
resized := resizeImageLanczos(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
arr = p.normalizeImageNet(arr)
return prepareImageTensor(arr)
}
// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline
// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize)
// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize)
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array {
intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH))
resized := resizeImageBicubic(intermediate, int(finalW), int(finalH))
pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH))
arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3})
arr = p.normalizeImageNet(arr)
return prepareImageTensor(arr)
}
// preprocessImage converts image to tensor for vision encoder
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array {
resized := resizeImageBicubic(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
if normalize {
arr = p.normalizeImageNet(arr)
}
return prepareImageTensor(arr)
}
// preprocessImageForVAE converts image to tensor for VAE encoding
// Returns: [B, C, T, H, W] tensor in [-1, 1] range
func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array {
resized := resizeImageLanczos(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
// Scale to [-1, 1]: arr * 2 - 1
arr = mlx.MulScalar(arr, 2.0)
arr = mlx.AddScalar(arr, -1.0)
// Transpose to [C, H, W] and make contiguous
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
// Add batch and temporal dimensions [1, C, 1, H, W]
arr = mlx.ExpandDims(arr, 0) // [1, C, H, W]
arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W]
arr = mlx.ToBFloat16(arr)
mlx.Eval(arr)
return arr
}
// smartResize implements Python Qwen2VL processor's smart_resize logic
// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints
func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) {
// Round to factor
hBar := int32(math.Round(float64(height)/float64(factor))) * factor
wBar := int32(math.Round(float64(width)/float64(factor))) * factor
// Ensure minimum factor size
if hBar < factor {
hBar = factor
}
if wBar < factor {
wBar = factor
}
// Check pixel constraints
total := hBar * wBar
if total > maxPixels {
// Scale down
beta := math.Sqrt(float64(maxPixels) / float64(total))
hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor
wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor
} else if total < minPixels {
// Scale up
beta := math.Sqrt(float64(minPixels) / float64(total))
hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
// calculateDimensions calculates width and height for a target area while maintaining ratio
// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge)
func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) {
width := math.Sqrt(float64(targetArea) * ratio)
height := width / ratio
m := float64(multiple)
width = math.Round(width/m) * m
height = math.Round(height/m) * m
// Ensure minimum dimensions
if width < m {
width = m
}
if height < m {
height = m
}
return int32(width), int32(height)
}
// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS)
func resizeImageLanczos(img image.Image, width, height int) image.Image {
bounds := img.Bounds()
dst := image.NewRGBA(image.Rect(0, 0, width, height))
// Lanczos3 kernel (a=3) to match PIL.LANCZOS
lanczos3 := &draw.Kernel{
Support: 3.0,
At: func(t float64) float64 {
if t == 0 {
return 1.0
}
if t < 0 {
t = -t
}
if t >= 3.0 {
return 0.0
}
// sinc(t) * sinc(t/3)
piT := math.Pi * t
return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3))
},
}
lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil)
return dst
}
// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC)
// Uses separable interpolation with PIL's coordinate mapping for exact match
func resizeImageBicubic(img image.Image, width, height int) image.Image {
bounds := img.Bounds()
srcW := bounds.Dx()
srcH := bounds.Dy()
// Convert to RGBA if needed
var src *image.RGBA
if rgba, ok := img.(*image.RGBA); ok {
src = rgba
} else {
src = image.NewRGBA(bounds)
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
src.Set(x, y, img.At(x, y))
}
}
}
// Keys cubic with a=-0.5 (PIL BICUBIC)
cubic := func(x float64) float64 {
if x < 0 {
x = -x
}
if x < 1 {
return 1.5*x*x*x - 2.5*x*x + 1
}
if x < 2 {
return -0.5*x*x*x + 2.5*x*x - 4*x + 2
}
return 0
}
// Horizontal pass: srcW -> width, keep srcH rows
temp := image.NewRGBA(image.Rect(0, 0, width, srcH))
for y := 0; y < srcH; y++ {
for dstX := 0; dstX < width; dstX++ {
// PIL coordinate mapping: center-to-center
srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5
baseX := int(math.Floor(srcXf))
var sumR, sumG, sumB, sumA, weightSum float64
for i := -1; i <= 2; i++ {
sx := baseX + i
if sx < 0 {
sx = 0
}
if sx >= srcW {
sx = srcW - 1
}
w := cubic(math.Abs(srcXf - float64(baseX+i)))
c := src.RGBAAt(sx, y)
sumR += float64(c.R) * w
sumG += float64(c.G) * w
sumB += float64(c.B) * w
sumA += float64(c.A) * w
weightSum += w
}
temp.SetRGBA(dstX, y, color.RGBA{
clampFloat(sumR, weightSum),
clampFloat(sumG, weightSum),
clampFloat(sumB, weightSum),
clampFloat(sumA, weightSum),
})
}
}
// Vertical pass: srcH -> height
dst := image.NewRGBA(image.Rect(0, 0, width, height))
for x := 0; x < width; x++ {
for dstY := 0; dstY < height; dstY++ {
srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5
baseY := int(math.Floor(srcYf))
var sumR, sumG, sumB, sumA, weightSum float64
for j := -1; j <= 2; j++ {
sy := baseY + j
if sy < 0 {
sy = 0
}
if sy >= srcH {
sy = srcH - 1
}
w := cubic(math.Abs(srcYf - float64(baseY+j)))
c := temp.RGBAAt(x, sy)
sumR += float64(c.R) * w
sumG += float64(c.G) * w
sumB += float64(c.B) * w
sumA += float64(c.A) * w
weightSum += w
}
dst.SetRGBA(x, dstY, color.RGBA{
clampFloat(sumR, weightSum),
clampFloat(sumG, weightSum),
clampFloat(sumB, weightSum),
clampFloat(sumA, weightSum),
})
}
}
return dst
}
// LoadAndPreprocessMultiple loads multiple images and preprocesses them
// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions)
func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) {
const vaeScaleFactor int32 = 8
const patchSize int32 = 2
condImages := make([]*mlx.Array, len(imagePaths))
vaeImages := make([]*mlx.Array, len(imagePaths))
dims := make([]ImageDims, len(imagePaths))
for i, imagePath := range imagePaths {
img, err := loadImageFile(imagePath)
if err != nil {
return nil, nil, nil, fmt.Errorf("image %d: %w", i, err)
}
bounds := img.Bounds()
origW := int32(bounds.Dx())
origH := int32(bounds.Dy())
ratio := float64(origW) / float64(origH)
// Calculate dimensions for condition image (vision encoder)
// Python pipeline does TWO resizes:
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
// Calculate dimensions for VAE image (1024x1024 area)
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
// Calculate derived dimensions
latentW := vaeW / vaeScaleFactor
latentH := vaeH / vaeScaleFactor
patchW := latentW / patchSize
patchH := latentH / patchSize
dims[i] = ImageDims{
OrigW: origW,
OrigH: origH,
CondW: condW,
CondH: condH,
VaeW: vaeW,
VaeH: vaeH,
LatentW: latentW,
LatentH: latentH,
PatchW: patchW,
PatchH: patchH,
}
fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n",
i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH)
// Preprocess for condition (vision encoder) - two-step resize to match Python pipeline
condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH)
// Preprocess for VAE ([-1, 1] range, 5D tensor)
vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH)
}
return condImages, vaeImages, dims, nil
}

View File

@@ -1,625 +0,0 @@
//go:build mlx
// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing.
// It reuses components from qwen_image where possible.
package qwen_image_edit
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image editing.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
Width int32 // Output width (default: from input image)
Height int32 // Output height (default: from input image)
Steps int // Denoising steps (default: 50)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
}
// Model represents a Qwen-Image-Edit diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
Processor *Processor // Image processor for vision encoder
TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image)
Transformer *qwen_image.Transformer // Reuse qwen_image transformer
VAE *VAE // Combined encoder + decoder
}
// Load loads the Qwen-Image-Edit model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image-Edit model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer from processor directory
fmt.Print(" Loading tokenizer... ")
processorPath := filepath.Join(modelPath, "processor")
tok, err := tokenizer.Load(processorPath)
if err != nil {
// Fallback to tokenizer directory
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err = tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
}
m.Tokenizer = tok
fmt.Println("✓")
// Load processor (image preprocessing config)
fmt.Print(" Loading processor... ")
m.Processor = &Processor{}
if err := m.Processor.Load(processorPath); err != nil {
return fmt.Errorf("processor: %w", err)
}
fmt.Println("✓")
// Load vision-language text encoder (Qwen2.5-VL from qwen_image package)
m.TextEncoder = &qwen_image.Qwen25VL{}
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer (reuse qwen_image)
m.Transformer = &qwen_image.Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE (encoder + decoder)
m.VAE = &VAE{}
if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE: %w", err)
}
mlx.Eval(mlx.Collect(m.VAE)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Edit edits an image based on a text prompt.
// inputImagePath: path to input image
// prompt: text description of desired edit
func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// EditFromConfig edits images using the unified config struct.
// Accepts one or more input images.
func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
if len(inputImagePaths) == 0 {
return nil, fmt.Errorf("no input images provided")
}
start := time.Now()
result, err := m.edit(inputImagePaths, cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n",
len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n",
len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// EditImage implements model.ImageEditModel interface.
func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Edit(inputImagePath, prompt, width, height, steps, seed)
}
// EditMultiImage edits using multiple source images.
// This matches diffusers' QwenImageEditPlusPipeline behavior.
func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
return m.EditFromConfig(inputImagePaths, cfg)
}
// edit is the internal editing pipeline that handles one or more images.
func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
// Load and preprocess all input images
fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths))
condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths)
if err != nil {
return nil, fmt.Errorf("preprocess images: %w", err)
}
for _, img := range condImages {
mlx.Keep(img)
}
for _, img := range vaeImages {
mlx.Keep(img)
}
mlx.Eval(append(condImages, vaeImages...)...)
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
vaeScaleFactor := int32(8)
// Output dimensions - if not specified, use first input image dimensions
if cfg.Width <= 0 {
cfg.Width = inputDims[0].VaeW
}
if cfg.Height <= 0 {
cfg.Height = inputDims[0].VaeH
}
// Output (noise) latent dimensions
outLatentH := cfg.Height / vaeScaleFactor
outLatentW := cfg.Width / vaeScaleFactor
outPH := outLatentH / tcfg.PatchSize
outPW := outLatentW / tcfg.PatchSize
noiseSeqLen := outPH * outPW
imgSeqLen := noiseSeqLen
// Encode prompt with all images for conditioning
posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages)
if err != nil {
return nil, fmt.Errorf("encoding prompt: %w", err)
}
mlx.Keep(posEmb)
mlx.Eval(posEmb)
var negEmb *mlx.Array
if useCFG {
negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages)
if err != nil {
return nil, fmt.Errorf("encoding negative prompt: %w", err)
}
mlx.Keep(negEmb)
mlx.Eval(negEmb)
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Encode all input images to latents and concatenate
fmt.Println("Encoding images to latents...")
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
for i, vaeImage := range vaeImages {
imageLatents := m.VAE.Encode(vaeImage)
imageLatents = m.VAE.Normalize(imageLatents)
imageLatents2D := mlx.Squeeze(imageLatents, 2)
packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize)
mlx.Keep(packed)
mlx.Eval(packed)
allImageLatentsPacked[i] = packed
}
imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1)
mlx.Keep(imageLatentsPacked)
mlx.Eval(imageLatentsPacked)
// Scheduler
scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, noiseSeqLen)
// Init noise latents in packed format
packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize
packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed)
latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize)
mlx.Eval(latents)
// RoPE cache
ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Denoising loop
fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
mlx.Eval(timestep)
latents2D := mlx.Squeeze(latents, 2)
patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize)
latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1)
var output *mlx.Array
if useCFG {
// CFG Batching: single forward pass with batch=2
// Tile inputs: [1, L, D] -> [2, L, D]
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
} else {
output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]})
}
noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
mlx.Eval(latents)
oldLatents.Free()
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds())
}
// Free denoising temporaries
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
imageLatentsPacked.Free()
// Decode latents
decoded := m.decodeAndPostprocess(latents)
latents.Free()
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling.
// This prevents CFG from inflating magnitude too much.
func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array {
// Upcast to float32 for precision
posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32)
negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32)
// CFG: pred = neg + scale * (pos - neg)
diff := mlx.Sub(posF32, negF32)
scaledDiff := mlx.MulScalar(diff, scale)
combPred := mlx.Add(negF32, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional norm
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
mlx.Eval(output)
return mlx.ToBFloat16(output)
}
// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1].
func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array {
latents = m.VAE.Denormalize(latents)
decoded := m.VAE.Decode(latents)
// Post-process: squeeze temporal dim and rescale to [0, 1]
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
mlx.Eval(decoded)
return decoded
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is an alias for backward compatibility.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}
// PrepareRoPEMultiImage computes RoPE with interpolation for image editing.
// Handles single or multiple input images with different resolutions.
//
// Parameters:
// - outPH, outPW: output patch dimensions (noise latent resolution)
// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...]
// - txtLen: text sequence length
// - axesDims: RoPE axis dimensions [16, 56, 56]
//
// Returns RoPE cache where:
// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions
// - First outPH*outPW positions are for noise latents (standard RoPE at output res)
// - Following positions are for each input image (interpolated from output res)
func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache {
theta := float64(10000)
maxIdx := int32(4096)
// Compute base frequencies for each axis dimension
freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta)
freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta)
freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta)
// Build frequency lookup tables
posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false)
posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false)
negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image
negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true)
negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true)
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
// Helper to compute RoPE for a single position at output resolution with scale_rope
computePosFreqs := func(framePos, y, x int32) []float32 {
row := make([]float32, headDim)
idx := 0
// Frame position
for i := 0; i < len(freqsT)*2; i++ {
row[idx+i] = posFreqsT[framePos][i]
}
idx += len(freqsT) * 2
// Height with scale_rope centering (using OUTPUT dimensions)
outHHalf := outPH / 2
hNegCount := outPH - outHHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = posFreqsH[posIdx][i]
}
}
idx += len(freqsH) * 2
// Width with scale_rope centering (using OUTPUT dimensions)
outWHalf := outPW / 2
wNegCount := outPW - outWHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = posFreqsW[posIdx][i]
}
}
return row
}
// Helper to compute RoPE for frame -1 (used for last condition image)
// This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:]
computeNegFrameFreqs := func(y, x int32) []float32 {
row := make([]float32, headDim)
idx := 0
// Frame -1: use last row of negative frame frequencies
negFrameIdx := maxIdx - 1
for i := 0; i < len(freqsT)*2; i++ {
row[idx+i] = negFreqsT[negFrameIdx][i]
}
idx += len(freqsT) * 2
// Height with scale_rope centering (using OUTPUT dimensions)
outHHalf := outPH / 2
hNegCount := outPH - outHHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = posFreqsH[posIdx][i]
}
}
idx += len(freqsH) * 2
// Width with scale_rope centering (using OUTPUT dimensions)
outWHalf := outPW / 2
wNegCount := outPW - outWHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = posFreqsW[posIdx][i]
}
}
return row
}
// Total image sequence length: noise + all input images
noiseSeqLen := outPH * outPW
totalImgLen := noiseSeqLen
for _, dims := range inputDims {
totalImgLen += dims.PatchH * dims.PatchW
}
imgFreqsData := make([]float32, totalImgLen*headDim)
idx := int32(0)
// Segment 0: Noise latents - standard RoPE at output resolution (frame 0)
for y := int32(0); y < outPH; y++ {
for x := int32(0); x < outPW; x++ {
row := computePosFreqs(0, y, x)
copy(imgFreqsData[idx:], row)
idx += headDim
}
}
// Segments 1..N: Edit image latents - INTERPOLATED RoPE
// For single image: use frame 1 (matches original PrepareRoPEInterpolated)
// For multiple images: Python uses frame -1 for the LAST condition image
// (_compute_condition_freqs), positive indices for others.
numImages := len(inputDims)
lastImgIdx := numImages - 1
for imgIdx, dims := range inputDims {
inPH := dims.PatchH
inPW := dims.PatchW
// Determine frame index for this image
// Single image case: use frame 1 (like original PrepareRoPEInterpolated)
// Multi-image case: last image uses frame -1, others use frame 1, 2, etc.
useNegFrame := numImages > 1 && imgIdx == lastImgIdx
// Map each input position to an output position using linear interpolation
for y := int32(0); y < inPH; y++ {
for x := int32(0); x < inPW; x++ {
// Interpolate: map input (y, x) to output grid position
// This is the key fix from DiffSynth's forward_sampling
var yOut, xOut int32
if inPH == 1 {
yOut = 0
} else {
// Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1)
yOut = y * (outPH - 1) / (inPH - 1)
}
if inPW == 1 {
xOut = 0
} else {
xOut = x * (outPW - 1) / (inPW - 1)
}
var row []float32
if useNegFrame {
// Last image in multi-image uses frame -1
row = computeNegFrameFreqs(yOut, xOut)
} else {
// Single image uses frame 1, multi-image uses frame 1, 2, etc.
frameIdx := int32(imgIdx + 1)
row = computePosFreqs(frameIdx, yOut, xOut)
}
copy(imgFreqsData[idx:], row)
idx += headDim
}
}
}
imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim})
imgFreqs = mlx.ToBFloat16(imgFreqs)
// Text frequencies - start after max video index
maxVidIdx := max(outPH/2, outPW/2)
txtFreqsData := make([]float32, txtLen*headDim)
idx = 0
for t := int32(0); t < txtLen; t++ {
pos := maxVidIdx + t
for i := 0; i < len(freqsT)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
}
idx += int32(len(freqsT) * 2)
for i := 0; i < len(freqsH)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
}
idx += int32(len(freqsH) * 2)
for i := 0; i < len(freqsW)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
}
idx += int32(len(freqsW) * 2)
}
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
txtFreqs = mlx.ToBFloat16(txtFreqs)
return &qwen_image.RoPECache{
ImgFreqs: imgFreqs,
TxtFreqs: txtFreqs,
}
}

View File

@@ -1,249 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"fmt"
"math"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
)
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := mlx.InitMLX(); err != nil {
fmt.Printf("Skipping qwen_image_edit tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestComputeAxisFreqs verifies frequency computation matches Python reference
func TestComputeAxisFreqs(t *testing.T) {
theta := float64(10000)
// Expected values from Python:
// freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim))
expectedFreqsT := []float64{
1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684,
0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017,
}
expectedFreqsH_first4 := []float64{
1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494,
}
expectedFreqsH_last4 := []float64{
0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437,
}
// Test temporal frequencies (dim=16)
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
if len(freqsT) != 8 {
t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT))
}
for i, expected := range expectedFreqsT {
if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 {
t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff)
}
}
// Test height/width frequencies (dim=56)
freqsH := qwen_image.ComputeAxisFreqs(56, theta)
if len(freqsH) != 28 {
t.Fatalf("expected 28 height frequencies, got %d", len(freqsH))
}
for i, expected := range expectedFreqsH_first4 {
if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 {
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff)
}
}
for i, expected := range expectedFreqsH_last4 {
idx := 24 + i // last 4 of 28
if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 {
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff)
}
}
}
// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions
func TestMakeFreqTable(t *testing.T) {
theta := float64(10000)
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
maxIdx := int32(4096)
// Test positive table
posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
// Position 0 should give cos=1, sin=0 for all frequencies
for i := 0; i < len(freqsT)*2; i += 2 {
if posTable[0][i] != 1.0 {
t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i])
}
if posTable[0][i+1] != 0.0 {
t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1])
}
}
// Position 1, first frequency (1.0): angle = 1*1 = 1
// cos(1) = 0.5403, sin(1) = 0.8415
if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 {
t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0])
}
if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 {
t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1])
}
// Test negative table
negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true)
// negTable[4095] corresponds to position -1
// cos(-1) = cos(1), sin(-1) = -sin(1)
if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 {
t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0])
}
if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 {
t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1])
}
// negTable[4094] corresponds to position -2
// cos(-2) = cos(2), sin(-2) = -sin(2)
cos2 := math.Cos(2.0)
sin2 := math.Sin(2.0)
if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 {
t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0])
}
if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 {
t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1])
}
}
// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case
func TestPrepareRoPE_QwenImage(t *testing.T) {
if !mlx.GPUIsAvailable() {
t.Skip("GPU not available")
}
mlx.SetDefaultDeviceCPU()
// 4x4 patch grid, single image
imgH, imgW := int32(4), int32(4)
txtLen := int32(5)
axesDims := []int32{16, 56, 56}
cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims)
mlx.Eval(cache.ImgFreqs, cache.TxtFreqs)
// Check shapes
imgShape := cache.ImgFreqs.Shape()
if imgShape[0] != 16 { // 4*4 patches
t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0])
}
// For single image (frame=0), all temporal values should be cos=1, sin=0
imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32)
mlx.Eval(imgFreqsCPU)
imgData := imgFreqsCPU.Data()
// Check first 16 values of patch 0 (temporal cos/sin pairs)
for i := 0; i < 16; i += 2 {
cosVal := imgData[i]
sinVal := imgData[i+1]
if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 {
t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal)
}
if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 {
t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal)
}
}
cache.ImgFreqs.Free()
cache.TxtFreqs.Free()
}
// TestScaleRopePositions verifies the centered position calculation for scale_rope=True
func TestScaleRopePositions(t *testing.T) {
// For a 4x4 grid with scale_rope=True:
// hHalf = 2, wHalf = 2
// hNegCount = 4 - 2 = 2 (positions 0,1 are negative)
// wNegCount = 4 - 2 = 2 (positions 0,1 are negative)
//
// Height positions:
// y=0: -(4-2) + 0 = -2
// y=1: -(4-2) + 1 = -1
// y=2: 2 - 2 = 0
// y=3: 3 - 2 = 1
//
// Same for width
pH, pW := int32(4), int32(4)
hHalf := pH / 2
wHalf := pW / 2
hNegCount := pH - hHalf
wNegCount := pW - wHalf
expectedH := []int32{-2, -1, 0, 1}
expectedW := []int32{-2, -1, 0, 1}
for y := int32(0); y < pH; y++ {
var hPos int32
if y < hNegCount {
hPos = -(pH - hHalf) + y
} else {
hPos = y - hNegCount
}
if hPos != expectedH[y] {
t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos)
}
}
for x := int32(0); x < pW; x++ {
var wPos int32
if x < wNegCount {
wPos = -(pW - wHalf) + x
} else {
wPos = x - wNegCount
}
if wPos != expectedW[x] {
t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos)
}
}
}
// TestRoPEHeadDimensions verifies the head dimension breakdown
func TestRoPEHeadDimensions(t *testing.T) {
// axes_dims_rope = [16, 56, 56]
// Each dimension uses half the values for frequencies
// So we get: 8 + 28 + 28 = 64 frequency values
// Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position
axesDims := []int32{16, 56, 56}
expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2)
expectedHeadDim := expectedFreqs * 2
if expectedFreqs != 64 {
t.Errorf("expected 64 frequency values, got %d", expectedFreqs)
}
if expectedHeadDim != 128 {
t.Errorf("expected head_dim=128, got %d", expectedHeadDim)
}
// This should match the transformer's attention head dimension
// hidden_size = 3072, num_heads = 24
// head_dim = 3072 / 24 = 128
}

View File

@@ -1,642 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"fmt"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds Qwen-Image VAE configuration
type VAEConfig struct {
ZDim int32 `json:"z_dim"` // 16
BaseDim int32 `json:"base_dim"` // 96
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
NumResBlocks int32 `json:"num_res_blocks"` // 2
LatentsMean []float32 `json:"latents_mean"` // 16 values
LatentsStd []float32 `json:"latents_std"` // 16 values
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
}
// defaultVAEConfig returns config for Qwen-Image VAE
func defaultVAEConfig() *VAEConfig {
return &VAEConfig{
ZDim: 16,
BaseDim: 96,
DimMult: []int32{1, 2, 4, 4},
NumResBlocks: 2,
LatentsMean: []float32{
-0.7571, -0.7089, -0.9113, 0.1075,
-0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632,
-0.1922, -0.9497, 0.2503, -0.2921,
},
LatentsStd: []float32{
2.8184, 1.4541, 2.3275, 2.6558,
1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579,
1.6382, 1.1253, 2.8251, 1.916,
},
TemperalDownsample: []bool{false, true, true},
}
}
// VAE is the full VAE with encoder and decoder
type VAE struct {
Config *VAEConfig
Encoder *VAEEncoder
Decoder *VAEDecoder
}
// Load loads the VAE from a directory
func (m *VAE) Load(path string) error {
fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")
cfg := defaultVAEConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Load weights as f32 for quality (matches Python default behavior)
// VAE decoder precision is critical for final image quality
fmt.Print(" Loading weights as f32... ")
if err := weights.Load(mlx.DtypeFloat32); err != nil {
return fmt.Errorf("failed to load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
// Load encoder
fmt.Print(" Loading encoder... ")
m.Encoder = &VAEEncoder{}
if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
return fmt.Errorf("encoder: %w", err)
}
fmt.Println("✓")
// Load decoder
fmt.Print(" Loading decoder... ")
m.Decoder = &VAEDecoder{}
if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
return fmt.Errorf("decoder: %w", err)
}
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// Encode encodes an image to latents
// x: [B, C, T, H, W] image tensor in [-1, 1] range
// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
return m.Encoder.Encode(x)
}
// Decode decodes latents to image
// z: [B, C, T, H, W] latents (denormalized)
// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
return m.Decoder.Decode(z)
}
// Normalize applies latent normalization
// Input z should be f32 (from VAE encoder), output is f32 for transformer
func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
// Mean/std are f32, will match z dtype through broadcasting
return mlx.Div(mlx.Sub(z, mean), std)
}
// Denormalize reverses latent normalization
// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
// Convert latents to f32 for VAE decoder quality
z = mlx.AsType(z, mlx.DtypeFloat32)
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
return mlx.Add(mlx.Mul(z, std), mean)
}
// VAEEncoder is the encoder part of the VAE
// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
// - Blocks 0,1: ResBlocks (base_dim)
// - Block 2: Downsample
// - Blocks 3,4: ResBlocks (base_dim*2)
// - Block 5: Downsample + temporal
// - Blocks 6,7: ResBlocks (base_dim*4)
// - Block 8: Downsample + temporal
// - Blocks 9,10: ResBlocks (base_dim*4)
type VAEEncoder struct {
Config *VAEConfig
ConvIn *CausalConv3d
Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers
MidBlock *MidBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
QuantConv *CausalConv3d
}
// EncoderBlock is either a ResBlock or a Downsample
type EncoderBlock interface {
Forward(x *mlx.Array) *mlx.Array
IsDownsample() bool
}
// EncoderResBlock wraps ResBlock
type EncoderResBlock struct {
*ResBlock
}
func (b *EncoderResBlock) IsDownsample() bool { return false }
// EncoderDownsample is a downsample layer
type EncoderDownsample struct {
Resample *CausalConv3d
TimeConv *CausalConv3d // Optional temporal downsample
}
func (d *EncoderDownsample) IsDownsample() bool { return true }
func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
// Spatial downsample with stride 2
// WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
x = d.forwardSpatialDownsample(x)
// NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
// with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
// The Python forward checks: if feat_cache is not None ... then use time_conv
// Since we don't support streaming, we skip time_conv entirely.
return x
}
// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
xShape := x.Shape()
B := xShape[0]
T := xShape[1]
H := xShape[2]
W := xShape[3]
C := xShape[4]
wShape := d.Resample.Weight.Shape()
outC := wShape[0]
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
// ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W
// Apply 2D conv with stride 2
weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
x = conv2DStrided(x, weight, 2)
if d.Resample.Bias != nil {
bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
// Output dims after stride 2: (H+1)/2, (W+1)/2
outH := (H + 1) / 2
outW := (W + 1) / 2
// Reshape back to [B, T, H', W', C]
x = mlx.Reshape(x, B, T, outH, outW, outC)
mlx.Eval(x)
return x
}
// loadFromWeights loads the encoder from pre-loaded weights
func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
e.Config = cfg
// Conv in
convIn, err := newCausalConv3d(weights, "encoder.conv_in")
if err != nil {
return err
}
e.ConvIn = convIn
// Encoder uses flat block structure:
// dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
// Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
// That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
e.Blocks = make([]EncoderBlock, 0, 11)
// Track dimensions
dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
blockIdx := 0
for stage := 0; stage < len(cfg.DimMult); stage++ {
inDim := cfg.BaseDim
if stage > 0 {
inDim = dims[stage-1]
}
outDim := dims[stage]
// ResBlocks for this stage (num_res_blocks per stage)
for r := int32(0); r < cfg.NumResBlocks; r++ {
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
currentInDim := inDim
if r > 0 {
currentInDim = outDim
}
block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
if err != nil {
return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
}
e.Blocks = append(e.Blocks, block)
blockIdx++
}
// Downsample after each stage except the last
if stage < len(cfg.DimMult)-1 {
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
if err != nil {
return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
}
e.Blocks = append(e.Blocks, down)
blockIdx++
}
}
// Mid block
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
if err != nil {
return err
}
e.MidBlock = midBlock
// Norm out
normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
if err != nil {
return err
}
e.NormOut = normOut
// Conv out
convOut, err := newCausalConv3d(weights, "encoder.conv_out")
if err != nil {
return err
}
e.ConvOut = convOut
// Quant conv
quantConv, err := newCausalConv3d(weights, "quant_conv")
if err != nil {
return err
}
e.QuantConv = quantConv
return nil
}
// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
block, err := newResBlock(weights, prefix, inDim, outDim)
if err != nil {
return nil, err
}
return &EncoderResBlock{block}, nil
}
// newEncoderDownsample creates a downsample layer for the encoder
func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
resample, err := newCausalConv3d(weights, prefix+".resample.1")
if err != nil {
return nil, err
}
var timeConv *CausalConv3d
if temporal {
timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
}
return &EncoderDownsample{
Resample: resample,
TimeConv: timeConv,
}, nil
}
// Encode encodes an image to latents
// x: [B, C, T, H, W] image tensor (channels-first)
// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
mlx.Eval(x)
// Conv in
x = e.ConvIn.Forward(x)
// Encoder blocks (mix of ResBlocks and Downsamplers)
for _, block := range e.Blocks {
prev := x
x = block.Forward(x)
prev.Free()
}
// Mid block
x = e.MidBlock.Forward(x)
// Norm + silu
{
prev := x
x = e.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// Conv out
{
prev := x
x = e.ConvOut.Forward(x)
prev.Free()
}
// Quant conv
{
prev := x
x = e.QuantConv.Forward(x)
prev.Free()
}
// Get mode from distribution (first half of channels = mean)
// Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
shape := x.Shape()
latentC := shape[4] / 2
x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
// Convert back to channels-first [N, C, T, H, W]
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
mlx.Eval(x)
return x
}
// VAEDecoder is the decoder part of the VAE
type VAEDecoder struct {
Config *VAEConfig
PostQuantConv *CausalConv3d
ConvIn *CausalConv3d
MidBlock *MidBlock
UpBlocks []*UpBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
}
// loadFromWeights loads the decoder from pre-loaded weights
func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
d.Config = cfg
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
if err != nil {
return err
}
d.PostQuantConv = postQuantConv
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
if err != nil {
return err
}
d.ConvIn = convIn
// Mid block
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
if err != nil {
return err
}
d.MidBlock = midBlock
// Up blocks (reversed dim_mult)
numUpBlocks := len(cfg.DimMult)
d.UpBlocks = make([]*UpBlock, numUpBlocks)
dimsMult := make([]int32, numUpBlocks+1)
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
for i := 0; i < numUpBlocks; i++ {
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
}
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
for i := range cfg.TemperalDownsample {
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
}
for i := 0; i < numUpBlocks; i++ {
inDim := cfg.BaseDim * dimsMult[i]
outDim := cfg.BaseDim * dimsMult[i+1]
if i > 0 {
inDim = inDim / 2
}
upsampleMode := ""
if i < numUpBlocks-1 {
if temporalUpsample[i] {
upsampleMode = "upsample3d"
} else {
upsampleMode = "upsample2d"
}
}
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
if err != nil {
return err
}
d.UpBlocks[i] = upBlock
}
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
if err != nil {
return err
}
d.NormOut = normOut
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
if err != nil {
return err
}
d.ConvOut = convOut
return nil
}
// Decode converts latents to image
// z: [B, C, T, H, W] denormalized latents
func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
var x *mlx.Array
// Convert from channels-first to channels-last
{
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
mlx.Eval(z)
}
// PostQuantConv
x = d.PostQuantConv.Forward(z)
z.Free()
// ConvIn
{
prev := x
x = d.ConvIn.Forward(x)
prev.Free()
}
// Mid block
x = d.MidBlock.Forward(x)
// Up blocks
for _, upBlock := range d.UpBlocks {
x = upBlock.Forward(x)
}
// NormOut + silu
{
prev := x
x = d.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// ConvOut
{
prev := x
x = d.ConvOut.Forward(x)
prev.Free()
}
// Post-processing: clamp and convert back to channels-first
{
prev := x
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
prev.Free()
mlx.Eval(x)
}
return x
}
// DownBlock handles downsampling in encoder
type DownBlock struct {
ResBlocks []*ResBlock
Downsampler *Downsample
}
// newDownBlock creates a down block
func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var downsampler *Downsample
if downsampleMode != "" {
downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
}
return &DownBlock{
ResBlocks: resBlocks,
Downsampler: downsampler,
}, nil
}
// Forward applies down block
func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
for _, block := range d.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
if d.Downsampler != nil {
prev := x
x = d.Downsampler.Forward(x)
prev.Free()
}
return x
}
// Downsample handles spatial downsampling
type Downsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newDownsample creates a downsampler
func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Downsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies downsampling to channels-last input [B, T, H, W, C]
func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := d.Conv.Shape()[0]
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding
// For 3x3 stride 2: pad 1 on all sides
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
// Conv with stride 2 using manual strided patching
weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
x = conv2DStrided(x, weight, 2)
if d.Bias != nil {
bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H/2, W/2, outC)
mlx.Eval(x)
return x
}

View File

@@ -32,10 +32,16 @@ func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
// Quantizes the weight immediately and evaluates to break lazy dependencies.
// Note: For modes like "nvfp4", qbiases will be nil.
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
// Eval immediately so bf16 weight can be freed
mlx.Eval(qw, scales, qbiases)
// Handle modes that don't return qbiases (e.g., nvfp4)
if qbiases != nil {
mlx.Eval(qw, scales, qbiases)
} else {
mlx.Eval(qw, scales)
}
return &QuantizedLinear{
Weight: qw,
Scales: scales,
@@ -77,10 +83,13 @@ func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear
// QuantizedLinear applies an affine transformation using quantized weights.
// Equivalent to mlx.nn.QuantizedLinear.
// Supports multiple quantization modes:
// - "affine": scale + zero-point bias (QBiases required)
// - "nvfp4": NVIDIA FP4 with E4M3 scales (QBiases nil)
type QuantizedLinear struct {
Weight *mlx.Array // Quantized weight data
Scales *mlx.Array // Scale factors for dequantization
QBiases *mlx.Array // Quantization biases (NOT layer bias)
QBiases *mlx.Array // Quantization biases (NOT layer bias), nil for nvfp4
Bias *mlx.Array // Layer bias [output_dims] or nil
GroupSize int
Bits int

View File

@@ -1,233 +0,0 @@
//go:build mlx
// Package runner provides a subprocess server for image generation.
// It listens on a port and handles HTTP requests for image generation.
package runner
import (
"context"
"encoding/json"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// Response is streamed back for each progress update
type Response struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"` // Base64-encoded PNG
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
}
// ImageModel is the interface for image generation models
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model ImageModel
modelName string
}
// Execute is the entry point for the image runner subprocess
func Execute(args []string) error {
fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
modelName := fs.String("model", "", "path to image model")
port := fs.Int("port", 0, "port to listen on")
if err := fs.Parse(args); err != nil {
return err
}
if *modelName == "" {
return fmt.Errorf("--model is required")
}
if *port == 0 {
return fmt.Errorf("--port is required")
}
err := mlx.InitMLX()
if err != nil {
slog.Error("unable to initialize MLX", "error", err)
return err
}
slog.Info("MLX library initialized")
slog.Info("starting image runner", "model", *modelName, "port", *port)
// Check memory requirements before loading
requiredMemory := imagegen.EstimateVRAM(*modelName)
availableMemory := mlx.GetMemoryLimit()
if availableMemory > 0 && availableMemory < requiredMemory {
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Detect model type and load appropriate model
modelType := imagegen.DetectModelType(*modelName)
slog.Info("detected model type", "type", modelType)
var model ImageModel
switch modelType {
case "Flux2KleinPipeline":
m := &flux2.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
model = m
default:
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
m := &zimage.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
model = m
}
server := &Server{
model: model,
modelName: *modelName,
}
// Set up HTTP handlers
mux := http.NewServeMux()
mux.HandleFunc("/health", server.healthHandler)
mux.HandleFunc("/completion", server.completionHandler)
httpServer := &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
Handler: mux,
}
// Handle shutdown
done := make(chan struct{})
go func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
slog.Info("shutting down image runner")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
close(done)
}()
slog.Info("image runner listening", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
return err
}
<-done
return nil
}
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}
func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Serialize generation requests - MLX model may not handle concurrent generation
s.mu.Lock()
defer s.mu.Unlock()
// Model applies its own defaults for width/height/steps
// Only seed needs to be set here if not provided
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
}
// Set up streaming response
w.Header().Set("Content-Type", "application/x-ndjson")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
// Generate image using the common interface
ctx := r.Context()
enc := json.NewEncoder(w)
// Progress callback streams step updates
progress := func(step, total int) {
resp := Response{Step: step, Total: total}
enc.Encode(resp)
w.Write([]byte("\n"))
flusher.Flush()
}
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
if err != nil {
// Don't send error for cancellation
if ctx.Err() != nil {
return
}
resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()
// Send final response with image data
resp := Response{
Image: imageData,
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
}

View File

@@ -17,17 +17,26 @@ type WeightSource interface {
GetTensor(name string) (*mlx.Array, error)
ListTensors() []string
HasTensor(name string) bool
Quantization() string // Returns "FP4", "FP8", or ""
Quantization() string // Returns "NVFP4", "FP4", "FP8", or ""
}
// quantizationParams returns groupSize, bits, mode for a quantization type.
// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
// QuantizationParams returns groupSize, bits, mode for a quantization type.
// MLX quantization modes:
// - "affine": scale + zero-point bias, group_size=32/64/128
// - "nvfp4": NVIDIA FP4 with E4M3 scales, group_size=16 (no bias)
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "FP4":
case "NVFP4":
// NVIDIA FP4: group_size=16, bits=4, E4M3 scales (no qbias)
return 16, 4, "nvfp4"
case "FP4", "Q4", "INT4":
// 4-bit quantization with affine mode (scale + qbias)
return 32, 4, "affine"
case "FP8", "Q8", "INT8", "":
// 8-bit quantization with affine mode (default for quantized models)
return 32, 8, "affine"
default:
return 32, 8, "affine" // FP8 or unknown
return 32, 8, "affine" // Default to affine
}
}
@@ -122,7 +131,8 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
// Handle nn.LinearLayer interface fields specially
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
linearLayerType := reflect.TypeOf((*nn.LinearLayer)(nil)).Elem()
if field.Type == linearLayerType {
if !hasTag {
continue // no tag = skip
}
@@ -217,11 +227,12 @@ func joinPath(prefix, suffix string) string {
}
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
// If {path}.weight_scale exists, creates a QuantizedLinear layer (or dequantizes if no kernel support).
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
hasScale := weights.HasTensor(scalePath)
if hasScale {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
@@ -245,9 +256,11 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
qbiases, _ = weights.GetTensor(qbiasPath)
}
groupSize, bits, mode := quantizationParams(weights.Quantization())
groupSize, bits, mode := QuantizationParams(weights.Quantization())
if mlx.MetalIsAvailable() {
// NVFP4 doesn't have native quantized matmul kernels in MLX yet,
// so we always dequantize at load time. Affine modes (FP4, FP8) have kernel support.
if mlx.MetalIsAvailable() && mode != "nvfp4" {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,

View File

@@ -1,82 +0,0 @@
package imagegen
import (
"runtime"
"testing"
)
// TestPlatformSupport verifies platform validation works correctly.
func TestPlatformSupport(t *testing.T) {
err := CheckPlatformSupport()
switch runtime.GOOS {
case "darwin":
if runtime.GOARCH == "arm64" {
// Apple Silicon should be supported
if err != nil {
t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
}
} else {
// Intel Mac should fail
if err == nil {
t.Error("Expected error on darwin/amd64 (Intel), got nil")
}
if err != nil && err.Error() == "" {
t.Error("Expected meaningful error message for unsupported platform")
}
}
case "linux", "windows":
// Linux/Windows are allowed (CUDA support checked at runtime)
if err != nil {
t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
}
default:
// Other platforms should fail
if err == nil {
t.Errorf("Expected error on unsupported platform %s, got nil", runtime.GOOS)
}
}
}
// TestMemoryRequirementsError verifies memory check returns clear error.
func TestMemoryRequirementsError(t *testing.T) {
// Test with insufficient memory
err := CheckMemoryRequirements("test-model", 8*GB)
if err == nil {
t.Error("Expected error for insufficient memory (8GB < 21GB default)")
}
// Test with sufficient memory
err = CheckMemoryRequirements("test-model", 32*GB)
if err != nil {
t.Errorf("Expected no error for sufficient memory (32GB), got: %v", err)
}
}
// TestEstimateVRAMReturnsReasonableDefaults verifies VRAM estimates are sensible.
func TestEstimateVRAMReturnsReasonableDefaults(t *testing.T) {
// Unknown model should return default (21GB)
vram := EstimateVRAM("unknown-model")
if vram < 10*GB || vram > 100*GB {
t.Errorf("VRAM estimate %d GB is outside reasonable range (10-100 GB)", vram/GB)
}
// Verify known pipeline estimates exist and are reasonable
for name, estimate := range modelVRAMEstimates {
if estimate < 10*GB {
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously low", name, estimate/GB)
}
if estimate > 200*GB {
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously high", name, estimate/GB)
}
}
}
// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
// This is a compile-time check but we document it as a test.
func TestServerInterfaceCompliance(t *testing.T) {
// The var _ llm.LlamaServer = (*Server)(nil) line in server.go
// ensures compile-time interface compliance.
// This test documents that requirement.
t.Log("Server implements llm.LlamaServer interface (compile-time checked)")
}

View File

@@ -44,23 +44,54 @@ func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*Manife
}, nil
}
// LoadAllWeightsFromManifest creates a weight loader for all tensors without component filtering.
// Used for LLM models where tensors don't have a component prefix.
func LoadAllWeightsFromManifest(manifest *ModelManifest) (*ManifestWeights, error) {
layers := manifest.GetAllTensorLayers()
if len(layers) == 0 {
return nil, fmt.Errorf("no tensor layers found in manifest")
}
tensors := make(map[string]ManifestLayer, len(layers))
for _, layer := range layers {
tensors[layer.Name] = layer
}
return &ManifestWeights{
manifest: manifest,
tensors: tensors,
cache: make(map[string]*mlx.Array),
}, nil
}
// Load loads all tensor blobs using native mmap (zero-copy).
// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
// If dtype is non-zero, tensors are converted to the specified dtype.
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
// Track native handles to free after batch eval
nativeHandles := make([]*mlx.SafetensorsFile, 0, len(mw.tensors))
arrays := make([]*mlx.Array, 0, len(mw.tensors))
for name, layer := range mw.tensors {
path := mw.manifest.BlobPath(layer.Digest)
// Load blob as safetensors (native mmap, zero-copy)
sf, err := mlx.LoadSafetensorsNative(path)
if err != nil {
// Free any handles we've accumulated
for _, h := range nativeHandles {
h.Free()
}
return fmt.Errorf("load %s: %w", name, err)
}
nativeHandles = append(nativeHandles, sf)
// Blob contains single tensor named "data"
arr := sf.Get("data")
if arr == nil {
sf.Free()
for _, h := range nativeHandles {
h.Free()
}
return fmt.Errorf("tensor 'data' not found in blob for %s", name)
}
@@ -68,11 +99,18 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
if dtype != 0 && arr.Dtype() != dtype {
arr = mlx.AsType(arr, dtype)
}
// ALWAYS make a contiguous copy to ensure independence from mmap
// Make contiguous copy to ensure independence from mmap
arr = mlx.Contiguous(arr)
mlx.Eval(arr)
mw.cache[name] = arr
sf.Free() // Safe to free - arr is now an independent copy
arrays = append(arrays, arr)
}
// Batch evaluate all tensors at once (much faster than one at a time)
mlx.Eval(arrays...)
// Now safe to free all native handles
for _, sf := range nativeHandles {
sf.Free()
}
return nil
@@ -107,18 +145,95 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
}
// Quantization returns the model's quantization type from model_index.json.
// Returns empty string if not quantized or unknown.
// Returns empty string if not quantized.
// Falls back to detecting from tensor names and shapes if not in config.
func (mw *ManifestWeights) Quantization() string {
if mw.manifest == nil {
return ""
}
// Try to read from model_index.json first
var index struct {
Quantization string `json:"quantization"`
}
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.Quantization != "" {
return index.Quantization
}
// Fallback: detect from tensor names
// Check if any tensors have _scale suffix (indicates quantization)
hasScales := false
hasQBias := false
for name := range mw.tensors {
if strings.HasSuffix(name, ".weight_scale") {
hasScales = true
}
if strings.HasSuffix(name, ".weight_qbias") {
hasQBias = true
}
}
if !hasScales {
// No scales = not quantized
return ""
}
return index.Quantization
// Has scales but no qbias = NVFP4 (or other non-affine mode)
if !hasQBias {
return "NVFP4"
}
// Has both scales and qbias = affine mode
// Need to determine FP4 vs FP8 from tensor shapes
// FP4: weight last dim is 1/8 of scales last dim * group_size
// FP8: weight last dim is 1/4 of scales last dim * group_size
//
// For affine mode with group_size=32:
// - FP4 (4 bits): 8 elements packed per uint32, so weight_dim = orig_dim / 8
// - FP8 (8 bits): 4 elements packed per uint32, so weight_dim = orig_dim / 4
// scales_dim = orig_dim / group_size
// So: weight_dim / scales_dim = group_size / pack_factor
// FP4: ratio = 32/8 = 4
// FP8: ratio = 32/4 = 8
// Find a weight/scale pair to check the ratio
for name := range mw.tensors {
if !strings.HasSuffix(name, ".weight") || strings.Contains(name, "_scale") || strings.Contains(name, "_qbias") {
continue
}
scaleName := name + "_scale"
if _, ok := mw.tensors[scaleName]; !ok {
continue
}
// Load both tensors to check shapes
weightLayer := mw.tensors[name]
scaleLayer := mw.tensors[scaleName]
// Get shapes from manifest layer metadata if available
// For now, default to FP4 since it's more common
// The actual shape check would require loading the tensor
// Simple heuristic: check if scale tensor is ~4x smaller than weight
// FP4: weight is packed 8 per uint32, scales are 1 per group (32)
// So scale size should be ~weight_size * 8 / 32 = weight_size / 4
// FP8: weight is packed 4 per uint32, scales are 1 per group (32)
// So scale size should be ~weight_size * 4 / 32 = weight_size / 8
// Rough size heuristic (assuming float16 scales)
// FP4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8
// FP8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16
ratio := float64(weightLayer.Size) / float64(scaleLayer.Size)
if ratio < 12 {
// Closer to 8 = FP4
return "FP4"
}
// Closer to 16 = FP8
return "FP8"
}
// Default to FP4 for affine mode (most common)
return "FP4"
}
// ReleaseAll frees all native handles and clears the tensor cache.

View File

@@ -1,797 +1,144 @@
//go:build mlx
package kvcache
// import (
// "errors"
// "fmt"
// "log/slog"
// "math"
// "slices"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
// // Causal cache stores K and V tensors according to their position in the
// // sequence. Returns the history and a mask for attending to past tokens
// //
// // The tensors are of shape embed dim, kv heads, batch size
// // The mask is of shape history size, batch size
// type Causal struct {
// DType ml.DType
// // swaWindowSize is the number of tokens that will be included in the mask
// // during attention operations. swaMemorySize is the number of tokens that
// // will be retained in memory for partial prefix caching. Set to math.MaxInt32
// // for unlimited or if sliding window attention is not being used.
// swaWindowSize int32
// swaMemorySize int32
// chunkSize int32
// opts CausalOptions
// // maxBatch is the largest batch that we might receive
// maxBatch int
// // config controls mostly backend-specific optimizations
// config *ml.CacheConfig
// // ** current forward pass **
// // size of the current batch
// curBatchSize int
// // locations for data storage for this batch
// curLoc ml.Tensor
// // mask of the cache as used by this batch
// curMask ml.Tensor
// // the active layer for Get and Put
// curLayer int
// // locations in the cache that are needed for this batch
// curCellRange cellRange
// // curSequences is the sequences corresponding to this pass's entries in the cache
// curSequences []int
// // curPositions is the positions corresponding to this pass's entries in the cache
// curPositions []int32
// // ** cache metadata **
// // for each possible location in the cache, stores the position and set of sequences
// // that reference the data there
// cells []cacheCell
// // maps from sequence to the range of locations where it is stored in the cache
// cellRanges map[int]cellRange
// // ** cache data storage **
// shiftFn shiftFn
// backend ml.Backend
// ctxs map[int]ml.Context
// keys, values map[int]ml.Tensor
// kHeadDims, vHeadDims, numKVHeads map[int]int
// }
// type cacheCell struct {
// pos int32
// sequences []int
// }
// type cellRange struct {
// min int
// max int
// }
// func NewCausalCache(shift shiftFn) *Causal {
// return &Causal{
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
// return &Causal{
// swaWindowSize: windowSize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
// return &Causal{
// swaWindowSize: windowSize,
// swaMemorySize: memorySize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
// return &Causal{
// chunkSize: chunkSize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// if c.config == nil {
// var config ml.CacheConfig
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
// config = cc.CacheConfig()
// }
// c.config = &config
// }
// if c.config.CachePadding == 0 {
// c.config.CachePadding = 1
// }
// if c.config.MaskBatchPadding == 0 {
// c.config.MaskBatchPadding = 1
// }
// // TODO what types do we handle here?
// // if c.config.MaskDType == ml.DTypeOther {
// // c.config.MaskDType = ml.DTypeFloat32
// // }
// if c.swaWindowSize == 0 {
// c.swaWindowSize = math.MaxInt32
// }
// if c.swaMemorySize == 0 {
// c.swaMemorySize = c.swaWindowSize
// }
// // We will allocate space in the cache for the stop token, which won't be part of a follow on
// // sequence, so allocate an extra token of storage to ensure that we can jump back without
// // causing a cache break. As an optimization, only do this when we have parallel sequences
// // because the extra token will live in the batch buffer and won't get overwritten if we
// // only have a single sequence.
// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
// }
// if int(c.swaMemorySize) >= capacity {
// c.swaMemorySize = math.MaxInt32
// }
// if c.swaMemorySize < c.swaWindowSize {
// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
// }
// var cacheSize int
// if c.swaMemorySize == math.MaxInt32 {
// cacheSize = maxSequences * capacity
// } else {
// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
// }
// cacheSize = roundUp(cacheSize, c.config.CachePadding)
// c.cells = make([]cacheCell, cacheSize)
// c.DType = dtype
// c.cellRanges = make(map[int]cellRange)
// c.backend = backend
// c.maxBatch = maxBatch
// }
// func (c *Causal) SetConfig(config ml.CacheConfig) {
// if c.config != nil {
// panic("config cannot be changed after being previously set, either by the model or backend")
// }
// c.config = &config
// }
// func (c *Causal) Close() {
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
// for _, ctx := range c.ctxs {
// ctx.Close()
// }
// }
// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
// // panic("XXX Causal.StartForward")
// c.curBatchSize = len(batch.Positions)
// c.curSequences = batch.Sequences
// c.curPositions = batch.Positions
// c.opts.Except = nil
// var locs []int32
// if !reserve {
// c.updateSlidingWindow()
// var err error
// locs, err = c.findLocs()
// if err != nil {
// return err
// }
// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))
// for i, pos := range batch.Positions {
// seq := batch.Sequences[i]
// loc := int(locs[i])
// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
// seqRange, ok := c.cellRanges[seq]
// if !ok {
// seqRange = newRange()
// }
// seqRange.min = min(seqRange.min, loc)
// c.curCellRange.min = min(c.curCellRange.min, loc)
// seqRange.max = max(seqRange.max, loc)
// c.curCellRange.max = max(c.curCellRange.max, loc)
// c.cellRanges[seq] = seqRange
// }
// } else {
// // If we are reserving memory, don't update any of the cache metadata but set the size
// // to the worst case.
// locs = make([]int32, c.curBatchSize)
// for i := range locs {
// locs[i] = int32(i)
// }
// c.curCellRange.min = 0
// c.curCellRange.max = len(c.cells) - 1
// }
// // XXX Building up the locs for what's already processed (if any)
// dummyLocs := []int{}
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
// for i := range c.curBatchSize {
// enabled := !slices.Contains(c.opts.Except, i)
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
// } else {
// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
// dummyLocs = append(dummyLocs, i)
// }
// }
// }
// }
// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs)
// slog.Info("XXX Causal.StartForward", "locs", locs)
// c.curLoc = ctx.Input().FromInts(locs, len(locs))
// c.curMask = c.buildMask(ctx)
// return nil
// }
// func newRange() cellRange {
// return cellRange{
// min: math.MaxInt,
// max: 0,
// }
// }
// // Returns a slice of locations where each token in the batch should be stored
// func (c *Causal) findLocs() ([]int32, error) {
// loc := make([]int32, 0, c.curBatchSize)
// for i := range c.cells {
// if len(c.cells[i].sequences) == 0 {
// loc = append(loc, int32(i))
// if len(loc) >= c.curBatchSize {
// return loc, nil
// }
// }
// }
// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
// }
// func (c *Causal) updateSlidingWindow() {
// c.curCellRange = newRange()
// if c.swaMemorySize == math.MaxInt32 {
// for _, seq := range c.curSequences {
// if seqRange, ok := c.cellRanges[seq]; ok {
// c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
// c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
// }
// }
// return
// }
// type lowestPosition struct {
// pos int32
// curBatch bool
// }
// // create a map of unique sequences to the lowest position in that sequence
// lowestPos := make(map[int]lowestPosition)
// for i := range c.curPositions {
// seq := c.curSequences[i]
// lowest, ok := lowestPos[seq]
// if !ok {
// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
// } else if c.curPositions[i] < lowest.pos {
// lowest.pos = c.curPositions[i]
// }
// lowestPos[seq] = lowest
// }
// // for any sequences are not part of this batch, clean up any tokens
// // that are no longer needed after the processing of the previous
// // batch
// for seq, seqRange := range c.cellRanges {
// if _, ok := lowestPos[seq]; !ok {
// var last int32
// for i := seqRange.min; i <= seqRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// last = max(last, c.cells[i].pos)
// }
// }
// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
// }
// }
// // delete any entries that are beyond the window of the oldest position in the sequence
// for seq, lowest := range lowestPos {
// oldRange, ok := c.cellRanges[seq]
// if !ok {
// continue
// }
// newRange := newRange()
// for i := oldRange.min; i <= oldRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// if c.cells[i].pos < lowest.pos-c.swaMemorySize {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
// } else {
// newRange.min = min(newRange.min, i)
// newRange.max = max(newRange.max, i)
// }
// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
// c.curCellRange.min = min(c.curCellRange.min, i)
// c.curCellRange.max = max(c.curCellRange.max, i)
// }
// }
// }
// c.cellRanges[seq] = newRange
// }
// }
// func roundDown(length, pad int) int {
// return (length / pad) * pad
// }
// func roundUp(length, pad int) int {
// return ((length + pad - 1) / pad) * pad
// }
// // Builds a mask of history x batch indicating whether for each token in the batch the
// // token in the history should apply. This is based on both the sequence and causality (the
// // position of the history is not ahead of the token in the batch).
// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
// // Align and pad the two dimensions as required by the backend
// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
// length := c.curCellRange.max - c.curCellRange.min + 1
// mask := make([]float32, batchSize*length)
// for i := range c.curBatchSize {
// enabled := !slices.Contains(c.opts.Except, i)
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
// }
// }
// }
// // Mask out any padding tokens we added. For padding that we added to the cache history, this
// // has already been masked out because the sequence doesn't match.
// for i := c.curBatchSize * length; i < len(mask); i++ {
// mask[i] = float32(math.Inf(-1))
// }
// maskTensor := ctx.Input().FromFloats(mask, batchSize, length)
// // if c.config.MaskDType != ml.DTypeFloat32 {
// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
// // }
// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})
// return maskTensor
// }
// func (c *Causal) SetLayer(layer int) {
// c.curLayer = layer
// }
// type CausalOptions struct {
// // Enabled controls whether the causal mask is generated for a particular index in a batch
// Except []int
// }
// // SetCausal disables causal mask generation for a particular range of indicies in
// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
// if !slices.Equal(c.opts.Except, opts.Except) {
// c.opts = opts
// if ctx != nil {
// c.curMask = c.buildMask(ctx)
// }
// }
// }
// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// key := c.keys[c.curLayer]
// value := c.values[c.curLayer]
// kHeadDim := c.kHeadDims[c.curLayer]
// vHeadDim := c.vHeadDims[c.curLayer]
// numKVHeads := c.numKVHeads[c.curLayer]
// // rowSize := numKVHeads * c.curBatchSize
// // cachedSize := c.curMask.Dim(1)
// cachedSize := c.curLoc.Dim(0)
// // kCellSize := kHeadDim * numKVHeads
// // vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Get full cache", "key", key)
// slog.Info("XXX Causal.Get full cache", "value", value)
// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
// slog.Info("XXX Causal.Get", "curMask", c.curMask)
// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
// // panic("XXX")
// // fmt.Fprintln(os.Stderr, key.ToString())
// // panic("full cache value")
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)
// // slog.Info("XXX Causal.Get after AsStrided", "key", key)
// // panic("XXX")
// // if c.config.PermutedV {
// // panic("permuted")
// // // TODO not converted
// // vHeadDim := value.Dim(1)
// // elemSize := value.Stride(2)
// // value = value.AsStrided(ctx,
// // []int{numKVHeads, vHeadDim, cachedSize},
// // []int{value.Stride(0), value.Stride(1)},
// // elemSize*c.curCellRange.min,
// // )
// // } else {
// // vHeadDim := c.vHeadDims[c.curLayer]
// // rowSize := value.Stride(2)
// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
// // panic("XXX")
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)
// // slog.Info("XXX Causal.Get after AsStrided", "value", value)
// // panic("XXX")
// // }
// // // TODO The mask changes from X,X to 1,X, and with the Row-order change
// // // the 1 becomes trailing and messes up later operations
// // // This isn't the right solution, but works around it...
// // if c.curMask.Dim(1) == 1 {
// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
// // }
// // fmt.Fprintln(os.Stderr, key.ToString())
// // fmt.Fprintln(os.Stderr, value.ToString())
// // panic("XXX")
// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())
// return key, value, c.curMask
// }
// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
// kHeadDim := key.Dim(3)
// vHeadDim := value.Dim(3)
// numKVHeads := key.Dim(1)
// batchSize := key.Dim(2)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// // slog.Info("XXX Causal.Put", "key", key, "value", value)
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
// // panic("XXX")
// if c.curBatchSize != batchSize {
// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
// }
// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
// if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// }
// if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
// c.kHeadDims[c.curLayer] = kHeadDim
// c.vHeadDims[c.curLayer] = vHeadDim
// c.numKVHeads[c.curLayer] = numKVHeads
// }
// if _, ok := c.values[c.curLayer]; !ok {
// // if c.config.PermutedV {
// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
// // } else {
// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
// // }
// }
// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
// // panic("XXX")
// // curLoc := 0 // TODO c.curLoc is now a tensor
// // kSize := numKVHeads * kHeadDim
// // vSize := numKVHeads * vHeadDim
// // start := []int{int(curLoc), 0}
// // kStop := []int{int(curLoc + batchSize), int(kSize)}
// // vStop := []int{int(curLoc + batchSize), int(vSize)}
// // strides := []int{1, 1}
// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)
// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)
// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
// // fmt.Fprintln(os.Stderr, keyCache.ToString())
// // panic("input value")
// // fmt.Fprintln(os.Stderr, t.ToString())
// // panic("XXX")
// // if c.config.PermutedV {
// // panic("permuted")
// // // TODO not adjusted
// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
// // value = value.Transpose(ctx, 2, 0, 1, 3)
// // valueCache := c.values[c.curLayer]
// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
// // } else {
// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)
// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
// // }
// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
// // panic("XXX")
// }
// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
// seqRange := newRange()
// for i := range c.cells {
// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
// if slices.Contains(c.cells[i].sequences, dstSeq) {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
// }
// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
// if i < seqRange.min {
// seqRange.min = i
// }
// if i > seqRange.max {
// seqRange.max = i
// }
// }
// }
// c.cellRanges[dstSeq] = seqRange
// }
// func (c *Causal) CanResume(seq int, pos int32) bool {
// if c.swaMemorySize == math.MaxInt32 {
// return true
// }
// seqRange, ok := c.cellRanges[seq]
// if !ok {
// return false
// }
// // for sliding window, check that the window of the new sequence is contained in
// // the window of what we are storing
// var first int32 = math.MaxInt32
// var last int32 = -1
// for i := seqRange.min; i <= seqRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// first = min(first, c.cells[i].pos)
// last = max(last, c.cells[i].pos)
// }
// }
// if last == -1 {
// return false
// }
// posWindowStart := max(0, pos-c.swaWindowSize)
// return posWindowStart >= first && pos <= last+1
// }
// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
// if c.shiftFn == nil {
// return ErrNotSupported
// }
// seqRange := c.cellRanges[seq]
// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
// size := min(seqRange.max-start+1, c.maxBatch)
// offsets := make([]int32, size)
// var batchFirst, batchLast int
// batchFirst = -1
// for i := range offsets {
// cell := c.cells[start+i]
// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
// offsets[i] = offset
// if batchFirst < 0 {
// batchFirst = i
// }
// batchLast = i
// }
// }
// if batchFirst < 0 {
// continue
// }
// offsets = offsets[batchFirst : batchLast+1]
// slog.Info("XXX Causal.shift creating new temporary context")
// ctx := c.backend.NewContext()
// kShift := ctx.Input().FromInts(offsets, len(offsets))
// for i, key := range c.keys {
// if key == nil {
// continue
// }
// kHeadDim := key.Dim(2)
// numKVHeads := key.Dim(1)
// rowSize := key.Stride(0)
// key = key.AsStrided(ctx,
// []int{len(offsets), numKVHeads, kHeadDim},
// []int{key.Stride(0), key.Stride(1)},
// rowSize*(start+batchFirst),
// )
// roped, err := c.shiftFn(ctx, i, key, kShift)
// if err != nil {
// ctx.Close()
// return err
// }
// ctx.Forward(roped.Copy(ctx, key))
// }
// ctx.Compute()
// ctx.Close()
// }
// return nil
// }
// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
// // TODO(jessegross): We should check to see if removing the middle of the sequence will
// // cause the sliding window to encompass tokens that we no longer have. If so, then we
// // should return an error, which will trigger the runner to evaluate the full history and
// // rebuild the window. However, if we have multimodal inputs in our history, this reuse
// // results in use after free, so we don't do it for now.
// var offset int32
// if endIndex != math.MaxInt32 {
// offset = beginIndex - endIndex
// }
// seqRange := newRange()
// for i := range c.cells {
// if slices.Contains(c.cells[i].sequences, seq) {
// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
// } else {
// if c.cells[i].pos >= endIndex {
// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
// return errors.New("shifting cells shared by multiple sequences not supported")
// }
// c.cells[i].pos += offset
// }
// if i < seqRange.min {
// seqRange.min = i
// }
// if i > seqRange.max {
// seqRange.max = i
// }
// }
// }
// }
// if seqRange == newRange() {
// delete(c.cellRanges, seq)
// return nil
// }
// c.cellRanges[seq] = seqRange
// if endIndex != math.MaxInt32 {
// err := c.shift(seq, endIndex+offset, offset)
// if err != nil {
// return err
// }
// }
// return nil
// }
import (
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type Causal struct {
DType ml.DType
// locations for data storage for this batch
curLocPut ml.Tensor
// locations for data storage for this batch
curLocGet ml.Tensor
// the active layer for Get and Put
curLayer int
capacity int
offset int
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
// TODO is this needed per layer, or will it always be consistent?
kHeadDims, vHeadDims, numKVHeads map[int]int
}
func NewCausalCache() *Causal {
return &Causal{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
kHeadDims: make(map[int]int),
vHeadDims: make(map[int]int),
numKVHeads: make(map[int]int),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.DType = dtype
c.capacity = capacity
c.backend = backend
}
func (c *Causal) SetConfig(config ml.CacheConfig) {}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
func (c *Causal) Close() {
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
locsPut := make([]int32, len(batch.Positions))
for i := c.offset; i < len(batch.Positions); i++ {
locsPut[i-c.offset] = int32(i)
}
c.offset += len(batch.Positions)
locsGet := make([]int32, c.offset)
for i := range c.offset {
locsGet[i] = int32(i)
}
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
return nil
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(3)
vHeadDim := value.Dim(3)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
kCellSize := kHeadDim * numKVHeads
vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
c.kHeadDims[c.curLayer] = kHeadDim
c.vHeadDims[c.curLayer] = vHeadDim
c.numKVHeads[c.curLayer] = numKVHeads
}
key = key.Reshape(ctx, batchSize, 1, kCellSize)
// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
// slog.Info("XXX Causal.Put ", "key", key)
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
value = value.Reshape(ctx, batchSize, 1, vCellSize)
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := c.kHeadDims[c.curLayer]
vHeadDim := c.vHeadDims[c.curLayer]
numKVHeads := c.numKVHeads[c.curLayer]
// rowSize := numKVHeads * c.curBatchSize
// cachedSize := c.curMask.Dim(1)
cachedSize := c.curLocGet.Dim(0)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
return key, value, nil
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("not implemented")
}
func (c *Causal) CanResume(seq int, pos int32) bool {
panic("not implemented")
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
panic("not implemented")
}

View File

@@ -1,973 +0,0 @@
package kvcache
// import (
// "fmt"
// "math"
// "slices"
// "testing"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// type testCase struct {
// name string
// in []float32
// inShape []int
// seqs []int
// pos []int32
// expected []float32
// expectedShape []int
// expectedMask []float32
// }
// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
// t.Helper()
// for _, permuted := range []bool{false, true} {
// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
// fn(t, &testBackend{permutedV: permuted})
// })
// }
// }
// func TestStore(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// inShape: []int{2, 3, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// expectedShape: []int{2, 3, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// },
// {
// name: "SecondBatch",
// in: []float32{115, 215, 125, 225, 135, 235},
// inShape: []int{2, 3, 1},
// seqs: []int{0},
// pos: []int32{4},
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
// expectedShape: []int{2, 3, 5},
// expectedMask: []float32{0, 0, 0, 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWA(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWACache(1, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, 0, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{5, 6, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, 0,
// 0, 0, x, x,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWASeparateBatches(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWACache(1, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 2, 16, 2)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "First seq 0",
// in: []float32{1, 2},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{0, 1},
// expected: []float32{1, 2},
// expectedShape: []int{1, 1, 2},
// expectedMask: []float32{
// 0, x,
// 0, 0,
// },
// },
// {
// name: "Second seq 0",
// in: []float32{3, 4},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{2, 3},
// expected: []float32{2, 3, 4},
// expectedShape: []int{1, 1, 3},
// expectedMask: []float32{
// 0, 0, x,
// x, 0, 0,
// },
// },
// {
// name: "First seq 1",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{0, 1},
// expected: []float32{5, 6},
// expectedShape: []int{1, 1, 2},
// expectedMask: []float32{
// 0, x,
// 0, 0,
// },
// },
// {
// name: "Second seq 1",
// in: []float32{7, 8},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{2, 3},
// expected: []float32{6, 3, 4, 7, 8},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, x, x, 0, x,
// x, x, x, 0, 0,
// },
// },
// {
// name: "Third seq 0",
// in: []float32{9, 10},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{9, 10, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, 0,
// 0, 0, x, x,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWAMem(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWAMemCache(1, 3, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, 0, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{5, 2, 3, 4, 6},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, x, x, 0, x,
// 0, x, x, x, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestChunkedAttention(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewChunkedAttentionCache(2, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// testCache(
// t, backend, cache,
// []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, x, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6, 7},
// inShape: []int{1, 1, 3},
// seqs: []int{0, 0, 0},
// pos: []int32{4, 5, 6},
// expected: []float32{1, 2, 3, 4, 5, 6, 7},
// expectedShape: []int{1, 1, 7},
// expectedMask: []float32{
// x, x, x, x, 0, x, x,
// x, x, x, x, 0, 0, x,
// x, x, x, x, x, x, 0,
// },
// },
// {
// name: "ThirdBatch",
// in: []float32{8, 9},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{7, 8},
// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
// expectedShape: []int{1, 1, 9},
// expectedMask: []float32{
// x, x, x, x, x, x, 0, 0, x,
// x, x, x, x, x, x, x, x, 0,
// },
// },
// },
// )
// })
// }
// func TestSequences(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 1, 1},
// pos: []int32{0, 1, 0, 1},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 1},
// pos: []int32{2, 2},
// expected: []float32{1, 2, 3, 4, 5, 6},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestRemove(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// return key.Add(ctx, shift), nil
// })
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 1, 1},
// pos: []int32{0, 1, 0, 1},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, x, 0, x,
// x, x, 0, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// err := cache.Remove(0, 1, math.MaxInt32)
// if err != nil {
// panic(err)
// }
// tests = []testCase{
// {
// name: "RemoveEnd",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 1},
// pos: []int32{1, 2},
// expected: []float32{1, 5, 3, 4, 6},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, 0, x, x, x,
// x, x, 0, 0, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// err = cache.Remove(0, 0, 1)
// if err != nil {
// panic(err)
// }
// tests = []testCase{
// {
// name: "RemoveMiddle",
// in: []float32{7, 8},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{1, 2},
// expected: []float32{7, 4, 3, 4, 6, 8},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{
// 0, 0, x, x, x, x,
// 0, 0, x, x, x, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestCopy(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// cache.CopyPrefix(0, 1, 2)
// tests = []testCase{
// {
// name: "Copy",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{3, 4},
// expected: []float32{1, 2, 3, 4, 5, 6},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
// for _, test := range tests {
// t.Run(test.name, func(t *testing.T) {
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
// if err != nil {
// panic(err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats(test.in, test.inShape...)
// cache.Put(context, tensor, tensor)
// out, _, mask := cache.Get(context)
// context.Forward(out, mask).Compute(out, mask)
// if !slices.Equal(out.Floats(), test.expected) {
// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
// }
// if !slices.Equal(out.Shape(), test.expectedShape) {
// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
// }
// if !slices.Equal(mask.Floats(), test.expectedMask) {
// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
// }
// })
// }
// }
// func TestCanResume(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// windowSize := int32(4)
// cache := NewSWACache(windowSize, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{
// Positions: []int32{0, 1, 2, 3, 4},
// Sequences: []int{0, 0, 0, 0, 0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
// cache.Put(context, tensor, tensor)
// // with window size 4, nothing has slid out of the window yet
// if !cache.CanResume(0, 0) {
// t.Errorf("CanResume(0, 0) = false, want true (within window)")
// }
// if !cache.CanResume(0, 1) {
// t.Errorf("CanResume(0, 1) = false, want true (within window)")
// }
// if !cache.CanResume(0, 2) {
// t.Errorf("CanResume(0, 2) = false, want true (within window)")
// }
// if !cache.CanResume(0, 3) {
// t.Errorf("CanResume(0, 3) = false, want true (latest position)")
// }
// if !cache.CanResume(0, 4) {
// t.Errorf("CanResume(0, 4) = false, want true (latest position)")
// }
// // shift window by adding position 5
// err = cache.StartForward(context, input.Batch{
// Positions: []int32{5},
// Sequences: []int{0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor = context.FromFloats([]float32{6}, 1, 1, 1)
// cache.Put(context, tensor, tensor)
// // only the latest position has overlapping windows
// if cache.CanResume(0, 0) {
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// }
// if cache.CanResume(0, 1) {
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// }
// if cache.CanResume(0, 2) {
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// }
// if cache.CanResume(0, 3) {
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// }
// if cache.CanResume(0, 4) {
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// }
// if !cache.CanResume(0, 5) {
// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
// }
// })
// }
// func TestCanResumeSWAMem(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// windowSize := int32(4)
// memSize := int32(5)
// cache := NewSWAMemCache(windowSize, memSize, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{
// Positions: []int32{0, 1, 2, 3, 4, 5, 6},
// Sequences: []int{0, 0, 0, 0, 0, 0, 0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
// cache.Put(context, tensor, tensor)
// // shift window by adding position 7
// err = cache.StartForward(context, input.Batch{
// Positions: []int32{7},
// Sequences: []int{0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor = context.FromFloats([]float32{8}, 1, 1, 1)
// cache.Put(context, tensor, tensor)
// // only the latest position has overlapping windows
// if cache.CanResume(0, 0) {
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// }
// if cache.CanResume(0, 1) {
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// }
// if cache.CanResume(0, 2) {
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// }
// if cache.CanResume(0, 3) {
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// }
// if cache.CanResume(0, 4) {
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// }
// if cache.CanResume(0, 5) {
// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
// }
// if !cache.CanResume(0, 6) {
// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
// }
// if !cache.CanResume(0, 7) {
// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
// }
// })
// }
// type testBackend struct {
// ml.Backend
// permutedV bool
// }
// func (b *testBackend) NewContext() ml.Context {
// return &testContext{}
// }
// func (b *testBackend) NewContextSize(int) ml.Context {
// return &testContext{}
// }
// func (b *testBackend) CacheConfig() ml.CacheConfig {
// return ml.CacheConfig{PermutedV: b.permutedV}
// }
// type testContext struct {
// ml.Context
// }
// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
// total := 0
// if len(shape) > 0 {
// total = 1
// for _, s := range shape {
// total *= s
// }
// }
// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
// }
// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
// return c.Empty(dtype, shape...)
// }
// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
// t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
// copy(t.data, s)
// return t
// }
// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
// f := make([]float32, len(s))
// for i := range f {
// f[i] = float32(s[i])
// }
// out := c.FromFloats(f, shape...)
// out.(*testTensor).dtype = ml.DTypeI32
// return out
// }
// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
// s := make([]float32, 0, int((stop-start)/step))
// for i := start; i < stop; i += step {
// s = append(s, i)
// }
// out := c.FromFloats(s, len(s))
// out.(*testTensor).dtype = dtype
// return out
// }
// func (c *testContext) Input() ml.Context { return c }
// func (c *testContext) Layer(int) ml.Context { return c }
// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
// func (c *testContext) Compute(...ml.Tensor) {}
// func (c *testContext) Reserve() {}
// func (c *testContext) MaxGraphNodes() int {
// return 10
// }
// func (c *testContext) Close() {}
// type testTensor struct {
// ml.Tensor
// dtype ml.DType
// elementSize int
// data []float32
// shape []int
// }
// func (t *testTensor) Dim(n int) int {
// return t.shape[n]
// }
// func (t *testTensor) Stride(n int) int {
// stride := t.elementSize
// for i := range n {
// stride *= t.shape[i]
// }
// return stride
// }
// func (t *testTensor) Shape() []int {
// return t.shape
// }
// func (t *testTensor) DType() ml.DType {
// return t.dtype
// }
// func (t *testTensor) Floats() []float32 {
// out := make([]float32, len(t.data))
// copy(out, t.data)
// return out
// }
// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
// for i := range out.data {
// out.data[i] = -t.data[i]
// }
// return out
// }
// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
// for i := range out.data {
// out.data[i] = t.data[i] + t2.(*testTensor).data[i]
// }
// return out
// }
// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
// return &testTensor{
// dtype: t.dtype,
// elementSize: t.elementSize,
// data: t.data,
// shape: shape,
// }
// }
// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
// offset /= t.elementSize
// var s []int
// switch len(shape) {
// case 1:
// s = []int{shape[0]}
// case 3:
// s = []int{shape[0], shape[2]}
// case 5:
// s = []int{shape[0], shape[2], shape[4]}
// default:
// panic("unsupported number of dimensions")
// }
// context := &testContext{}
// view := context.Empty(t.dtype, s...).(*testTensor)
// view.data = t.data[offset : offset+len(view.data)]
// return view
// }
// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
// if len(t.shape) > 4 || len(order) > 4 {
// panic("permute only supports up to 4 dimensions")
// }
// if len(order) != len(t.shape) && len(order) != 4 {
// panic("invalid number of dimensions for permute")
// }
// // ggml_permute expects 4 axes, so fill in any missing dimensions.
// orderFull := append(make([]int, 0, 4), order...)
// for len(orderFull) < 4 {
// orderFull = append(orderFull, len(orderFull))
// }
// seen := [4]bool{}
// shape4 := [4]int{1, 1, 1, 1}
// for i := 0; i < len(t.shape) && i < 4; i++ {
// shape4[i] = t.shape[i]
// }
// newShape4 := [4]int{1, 1, 1, 1}
// for axis := range 4 {
// dst := orderFull[axis]
// if dst < 0 || dst >= 4 {
// panic("invalid axis for permute")
// }
// if seen[dst] {
// panic("duplicate axis for permute")
// }
// seen[dst] = true
// newShape4[dst] = shape4[axis]
// }
// total := len(t.data)
// newData := make([]float32, total)
// if total > 0 {
// oldDims := shape4
// newDims := newShape4
// oldStride := [4]int{1, 1, 1, 1}
// newStride := [4]int{1, 1, 1, 1}
// for i := 1; i < 4; i++ {
// oldStride[i] = oldStride[i-1] * oldDims[i-1]
// newStride[i] = newStride[i-1] * newDims[i-1]
// }
// var coords [4]int
// var newCoords [4]int
// for idx := range total {
// remainder := idx
// for axis := range 4 {
// dim := oldDims[axis]
// if dim == 0 {
// coords[axis] = 0
// continue
// }
// coords[axis] = remainder % dim
// remainder /= dim
// }
// for axis := range 4 {
// newCoords[orderFull[axis]] = coords[axis]
// }
// newIndex := 0
// for axis := range 4 {
// if newDims[axis] == 0 {
// continue
// }
// newIndex += newCoords[axis] * newStride[axis]
// }
// newData[newIndex] = t.data[idx]
// }
// }
// numDims := 4
// for numDims > 1 && newShape4[numDims-1] <= 1 {
// numDims--
// }
// newShape := make([]int, numDims)
// copy(newShape, newShape4[:numDims])
// return &testTensor{
// dtype: t.dtype,
// elementSize: t.elementSize,
// data: newData,
// shape: newShape,
// }
// }
// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
// dst := t
// srcTensor := src.(*testTensor)
// idxTensor := idxs.(*testTensor)
// shapeTo4D := func(shape []int) [4]int {
// out := [4]int{1, 1, 1, 1}
// for i := 0; i < len(shape) && i < 4; i++ {
// out[i] = shape[i]
// }
// return out
// }
// computeStrides := func(shape [4]int) [4]int {
// out := [4]int{1, 1, 1, 1}
// for i := 1; i < 4; i++ {
// out[i] = out[i-1] * shape[i-1]
// }
// return out
// }
// dstShape4D := shapeTo4D(dst.shape)
// srcShape4D := shapeTo4D(srcTensor.shape)
// idxShape4D := shapeTo4D(idxTensor.shape)
// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
// panic("SetRows requires matching tensor shapes")
// }
// if srcShape4D[1] != idxShape4D[0] {
// panic("SetRows rows/index mismatch")
// }
// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
// panic("SetRows cannot broadcast indices")
// }
// if idxShape4D[3] != 1 {
// panic("SetRows expects 1D or 2D index tensors")
// }
// dstStride := computeStrides(dstShape4D)
// srcStride := computeStrides(srcShape4D)
// idxStride := computeStrides(idxShape4D)
// numColumns := srcShape4D[0]
// numRows := srcShape4D[1]
// for dim3Index := range dstShape4D[3] {
// for dim2Index := range dstShape4D[2] {
// idxDim2 := 0
// idxDim3 := 0
// if idxShape4D[1] > 0 {
// idxDim2 = dim2Index % idxShape4D[1]
// }
// if idxShape4D[2] > 0 {
// idxDim3 = dim3Index % idxShape4D[2]
// }
// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
// for row := range numRows {
// idx := int(idxTensor.data[idxBase+row*idxStride[0]])
// if idx < 0 || idx >= dstShape4D[1] {
// panic("SetRows index out of range")
// }
// srcOffset := srcBase + row*srcStride[1]
// dstOffset := dstBase + idx*dstStride[1]
// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
// }
// }
// }
// return dst
// }
// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// copy(t2.(*testTensor).data, t.data)
// return nil
// }

View File

@@ -1,144 +0,0 @@
//go:build mlx
package kvcache
import (
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type MLXCausal struct {
DType ml.DType
// locations for data storage for this batch
curLocPut ml.Tensor
// locations for data storage for this batch
curLocGet ml.Tensor
// the active layer for Get and Put
curLayer int
capacity int
offset int
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
// TODO is this needed per layer, or will it always be consistent?
kHeadDims, vHeadDims, numKVHeads map[int]int
}
func NewMLXCausalCache() *MLXCausal {
return &MLXCausal{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
kHeadDims: make(map[int]int),
vHeadDims: make(map[int]int),
numKVHeads: make(map[int]int),
}
}
func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.DType = dtype
c.capacity = capacity
c.backend = backend
}
func (c *MLXCausal) SetConfig(config ml.CacheConfig) {}
func (c *MLXCausal) SetLayer(layer int) {
c.curLayer = layer
}
func (c *MLXCausal) Close() {
// slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs))
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
locsPut := make([]int32, len(batch.Positions))
for i := c.offset; i < len(batch.Positions); i++ {
locsPut[i-c.offset] = int32(i)
}
c.offset += len(batch.Positions)
locsGet := make([]int32, c.offset)
for i := range c.offset {
locsGet[i] = int32(i)
}
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
// slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
return nil
}
func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(3)
vHeadDim := value.Dim(3)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
kCellSize := kHeadDim * numKVHeads
vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
c.kHeadDims[c.curLayer] = kHeadDim
c.vHeadDims[c.curLayer] = vHeadDim
c.numKVHeads[c.curLayer] = numKVHeads
}
key = key.Reshape(ctx, batchSize, 1, kCellSize)
// slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
// slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut)
// slog.Info("XXX MLXCausal.Put ", "key", key)
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
value = value.Reshape(ctx, batchSize, 1, vCellSize)
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}
func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := c.kHeadDims[c.curLayer]
vHeadDim := c.vHeadDims[c.curLayer]
numKVHeads := c.numKVHeads[c.curLayer]
// rowSize := numKVHeads * c.curBatchSize
// cachedSize := c.curMask.Dim(1)
cachedSize := c.curLocGet.Dim(0)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
return key, value, nil
}
func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("not implemented")
}
func (c *MLXCausal) CanResume(seq int, pos int32) bool {
panic("not implemented")
}
func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error {
panic("not implemented")
}

134
x/mlxrunner/imagegen.go Normal file
View File

@@ -0,0 +1,134 @@
//go:build mlx
package mlxrunner
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"sync"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// ImageModel is the interface for image generation models.
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
}
var imageGenMu sync.Mutex
// loadImageModel loads an image generation model.
func (s *server) loadImageModel() error {
// Check memory requirements before loading
var requiredMemory uint64
if manifest, err := imagegen.LoadManifest(s.modelName); err == nil {
requiredMemory = uint64(manifest.TotalTensorSize())
}
availableMemory := mlx.GetMemoryLimit()
if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory {
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Detect model type and load appropriate model
modelType := imagegen.DetectModelType(s.modelName)
slog.Info("detected image model type", "type", modelType)
var model ImageModel
switch modelType {
case "Flux2KleinPipeline":
m := &flux2.Model{}
if err := m.Load(s.modelName); err != nil {
return fmt.Errorf("failed to load flux2 model: %w", err)
}
model = m
default:
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
m := &zimage.Model{}
if err := m.Load(s.modelName); err != nil {
return fmt.Errorf("failed to load zimage model: %w", err)
}
model = m
}
s.imageModel = model
return nil
}
// handleImageCompletion handles image generation requests.
func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, req Request) {
// Serialize generation requests - MLX model may not handle concurrent generation
imageGenMu.Lock()
defer imageGenMu.Unlock()
// Set seed if not provided
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
}
// Set up streaming response
w.Header().Set("Content-Type", "application/x-ndjson")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
ctx := r.Context()
enc := json.NewEncoder(w)
// Progress callback streams step updates
progress := func(step, total int) {
resp := Response{Step: step, Total: total}
enc.Encode(resp)
w.Write([]byte("\n"))
flusher.Flush()
}
// Generate image
img, err := s.imageModel.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
if err != nil {
// Don't send error for cancellation
if ctx.Err() != nil {
return
}
resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()
// Send final response with image data
resp := Response{
Image: imageData,
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
}

420
x/mlxrunner/llm.go Normal file
View File

@@ -0,0 +1,420 @@
//go:build mlx
package mlxrunner
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// TextModel is the interface for LLM text generation models.
type TextModel interface {
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
NewCache(maxSeqLen int32) []cache.Cache
Tokenizer() *tokenizer.Tokenizer
VocabSize() int32
MaxContextLength() int32
NumLayers() int
}
// llmState holds the state for LLM generation
type llmState struct {
model TextModel
}
var llmMu sync.Mutex
// Dedicated stream for generation (like mlx-lm's generation_stream)
var generationStream *mlx.Stream
// withStream runs fn with the generation stream as default
func withStream(fn func()) {
// Lazy initialization of generationStream
if generationStream == nil {
generationStream = mlx.NewStream()
}
orig := mlx.GetDefaultStream()
mlx.SetDefaultStream(generationStream)
fn()
mlx.SetDefaultStream(orig)
}
// Decoder wraps model + cache for autoregressive generation.
// This matches the pattern from cmd/engine/generate.go
type Decoder struct {
model TextModel
caches []cache.Cache
vocabSize int32
temp float32
token *mlx.Array // Current token (kept across iterations)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
}
func NewDecoder(m TextModel, temp float32) *Decoder {
caches := m.NewCache(0)
return &Decoder{
model: m,
caches: caches,
vocabSize: m.VocabSize(),
temp: temp,
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
}
}
func (d *Decoder) prefill(inputIDs []int32) int {
processed := 0
// Track old cache state to free after each chunk
var oldCacheState []*mlx.Array
// Process all-but-1 tokens in chunks, eval cache state for memory management
for len(inputIDs) > 1 {
chunkSize := min(2048, len(inputIDs)-1)
if chunkSize <= 0 {
break
}
chunk := inputIDs[:chunkSize]
// Save old cache state before forward
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
var cacheState []*mlx.Array
withStream(func() {
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
d.model.Forward(x, d.caches)
for _, c := range d.caches {
cacheState = append(cacheState, c.State()...)
}
})
mlx.Eval(cacheState...)
// Free old cache state
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
inputIDs = inputIDs[chunkSize:]
processed += chunkSize
}
// Save old cache state before final step
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
// Final token + sampling
withStream(func() {
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
mlx.Eval(x) // Materialize before any other evals
logits := d.model.Forward(x, d.caches)
d.token = sample(logits, d.temp, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Free old cache state from before final step
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
mlx.ClearCache()
return processed + len(inputIDs)
}
func (d *Decoder) step() int32 {
prevToken := d.token
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
d.oldCacheState = append(d.oldCacheState, c.State()...)
}
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
d.token = sample(logits, d.temp, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
mlx.Keep(d.token)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
arr.Free()
}
return val
}
// sample samples from logits using temperature scaling
func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
// Get last position logits: [1, L, vocab] -> [vocab]
shape := logits.Shape()
seqLen := shape[1]
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize})
lastLogits = mlx.Reshape(lastLogits, vocabSize)
if temp <= 0 || temp < 0.01 {
// Greedy decoding
return mlx.Argmax(lastLogits, -1, false)
}
// Apply temperature scaling
scaled := mlx.DivScalar(lastLogits, temp)
return mlx.RandomCategorical(scaled, -1, 1)
}
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
func (s *server) loadLLMModel() error {
// Load the manifest to get model information
manifest, err := imagegen.LoadManifest(s.modelName)
if err != nil {
return fmt.Errorf("failed to load manifest: %w", err)
}
// Detect model architecture from config.json
configData, err := manifest.ReadConfig("config.json")
if err != nil {
return fmt.Errorf("failed to read config.json: %w", err)
}
var modelConfig struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(configData, &modelConfig); err != nil {
return fmt.Errorf("failed to parse config.json: %w", err)
}
arch := ""
if len(modelConfig.Architectures) > 0 {
arch = modelConfig.Architectures[0]
}
if arch == "" {
arch = modelConfig.ModelType
}
slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType)
// Load the appropriate model based on architecture
var model TextModel
archLower := strings.ToLower(arch)
switch {
case strings.Contains(archLower, "glm4moelite"):
m, err := glm4_moe_lite.LoadFromManifest(manifest)
if err != nil {
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
}
model = m
slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers())
default:
return fmt.Errorf("LLM architecture %q is not yet supported. "+
"Supported architectures: glm4-moe-lite. "+
"Please convert your model to GGUF format or use a supported architecture", arch)
}
s.llmModel = &llmState{
model: model,
}
return nil
}
// handleLLMCompletion handles LLM text generation requests.
func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) {
if s.llmModel == nil {
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
return
}
// Serialize generation requests
llmMu.Lock()
defer llmMu.Unlock()
if err := s.llmGenerate(w, r, req); err != nil {
slog.Error("LLM generation failed", "error", err)
// Don't send error if we've already started streaming
}
}
// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine
func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error {
state := s.llmModel
// Set up streaming response
w.Header().Set("Content-Type", "application/x-ndjson")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
return errors.New("streaming not supported")
}
tok := state.model.Tokenizer()
// The prompt is already formatted by the server using the model's renderer
// (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here.
prompt := req.Prompt
// Tokenize the prompt
inputIDs := tok.Encode(prompt, true)
slog.Debug("tokenized prompt", "num_tokens", len(inputIDs))
// Generation parameters
maxTokens := int(state.model.MaxContextLength())
if maxTokens <= 0 {
maxTokens = 4096
}
if req.Options != nil && req.Options.NumPredict > 0 {
maxTokens = req.Options.NumPredict
}
temperature := float32(0.7)
if req.Options != nil && req.Options.Temperature > 0 {
temperature = float32(req.Options.Temperature)
}
// Enable MLX compilation for better performance
mlx.EnableCompile()
// Create decoder with fresh caches
dec := NewDecoder(state.model, temperature)
prefillStart := time.Now()
prefillTokens := dec.prefill(inputIDs)
// Prefill measurement includes time to first token
firstToken := dec.step()
prefillDuration := time.Since(prefillStart)
promptEvalDuration := prefillDuration
enc := json.NewEncoder(w)
ctx := r.Context()
generated := 0
stopReason := "max_tokens"
// Handle first token
generated++
if tok.IsEOS(firstToken) {
resp := Response{
Done: true,
StopReason: fmt.Sprintf("first_token_eos:%d", firstToken),
PromptEvalCount: prefillTokens,
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
}
enc.Encode(resp)
flusher.Flush()
return nil
}
text := tok.Decode([]int32{firstToken})
resp := Response{Content: text}
enc.Encode(resp)
flusher.Flush()
genStart := time.Now()
// Generation loop
for n := 1; n < maxTokens; n++ {
// Check for cancellation
select {
case <-ctx.Done():
stopReason = fmt.Sprintf("context_cancelled:%d", generated)
break
default:
}
if stopReason != "max_tokens" {
break
}
token := dec.step()
generated++
if tok.IsEOS(token) {
stopReason = fmt.Sprintf("eos_token:%d", token)
break
}
text := tok.Decode([]int32{token})
// Check for stop sequences
if req.Options != nil && len(req.Options.Stop) > 0 {
shouldStop := false
var matchedStop string
for _, stop := range req.Options.Stop {
if strings.Contains(text, stop) {
text = strings.Split(text, stop)[0]
shouldStop = true
matchedStop = stop
break
}
}
if shouldStop {
if text != "" {
resp := Response{Content: text}
enc.Encode(resp)
flusher.Flush()
}
stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop)
break
}
}
resp := Response{Content: text}
enc.Encode(resp)
flusher.Flush()
// Periodically clear MLX cache
if n%256 == 0 {
mlx.ClearCache()
}
}
// Clean up
mlx.ClearCache()
// Send final response with stats
evalDuration := time.Since(genStart)
resp = Response{
Done: true,
StopReason: fmt.Sprintf("%s:generated=%d", stopReason, generated),
PromptEvalCount: prefillTokens,
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
EvalCount: generated,
EvalDuration: int(evalDuration.Nanoseconds()),
}
enc.Encode(resp)
flusher.Flush()
return nil
}

204
x/mlxrunner/runner.go Normal file
View File

@@ -0,0 +1,204 @@
//go:build mlx
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
package mlxrunner
import (
"context"
"encoding/json"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// Execute is the entry point for the unified MLX runner subprocess.
func Execute(args []string) error {
// Set up logging with appropriate level from environment
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))
fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
modelName := fs.String("model", "", "path to model")
port := fs.Int("port", 0, "port to listen on")
if err := fs.Parse(args); err != nil {
return err
}
if *modelName == "" {
return fmt.Errorf("--model is required")
}
if *port == 0 {
return fmt.Errorf("--port is required")
}
// Initialize MLX
if err := mlx.InitMLX(); err != nil {
slog.Error("unable to initialize MLX", "error", err)
return err
}
slog.Info("MLX library initialized")
// Detect model type from capabilities
mode := detectModelMode(*modelName)
slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
// Create and start server
server, err := newServer(*modelName, *port, mode)
if err != nil {
return fmt.Errorf("failed to create server: %w", err)
}
// Set up HTTP handlers
mux := http.NewServeMux()
mux.HandleFunc("/health", server.healthHandler)
mux.HandleFunc("/completion", server.completionHandler)
// LLM-specific endpoints
if mode == ModeLLM {
mux.HandleFunc("/tokenize", server.tokenizeHandler)
mux.HandleFunc("/embedding", server.embeddingHandler)
}
httpServer := &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
Handler: mux,
}
// Handle shutdown
done := make(chan struct{})
go func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
slog.Info("shutting down mlx runner")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
close(done)
}()
slog.Info("mlx runner listening", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
return err
}
<-done
return nil
}
// detectModelMode determines whether a model is an LLM or image generation model.
func detectModelMode(modelName string) ModelMode {
// Check for image generation model by looking at model_index.json
modelType := imagegen.DetectModelType(modelName)
if modelType != "" {
// Known image generation model types
switch modelType {
case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
return ModeImageGen
}
}
// Default to LLM mode for safetensors models without known image gen types
return ModeLLM
}
// server holds the model and handles HTTP requests.
type server struct {
mode ModelMode
modelName string
port int
// Image generation model (when mode == ModeImageGen)
imageModel ImageModel
// LLM model (when mode == ModeLLM)
llmModel *llmState
}
// newServer creates a new server instance and loads the appropriate model.
func newServer(modelName string, port int, mode ModelMode) (*server, error) {
s := &server{
mode: mode,
modelName: modelName,
port: port,
}
switch mode {
case ModeImageGen:
if err := s.loadImageModel(); err != nil {
return nil, fmt.Errorf("failed to load image model: %w", err)
}
case ModeLLM:
if err := s.loadLLMModel(); err != nil {
return nil, fmt.Errorf("failed to load LLM model: %w", err)
}
}
return s, nil
}
func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
resp := HealthResponse{Status: "ok"}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch s.mode {
case ModeImageGen:
s.handleImageCompletion(w, r, req)
case ModeLLM:
s.handleLLMCompletion(w, r, req)
}
}
func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
if s.llmModel == nil {
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
return
}
var req struct {
Content string `json:"content"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
tok := s.llmModel.model.Tokenizer()
tokens := tok.Encode(req.Content, false)
// Convert int32 to int for JSON response
intTokens := make([]int, len(tokens))
for i, t := range tokens {
intTokens[i] = int(t)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens})
}
func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented)
}

View File

@@ -1,10 +1,10 @@
//go:build !mlx
package runner
package mlxrunner
import "errors"
// Execute returns an error when not built with MLX support.
func Execute(args []string) error {
return errors.New("image generation not available: build with mlx tag")
return errors.New("MLX runner not available: build with mlx tag")
}

View File

@@ -1,4 +1,4 @@
package imagegen
package mlxrunner
import (
"bufio"
@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net"
@@ -22,19 +23,19 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen"
)
// Server wraps an image generation subprocess to implement llm.LlamaServer.
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
// like any other model. The plan is to eventually bring this into the llm/ package
// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
// separate allows for independent iteration on image generation support.
// like any other model. It supports both LLM (safetensors) and image generation models.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
port int
modelName string
mode ModelMode
vramSize uint64
done chan error
client *http.Client
@@ -42,10 +43,10 @@ type Server struct {
lastErrLock sync.Mutex
}
// NewServer spawns a new image generation subprocess and waits until it's ready.
func NewServer(modelName string) (*Server, error) {
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
func NewServer(modelName string, mode ModelMode) (*Server, error) {
// Validate platform support before attempting to start
if err := CheckPlatformSupport(); err != nil {
if err := imagegen.CheckPlatformSupport(); err != nil {
return nil, err
}
@@ -70,8 +71,8 @@ func NewServer(modelName string) (*Server, error) {
exe = eval
}
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
// Spawn subprocess: ollama runner --mlx-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
@@ -104,11 +105,21 @@ func NewServer(modelName string) (*Server, error) {
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if manifest, err := imagegen.LoadManifest(modelName); err == nil {
vramSize = uint64(manifest.TotalTensorSize())
} else {
// Fallback: default to 8GB if manifest can't be loaded
vramSize = 8 * 1024 * 1024 * 1024
}
s := &Server{
cmd: cmd,
port: port,
modelName: modelName,
vramSize: EstimateVRAM(modelName),
mode: mode,
vramSize: vramSize,
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
}
@@ -119,23 +130,23 @@ func NewServer(modelName string) (*Server, error) {
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
slog.Info("image-runner", "msg", scanner.Text())
slog.Info("mlx-runner", "msg", scanner.Text())
}
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
slog.Warn("image-runner", "msg", line)
slog.Warn("mlx-runner", "msg", line)
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
}
}()
slog.Info("starting image runner subprocess", "exe", exe, "model", modelName, "port", port)
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start image runner: %w", err)
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
}
// Reap subprocess when it exits
@@ -158,6 +169,7 @@ func (s *Server) ModelPath() string {
return s.modelName
}
// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
@@ -193,18 +205,18 @@ func (s *Server) waitUntilRunning() error {
// Include recent stderr lines for better error context
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
}
return fmt.Errorf("image runner exited unexpectedly: %w", err)
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
case <-timeout:
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
}
return errors.New("timeout waiting for image runner to start")
return errors.New("timeout waiting for mlx runner to start")
case <-ticker.C:
if err := s.Ping(ctx); err == nil {
slog.Info("image runner is ready", "port", s.port)
slog.Info("mlx runner is ready", "port", s.port)
return nil
}
}
@@ -218,27 +230,43 @@ func (s *Server) getLastErr() string {
return s.lastErr
}
func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
// WaitUntilRunning satisfies the LlamaServer interface.
func (s *Server) WaitUntilRunning(ctx context.Context) error {
return nil
}
// Completion handles both text and image generation requests.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
seed := req.Seed
if seed == 0 {
seed = time.Now().UnixNano()
}
// Extract raw image bytes from llm.ImageData slice
var images [][]byte
for _, img := range req.Images {
images = append(images, img.Data)
}
// Build request for subprocess
creq := struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}{
creq := Request{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Steps: int(req.Steps),
Seed: seed,
Images: images,
}
// Pass LLM options if present
if req.Options != nil {
creq.Options = &RequestOptions{
NumPredict: req.Options.NumPredict,
Temperature: float64(req.Options.Temperature),
TopP: float64(req.Options.TopP),
TopK: req.Options.TopK,
Stop: req.Options.Stop,
}
}
body, err := json.Marshal(creq)
@@ -260,31 +288,47 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("request failed: %d", resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%s", strings.TrimSpace(string(body)))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
// Parse subprocess response (has singular "image" field)
// Parse subprocess response
var raw struct {
Image string `json:"image,omitempty"`
Content string `json:"content,omitempty"`
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
Image string `json:"image,omitempty"`
Content string `json:"content,omitempty"`
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
// Log stop reason when generation completes
if raw.Done && raw.StopReason != "" {
slog.Info("mlx generation completed", "stop_reason", raw.StopReason)
}
// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
TotalSteps: raw.Total,
Image: raw.Image,
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
TotalSteps: raw.Total,
Image: raw.Image,
PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration),
}
fn(cresp)
@@ -293,7 +337,20 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
}
}
return scanner.Err()
// Scanner exited without receiving Done - connection was likely closed
scanErr := scanner.Err()
if scanErr != nil {
slog.Error("mlx scanner error", "error", scanErr)
} else {
slog.Warn("mlx scanner EOF without Done response - subprocess may have crashed")
}
// Check if subprocess is still alive
if s.HasExited() {
slog.Error("mlx subprocess has exited unexpectedly")
}
return scanErr
}
// Close terminates the subprocess.
@@ -302,7 +359,7 @@ func (s *Server) Close() error {
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
slog.Info("stopping image runner subprocess", "pid", s.cmd.Process.Pid)
slog.Info("stopping mlx runner subprocess", "pid", s.cmd.Process.Pid)
s.cmd.Process.Signal(os.Interrupt)
// Wait briefly for graceful shutdown
@@ -331,18 +388,51 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}
// Embedding returns embeddings for the input.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("not supported")
return nil, 0, errors.New("embeddings not supported for MLX models")
}
// Tokenize tokenizes the input content.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
return nil, errors.New("not supported")
body, err := json.Marshal(map[string]string{"content": content})
if err != nil {
return nil, err
}
url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode)
}
var result struct {
Tokens []int `json:"tokens"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}
return result.Tokens, nil
}
// Detokenize converts tokens back to text.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("not supported")
return "", errors.New("detokenization not supported for MLX models")
}
// Pid returns the process ID of the subprocess.
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
@@ -352,9 +442,17 @@ func (s *Server) Pid() int {
return -1
}
func (s *Server) GetPort() int { return s.port }
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
// GetPort returns the port the subprocess is listening on.
func (s *Server) GetPort() int {
return s.port
}
// GetDeviceInfos returns device information.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}
// HasExited returns whether the subprocess has exited.
func (s *Server) HasExited() bool {
select {
case <-s.done:

81
x/mlxrunner/types.go Normal file
View File

@@ -0,0 +1,81 @@
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
//
// This package handles safetensors models created with `ollama create --experimental`,
// supporting both text generation (LLM) and image generation (diffusion) models
// through a single unified interface.
package mlxrunner
// Request is the request format for completion requests.
type Request struct {
Prompt string `json:"prompt"`
// LLM-specific fields
Options *RequestOptions `json:"options,omitempty"`
// Image generation fields
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
}
// RequestOptions contains LLM-specific generation options.
type RequestOptions struct {
NumPredict int `json:"num_predict,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop []string `json:"stop,omitempty"`
}
// Response is streamed back for each progress update.
type Response struct {
// Text generation response
Content string `json:"content,omitempty"`
// Image generation response
Image string `json:"image,omitempty"` // Base64-encoded PNG
// Common fields
Done bool `json:"done"`
DoneReason int `json:"done_reason,omitempty"`
StopReason string `json:"stop_reason,omitempty"` // Debug: why generation stopped
// Progress fields
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
// Statistics
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
// HealthResponse is returned by the health endpoint.
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress,omitempty"`
}
// ModelMode represents the type of model being run.
type ModelMode int
const (
// ModeLLM indicates a text generation model.
ModeLLM ModelMode = iota
// ModeImageGen indicates an image generation model.
ModeImageGen
)
func (m ModelMode) String() string {
switch m {
case ModeLLM:
return "llm"
case ModeImageGen:
return "imagegen"
default:
return "unknown"
}
}

View File

@@ -87,7 +87,7 @@ func New(c fs.Config) (model.Model, error) {
// m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
// TODO need to implement sliding window...
m.Cache = kvcache.NewMLXCausalCache()
m.Cache = kvcache.NewCausalCache()
return &m, nil
}

View File

@@ -9,7 +9,8 @@ import (
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
// modelConfig represents the HuggingFace config.json structure
@@ -35,22 +36,22 @@ type modelConfig struct {
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
// It reads the config.json layer and returns a map compatible with GGML's KV format.
func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
manifest, err := imagegen.LoadManifest(modelName)
func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
var config modelConfig
if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
if err := mf.ReadConfigJSON("config.json", &config); err != nil {
return nil, fmt.Errorf("failed to read config.json: %w", err)
}
// Calculate total tensor bytes from manifest layers
var totalBytes int64
var tensorCount int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
for _, layer := range mf.Layers {
if layer.MediaType == manifest.MediaTypeImageTensor {
totalBytes += layer.Size
tensorCount++
}
@@ -151,27 +152,30 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
manifest, err := imagegen.LoadManifest(modelName)
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
return getTensorInfoFromManifest(manifest)
return getTensorInfoFromManifest(mf)
}
// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
var tensors []api.Tensor
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType != "application/vnd.ollama.image.tensor" {
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
}
// Read the safetensors header from the blob
blobPath := manifest.BlobPath(layer.Digest)
blobPath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
continue
}
info, err := readSafetensorsHeader(blobPath)
if err != nil {
// Skip tensors we can't read
@@ -195,29 +199,51 @@ func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor,
}
// GetSafetensorsDtype returns the quantization type for a safetensors model.
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
// Reads from model_index.json first, falls back to detection from tensor names.
// Otherwise returns the torch_dtype from config.json.
func GetSafetensorsDtype(modelName string) (string, error) {
manifest, err := imagegen.LoadManifest(modelName)
func GetSafetensorsDtype(name model.Name) (string, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return "", fmt.Errorf("failed to load manifest: %w", err)
}
// Check if model is quantized by looking for _scale tensors
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
// First try to read quantization from model_index.json
var modelIndex struct {
Quantization string `json:"quantization"`
}
if err := mf.ReadConfigJSON("model_index.json", &modelIndex); err == nil && modelIndex.Quantization != "" {
return modelIndex.Quantization, nil
}
// Fallback: detect from tensor names
hasScales := false
hasQBias := false
for _, layer := range mf.Layers {
if layer.MediaType == manifest.MediaTypeImageTensor {
if strings.HasSuffix(layer.Name, "_scale") {
// Model is quantized - return FP8 (affine quantization)
return "FP8", nil
hasScales = true
}
if strings.HasSuffix(layer.Name, "_qbias") {
hasQBias = true
}
}
}
if hasScales {
if hasQBias {
// Affine mode (has scale + qbias) - could be FP4 or FP8
// Default to FP4 as it's more common
return "FP4", nil
}
// No qbias = NVFP4
return "NVFP4", nil
}
// Not quantized - return torch_dtype from config.json
var cfg struct {
TorchDtype string `json:"torch_dtype"`
}
if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
if err := mf.ReadConfigJSON("config.json", &cfg); err != nil {
return "", fmt.Errorf("failed to read config.json: %w", err)
}

View File

@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/manifest"
)
func TestBuildModelInfo(t *testing.T) {
@@ -451,8 +451,14 @@ func TestParseSafetensorsHeader_Errors(t *testing.T) {
}
func TestGetTensorInfoFromManifest(t *testing.T) {
// Create a temp directory for blobs
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)
blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}
// Create test tensor blobs
tensors := []struct {
@@ -463,26 +469,26 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
}{
{
name: "model.embed_tokens.weight",
digest: "sha256:abc123",
digest: "sha256:abc123abc123abc123abc123abc123abc123abc123abc123abc123abc123abc0",
dtype: "BF16",
shape: []int64{262144, 2560},
},
{
name: "model.layers.0.self_attn.q_proj.weight",
digest: "sha256:def456",
digest: "sha256:def456def456def456def456def456def456def456def456def456def456def0",
dtype: "BF16",
shape: []int64{2560, 2560},
},
{
name: "model.norm.weight",
digest: "sha256:ghi789",
digest: "sha256:789789789789789789789789789789789789789789789789789789789789abc0",
dtype: "F32",
shape: []int64{2560},
},
}
// Create blob files
var layers []imagegen.ManifestLayer
var layers []manifest.Layer
for _, tensor := range tensors {
// Create safetensors blob
header := map[string]any{
@@ -498,15 +504,17 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
// Write blob file
blobName := "sha256-" + tensor.digest[7:]
blobPath := filepath.Join(tempDir, blobName)
// Write blob file using the digest format expected by GetBlobsPath
blobPath, err := manifest.BlobsPath(tensor.digest)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob: %v", err)
}
layers = append(layers, imagegen.ManifestLayer{
MediaType: "application/vnd.ollama.image.tensor",
layers = append(layers, manifest.Layer{
MediaType: manifest.MediaTypeImageTensor,
Digest: tensor.digest,
Size: int64(buf.Len() + 1000), // header + fake data
Name: tensor.name,
@@ -514,21 +522,20 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
}
// Add a non-tensor layer (should be skipped)
layers = append(layers, imagegen.ManifestLayer{
layers = append(layers, manifest.Layer{
MediaType: "application/vnd.ollama.image.json",
Digest: "sha256:config",
Digest: "sha256:0000000000000000000000000000000000000000000000000000000000000000",
Size: 100,
Name: "config.json",
})
manifest := &imagegen.ModelManifest{
Manifest: &imagegen.Manifest{
Layers: layers,
},
BlobDir: tempDir,
mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: layers,
}
result, err := getTensorInfoFromManifest(manifest)
result, err := getTensorInfoFromManifest(mf)
if err != nil {
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
}

51
x/server/thinking.go Normal file
View File

@@ -0,0 +1,51 @@
package server
import (
"strings"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
// IsSafetensorsThinkingModel checks if a safetensors model supports thinking
// based on its architecture from config.json.
func IsSafetensorsThinkingModel(name model.Name) bool {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return false
}
var config struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := mf.ReadConfigJSON("config.json", &config); err != nil {
return false
}
// Determine architecture
arch := config.ModelType
if arch == "" && len(config.Architectures) > 0 {
arch = config.Architectures[0]
}
if arch == "" {
return false
}
archLower := strings.ToLower(arch)
// List of architectures that support thinking
thinkingArchitectures := []string{
"glm4moe", // GLM-4 MoE models
"deepseek", // DeepSeek models
"qwen3", // Qwen3 models
}
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(archLower, thinkArch) {
return true
}
}
return false
}