mirror of
https://github.com/ollama/ollama.git
synced 2026-01-23 15:02:09 -05:00
Compare commits
10 Commits
v0.14.3-rc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
771d9280ec | ||
|
|
862bc0a3bf | ||
|
|
c01608b6a1 | ||
|
|
199c41e16e | ||
|
|
3b3bf6c217 | ||
|
|
f52c21f457 | ||
|
|
b5d0f72f16 | ||
|
|
148a1be0a3 | ||
|
|
d6dd430abd | ||
|
|
ae78112c50 |
@@ -35,6 +35,7 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
@@ -2026,6 +2027,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
config.ConfigCmd(checkServerHeartbeat),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
36
cmd/config/claude.go
Normal file
36
cmd/config/claude.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string) []string {
|
||||
if model != "" {
|
||||
return []string{"--model", model}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string) error {
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command("claude", c.args(model)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL=http://localhost:11434",
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
42
cmd/config/claude_test.go
Normal file
42
cmd/config/claude_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClaudeIntegration(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Claude Code" {
|
||||
t.Errorf("String() = %q, want %q", got, "Claude Code")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeArgs(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
61
cmd/config/codex.go
Normal file
61
cmd/config/codex.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
// Codex implements Runner for Codex integration
|
||||
type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
func (c *Codex) args(model string) []string {
|
||||
args := []string{"--oss"}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func checkCodexVersion() error {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
|
||||
}
|
||||
|
||||
out, err := exec.Command("codex", "--version").Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get codex version: %w", err)
|
||||
}
|
||||
|
||||
// Parse output like "codex-cli 0.87.0"
|
||||
fields := strings.Fields(strings.TrimSpace(string(out)))
|
||||
if len(fields) < 2 {
|
||||
return fmt.Errorf("unexpected codex version output: %s", string(out))
|
||||
}
|
||||
|
||||
version := "v" + fields[len(fields)-1]
|
||||
minVersion := "v0.81.0"
|
||||
|
||||
if semver.Compare(version, minVersion) < 0 {
|
||||
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
28
cmd/config/codex_test.go
Normal file
28
cmd/config/codex_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCodexArgs(t *testing.T) {
|
||||
c := &Codex{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", []string{"--oss"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
115
cmd/config/config.go
Normal file
115
cmd/config/config.go
Normal file
@@ -0,0 +1,115 @@
|
||||
// Package config provides integration configuration for external coding tools
|
||||
// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||
}
|
||||
|
||||
func load() (*config, error) {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &config{Integrations: make(map[string]*integration)}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
|
||||
}
|
||||
if cfg.Integrations == nil {
|
||||
cfg.Integrations = make(map[string]*integration)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func save(cfg *config) error {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeWithBackup(path, data)
|
||||
}
|
||||
|
||||
func saveIntegration(appName string, models []string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||
Models: models,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func loadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ic, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||
if !ok {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]integration, 0, len(cfg.Integrations))
|
||||
for _, ic := range cfg.Integrations {
|
||||
result = append(result, *ic)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
373
cmd/config/config_test.go
Normal file
373
cmd/config/config_test.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||
func setTestHome(t *testing.T, dir string) {
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
||||
func editorPaths(r Runner) []string {
|
||||
if editor, ok := r.(Editor); ok {
|
||||
return editor.Paths()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIntegrationConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("save and load round-trip", func(t *testing.T) {
|
||||
models := []string{"llama3.2", "mistral", "qwen2.5"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(config.Models) != len(models) {
|
||||
t.Errorf("expected %d models, got %d", len(models), len(config.Models))
|
||||
}
|
||||
for i, m := range models {
|
||||
if config.Models[i] != m {
|
||||
t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
config, _ := loadIntegration("codex")
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "model-a" {
|
||||
t.Errorf("expected model-a, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns empty for no models", func(t *testing.T) {
|
||||
config := &integration{Models: []string{}}
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "" {
|
||||
t.Errorf("expected empty string, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||
saveIntegration("Claude", []string{"model-x"})
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "model-x" {
|
||||
t.Errorf("expected model-x, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple integrations in single file", func(t *testing.T) {
|
||||
saveIntegration("app1", []string{"model-1"})
|
||||
saveIntegration("app2", []string{"model-2"})
|
||||
|
||||
config1, _ := loadIntegration("app1")
|
||||
config2, _ := loadIntegration("app2")
|
||||
|
||||
defaultModel1 := ""
|
||||
if len(config1.Models) > 0 {
|
||||
defaultModel1 = config1.Models[0]
|
||||
}
|
||||
defaultModel2 := ""
|
||||
if len(config2.Models) > 0 {
|
||||
defaultModel2 = config2.Models[0]
|
||||
}
|
||||
if defaultModel1 != "model-1" {
|
||||
t.Errorf("expected model-1, got %s", defaultModel1)
|
||||
}
|
||||
if defaultModel2 != "model-2" {
|
||||
t.Errorf("expected model-2, got %s", defaultModel2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestListIntegrations(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty when no integrations", func(t *testing.T) {
|
||||
configs, err := listIntegrations()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(configs) != 0 {
|
||||
t.Errorf("expected 0 integrations, got %d", len(configs))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns all saved integrations", func(t *testing.T) {
|
||||
saveIntegration("claude", []string{"model-1"})
|
||||
saveIntegration("droid", []string{"model-2"})
|
||||
|
||||
configs, err := listIntegrations()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(configs) != 2 {
|
||||
t.Errorf("expected 2 integrations, got %d", len(configs))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEditorPaths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
||||
r := integrations["claude"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for claude, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
||||
r := integrations["codex"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for codex, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
||||
settingsDir, _ := os.UserHomeDir()
|
||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
configDir := filepath.Join(home, ".config", "opencode")
|
||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["opencode"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 2 {
|
||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Create corrupted config.json file
|
||||
dir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
// Corrupted file is treated as empty, so loadIntegration returns not found
|
||||
_, err := loadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveIntegration_NilModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveIntegration("test", nil); err != nil {
|
||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("test")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration failed: %v", err)
|
||||
}
|
||||
|
||||
if config.Models == nil {
|
||||
// nil is acceptable
|
||||
} else if len(config.Models) != 0 {
|
||||
t.Errorf("expected empty or nil models, got %v", config.Models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveIntegration_EmptyAppName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
err := saveIntegration("", []string{"model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name, got nil")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
|
||||
t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
_, err := loadIntegration("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration, got nil")
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Logf("error type is os.ErrNotExist as expected: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigPath(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
|
||||
if path != expected {
|
||||
t.Errorf("expected %s, got %s", expected, path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty config when file does not exist", func(t *testing.T) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg == nil {
|
||||
t.Fatal("expected non-nil config")
|
||||
}
|
||||
if cfg.Integrations == nil {
|
||||
t.Error("expected non-nil Integrations map")
|
||||
}
|
||||
if len(cfg.Integrations) != 0 {
|
||||
t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("loads existing config", func(t *testing.T) {
|
||||
path, _ := configPath()
|
||||
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||
os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Integrations["test"] == nil {
|
||||
t.Fatal("expected test integration")
|
||||
}
|
||||
if len(cfg.Integrations["test"].Models) != 1 {
|
||||
t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for corrupted JSON", func(t *testing.T) {
|
||||
path, _ := configPath()
|
||||
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||
os.WriteFile(path, []byte(`{corrupted`), 0o644)
|
||||
|
||||
_, err := load()
|
||||
if err == nil {
|
||||
t.Error("expected error for corrupted JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("creates config file", func(t *testing.T) {
|
||||
cfg := &config{
|
||||
Integrations: map[string]*integration{
|
||||
"test": {Models: []string{"model-a", "model-b"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
path, _ := configPath()
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
t.Error("config file was not created")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("round-trip preserves data", func(t *testing.T) {
|
||||
cfg := &config{
|
||||
Integrations: map[string]*integration{
|
||||
"claude": {Models: []string{"llama3.2", "mistral"}},
|
||||
"codex": {Models: []string{"qwen2.5"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(loaded.Integrations) != 2 {
|
||||
t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
|
||||
}
|
||||
if loaded.Integrations["claude"] == nil {
|
||||
t.Error("missing claude integration")
|
||||
}
|
||||
if len(loaded.Integrations["claude"].Models) != 2 {
|
||||
t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
|
||||
}
|
||||
})
|
||||
}
|
||||
184
cmd/config/droid.go
Normal file
184
cmd/config/droid.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// Droid implements Runner and Editor for Droid integration
|
||||
type Droid struct{}
|
||||
|
||||
// droidSettings represents the Droid settings.json file (only fields we use)
|
||||
type droidSettings struct {
|
||||
CustomModels []modelEntry `json:"customModels"`
|
||||
SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
|
||||
}
|
||||
|
||||
type sessionSettings struct {
|
||||
Model string `json:"model"`
|
||||
ReasoningEffort string `json:"reasoningEffort"`
|
||||
}
|
||||
|
||||
type modelEntry struct {
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"displayName"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKey string `json:"apiKey"`
|
||||
Provider string `json:"provider"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens"`
|
||||
SupportsImages bool `json:"supportsImages"`
|
||||
ID string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := d.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (d *Droid) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(home, ".factory", "settings.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Droid) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settingsPath := filepath.Join(home, ".factory", "settings.json")
|
||||
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read file once, unmarshal twice:
|
||||
// map preserves unknown fields for writing back (including extra fields in model entries)
|
||||
settingsMap := make(map[string]any)
|
||||
var settings droidSettings
|
||||
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||
if err := json.Unmarshal(data, &settingsMap); err != nil {
|
||||
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
|
||||
}
|
||||
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||
}
|
||||
|
||||
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||
// Rebuild Ollama models
|
||||
var nonOllamaModels []any
|
||||
if rawModels, ok := settingsMap["customModels"].([]any); ok {
|
||||
for _, raw := range rawModels {
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
if m["apiKey"] != "ollama" {
|
||||
nonOllamaModels = append(nonOllamaModels, raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
modelID := fmt.Sprintf("custom:%s-%d", model, i)
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
DisplayName: model,
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
SupportsImages: false,
|
||||
ID: modelID,
|
||||
Index: i,
|
||||
})
|
||||
if i == 0 {
|
||||
defaultModelID = modelID
|
||||
}
|
||||
}
|
||||
|
||||
settingsMap["customModels"] = append(newModels, nonOllamaModels...)
|
||||
|
||||
// Update session default settings (preserve unknown fields in the nested object)
|
||||
sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
sessionSettings = make(map[string]any)
|
||||
}
|
||||
sessionSettings["model"] = defaultModelID
|
||||
|
||||
if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
|
||||
sessionSettings["reasoningEffort"] = "none"
|
||||
}
|
||||
|
||||
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||
|
||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, data)
|
||||
}
|
||||
|
||||
func (d *Droid) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings droidSettings
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, m := range settings.CustomModels {
|
||||
if m.APIKey == "ollama" {
|
||||
result = append(result, m.Model)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
|
||||
|
||||
func isValidReasoningEffort(effort string) bool {
|
||||
return slices.Contains(validReasoningEfforts, effort)
|
||||
}
|
||||
1302
cmd/config/droid_test.go
Normal file
1302
cmd/config/droid_test.go
Normal file
File diff suppressed because it is too large
Load Diff
99
cmd/config/files.go
Normal file
99
cmd/config/files.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
func readJSONFile(path string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
info, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||
}
|
||||
|
||||
func backupDir() string {
|
||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||
}
|
||||
|
||||
func backupToTmp(srcPath string) (string, error) {
|
||||
dir := backupDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
|
||||
if err := copyFile(srcPath, backupPath); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return backupPath, nil
|
||||
}
|
||||
|
||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
||||
func writeWithBackup(path string, data []byte) error {
|
||||
var backupPath string
|
||||
// backup must be created before any writes to the target file
|
||||
if existingContent, err := os.ReadFile(path); err == nil {
|
||||
if !bytes.Equal(existingContent, data) {
|
||||
backupPath, err = backupToTmp(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backup failed: %w", err)
|
||||
}
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("read existing file: %w", err)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
tmp, err := os.CreateTemp(dir, ".tmp-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp failed: %w", err)
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("write failed: %w", err)
|
||||
}
|
||||
if err := tmp.Sync(); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("sync failed: %w", err)
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("close failed: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
if backupPath != "" {
|
||||
_ = copyFile(backupPath, path)
|
||||
}
|
||||
return fmt.Errorf("rename failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
502
cmd/config/files_test.go
Normal file
502
cmd/config/files_test.go
Normal file
@@ -0,0 +1,502 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustMarshal(t *testing.T, v any) []byte {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestWriteWithBackup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
t.Run("creates file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "new.json")
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
if err := json.Unmarshal(content, &result); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result["key"] != "value" {
|
||||
t.Errorf("expected value, got %s", result["key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "backup.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(backupDir())
|
||||
if err != nil {
|
||||
t.Fatal("backup directory not created")
|
||||
}
|
||||
|
||||
var foundBackup bool
|
||||
for _, entry := range entries {
|
||||
if filepath.Ext(entry.Name()) != ".json" {
|
||||
name := entry.Name()
|
||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||
backupPath := filepath.Join(backupDir(), name)
|
||||
backup, err := os.ReadFile(backupPath)
|
||||
if err == nil {
|
||||
var backupData map[string]bool
|
||||
json.Unmarshal(backup, &backupData)
|
||||
if backupData["original"] {
|
||||
foundBackup = true
|
||||
os.Remove(backupPath)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup file not created in /tmp/ollama-backups")
|
||||
}
|
||||
|
||||
current, _ := os.ReadFile(path)
|
||||
var currentData map[string]bool
|
||||
json.Unmarshal(current, ¤tData)
|
||||
if !currentData["updated"] {
|
||||
t.Error("file doesn't contain updated data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup for new file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "nobak.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
for _, entry := range entries {
|
||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||
t.Error("backup should not exist for new file")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup when content unchanged", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "unchanged.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries1, _ := os.ReadDir(backupDir())
|
||||
countBefore := 0
|
||||
for _, e := range entries1 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countBefore++
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries2, _ := os.ReadDir(backupDir())
|
||||
countAfter := 0
|
||||
for _, e := range entries2 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countAfter++
|
||||
}
|
||||
}
|
||||
|
||||
if countAfter != countBefore {
|
||||
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "timestamped.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
data := mustMarshal(t, map[string]int{"v": 2})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
var found bool
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
|
||||
timestamp := name[len("timestamped.json."):]
|
||||
for _, c := range timestamp {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
|
||||
}
|
||||
}
|
||||
found = true
|
||||
os.Remove(filepath.Join(backupDir(), name))
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("backup file with timestamp not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for files.go
|
||||
|
||||
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
|
||||
// User could lose their config with no way to recover.
|
||||
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
// Create original file
|
||||
originalContent := []byte(`{"original": true}`)
|
||||
os.WriteFile(path, originalContent, 0o644)
|
||||
|
||||
// Make backup directory read-only to force backup failure
|
||||
backupDir := backupDir()
|
||||
os.MkdirAll(backupDir, 0o755)
|
||||
os.Chmod(backupDir, 0o444) // Read-only
|
||||
defer os.Chmod(backupDir, 0o755)
|
||||
|
||||
newContent := []byte(`{"updated": true}`)
|
||||
err := writeWithBackup(path, newContent)
|
||||
|
||||
// Should fail because backup couldn't be created
|
||||
if err == nil {
|
||||
t.Error("expected error when backup fails, got nil")
|
||||
}
|
||||
|
||||
// Original file should be preserved
|
||||
current, _ := os.ReadFile(path)
|
||||
if string(current) != string(originalContent) {
|
||||
t.Errorf("original file was modified despite backup failure: got %s", string(current))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
|
||||
// Common issue when config owned by root or wrong perms.
|
||||
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a read-only directory
|
||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||
os.MkdirAll(readOnlyDir, 0o755)
|
||||
os.Chmod(readOnlyDir, 0o444)
|
||||
defer os.Chmod(readOnlyDir, 0o755)
|
||||
|
||||
path := filepath.Join(readOnlyDir, "config.json")
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected permission error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||
// writeWithBackup doesn't create directories - caller is responsible.
|
||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
// Should fail because directory doesn't exist
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
|
||||
// Documents what happens if user symlinks their config file.
|
||||
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlink tests may require admin on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
realFile := filepath.Join(tmpDir, "real.json")
|
||||
symlink := filepath.Join(tmpDir, "link.json")
|
||||
|
||||
// Create real file and symlink
|
||||
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
|
||||
os.Symlink(realFile, symlink)
|
||||
|
||||
// Write through symlink
|
||||
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||
}
|
||||
|
||||
// The real file should be updated (symlink followed for temp file creation)
|
||||
content, _ := os.ReadFile(symlink)
|
||||
if string(content) != `{"v": 2}` {
|
||||
t.Errorf("symlink target not updated correctly: got %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||
// User may have config files with unusual names.
|
||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// File with spaces and special chars
|
||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
|
||||
|
||||
backupPath, err := backupToTmp(path)
|
||||
if err != nil {
|
||||
t.Fatalf("backupToTmp with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify backup exists and has correct content
|
||||
content, err := os.ReadFile(backupPath)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read backup: %v", err)
|
||||
}
|
||||
if string(content) != `{"test": true}` {
|
||||
t.Errorf("backup content mismatch: got %s", string(content))
|
||||
}
|
||||
|
||||
os.Remove(backupPath)
|
||||
}
|
||||
|
||||
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
|
||||
func TestCopyFile_PreservesPermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission preservation tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
src := filepath.Join(tmpDir, "src.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
// Create source with specific permissions
|
||||
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err != nil {
|
||||
t.Fatalf("copyFile failed: %v", err)
|
||||
}
|
||||
|
||||
srcInfo, _ := os.Stat(src)
|
||||
dstInfo, _ := os.Stat(dst)
|
||||
|
||||
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
|
||||
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent source, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||
os.MkdirAll(dirPath, 0o755)
|
||||
|
||||
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when target is a directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "empty.json")
|
||||
|
||||
err := writeWithBackup(path, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read file: %v", err)
|
||||
}
|
||||
if len(content) != 0 {
|
||||
t.Errorf("expected empty file, got %d bytes", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
|
||||
// cannot be read (for backup comparison) but directory is writable.
|
||||
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "unreadable.json")
|
||||
|
||||
// Create file and make it unreadable
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
os.Chmod(path, 0o000)
|
||||
defer os.Chmod(path, 0o644)
|
||||
|
||||
// Should fail because we can't read the file to compare/backup
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when file is unreadable, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||
// within the same second (timestamp collision scenario).
|
||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "rapid.json")
|
||||
|
||||
// Create initial file
|
||||
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
|
||||
|
||||
// Rapid successive writes
|
||||
for i := 1; i <= 3; i++ {
|
||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatalf("write %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify final content
|
||||
content, _ := os.ReadFile(path)
|
||||
if string(content) != `{"v": 3}` {
|
||||
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
|
||||
}
|
||||
|
||||
// Verify at least one backup exists
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
var backupCount int
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||
backupCount++
|
||||
}
|
||||
}
|
||||
if backupCount == 0 {
|
||||
t.Error("expected at least one backup file from rapid writes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
|
||||
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("test modifies system temp directory")
|
||||
}
|
||||
|
||||
// Create a file at the backup directory path
|
||||
backupPath := backupDir()
|
||||
// Clean up any existing directory first
|
||||
os.RemoveAll(backupPath)
|
||||
// Create a file instead of directory
|
||||
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
|
||||
defer func() {
|
||||
os.Remove(backupPath)
|
||||
os.MkdirAll(backupPath, 0o755)
|
||||
}()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "test.json")
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when backup dir is a file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
|
||||
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Count existing temp files
|
||||
countTempFiles := func() int {
|
||||
entries, _ := os.ReadDir(tmpDir)
|
||||
count := 0
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
before := countTempFiles()
|
||||
|
||||
// Create a file, then make directory read-only to cause rename failure
|
||||
path := filepath.Join(tmpDir, "orphan.json")
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make a subdirectory and try to write there after making parent read-only
|
||||
subDir := filepath.Join(tmpDir, "subdir")
|
||||
os.MkdirAll(subDir, 0o755)
|
||||
subPath := filepath.Join(subDir, "config.json")
|
||||
os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make subdir read-only after creating temp file would succeed but rename would fail
|
||||
// This is tricky to test - the temp file is created in the same dir, so if we can't
|
||||
// rename, we also couldn't create. Let's just verify normal failure cleanup works.
|
||||
|
||||
// Force a failure by making the target a directory
|
||||
badPath := filepath.Join(tmpDir, "isdir")
|
||||
os.MkdirAll(badPath, 0o755)
|
||||
|
||||
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
||||
|
||||
after := countTempFiles()
|
||||
if after > before {
|
||||
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
|
||||
}
|
||||
}
|
||||
362
cmd/config/integrations.go
Normal file
362
cmd/config/integrations.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Runners execute the launching of a model with the integration - claude, codex
|
||||
// Editors can edit config files (supports multi-model selection) - opencode, droid
|
||||
// They are composable interfaces where in some cases an editor is also a runner - opencode, droid
|
||||
// Runner can run an integration with a model.
|
||||
|
||||
type Runner interface {
|
||||
Run(model string) error
|
||||
// String returns the human-readable name of the integration
|
||||
String() string
|
||||
}
|
||||
|
||||
// Editor can edit config files (supports multi-model selection)
|
||||
type Editor interface {
|
||||
// Paths returns the paths to the config files for the integration
|
||||
Paths() []string
|
||||
// Edit updates the config files for the integration with the given models
|
||||
Edit(models []string) error
|
||||
// Models returns the models currently configured for the integration
|
||||
// TODO(parthsareen): add error return to Models()
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
"codex": &Codex{},
|
||||
"droid": &Droid{},
|
||||
"opencode": &OpenCode{},
|
||||
}
|
||||
|
||||
func selectIntegration() (string, error) {
|
||||
if len(integrations) == 0 {
|
||||
return "", fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
names := slices.Sorted(maps.Keys(integrations))
|
||||
var items []selectItem
|
||||
for _, name := range names {
|
||||
r := integrations[name]
|
||||
description := r.String()
|
||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
||||
}
|
||||
items = append(items, selectItem{Name: name, Description: description})
|
||||
}
|
||||
|
||||
return selectPrompt("Select integration:", items)
|
||||
}
|
||||
|
||||
// selectModels lets the user select models for an integration
|
||||
func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(models.Models) == 0 {
|
||||
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
var items []selectItem
|
||||
cloudModels := make(map[string]bool)
|
||||
for _, m := range models.Models {
|
||||
if m.RemoteModel != "" {
|
||||
cloudModels[m.Name] = true
|
||||
}
|
||||
items = append(items, selectItem{Name: m.Name})
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
// Get previously configured models (saved config takes precedence)
|
||||
var preChecked []string
|
||||
if saved, err := loadIntegration(name); err == nil {
|
||||
preChecked = saved.Models
|
||||
} else if editor, ok := r.(Editor); ok {
|
||||
preChecked = editor.Models()
|
||||
}
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If current model is configured, move to front of preChecked
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
}
|
||||
|
||||
// Sort: checked first, then alphabetical
|
||||
slices.SortFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
|
||||
var selected []string
|
||||
// only editors support multi-model selection
|
||||
if _, ok := r.(Editor); ok {
|
||||
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selected = []string{model}
|
||||
}
|
||||
|
||||
// if any model in selected is a cloud model, ensure signed in
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) > 0 {
|
||||
// ensure user is signed in
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return nil, fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
// TODO(parthsareen): extract into auth package for cmd
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName)
|
||||
}
|
||||
|
||||
// ConfigCmd returns the cobra command for configuring integrations.
|
||||
func ConfigCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
|
||||
var modelFlag string
|
||||
var launchFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "config [INTEGRATION]",
|
||||
Short: "Configure an external integration to use Ollama",
|
||||
Long: `Configure an external application to use Ollama models.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
|
||||
Examples:
|
||||
ollama config
|
||||
ollama config claude
|
||||
ollama config droid --launch`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
var name string
|
||||
if len(args) > 0 {
|
||||
name = args[0]
|
||||
} else {
|
||||
var err error
|
||||
name, err = selectIntegration()
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r, ok := integrations[strings.ToLower(name)]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
// If --launch without --model, use saved config if available
|
||||
if launchFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0])
|
||||
}
|
||||
}
|
||||
|
||||
var models []string
|
||||
if modelFlag != "" {
|
||||
// When --model is specified, merge with existing models (new model becomes default)
|
||||
models = []string{modelFlag}
|
||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||
for _, m := range existing.Models {
|
||||
if m != modelFlag {
|
||||
models = append(models, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if editor, isEditor := r.(Editor); isEditor {
|
||||
paths := editor.Paths()
|
||||
if len(paths) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
|
||||
for _, p := range paths {
|
||||
fmt.Fprintf(os.Stderr, " %s\n", p)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
|
||||
|
||||
if ok, _ := confirmPrompt("Proceed?"); !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := saveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
if editor, isEditor := r.(Editor); isEditor {
|
||||
if err := editor.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, isEditor := r.(Editor); isEditor {
|
||||
if len(models) == 1 {
|
||||
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
|
||||
}
|
||||
}
|
||||
|
||||
if slices.ContainsFunc(models, func(m string) bool {
|
||||
return !strings.HasSuffix(m, "cloud")
|
||||
}) {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "Coding agents work best with at least 64k context. Either:")
|
||||
fmt.Fprintln(os.Stderr, " - Set the context slider in Ollama app settings")
|
||||
fmt.Fprintln(os.Stderr, " - Run: OLLAMA_CONTEXT_LENGTH=64000 ollama serve")
|
||||
}
|
||||
|
||||
if launchFlag {
|
||||
return runIntegration(name, models[0])
|
||||
}
|
||||
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||
return runIntegration(name, models[0])
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Run 'ollama config %s --launch' to start with %s\n", strings.ToLower(name), models[0])
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||
cmd.Flags().BoolVar(&launchFlag, "launch", false, "Launch the integration after configuring")
|
||||
return cmd
|
||||
}
|
||||
188
cmd/config/integrations_test.go
Normal file
188
cmd/config/integrations_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestIntegrationLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantFound bool
|
||||
wantName string
|
||||
}{
|
||||
{"claude lowercase", "claude", true, "Claude Code"},
|
||||
{"claude uppercase", "CLAUDE", true, "Claude Code"},
|
||||
{"claude mixed case", "Claude", true, "Claude Code"},
|
||||
{"codex", "codex", true, "Codex"},
|
||||
{"droid", "droid", true, "Droid"},
|
||||
{"opencode", "opencode", true, "OpenCode"},
|
||||
{"unknown integration", "unknown", false, ""},
|
||||
{"empty string", "", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, found := integrations[strings.ToLower(tt.input)]
|
||||
if found != tt.wantFound {
|
||||
t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound)
|
||||
}
|
||||
if found && r.String() != tt.wantName {
|
||||
t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationRegistry(t *testing.T) {
|
||||
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
|
||||
|
||||
for _, name := range expectedIntegrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
t.Fatalf("integration %q not found in registry", name)
|
||||
}
|
||||
if r.String() == "" {
|
||||
t.Error("integration.String() should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasLocalModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
models []string
|
||||
want bool
|
||||
}{
|
||||
{"empty list", []string{}, false},
|
||||
{"single local model", []string{"llama3.2"}, true},
|
||||
{"single cloud model", []string{"cloud-model"}, false},
|
||||
{"mixed models", []string{"cloud-model", "llama3.2"}, true},
|
||||
{"multiple local models", []string{"llama3.2", "qwen2.5"}, true},
|
||||
{"multiple cloud models", []string{"cloud-a", "cloud-b"}, false},
|
||||
{"local model first", []string{"llama3.2", "cloud-model"}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
||||
return !strings.Contains(m, "cloud")
|
||||
})
|
||||
if got != tt.want {
|
||||
t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCmd(t *testing.T) {
|
||||
// Mock checkServerHeartbeat that always succeeds
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := ConfigCmd(mockCheck)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "config [INTEGRATION]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "config [INTEGRATION]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
}
|
||||
if cmd.Long == "" {
|
||||
t.Error("Long description should not be empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
modelFlag := cmd.Flags().Lookup("model")
|
||||
if modelFlag == nil {
|
||||
t.Error("--model flag should exist")
|
||||
}
|
||||
|
||||
launchFlag := cmd.Flags().Lookup("launch")
|
||||
if launchFlag == nil {
|
||||
t.Error("--launch flag should exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreRunE is set", func(t *testing.T) {
|
||||
if cmd.PreRunE == nil {
|
||||
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model")
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown integration, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown integration") {
|
||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
models []string
|
||||
want bool
|
||||
reason string
|
||||
}{
|
||||
{"empty list", []string{}, false, "empty list has no local models"},
|
||||
{"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"},
|
||||
{"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"},
|
||||
{"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"},
|
||||
{"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"},
|
||||
{"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"},
|
||||
{"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
||||
return !strings.Contains(m, "cloud")
|
||||
})
|
||||
if got != tt.want {
|
||||
t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCmd_NilHeartbeat(t *testing.T) {
|
||||
// This should not panic - cmd creation should work even with nil
|
||||
cmd := ConfigCmd(nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("ConfigCmd returned nil")
|
||||
}
|
||||
|
||||
// PreRunE should be nil when passed nil
|
||||
if cmd.PreRunE != nil {
|
||||
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||
for name, r := range integrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test String() doesn't panic and returns non-empty
|
||||
displayName := r.String()
|
||||
if displayName == "" {
|
||||
t.Error("String() should not return empty")
|
||||
}
|
||||
|
||||
// Test Run() exists (we can't call it without actually running the command)
|
||||
// Just verify the method is available
|
||||
var _ func(string) error = r.Run
|
||||
})
|
||||
}
|
||||
}
|
||||
203
cmd/config/opencode.go
Normal file
203
cmd/config/opencode.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string) error {
|
||||
if _, err := exec.LookPath("opencode"); err != nil {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := o.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (o *OpenCode) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var paths []string
|
||||
p := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
paths = append(paths, p)
|
||||
}
|
||||
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if _, err := os.Stat(sp); err == nil {
|
||||
paths = append(paths, sp)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func (o *OpenCode) Edit(modelList []string) error {
|
||||
if len(modelList) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
|
||||
}
|
||||
|
||||
config["$schema"] = "https://opencode.ai/config.json"
|
||||
|
||||
provider, ok := config["provider"].(map[string]any)
|
||||
if !ok {
|
||||
provider = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama (local)",
|
||||
"options": map[string]any{
|
||||
"baseURL": "http://localhost:11434/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
models = make(map[string]any)
|
||||
}
|
||||
|
||||
selectedSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
selectedSet[m] = true
|
||||
}
|
||||
|
||||
for name, cfg := range models {
|
||||
if cfgMap, ok := cfg.(map[string]any); ok {
|
||||
if displayName, ok := cfgMap["name"].(string); ok {
|
||||
if strings.HasSuffix(displayName, "[Ollama]") && !selectedSet[name] {
|
||||
delete(models, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range modelList {
|
||||
models[model] = map[string]any{
|
||||
"name": fmt.Sprintf("%s [Ollama]", model),
|
||||
}
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
provider["ollama"] = ollama
|
||||
config["provider"] = provider
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := map[string]any{
|
||||
"recent": []any{},
|
||||
"favorite": []any{},
|
||||
"variant": map[string]any{},
|
||||
}
|
||||
if data, err := os.ReadFile(statePath); err == nil {
|
||||
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
|
||||
}
|
||||
|
||||
recent, _ := state["recent"].([]any)
|
||||
|
||||
modelSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
modelSet[m] = true
|
||||
}
|
||||
|
||||
// Filter out existing Ollama models we're about to re-add
|
||||
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
|
||||
e, ok := entry.(map[string]any)
|
||||
if !ok || e["providerID"] != "ollama" {
|
||||
return false
|
||||
}
|
||||
modelID, _ := e["modelID"].(string)
|
||||
return modelSet[modelID]
|
||||
})
|
||||
|
||||
// Prepend models in reverse order so first model ends up first
|
||||
for _, model := range slices.Backward(modelList) {
|
||||
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
|
||||
"providerID": "ollama",
|
||||
"modelID": model,
|
||||
}))
|
||||
}
|
||||
|
||||
const maxRecentModels = 10
|
||||
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
|
||||
|
||||
state["recent"] = newRecent
|
||||
|
||||
stateData, err := json.MarshalIndent(state, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(statePath, stateData)
|
||||
}
|
||||
|
||||
func (o *OpenCode) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
provider, _ := config["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := slices.Collect(maps.Keys(models))
|
||||
slices.Sort(keys)
|
||||
return keys
|
||||
}
|
||||
437
cmd/config/opencode_test.go
Normal file
437
cmd/config/opencode_test.go
Normal file
@@ -0,0 +1,437 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpenCodeIntegration(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := o.String(); got != "OpenCode" {
|
||||
t.Errorf("String() = %q, want %q", got, "OpenCode")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = o
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = o
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(configDir)
|
||||
os.RemoveAll(stateDir)
|
||||
}
|
||||
|
||||
t.Run("fresh install", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other providers", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
if provider["anthropic"] == nil {
|
||||
t.Error("anthropic provider was removed")
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("update existing model", func(t *testing.T) {
|
||||
cleanup()
|
||||
o.Edit([]string{"llama3.2"})
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme was removed")
|
||||
}
|
||||
if cfg["keybindings"] == nil {
|
||||
t.Error("keybindings was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model state - insert at index 0", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
|
||||
})
|
||||
|
||||
t.Run("model state - preserve favorites and variants", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
if len(state["favorite"].([]any)) != 1 {
|
||||
t.Error("favorite was modified")
|
||||
}
|
||||
if state["variant"].(map[string]any)["a"] != "b" {
|
||||
t.Error("variant was modified")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model state - deduplicate on re-add", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
recent := state["recent"].([]any)
|
||||
if len(recent) != 2 {
|
||||
t.Errorf("expected 2 recent entries, got %d", len(recent))
|
||||
}
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("remove model", func(t *testing.T) {
|
||||
cleanup()
|
||||
// First add two models
|
||||
o.Edit([]string{"llama3.2", "mistral"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||
|
||||
// Then remove one by only selecting the other
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
// Add a non-Ollama model manually
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
|
||||
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
|
||||
})
|
||||
}
|
||||
|
||||
func assertOpenCodeModelExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider not found")
|
||||
}
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("ollama provider not found")
|
||||
}
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("models not found")
|
||||
}
|
||||
if models[model] == nil {
|
||||
t.Errorf("model %s not found", model)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
return // No provider means no model
|
||||
}
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
return // No ollama means no model
|
||||
}
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
return // No models means no model
|
||||
}
|
||||
if models[model] != nil {
|
||||
t.Errorf("model %s should not exist but was found", model)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
recent, ok := state["recent"].([]any)
|
||||
if !ok {
|
||||
t.Fatal("recent not found")
|
||||
}
|
||||
if index >= len(recent) {
|
||||
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
|
||||
}
|
||||
entry, ok := recent[index].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("entry is not a map")
|
||||
}
|
||||
if entry["providerID"] != providerID {
|
||||
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
|
||||
}
|
||||
if entry["modelID"] != modelID {
|
||||
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
|
||||
}
|
||||
}
|
||||
|
||||
// Edge case tests for opencode.go
|
||||
|
||||
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
|
||||
|
||||
// Should not panic - corrupted JSON should be treated as empty
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit failed with corrupted config: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid JSON was created
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Errorf("resulting config is not valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit failed with corrupted state: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid state was created
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Errorf("resulting state is not valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with wrong type provider failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify provider is now correct type
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
|
||||
}
|
||||
if provider["ollama"] == nil {
|
||||
t.Error("ollama provider should be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with wrong type recent failed: %v", err)
|
||||
}
|
||||
|
||||
// The function should handle this gracefully
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
|
||||
// recent should be properly set after setup
|
||||
recent, ok := state["recent"].([]any)
|
||||
if !ok {
|
||||
t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
|
||||
} else if len(recent) == 0 {
|
||||
t.Logf("Note: recent is empty (documenting behavior)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
|
||||
os.WriteFile(configPath, []byte(originalContent), 0o644)
|
||||
|
||||
// Empty models should be no-op
|
||||
err := o.Edit([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with empty models failed: %v", err)
|
||||
}
|
||||
|
||||
// Original content should be preserved (file not modified)
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != originalContent {
|
||||
t.Errorf("empty models should not modify file, but content changed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Model name with special characters (though unusual)
|
||||
specialModel := `model-with-"quotes"`
|
||||
|
||||
err := o.Edit([]string{specialModel})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was stored correctly
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
data, _ := os.ReadFile(configPath)
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("resulting config is invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Model should be accessible
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
|
||||
if models[specialModel] == nil {
|
||||
t.Errorf("model with special chars not found in config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
models := o.Models()
|
||||
if len(models) > 0 {
|
||||
t.Errorf("expected nil/empty for missing config, got %v", models)
|
||||
}
|
||||
}
|
||||
499
cmd/config/selector.go
Normal file
499
cmd/config/selector.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiHideCursor = "\033[?25l"
|
||||
ansiShowCursor = "\033[?25h"
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiClearDown = "\033[J"
|
||||
)
|
||||
|
||||
const maxDisplayedItems = 10
|
||||
|
||||
var errCancelled = errors.New("cancelled")
|
||||
|
||||
type selectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
type inputEvent int
|
||||
|
||||
const (
|
||||
eventNone inputEvent = iota
|
||||
eventEnter
|
||||
eventEscape
|
||||
eventUp
|
||||
eventDown
|
||||
eventTab
|
||||
eventBackspace
|
||||
eventChar
|
||||
)
|
||||
|
||||
type selectState struct {
|
||||
items []selectItem
|
||||
filter string
|
||||
selected int
|
||||
scrollOffset int
|
||||
}
|
||||
|
||||
func newSelectState(items []selectItem) *selectState {
|
||||
return &selectState{items: items}
|
||||
}
|
||||
|
||||
func (s *selectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if len(filtered) > 0 && s.selected < len(filtered) {
|
||||
return true, filtered[s.selected].Name, nil
|
||||
}
|
||||
case eventEscape:
|
||||
return true, "", errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
case eventUp:
|
||||
if s.selected > 0 {
|
||||
s.selected--
|
||||
if s.selected < s.scrollOffset {
|
||||
s.scrollOffset = s.selected
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.selected < len(filtered)-1 {
|
||||
s.selected++
|
||||
if s.selected >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.selected - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
type multiSelectState struct {
|
||||
items []selectItem
|
||||
itemIndex map[string]int
|
||||
filter string
|
||||
highlighted int
|
||||
scrollOffset int
|
||||
checked map[int]bool
|
||||
checkOrder []int
|
||||
focusOnButton bool
|
||||
}
|
||||
|
||||
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
|
||||
s := &multiSelectState{
|
||||
items: items,
|
||||
itemIndex: make(map[string]int, len(items)),
|
||||
checked: make(map[int]bool),
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
s.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := s.itemIndex[name]; ok {
|
||||
s.checked[idx] = true
|
||||
s.checkOrder = append(s.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *multiSelectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *multiSelectState) toggleItem() {
|
||||
filtered := s.filtered()
|
||||
if len(filtered) == 0 || s.highlighted >= len(filtered) {
|
||||
return
|
||||
}
|
||||
|
||||
item := filtered[s.highlighted]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
if s.checked[origIdx] {
|
||||
delete(s.checked, origIdx)
|
||||
for i, idx := range s.checkOrder {
|
||||
if idx == origIdx {
|
||||
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.checked[origIdx] = true
|
||||
s.checkOrder = append(s.checkOrder, origIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if s.focusOnButton && len(s.checkOrder) > 0 {
|
||||
var res []string
|
||||
for _, idx := range s.checkOrder {
|
||||
res = append(res, s.items[idx].Name)
|
||||
}
|
||||
return true, res, nil
|
||||
} else if !s.focusOnButton {
|
||||
s.toggleItem()
|
||||
}
|
||||
case eventTab:
|
||||
if len(s.checkOrder) > 0 {
|
||||
s.focusOnButton = !s.focusOnButton
|
||||
}
|
||||
case eventEscape:
|
||||
return true, nil, errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
case eventUp:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted > 0 {
|
||||
s.highlighted--
|
||||
if s.highlighted < s.scrollOffset {
|
||||
s.scrollOffset = s.highlighted
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted < len(filtered)-1 {
|
||||
s.highlighted++
|
||||
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
func (s *multiSelectState) selectedCount() int {
|
||||
return len(s.checkOrder)
|
||||
}
|
||||
|
||||
// Terminal I/O handling
|
||||
|
||||
type terminalState struct {
|
||||
fd int
|
||||
oldState *term.State
|
||||
}
|
||||
|
||||
func enterRawMode() (*terminalState, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Fprint(os.Stderr, ansiHideCursor)
|
||||
return &terminalState{fd: fd, oldState: oldState}, nil
|
||||
}
|
||||
|
||||
func (t *terminalState) restore() {
|
||||
fmt.Fprint(os.Stderr, ansiShowCursor)
|
||||
term.Restore(t.fd, t.oldState)
|
||||
}
|
||||
|
||||
func clearLines(n int) {
|
||||
if n > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA", n)
|
||||
fmt.Fprint(os.Stderr, ansiClearDown)
|
||||
}
|
||||
}
|
||||
|
||||
func parseInput(r io.Reader) (inputEvent, byte, error) {
|
||||
buf := make([]byte, 3)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case n == 1 && buf[0] == 13:
|
||||
return eventEnter, 0, nil
|
||||
case n == 1 && (buf[0] == 3 || buf[0] == 27):
|
||||
return eventEscape, 0, nil
|
||||
case n == 1 && buf[0] == 9:
|
||||
return eventTab, 0, nil
|
||||
case n == 1 && buf[0] == 127:
|
||||
return eventBackspace, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
|
||||
return eventUp, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
|
||||
return eventDown, 0, nil
|
||||
case n == 1 && buf[0] >= 32 && buf[0] < 127:
|
||||
return eventChar, buf[0], nil
|
||||
}
|
||||
|
||||
return eventNone, 0, nil
|
||||
}
|
||||
|
||||
// Rendering
|
||||
|
||||
func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
prefix := " "
|
||||
if idx == s.selected {
|
||||
prefix = " " + ansiBold + "> "
|
||||
}
|
||||
if item.Description != "" {
|
||||
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
checkbox := "[ ]"
|
||||
if s.checked[origIdx] {
|
||||
checkbox = "[x]"
|
||||
}
|
||||
|
||||
prefix := " "
|
||||
suffix := ""
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
prefix = "> "
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
|
||||
suffix = " " + ansiGray + "(default)" + ansiReset
|
||||
}
|
||||
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "\r\n")
|
||||
lineCount++
|
||||
count := s.selectedCount()
|
||||
switch {
|
||||
case count == 0:
|
||||
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
|
||||
case s.focusOnButton:
|
||||
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
|
||||
default:
|
||||
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
// selectPrompt prompts the user to select a single item from a list.
|
||||
func selectPrompt(prompt string, items []selectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newSelectState(items)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
|
||||
// multiSelectPrompt prompts the user to select multiple items from a list.
|
||||
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newMultiSelectState(items, preChecked)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
|
||||
func confirmPrompt(prompt string) (bool, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s [y/n] ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterItems(items []selectItem, filter string) []selectItem {
|
||||
if filter == "" {
|
||||
return items
|
||||
}
|
||||
var result []selectItem
|
||||
filterLower := strings.ToLower(filter)
|
||||
for _, item := range items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
913
cmd/config/selector_test.go
Normal file
913
cmd/config/selector_test.go
Normal file
@@ -0,0 +1,913 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilterItems(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama3.2:latest"},
|
||||
{Name: "qwen2.5:7b"},
|
||||
{Name: "deepseek-v3:cloud"},
|
||||
{Name: "GPT-OSS:20b"},
|
||||
}
|
||||
|
||||
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
|
||||
result := filterItems(items, "")
|
||||
if len(result) != len(items) {
|
||||
t.Errorf("expected %d items, got %d", len(items), len(result))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
|
||||
result := filterItems(items, "LLAMA")
|
||||
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
|
||||
t.Errorf("expected llama3.2:latest, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
|
||||
result := filterItems(items, "gpt")
|
||||
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
|
||||
t.Errorf("expected GPT-OSS:20b, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PartialMatch", func(t *testing.T) {
|
||||
result := filterItems(items, "deep")
|
||||
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
|
||||
t.Errorf("expected deepseek-v3:cloud, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
|
||||
result := filterItems(items, "nonexistent")
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0 items, got %d", len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0, got %d", s.selected)
|
||||
}
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected empty filter, got %q", s.filter)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item1" || err != nil {
|
||||
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "item3"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item3" || err != nil {
|
||||
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "nonexistent"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != "" || err != nil {
|
||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != "" || err != errCancelled {
|
||||
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 2 {
|
||||
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventChar, 'i')
|
||||
s.handleInput(eventChar, 't')
|
||||
s.handleInput(eventChar, 'e')
|
||||
s.handleInput(eventChar, 'm')
|
||||
s.handleInput(eventChar, '2')
|
||||
if s.filter != "item2" {
|
||||
t.Errorf("expected filter='item2', got %q", s.filter)
|
||||
}
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 1 || filtered[0].Name != "item2" {
|
||||
t.Errorf("expected [item2], got %v", filtered)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after typing, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected filter='', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.selected = 2
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
|
||||
// maxDisplayedItems is 10, so with 15 items we need to scroll
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
|
||||
// move down 12 times (past the 10-item viewport)
|
||||
for range 12 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != 12 {
|
||||
t.Errorf("expected selected=12, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 3 {
|
||||
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
s.selected = 5
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventUp, 0)
|
||||
|
||||
if s.selected != 4 {
|
||||
t.Errorf("expected selected=4, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 4 {
|
||||
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false initially")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item3"})
|
||||
if s.selectedCount() != 2 {
|
||||
t.Errorf("expected 2 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if !s.checked[1] || !s.checked[2] {
|
||||
t.Error("expected item2 and item3 to be checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
|
||||
// order matters: first checked = default model
|
||||
s := newMultiSelectState(items, []string{"item3", "item1"})
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
|
||||
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
|
||||
if s.selectedCount() != 1 {
|
||||
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.toggleItem()
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.toggleItem()
|
||||
if s.checked[0] {
|
||||
t.Error("expected item1 to be unchecked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
|
||||
s.highlighted = 1 // toggle item2
|
||||
s.toggleItem()
|
||||
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
// should be [0, 2] (item1, item3) with item2 removed
|
||||
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
|
||||
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventEnter, 0)
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after enter")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
s.focusOnButton = true
|
||||
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
|
||||
if !done || err != nil {
|
||||
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
|
||||
}
|
||||
// result should preserve selection order
|
||||
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
|
||||
t.Errorf("expected [item2, item1], got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.focusOnButton = true
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != nil || err != nil {
|
||||
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("tab should not focus button when nothing selected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after first tab")
|
||||
}
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus back on list after second tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != nil || err != errCancelled {
|
||||
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
|
||||
t.Error("expected item2 (idx 1) to be default (first checked)")
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected item1 (idx 0) to NOT be default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected isDefault=false when nothing checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.highlighted != 1 {
|
||||
t.Errorf("expected highlighted=1, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 1
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus to return to list on arrow key")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.filter != "x" {
|
||||
t.Errorf("expected filter='x', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newMultiSelectState(manyItems, nil)
|
||||
s.highlighted = 10
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventChar, 'x')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.filter = "x"
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false after backspace")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseInput(t *testing.T) {
|
||||
t.Run("Enter", func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{13}))
|
||||
if err != nil || event != eventEnter || char != 0 {
|
||||
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{3}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{9}))
|
||||
if err != nil || event != eventTab {
|
||||
t.Errorf("expected eventTab, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{127}))
|
||||
if err != nil || event != eventBackspace {
|
||||
t.Errorf("expected eventBackspace, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
|
||||
if err != nil || event != eventUp {
|
||||
t.Errorf("expected eventUp, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DownArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
|
||||
if err != nil || event != eventDown {
|
||||
t.Errorf("expected eventDown, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PrintableChars", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
char byte
|
||||
}{
|
||||
{"lowercase", 'a'},
|
||||
{"uppercase", 'Z'},
|
||||
{"digit", '5'},
|
||||
{"space", ' '},
|
||||
{"tilde", '~'},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
|
||||
if err != nil || event != eventChar || char != tt.char {
|
||||
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "first item"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsPromptAndItems", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Select:") {
|
||||
t.Error("expected prompt in output")
|
||||
}
|
||||
if !strings.Contains(output, "item1") {
|
||||
t.Error("expected item1 in output")
|
||||
}
|
||||
if !strings.Contains(output, "first item") {
|
||||
t.Error("expected description in output")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item2 in output")
|
||||
}
|
||||
if lineCount != 3 { // 1 prompt + 2 items
|
||||
t.Errorf("expected 3 lines, got %d", lineCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "xyz"
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' message")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
// 15 items - 10 displayed = 5 more
|
||||
if !strings.Contains(buf.String(), "5 more") {
|
||||
t.Error("expected '5 more' indicator")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderMultiSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsCheckboxes", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "[x]") {
|
||||
t.Error("expected checked checkbox [x]")
|
||||
}
|
||||
if !strings.Contains(output, "[ ]") {
|
||||
t.Error("expected unchecked checkbox [ ]")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsDefaultMarker", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "(default)") {
|
||||
t.Error("expected (default) marker for first checked item")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsSelectedCount", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "2 selected") {
|
||||
t.Error("expected '2 selected' in output")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "Select at least one") {
|
||||
t.Error("expected 'Select at least one' helper text")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
t.Error("errCancelled should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Message", func(t *testing.T) {
|
||||
if errCancelled.Error() != "cancelled" {
|
||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for selector.go
|
||||
|
||||
// TestSelectState_SingleItem verifies that single item list works without crash.
|
||||
// List with only one item should still work.
|
||||
func TestSelectState_SingleItem(t *testing.T) {
|
||||
items := []selectItem{{Name: "only-one"}}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Down should do nothing (already at bottom)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Up should do nothing (already at top)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Enter should select the only item
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "only-one" || err != nil {
|
||||
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
|
||||
// List with exactly maxDisplayedItems items should not scroll.
|
||||
func TestSelectState_ExactlyMaxItems(t *testing.T) {
|
||||
items := make([]selectItem, maxDisplayedItems)
|
||||
for i := range items {
|
||||
items[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Move to last item
|
||||
for range maxDisplayedItems - 1 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
|
||||
// Should not scroll when exactly at max
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
|
||||
}
|
||||
|
||||
// One more down should do nothing
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
|
||||
// User typing "model.v1" shouldn't match "modelsv1".
|
||||
func TestFilterItems_RegexSpecialChars(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "model.v1"},
|
||||
{Name: "modelsv1"},
|
||||
{Name: "model-v1"},
|
||||
}
|
||||
|
||||
// Filter with dot should only match literal dot
|
||||
result := filterItems(items, "model.v1")
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 exact match, got %d", len(result))
|
||||
}
|
||||
if len(result) > 0 && result[0].Name != "model.v1" {
|
||||
t.Errorf("expected 'model.v1', got %s", result[0].Name)
|
||||
}
|
||||
|
||||
// Other regex special chars should be literal too
|
||||
items2 := []selectItem{
|
||||
{Name: "test[0]"},
|
||||
{Name: "test0"},
|
||||
{Name: "test(1)"},
|
||||
}
|
||||
|
||||
result2 := filterItems(items2, "test[0]")
|
||||
if len(result2) != 1 || result2[0].Name != "test[0]" {
|
||||
t.Errorf("expected only 'test[0]', got %v", result2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
|
||||
// itemIndex uses name as key - duplicates cause collision. This documents
|
||||
// the current behavior: the last index for a duplicate name is stored
|
||||
func TestMultiSelectState_DuplicateNames(t *testing.T) {
|
||||
// Duplicate names - this is an edge case that shouldn't happen in practice
|
||||
items := []selectItem{
|
||||
{Name: "duplicate"},
|
||||
{Name: "duplicate"},
|
||||
{Name: "unique"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
|
||||
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
|
||||
// When there are duplicates, only the last occurrence's index is stored
|
||||
if s.itemIndex["duplicate"] != 1 {
|
||||
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
|
||||
}
|
||||
|
||||
// Toggle item at highlighted=0 (first "duplicate")
|
||||
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
|
||||
// So it actually toggles the SECOND duplicate item, not the first
|
||||
s.toggleItem()
|
||||
|
||||
// This documents the potentially surprising behavior:
|
||||
// We toggled at highlighted=0, but itemIndex lookup returned 1
|
||||
if !s.checked[1] {
|
||||
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
|
||||
}
|
||||
if s.checked[0] {
|
||||
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
|
||||
// Prevents index-out-of-bounds on next keystroke
|
||||
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
s.selected = 2 // Select "cherry"
|
||||
|
||||
// Type a filter that removes cherry from results
|
||||
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
|
||||
|
||||
// Selection should reset to 0
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after filter, got %d", s.selected)
|
||||
}
|
||||
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 2 {
|
||||
t.Errorf("expected 2 filtered items, got %d", len(filtered))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
|
||||
// Model names might contain unicode characters
|
||||
func TestFilterItems_UnicodeCharacters(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama-日本語"},
|
||||
{Name: "模型-chinese"},
|
||||
{Name: "émoji-🦙"},
|
||||
{Name: "regular-model"},
|
||||
}
|
||||
|
||||
t.Run("filter japanese", func(t *testing.T) {
|
||||
result := filterItems(items, "日本")
|
||||
if len(result) != 1 || result[0].Name != "llama-日本語" {
|
||||
t.Errorf("expected llama-日本語, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter chinese", func(t *testing.T) {
|
||||
result := filterItems(items, "模型")
|
||||
if len(result) != 1 || result[0].Name != "模型-chinese" {
|
||||
t.Errorf("expected 模型-chinese, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter emoji", func(t *testing.T) {
|
||||
result := filterItems(items, "🦙")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter accented char", func(t *testing.T) {
|
||||
result := filterItems(items, "émoji")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
|
||||
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 2 // Highlight "cherry"
|
||||
|
||||
// Type a filter that removes cherry
|
||||
s.handleInput(eventChar, 'a')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
|
||||
// Empty list should be handled gracefully.
|
||||
func TestMultiSelectState_EmptyItems(t *testing.T) {
|
||||
s := newMultiSelectState([]selectItem{}, nil)
|
||||
|
||||
// Toggle should not panic on empty list
|
||||
s.toggleItem()
|
||||
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
|
||||
}
|
||||
|
||||
// Render should handle empty list
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderMultiSelect(&buf, "Select:", s)
|
||||
if lineCount == 0 {
|
||||
t.Error("renderMultiSelect should produce output even for empty list")
|
||||
}
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' for empty list")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
|
||||
func TestSelectState_RenderWithDescriptions(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "First item description"},
|
||||
{Name: "item2", Description: ""},
|
||||
{Name: "item3", Description: "Third item"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "First item description") {
|
||||
t.Error("expected description to be rendered")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item without description to be rendered")
|
||||
}
|
||||
}
|
||||
@@ -159,6 +159,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
sb.WriteString(before)
|
||||
if !ok {
|
||||
fmt.Fprintln(&sb)
|
||||
scanner.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ var (
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"lfm2.5-thinking",
|
||||
"ministral-3",
|
||||
"qwen3-coder:30b",
|
||||
"gpt-oss:20b",
|
||||
@@ -143,6 +144,7 @@ var (
|
||||
"granite3.3",
|
||||
"hermes3",
|
||||
"internlm2",
|
||||
"lfm2.5-thinking",
|
||||
"llama-guard3",
|
||||
"llama-pro",
|
||||
"llama2-chinese",
|
||||
@@ -263,6 +265,7 @@ var (
|
||||
"snowflake-arctic-embed2",
|
||||
}
|
||||
libraryToolsModels = []string{
|
||||
"lfm2.5-thinking",
|
||||
"qwen3-vl",
|
||||
"gpt-oss:20b",
|
||||
"gpt-oss:120b",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
@@ -14,7 +14,7 @@ type Layer struct {
|
||||
Size int64 `json:"size"`
|
||||
From string `json:"from,omitempty"`
|
||||
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
|
||||
status string
|
||||
Status string `json:"-"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -22,7 +22,7 @@ const (
|
||||
)
|
||||
|
||||
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
blobs, err := GetBlobsPath("")
|
||||
blobs, err := BlobsPath("")
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
@@ -45,7 +45,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
}
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
|
||||
blob, err := GetBlobsPath(digest)
|
||||
blob, err := BlobsPath(digest)
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
MediaType: mediatype,
|
||||
Digest: digest,
|
||||
Size: n,
|
||||
status: fmt.Sprintf("%s %s", status, digest),
|
||||
Status: fmt.Sprintf("%s %s", status, digest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||
return Layer{}, errors.New("creating new layer from layer with empty digest")
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(digest)
|
||||
blob, err := BlobsPath(digest)
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
@@ -89,7 +89,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||
Digest: digest,
|
||||
Size: fi.Size(),
|
||||
From: from,
|
||||
status: fmt.Sprintf("using existing layer %s", digest),
|
||||
Status: fmt.Sprintf("using existing layer %s", digest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -98,7 +98,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
||||
return nil, errors.New("opening layer with empty digest")
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(l.Digest)
|
||||
blob, err := BlobsPath(l.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -126,7 +126,7 @@ func (l *Layer) Remove() error {
|
||||
}
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(l.Digest)
|
||||
blob, err := BlobsPath(l.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1,10 +1,9 @@
|
||||
package server
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -33,12 +32,38 @@ func (m *Manifest) Size() (size int64) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Manifest) Digest() string {
|
||||
return m.digest
|
||||
}
|
||||
|
||||
func (m *Manifest) FileInfo() os.FileInfo {
|
||||
return m.fi
|
||||
}
|
||||
|
||||
// ReadConfigJSON reads and unmarshals a config layer as JSON.
|
||||
func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
|
||||
for _, layer := range m.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
|
||||
blobPath, err := BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("config %q not found in manifest", configPath)
|
||||
}
|
||||
|
||||
func (m *Manifest) Remove() error {
|
||||
if err := os.Remove(m.filepath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -70,11 +95,11 @@ func (m *Manifest) RemoveLayers() error {
|
||||
if _, used := inUse[layer.Digest]; used {
|
||||
continue
|
||||
}
|
||||
blob, err := GetBlobsPath(layer.Digest)
|
||||
blob, err := BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.Remove(blob); os.IsNotExist(err) {
|
||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||
} else if err != nil {
|
||||
return err
|
||||
@@ -89,7 +114,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||
return nil, model.Unqualified(n)
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -121,7 +146,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||
}
|
||||
|
||||
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -148,7 +173,7 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||
}
|
||||
|
||||
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
95
manifest/paths.go
Normal file
95
manifest/paths.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||
|
||||
func Path() (string, error) {
|
||||
path := filepath.Join(envconfig.Models(), "manifests")
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// PathForName returns the path to the manifest file for a specific model name.
|
||||
func PathForName(n model.Name) (string, error) {
|
||||
if !n.IsValid() {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(manifests, n.Filepath()), nil
|
||||
}
|
||||
|
||||
func BlobsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
path := filepath.Join(envconfig.Models(), "blobs", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// PruneDirectory removes empty directories recursively.
|
||||
func PruneDirectory(path string) error {
|
||||
info, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
entries, err = os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(entries) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -609,3 +609,49 @@ func ImageGenerationsMiddleware() gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ImageEditsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ImageEditRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Image == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
|
||||
return
|
||||
}
|
||||
|
||||
genReq, err := openai.FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ImageWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1112,3 +1112,129 @@ func TestImageWriterResponse(t *testing.T) {
|
||||
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageEditsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.GenerateRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.GenerateRequest
|
||||
|
||||
// Base64-encoded test image (1x1 pixel PNG)
|
||||
testImage := ""
|
||||
decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "image edit basic",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Images: []api.ImageData{decodedImage},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit with size",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `",
|
||||
"size": "512x768"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Images: []api.ImageData{decodedImage},
|
||||
Width: 512,
|
||||
Height: 768,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing prompt",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "prompt is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing model",
|
||||
body: `{
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "model is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing image",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "image is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if tc.err.Error.Message != "" {
|
||||
var errResp openai.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match:\n%s", diff)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -794,3 +794,47 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ImageEditRequest is an OpenAI-compatible image edit request.
|
||||
type ImageEditRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Image string `json:"image"` // Base64-encoded image data
|
||||
Size string `json:"size,omitempty"` // e.g., "1024x1024"
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
|
||||
func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
|
||||
req := api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
}
|
||||
|
||||
// Decode the input image
|
||||
if r.Image != "" {
|
||||
imgData, err := decodeImageURL(r.Image)
|
||||
if err != nil {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
|
||||
}
|
||||
req.Images = append(req.Images, imgData)
|
||||
}
|
||||
|
||||
// Parse size if provided (e.g., "1024x768")
|
||||
if r.Size != "" {
|
||||
var w, h int32
|
||||
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
|
||||
req.Width = w
|
||||
req.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
if r.Seed != nil {
|
||||
if req.Options == nil {
|
||||
req.Options = map[string]any{}
|
||||
}
|
||||
req.Options["seed"] = *r.Seed
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -448,3 +448,86 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_Basic(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if result.Prompt != "make it blue" {
|
||||
t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
|
||||
}
|
||||
|
||||
if len(result.Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Images))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_WithSize(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
Size: "512x768",
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Width != 512 {
|
||||
t.Errorf("expected width 512, got %d", result.Width)
|
||||
}
|
||||
|
||||
if result.Height != 768 {
|
||||
t.Errorf("expected height 768, got %d", result.Height)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_WithSeed(t *testing.T) {
|
||||
seed := int64(12345)
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
Seed: &seed,
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Options == nil {
|
||||
t.Fatal("expected options to be set")
|
||||
}
|
||||
|
||||
if result.Options["seed"] != seed {
|
||||
t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_InvalidImage(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: "not-valid-base64",
|
||||
}
|
||||
|
||||
_, err := FromImageEditRequest(req)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid image")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +95,21 @@ func (i *Instance) Readline() (string, error) {
|
||||
|
||||
var currentLineBuf []rune
|
||||
|
||||
// draining tracks if we're processing buffered input from cooked mode.
|
||||
// In cooked mode Enter sends \n, but in raw mode Ctrl+J sends \n.
|
||||
// We treat \n from cooked mode as submit, not multiline.
|
||||
// We check Buffered() after the first read since the bufio buffer is
|
||||
// empty until then. This is compatible with """ multiline mode in
|
||||
// interactive.go since each Readline() call is independent.
|
||||
var draining, stopDraining bool
|
||||
|
||||
for {
|
||||
// Apply deferred state change from previous iteration
|
||||
if stopDraining {
|
||||
draining = false
|
||||
stopDraining = false
|
||||
}
|
||||
|
||||
// don't show placeholder when pasting unless we're in multiline mode
|
||||
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
|
||||
if buf.IsEmpty() && showPlaceholder {
|
||||
@@ -105,6 +119,15 @@ func (i *Instance) Readline() (string, error) {
|
||||
|
||||
r, err := i.Terminal.Read()
|
||||
|
||||
// After reading, check if there's more buffered data. If so, we're
|
||||
// processing cooked-mode input. Once buffer empties, the current
|
||||
// char is the last buffered one (still drain it), then stop next iteration.
|
||||
if i.Terminal.reader.Buffered() > 0 {
|
||||
draining = true
|
||||
} else if draining {
|
||||
stopDraining = true
|
||||
}
|
||||
|
||||
if buf.IsEmpty() {
|
||||
fmt.Print(ClearToEOL)
|
||||
}
|
||||
@@ -232,15 +255,20 @@ func (i *Instance) Readline() (string, error) {
|
||||
fd := os.Stdin.Fd()
|
||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||
case CharCtrlJ:
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
// If not draining cooked-mode input, treat as multiline
|
||||
if !draining {
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
// Draining cooked-mode input: treat \n as submit
|
||||
fallthrough
|
||||
case CharEnter:
|
||||
output := buf.String()
|
||||
if len(i.pastedLines) > 0 {
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -90,7 +91,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
oldManifest, _ := ParseNamedManifest(name)
|
||||
oldManifest, _ := manifest.ParseNamedManifest(name)
|
||||
|
||||
var baseLayers []*layerGGML
|
||||
var err error
|
||||
@@ -123,9 +124,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
|
||||
manifest, mErr := ParseNamedManifest(fromName)
|
||||
if mErr == nil && manifest.Config.Digest != "" {
|
||||
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
|
||||
mf, mErr := manifest.ParseNamedManifest(fromName)
|
||||
if mErr == nil && mf.Config.Digest != "" {
|
||||
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
|
||||
if pErr == nil {
|
||||
if cfgFile, fErr := os.Open(configPath); fErr == nil {
|
||||
var baseConfig model.ConfigV2
|
||||
@@ -342,7 +343,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
|
||||
return "gguf"
|
||||
} else {
|
||||
// try to see if we can find a gguf file even without the file extension
|
||||
blobPath, err := GetBlobsPath(files[fn])
|
||||
blobPath, err := manifest.BlobsPath(files[fn])
|
||||
if err != nil {
|
||||
slog.Error("error getting blobs path", "file", fn)
|
||||
return ""
|
||||
@@ -394,7 +395,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
|
||||
}
|
||||
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
blobPath, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -432,7 +433,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer, err := NewLayer(t, mediaType)
|
||||
layer, err := manifest.NewLayer(t, mediaType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -465,7 +466,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
|
||||
}
|
||||
|
||||
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
|
||||
var layers []Layer
|
||||
var layers []manifest.Layer
|
||||
for _, layer := range baseLayers {
|
||||
if layer.GGML != nil {
|
||||
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
|
||||
@@ -550,13 +551,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
if layer.status != "" {
|
||||
fn(api.ProgressResponse{Status: layer.status})
|
||||
if layer.Status != "" {
|
||||
fn(api.ProgressResponse{Status: layer.Status})
|
||||
}
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
if err := WriteManifest(name, *configLayer, layers); err != nil {
|
||||
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -577,7 +578,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(layer.Digest)
|
||||
blob, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -599,7 +600,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
|
||||
}
|
||||
temp.Seek(0, io.SeekStart)
|
||||
fn(api.ProgressResponse{Status: "verifying conversion"})
|
||||
newLayer, err := NewLayer(temp, layer.MediaType)
|
||||
newLayer, err := manifest.NewLayer(temp, layer.MediaType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -619,7 +620,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
||||
var layers []*layerGGML
|
||||
|
||||
fn(api.ProgressResponse{Status: "parsing GGUF"})
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
blobPath, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -654,7 +655,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||
layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||
if err != nil {
|
||||
slog.Debug("could not create new layer from layer", "error", err)
|
||||
return nil, err
|
||||
@@ -665,8 +666,8 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
||||
return detectChatTemplate(layers)
|
||||
}
|
||||
|
||||
func removeLayer(layers []Layer, mediatype string) []Layer {
|
||||
return slices.DeleteFunc(layers, func(layer Layer) bool {
|
||||
func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
|
||||
return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
|
||||
if layer.MediaType != mediatype {
|
||||
return false
|
||||
}
|
||||
@@ -680,7 +681,7 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
|
||||
})
|
||||
}
|
||||
|
||||
func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
||||
func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
|
||||
layers = removeLayer(layers, "application/vnd.ollama.image.template")
|
||||
if _, err := template.Parse(t); err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
|
||||
@@ -690,7 +691,7 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
||||
}
|
||||
|
||||
blob := strings.NewReader(t)
|
||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.template")
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -699,11 +700,11 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setSystem(layers []Layer, s string) ([]Layer, error) {
|
||||
func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
|
||||
layers = removeLayer(layers, "application/vnd.ollama.image.system")
|
||||
if s != "" {
|
||||
blob := strings.NewReader(s)
|
||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.system")
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -712,9 +713,9 @@ func setSystem(layers []Layer, s string) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setLicense(layers []Layer, l string) ([]Layer, error) {
|
||||
func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
|
||||
blob := strings.NewReader(l)
|
||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.license")
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -722,7 +723,7 @@ func setLicense(layers []Layer, l string) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
|
||||
if p == nil {
|
||||
p = make(map[string]any)
|
||||
}
|
||||
@@ -731,7 +732,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
digestPath, err := GetBlobsPath(layer.Digest)
|
||||
digestPath, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -765,7 +766,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
if err := json.NewEncoder(&b).Encode(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -773,7 +774,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||
func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
|
||||
// this leaves the old messages intact if no new messages were specified
|
||||
// which may not be the correct behaviour
|
||||
if len(m) == 0 {
|
||||
@@ -786,7 +787,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||
if err := json.NewEncoder(&b).Encode(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -794,7 +795,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
|
||||
func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
|
||||
digests := make([]string, len(layers))
|
||||
for i, layer := range layers {
|
||||
digests[i] = layer.Digest
|
||||
@@ -805,7 +806,7 @@ func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
|
||||
if err := json.NewEncoder(&b).Encode(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
)
|
||||
|
||||
func TestConvertFromSafetensors(t *testing.T) {
|
||||
@@ -17,7 +18,7 @@ func TestConvertFromSafetensors(t *testing.T) {
|
||||
|
||||
// Helper function to create a new layer and return its digest
|
||||
makeTemp := func(content string) string {
|
||||
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
|
||||
l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create layer: %v", err)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,8 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const maxRetries = 6
|
||||
@@ -456,7 +458,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||
}
|
||||
|
||||
type downloadOpts struct {
|
||||
mp ModelPath
|
||||
n model.Name
|
||||
digest string
|
||||
regOpts *registryOptions
|
||||
fn func(api.ProgressResponse)
|
||||
@@ -465,10 +467,10 @@ type downloadOpts struct {
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
|
||||
if opts.digest == "" {
|
||||
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
|
||||
return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
|
||||
}
|
||||
|
||||
fp, err := GetBlobsPath(opts.digest)
|
||||
fp, err := manifest.BlobsPath(opts.digest)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -492,8 +494,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
|
||||
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
|
||||
download := data.(*blobDownload)
|
||||
if !ok {
|
||||
requestURL := opts.mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
|
||||
requestURL := opts.n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
|
||||
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
|
||||
blobDownloadManager.Delete(opts.digest)
|
||||
return false, err
|
||||
|
||||
211
server/images.go
211
server/images.go
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -24,6 +23,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
@@ -75,12 +75,6 @@ type Model struct {
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for image generation model via config capabilities
|
||||
if slices.Contains(m.Config.Capabilities, "image") {
|
||||
return []model.Capability{model.CapabilityImage}
|
||||
}
|
||||
|
||||
// Check for completion capability
|
||||
if m.ModelPath != "" {
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
@@ -274,44 +268,22 @@ func (m *Model) String() string {
|
||||
return modelfile.String()
|
||||
}
|
||||
|
||||
func GetManifest(mp ModelPath) (*Manifest, string, error) {
|
||||
fp, err := mp.GetManifestPath()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
f, err := os.Open(fp)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
sha256sum := sha256.New()
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func GetModel(name string) (*Model, error) {
|
||||
mp := ParseModelPath(name)
|
||||
manifest, digest, err := GetManifest(mp)
|
||||
n := model.ParseName(name)
|
||||
mf, err := manifest.ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model := &Model{
|
||||
Name: mp.GetFullTagname(),
|
||||
ShortName: mp.GetShortTagname(),
|
||||
Digest: digest,
|
||||
m := &Model{
|
||||
Name: n.String(),
|
||||
ShortName: n.DisplayShortest(),
|
||||
Digest: mf.Digest(),
|
||||
Template: template.DefaultTemplate,
|
||||
}
|
||||
|
||||
if manifest.Config.Digest != "" {
|
||||
filename, err := GetBlobsPath(manifest.Config.Digest)
|
||||
if mf.Config.Digest != "" {
|
||||
filename, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -322,29 +294,29 @@ func GetModel(name string) (*Model, error) {
|
||||
}
|
||||
defer configFile.Close()
|
||||
|
||||
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
|
||||
if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, layer := range manifest.Layers {
|
||||
filename, err := GetBlobsPath(layer.Digest)
|
||||
for _, layer := range mf.Layers {
|
||||
filename, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch layer.MediaType {
|
||||
case "application/vnd.ollama.image.model":
|
||||
model.ModelPath = filename
|
||||
model.ParentModel = layer.From
|
||||
m.ModelPath = filename
|
||||
m.ParentModel = layer.From
|
||||
case "application/vnd.ollama.image.embed":
|
||||
// Deprecated in versions > 0.1.2
|
||||
// TODO: remove this warning in a future version
|
||||
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
||||
case "application/vnd.ollama.image.adapter":
|
||||
model.AdapterPaths = append(model.AdapterPaths, filename)
|
||||
m.AdapterPaths = append(m.AdapterPaths, filename)
|
||||
case "application/vnd.ollama.image.projector":
|
||||
model.ProjectorPaths = append(model.ProjectorPaths, filename)
|
||||
m.ProjectorPaths = append(m.ProjectorPaths, filename)
|
||||
case "application/vnd.ollama.image.prompt",
|
||||
"application/vnd.ollama.image.template":
|
||||
bts, err := os.ReadFile(filename)
|
||||
@@ -352,7 +324,7 @@ func GetModel(name string) (*Model, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model.Template, err = template.Parse(string(bts))
|
||||
m.Template, err = template.Parse(string(bts))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -362,7 +334,7 @@ func GetModel(name string) (*Model, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model.System = string(bts)
|
||||
m.System = string(bts)
|
||||
case "application/vnd.ollama.image.params":
|
||||
params, err := os.Open(filename)
|
||||
if err != nil {
|
||||
@@ -371,7 +343,7 @@ func GetModel(name string) (*Model, error) {
|
||||
defer params.Close()
|
||||
|
||||
// parse model options parameters into a map so that we can see which fields have been specified explicitly
|
||||
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
|
||||
if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "application/vnd.ollama.image.messages":
|
||||
@@ -381,7 +353,7 @@ func GetModel(name string) (*Model, error) {
|
||||
}
|
||||
defer msgs.Close()
|
||||
|
||||
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
|
||||
if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "application/vnd.ollama.image.license":
|
||||
@@ -389,11 +361,11 @@ func GetModel(name string) (*Model, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.License = append(model.License, string(bts))
|
||||
m.License = append(m.License, string(bts))
|
||||
}
|
||||
}
|
||||
|
||||
return model, nil
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func CopyModel(src, dst model.Name) error {
|
||||
@@ -408,7 +380,7 @@ func CopyModel(src, dst model.Name) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := manifest.Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -437,7 +409,7 @@ func CopyModel(src, dst model.Name) error {
|
||||
|
||||
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
|
||||
manifests, err := Manifests(true)
|
||||
manifests, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -452,7 +424,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||
|
||||
// only delete the files which are still in the deleteMap
|
||||
for k := range deleteMap {
|
||||
fp, err := GetBlobsPath(k)
|
||||
fp, err := manifest.BlobsPath(k)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
|
||||
continue
|
||||
@@ -468,7 +440,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||
|
||||
func PruneLayers() error {
|
||||
deleteMap := make(map[string]struct{})
|
||||
p, err := GetBlobsPath("")
|
||||
p, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -483,9 +455,9 @@ func PruneLayers() error {
|
||||
name := blob.Name()
|
||||
name = strings.ReplaceAll(name, "-", ":")
|
||||
|
||||
_, err := GetBlobsPath(name)
|
||||
_, err := manifest.BlobsPath(name)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrInvalidDigestFormat) {
|
||||
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
|
||||
// remove invalid blobs (e.g. partial downloads)
|
||||
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
|
||||
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
|
||||
@@ -510,63 +482,30 @@ func PruneLayers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func PruneDirectory(path string) error {
|
||||
info, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
entries, err = os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(entries) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
n := model.ParseName(name)
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
if n.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
manifest, _, err := GetManifest(mp)
|
||||
mf, err := manifest.ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
|
||||
return err
|
||||
}
|
||||
|
||||
var layers []Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
if manifest.Config.Digest != "" {
|
||||
layers = append(layers, manifest.Config)
|
||||
var layers []manifest.Layer
|
||||
layers = append(layers, mf.Layers...)
|
||||
if mf.Config.Digest != "" {
|
||||
layers = append(layers, mf.Config)
|
||||
}
|
||||
|
||||
// Use fast transfer for models with tensor layers (many small blobs)
|
||||
if hasTensorLayers(layers) {
|
||||
// Read raw manifest JSON to preserve tensor metadata fields
|
||||
manifestPath, err := mp.GetManifestPath()
|
||||
manifestPath, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -574,7 +513,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
|
||||
if err := pushWithTransfer(ctx, n, layers, manifestJSON, regOpts, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "success"})
|
||||
@@ -582,17 +521,17 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||
if err := uploadBlob(ctx, n, layer, regOpts, fn); err != nil {
|
||||
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pushing manifest"})
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
requestURL := n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
manifestJSON, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -611,44 +550,44 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
n := model.ParseName(name)
|
||||
|
||||
// build deleteMap to prune unused layers
|
||||
deleteMap := make(map[string]struct{})
|
||||
manifest, _, err := GetManifest(mp)
|
||||
existingMf, err := manifest.ParseNamedManifest(n)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// noop
|
||||
} else if err != nil {
|
||||
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
|
||||
} else {
|
||||
for _, l := range manifest.Layers {
|
||||
for _, l := range existingMf.Layers {
|
||||
deleteMap[l.Digest] = struct{}{}
|
||||
}
|
||||
if manifest.Config.Digest != "" {
|
||||
deleteMap[manifest.Config.Digest] = struct{}{}
|
||||
if existingMf.Config.Digest != "" {
|
||||
deleteMap[existingMf.Config.Digest] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
if n.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
|
||||
manifest, err = pullModelManifest(ctx, mp, regOpts)
|
||||
mf, err := pullModelManifest(ctx, n, regOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pull model manifest: %s", err)
|
||||
}
|
||||
|
||||
var layers []Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
if manifest.Config.Digest != "" {
|
||||
layers = append(layers, manifest.Config)
|
||||
var layers []manifest.Layer
|
||||
layers = append(layers, mf.Layers...)
|
||||
if mf.Config.Digest != "" {
|
||||
layers = append(layers, mf.Config)
|
||||
}
|
||||
|
||||
// Use fast transfer for models with tensor layers (many small blobs)
|
||||
if hasTensorLayers(layers) {
|
||||
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
|
||||
if err := pullWithTransfer(ctx, n, layers, mf, regOpts, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "success"})
|
||||
@@ -658,7 +597,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
skipVerify := make(map[string]bool)
|
||||
for _, layer := range layers {
|
||||
cacheHit, err := downloadBlob(ctx, downloadOpts{
|
||||
mp: mp,
|
||||
n: n,
|
||||
digest: layer.Digest,
|
||||
regOpts: regOpts,
|
||||
fn: fn,
|
||||
@@ -677,7 +616,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
if err := verifyBlob(layer.Digest); err != nil {
|
||||
if errors.Is(err, errDigestMismatch) {
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
fp, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -692,16 +631,16 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
for _, layer := range layers {
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
delete(deleteMap, manifest.Config.Digest)
|
||||
delete(deleteMap, mf.Config.Digest)
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
manifestJSON, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
fp, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -728,9 +667,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
// hasTensorLayers checks if any layer has tensor media type.
|
||||
func hasTensorLayers(layers []Layer) bool {
|
||||
func hasTensorLayers(layers []manifest.Layer) bool {
|
||||
for _, layer := range layers {
|
||||
if layer.MediaType == MediaTypeImageTensor {
|
||||
if layer.MediaType == manifest.MediaTypeImageTensor {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -738,7 +677,7 @@ func hasTensorLayers(layers []Layer) bool {
|
||||
}
|
||||
|
||||
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
|
||||
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, mf *manifest.Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
blobs := make([]transfer.Blob, len(layers))
|
||||
for i, layer := range layers {
|
||||
blobs[i] = transfer.Blob{
|
||||
@@ -747,12 +686,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
}
|
||||
}
|
||||
|
||||
destDir, err := GetBlobsPath("")
|
||||
destDir, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
base := mp.BaseURL()
|
||||
base := n.BaseURL()
|
||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
base.Scheme = "http"
|
||||
}
|
||||
@@ -784,7 +723,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
Blobs: blobs,
|
||||
BaseURL: baseURL,
|
||||
DestDir: destDir,
|
||||
Repository: mp.GetNamespaceRepository(),
|
||||
Repository: n.DisplayNamespaceModel(),
|
||||
Progress: progress,
|
||||
Token: regOpts.Token,
|
||||
GetToken: getToken,
|
||||
@@ -795,12 +734,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
|
||||
// Write manifest
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
manifestJSON, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
fp, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -812,7 +751,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
}
|
||||
|
||||
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
|
||||
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
blobs := make([]transfer.Blob, len(layers))
|
||||
for i, layer := range layers {
|
||||
blobs[i] = transfer.Blob{
|
||||
@@ -822,12 +761,12 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
}
|
||||
}
|
||||
|
||||
srcDir, err := GetBlobsPath("")
|
||||
srcDir, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
base := mp.BaseURL()
|
||||
base := n.BaseURL()
|
||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
base.Scheme = "http"
|
||||
}
|
||||
@@ -864,13 +803,13 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
GetToken: getToken,
|
||||
Logger: slog.Default(),
|
||||
Manifest: manifestJSON,
|
||||
ManifestRef: mp.Tag,
|
||||
Repository: mp.GetNamespaceRepository(),
|
||||
ManifestRef: n.Tag,
|
||||
Repository: n.DisplayNamespaceModel(),
|
||||
})
|
||||
}
|
||||
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
|
||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
|
||||
requestURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
@@ -880,7 +819,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var m Manifest
|
||||
var m manifest.Manifest
|
||||
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1042,7 +981,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
|
||||
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
|
||||
|
||||
func verifyBlob(digest string) error {
|
||||
fp, err := GetBlobsPath(digest)
|
||||
fp, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -56,6 +56,15 @@ func TestModelCapabilities(t *testing.T) {
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImage},
|
||||
},
|
||||
{
|
||||
name: "model with image and vision capability (image editing)",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image", "vision"},
|
||||
},
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImage, model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with completion capability",
|
||||
model: Model{
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -20,19 +21,19 @@ import (
|
||||
var intermediateBlobs map[string]string = make(map[string]string)
|
||||
|
||||
type layerGGML struct {
|
||||
Layer
|
||||
manifest.Layer
|
||||
*ggml.GGML
|
||||
}
|
||||
|
||||
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
||||
m, err := ParseNamedManifest(name)
|
||||
m, err := manifest.ParseNamedManifest(name)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err = ParseNamedManifest(name)
|
||||
m, err = manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -41,7 +42,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||
}
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
|
||||
layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -50,7 +51,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||
case "application/vnd.ollama.image.model",
|
||||
"application/vnd.ollama.image.projector",
|
||||
"application/vnd.ollama.image.adapter":
|
||||
blobpath, err := GetBlobsPath(layer.Digest)
|
||||
blobpath, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -81,12 +82,12 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
||||
if t, err := template.Named(s); err != nil {
|
||||
slog.Debug("template detection", "error", err, "template", s)
|
||||
} else {
|
||||
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||
layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
|
||||
layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
|
||||
layers = append(layers, &layerGGML{layer, nil})
|
||||
|
||||
if t.Parameters != nil {
|
||||
@@ -95,7 +96,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type ModelPath struct {
|
||||
ProtocolScheme string
|
||||
Registry string
|
||||
Namespace string
|
||||
Repository string
|
||||
Tag string
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultRegistry = "registry.ollama.ai"
|
||||
DefaultNamespace = "library"
|
||||
DefaultTag = "latest"
|
||||
DefaultProtocolScheme = "https"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidImageFormat = errors.New("invalid image format")
|
||||
ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
||||
ErrInsecureProtocol = errors.New("insecure protocol http")
|
||||
ErrModelPathInvalid = errors.New("invalid model path")
|
||||
)
|
||||
|
||||
func ParseModelPath(name string) ModelPath {
|
||||
mp := ModelPath{
|
||||
ProtocolScheme: DefaultProtocolScheme,
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "",
|
||||
Tag: DefaultTag,
|
||||
}
|
||||
|
||||
before, after, found := strings.Cut(name, "://")
|
||||
if found {
|
||||
mp.ProtocolScheme = before
|
||||
name = after
|
||||
}
|
||||
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
mp.Registry = parts[0]
|
||||
mp.Namespace = parts[1]
|
||||
mp.Repository = parts[2]
|
||||
case 2:
|
||||
mp.Namespace = parts[0]
|
||||
mp.Repository = parts[1]
|
||||
case 1:
|
||||
mp.Repository = parts[0]
|
||||
}
|
||||
|
||||
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
|
||||
mp.Repository = repo
|
||||
mp.Tag = tag
|
||||
}
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetNamespaceRepository() string {
|
||||
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetFullTagname() string {
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetShortTagname() string {
|
||||
if mp.Registry == DefaultRegistry {
|
||||
if mp.Namespace == DefaultNamespace {
|
||||
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
|
||||
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
|
||||
func (mp ModelPath) GetManifestPath() (string, error) {
|
||||
name := model.Name{
|
||||
Host: mp.Registry,
|
||||
Namespace: mp.Namespace,
|
||||
Model: mp.Repository,
|
||||
Tag: mp.Tag,
|
||||
}
|
||||
if !name.IsValid() {
|
||||
return "", fs.ErrNotExist
|
||||
}
|
||||
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
|
||||
}
|
||||
|
||||
func (mp ModelPath) BaseURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: mp.ProtocolScheme,
|
||||
Host: mp.Registry,
|
||||
}
|
||||
}
|
||||
|
||||
func GetManifestPath() (string, error) {
|
||||
path := filepath.Join(envconfig.Models(), "manifests")
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func GetBlobsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
path := filepath.Join(envconfig.Models(), "blobs", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetBlobsPath(t *testing.T) {
|
||||
// GetBlobsPath expects an actual directory to exist
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
digest string
|
||||
expected string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"empty digest",
|
||||
"",
|
||||
filepath.Join(tempDir, "blobs"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid with colon",
|
||||
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
|
||||
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid with dash",
|
||||
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
|
||||
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"digest too short",
|
||||
"sha256-45640291",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
{
|
||||
"digest too long",
|
||||
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
{
|
||||
"digest invalid chars",
|
||||
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", tempDir)
|
||||
|
||||
got, err := GetBlobsPath(tc.digest)
|
||||
|
||||
require.ErrorIs(t, tc.err, err, tc.name)
|
||||
assert.Equal(t, tc.expected, got, tc.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arg string
|
||||
want ModelPath
|
||||
}{
|
||||
{
|
||||
"full path https",
|
||||
"https://example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"full path http",
|
||||
"http://example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "http",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no protocol",
|
||||
"example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no registry",
|
||||
"ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no namespace",
|
||||
"repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no tag",
|
||||
"repo",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: DefaultTag,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ParseModelPath(tc.arg)
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("got: %q want: %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -39,6 +39,7 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/middleware"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/model/renderers"
|
||||
@@ -974,7 +975,7 @@ func (s *Server) PushHandler(c *gin.Context) {
|
||||
// is.
|
||||
func getExistingName(n model.Name) (model.Name, error) {
|
||||
var zero model.Name
|
||||
existing, err := Manifests(true)
|
||||
existing, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
@@ -1018,7 +1019,7 @@ func (s *Server) DeleteHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
m, err := ParseNamedManifest(n)
|
||||
m, err := manifest.ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
@@ -1080,7 +1081,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
return nil, ErrModelPathInvalid
|
||||
return nil, model.Unqualified(name)
|
||||
}
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
@@ -1112,7 +1113,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
|
||||
// For safetensors LLM models (experimental), populate details from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
|
||||
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
|
||||
modelDetails.Family = arch
|
||||
}
|
||||
@@ -1121,7 +1122,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
}
|
||||
// Get torch_dtype directly from config.json for quantization level
|
||||
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
|
||||
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
|
||||
modelDetails.QuantizationLevel = dtype
|
||||
}
|
||||
}
|
||||
@@ -1135,7 +1136,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
|
||||
}
|
||||
|
||||
manifest, err := ParseNamedManifest(name)
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1147,7 +1148,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
Details: modelDetails,
|
||||
Messages: msgs,
|
||||
Capabilities: m.Capabilities(),
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
ModifiedAt: mf.FileInfo().ModTime(),
|
||||
Requires: m.Config.Requires,
|
||||
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
|
||||
// default we return an empty map.
|
||||
@@ -1214,7 +1215,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
@@ -1223,12 +1224,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
|
||||
// For safetensors LLM models (experimental), populate ModelInfo from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
|
||||
resp.ModelInfo = info
|
||||
}
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
@@ -1285,7 +1286,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
|
||||
}
|
||||
|
||||
func (s *Server) ListHandler(c *gin.Context) {
|
||||
ms, err := Manifests(true)
|
||||
ms, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1316,8 +1317,8 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
RemoteModel: cf.RemoteModel,
|
||||
RemoteHost: cf.RemoteHost,
|
||||
Size: m.Size(),
|
||||
Digest: m.digest,
|
||||
ModifiedAt: m.fi.ModTime(),
|
||||
Digest: m.Digest(),
|
||||
ModifiedAt: m.FileInfo().ModTime(),
|
||||
Details: api.ModelDetails{
|
||||
Format: cf.ModelFormat,
|
||||
Family: cf.ModelFamily,
|
||||
@@ -1376,7 +1377,7 @@ func (s *Server) CopyHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (s *Server) HeadBlobHandler(c *gin.Context) {
|
||||
path, err := GetBlobsPath(c.Param("digest"))
|
||||
path, err := manifest.BlobsPath(c.Param("digest"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1392,7 +1393,7 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
|
||||
|
||||
func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
|
||||
p, err := GetBlobsPath(ib)
|
||||
p, err := manifest.BlobsPath(ib)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1410,7 +1411,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
path, err := GetBlobsPath(c.Param("digest"))
|
||||
path, err := manifest.BlobsPath(c.Param("digest"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1428,7 +1429,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
layer, err := NewLayer(c.Request.Body, "")
|
||||
layer, err := manifest.NewLayer(c.Request.Body, "")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1603,8 +1604,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoint
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
@@ -1628,7 +1630,7 @@ func Serve(ln net.Listener) error {
|
||||
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
||||
slog.Info("server config", "env", envconfig.Values())
|
||||
|
||||
blobsDir, err := GetBlobsPath("")
|
||||
blobsDir, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1637,7 +1639,7 @@ func Serve(ln net.Listener) error {
|
||||
}
|
||||
|
||||
if !envconfig.NoPrune() {
|
||||
if _, err := Manifests(false); err != nil {
|
||||
if _, err := manifest.Manifests(false); err != nil {
|
||||
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
|
||||
} else {
|
||||
// clean up unused layers and manifests
|
||||
@@ -1645,12 +1647,12 @@ func Serve(ln net.Listener) error {
|
||||
return err
|
||||
}
|
||||
|
||||
manifestsPath, err := GetManifestPath()
|
||||
manifestsPath, err := manifest.Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := PruneDirectory(manifestsPath); err != nil {
|
||||
if err := manifest.PruneDirectory(manifestsPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -2506,8 +2508,14 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
return
|
||||
}
|
||||
|
||||
// Set headers for streaming response
|
||||
c.Header("Content-Type", "application/x-ndjson")
|
||||
// Check streaming preference
|
||||
isStreaming := req.Stream == nil || *req.Stream
|
||||
|
||||
contentType := "application/x-ndjson"
|
||||
if !isStreaming {
|
||||
contentType = "application/json; charset=utf-8"
|
||||
}
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
// Get seed from options if provided
|
||||
var seed int64
|
||||
@@ -2522,13 +2530,21 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
}
|
||||
}
|
||||
|
||||
var images []llm.ImageData
|
||||
for i, imgData := range req.Images {
|
||||
images = append(images, llm.ImageData{ID: i, Data: imgData})
|
||||
}
|
||||
|
||||
var streamStarted bool
|
||||
var finalResponse api.GenerateResponse
|
||||
|
||||
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: seed,
|
||||
Images: images,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
streamStarted = true
|
||||
res := api.GenerateResponse{
|
||||
@@ -2552,6 +2568,11 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
if !isStreaming {
|
||||
finalResponse = res
|
||||
return
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(res)
|
||||
c.Writer.Write(append(data, '\n'))
|
||||
c.Writer.Flush()
|
||||
@@ -2561,5 +2582,10 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
if !streamStarted {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !isStreaming {
|
||||
c.JSON(http.StatusOK, finalResponse)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -223,15 +224,15 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
manifest, err := ParseNamedManifest(model.ParseName("child"))
|
||||
mf, err := manifest.ParseNamedManifest(model.ParseName("child"))
|
||||
if err != nil {
|
||||
t.Fatalf("parse manifest: %v", err)
|
||||
}
|
||||
if manifest.Config.Digest == "" {
|
||||
if mf.Config.Digest == "" {
|
||||
t.Fatalf("unexpected empty config digest for child manifest")
|
||||
}
|
||||
|
||||
configPath, err := GetBlobsPath(manifest.Config.Digest)
|
||||
configPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
t.Fatalf("config blob path: %v", err)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -93,13 +94,13 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create a manifest with duplicate layers
|
||||
if err := WriteManifest(n, config, []Layer{config}); err != nil {
|
||||
if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,9 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||
@@ -71,6 +73,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
|
||||
return
|
||||
}
|
||||
|
||||
func (mockRunner) Ping(_ context.Context) error { return nil }
|
||||
|
||||
func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
|
||||
return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
|
||||
return mock, nil
|
||||
@@ -2193,3 +2197,246 @@ func TestGenerateUnload(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateWithImages(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("images passed to completion request", func(t *testing.T) {
|
||||
testImage := []byte("test-image-data")
|
||||
|
||||
mock.CompletionResponse.Content = "Image processed"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Describe this image",
|
||||
Images: []api.ImageData{testImage},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify images were passed to the completion request
|
||||
if len(mock.CompletionRequest.Images) != 1 {
|
||||
t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) {
|
||||
t.Errorf("image data mismatch in completion request")
|
||||
}
|
||||
|
||||
if mock.CompletionRequest.Images[0].ID != 0 {
|
||||
t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple images passed to completion request", func(t *testing.T) {
|
||||
testImage1 := []byte("test-image-1")
|
||||
testImage2 := []byte("test-image-2")
|
||||
|
||||
mock.CompletionResponse.Content = "Images processed"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Compare these images",
|
||||
Images: []api.ImageData{testImage1, testImage2},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify both images were passed
|
||||
if len(mock.CompletionRequest.Images) != 2 {
|
||||
t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) {
|
||||
t.Errorf("first image data mismatch")
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) {
|
||||
t.Errorf("second image data mismatch")
|
||||
}
|
||||
|
||||
if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 {
|
||||
t.Errorf("expected image IDs 0 and 1, got %d and %d",
|
||||
mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no images when none provided", func(t *testing.T) {
|
||||
mock.CompletionResponse.Content = "No images"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Hello",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify no images in completion request
|
||||
if len(mock.CompletionRequest.Images) != 0 {
|
||||
t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestImageGenerateStreamFalse tests that image generation respects stream=false
|
||||
// and returns a single JSON response instead of streaming ndjson.
|
||||
func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
|
||||
mock := mockRunner{}
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
|
||||
fn(llm.CompletionResponse{Step: 2, TotalSteps: 3, Done: false})
|
||||
fn(llm.CompletionResponse{Step: 3, TotalSteps: 3, Done: true, DoneReason: llm.DoneReasonStop, Image: "base64image"})
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: map[string]*runnerRef{
|
||||
"": {
|
||||
llama: &mock,
|
||||
Options: &opts,
|
||||
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
|
||||
numParallel: 1,
|
||||
},
|
||||
},
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create model manifest with image capability
|
||||
n := model.ParseName("test-image")
|
||||
cfg := model.ConfigV2{Capabilities: []string{"image"}}
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(&cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := manifest.WriteManifest(n, configLayer, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
streamFalse := false
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-image",
|
||||
Prompt: "test prompt",
|
||||
Stream: &streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
if ct := w.Header().Get("Content-Type"); ct != "application/json; charset=utf-8" {
|
||||
t.Errorf("expected Content-Type 'application/json; charset=utf-8', got %q", ct)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
lines := strings.Split(strings.TrimSpace(body), "\n")
|
||||
if len(lines) != 1 {
|
||||
t.Errorf("expected 1 response line, got %d:\n%s", len(lines), body)
|
||||
}
|
||||
|
||||
var resp api.GenerateResponse
|
||||
if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Image != "base64image" {
|
||||
t.Errorf("expected image 'base64image', got %q", resp.Image)
|
||||
}
|
||||
|
||||
if !resp.Done {
|
||||
t.Errorf("expected done=true")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,12 +21,14 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var blobUploadManager sync.Map
|
||||
|
||||
type blobUpload struct {
|
||||
Layer
|
||||
manifest.Layer
|
||||
|
||||
Total int64
|
||||
Completed atomic.Int64
|
||||
@@ -51,7 +53,7 @@ const (
|
||||
)
|
||||
|
||||
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||
p, err := GetBlobsPath(b.Digest)
|
||||
p, err := manifest.BlobsPath(b.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -59,7 +61,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
|
||||
if b.From != "" {
|
||||
values := requestURL.Query()
|
||||
values.Add("mount", b.Digest)
|
||||
values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
|
||||
values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
|
||||
requestURL.RawQuery = values.Encode()
|
||||
}
|
||||
|
||||
@@ -128,7 +130,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
|
||||
defer blobUploadManager.Delete(b.Digest)
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
|
||||
p, err := GetBlobsPath(b.Digest)
|
||||
p, err := manifest.BlobsPath(b.Digest)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
@@ -364,9 +366,9 @@ func (p *progressWriter) Rollback() {
|
||||
p.written = 0
|
||||
}
|
||||
|
||||
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
||||
func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
requestURL := n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
|
||||
switch {
|
||||
@@ -388,8 +390,8 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOp
|
||||
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
|
||||
upload := data.(*blobUpload)
|
||||
if !ok {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
|
||||
requestURL := n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
|
||||
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
|
||||
blobUploadManager.Delete(layer.Digest)
|
||||
return err
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
@@ -35,22 +36,25 @@ func Unqualified(n Name) error {
|
||||
const MissingPart = "!MISSING!"
|
||||
|
||||
const (
|
||||
defaultHost = "registry.ollama.ai"
|
||||
defaultNamespace = "library"
|
||||
defaultTag = "latest"
|
||||
defaultHost = "registry.ollama.ai"
|
||||
defaultNamespace = "library"
|
||||
defaultTag = "latest"
|
||||
defaultProtocolScheme = "https"
|
||||
)
|
||||
|
||||
// DefaultName returns a name with the default values for the host, namespace,
|
||||
// and tag parts. The model and digest parts are empty.
|
||||
// tag, and protocol scheme parts. The model and digest parts are empty.
|
||||
//
|
||||
// - The default host is ("registry.ollama.ai")
|
||||
// - The default namespace is ("library")
|
||||
// - The default tag is ("latest")
|
||||
// - The default protocol scheme is ("https")
|
||||
func DefaultName() Name {
|
||||
return Name{
|
||||
Host: defaultHost,
|
||||
Namespace: defaultNamespace,
|
||||
Tag: defaultTag,
|
||||
Host: defaultHost,
|
||||
Namespace: defaultNamespace,
|
||||
Tag: defaultTag,
|
||||
ProtocolScheme: defaultProtocolScheme,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,10 +91,11 @@ func (k partKind) String() string {
|
||||
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
|
||||
// is valid.
|
||||
type Name struct {
|
||||
Host string
|
||||
Namespace string
|
||||
Model string
|
||||
Tag string
|
||||
Host string
|
||||
Namespace string
|
||||
Model string
|
||||
Tag string
|
||||
ProtocolScheme string
|
||||
}
|
||||
|
||||
// ParseName parses and assembles a Name from a name string. The
|
||||
@@ -160,7 +165,9 @@ func ParseNameBare(s string) Name {
|
||||
}
|
||||
|
||||
scheme, host, ok := strings.Cut(s, "://")
|
||||
if !ok {
|
||||
if ok {
|
||||
n.ProtocolScheme = scheme
|
||||
} else {
|
||||
host = scheme
|
||||
}
|
||||
n.Host = host
|
||||
@@ -189,12 +196,13 @@ func ParseNameFromFilepath(s string) (n Name) {
|
||||
return n
|
||||
}
|
||||
|
||||
// Merge merges the host, namespace, and tag parts of the two names,
|
||||
// Merge merges the host, namespace, tag, and protocol scheme parts of the two names,
|
||||
// preferring the non-empty parts of a.
|
||||
func Merge(a, b Name) Name {
|
||||
a.Host = cmp.Or(a.Host, b.Host)
|
||||
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
|
||||
a.Tag = cmp.Or(a.Tag, b.Tag)
|
||||
a.ProtocolScheme = cmp.Or(a.ProtocolScheme, b.ProtocolScheme)
|
||||
return a
|
||||
}
|
||||
|
||||
@@ -305,6 +313,23 @@ func (n Name) EqualFold(o Name) bool {
|
||||
strings.EqualFold(n.Tag, o.Tag)
|
||||
}
|
||||
|
||||
// BaseURL returns the base URL for the registry.
|
||||
func (n Name) BaseURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: n.ProtocolScheme,
|
||||
Host: n.Host,
|
||||
}
|
||||
}
|
||||
|
||||
// DisplayNamespaceModel returns the namespace and model joined by "/".
|
||||
func (n Name) DisplayNamespaceModel() string {
|
||||
var b strings.Builder
|
||||
b.WriteString(n.Namespace)
|
||||
b.WriteByte('/')
|
||||
b.WriteString(n.Model)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func isValidLen(kind partKind, s string) bool {
|
||||
switch kind {
|
||||
case kindHost:
|
||||
|
||||
@@ -32,10 +32,11 @@ func TestParseNameParts(t *testing.T) {
|
||||
{
|
||||
in: "scheme://host:port/namespace/model:tag",
|
||||
want: Name{
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
ProtocolScheme: "scheme",
|
||||
},
|
||||
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
|
||||
},
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
)
|
||||
@@ -103,7 +103,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
|
||||
func newLayerCreator() create.LayerCreator {
|
||||
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
layer, err := manifest.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, err
|
||||
}
|
||||
@@ -141,13 +141,13 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||
weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales
|
||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||
scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -169,7 +169,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||
qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -186,7 +186,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
||||
|
||||
// createUnquantizedLayer creates a single tensor layer without quantization.
|
||||
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
layer, err := manifest.NewLayer(r, manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -221,15 +221,15 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
configLayer, err := manifest.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer
|
||||
serverLayers := make([]server.Layer, 0, len(layers))
|
||||
// Convert LayerInfo to manifest.Layer
|
||||
manifestLayers := make([]manifest.Layer, 0, len(layers))
|
||||
for _, l := range layers {
|
||||
serverLayers = append(serverLayers, server.Layer{
|
||||
manifestLayers = append(manifestLayers, manifest.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
@@ -243,19 +243,19 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverLayers = append(serverLayers, modelfileLayers...)
|
||||
manifestLayers = append(manifestLayers, modelfileLayers...)
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
return manifest.WriteManifest(name, configLayer, manifestLayers)
|
||||
}
|
||||
}
|
||||
|
||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
var layers []server.Layer
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
var layers []manifest.Layer
|
||||
|
||||
if mf.Template != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create template layer: %w", err)
|
||||
}
|
||||
@@ -263,7 +263,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
}
|
||||
|
||||
if mf.System != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system layer: %w", err)
|
||||
}
|
||||
@@ -271,7 +271,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
}
|
||||
|
||||
if mf.License != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create license layer: %w", err)
|
||||
}
|
||||
|
||||
6
x/imagegen/cache/step.go
vendored
6
x/imagegen/cache/step.go
vendored
@@ -9,7 +9,7 @@ import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
// shallow layers change little between consecutive steps, so we can
|
||||
// cache their outputs and skip recomputation on non-refresh steps.
|
||||
//
|
||||
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
|
||||
// Supports both single-stream and dual-stream architectures:
|
||||
// - Single-stream: use Get/Set for the single output per layer
|
||||
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
|
||||
//
|
||||
@@ -87,7 +87,7 @@ func (c *StepCache) Set(layer int, arr *mlx.Array) {
|
||||
}
|
||||
|
||||
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
// Used for dual-stream architectures.
|
||||
func (c *StepCache) Get2(layer int) *mlx.Array {
|
||||
if layer < len(c.layers2) {
|
||||
return c.layers2[layer]
|
||||
@@ -96,7 +96,7 @@ func (c *StepCache) Get2(layer int) *mlx.Array {
|
||||
}
|
||||
|
||||
// Set2 stores a layer output (stream 2), freeing any previous value.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
// Used for dual-stream architectures.
|
||||
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
|
||||
if layer < len(c.layers2) {
|
||||
if c.layers2[layer] != nil {
|
||||
|
||||
@@ -10,7 +10,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -75,6 +78,7 @@ Image Generation Flags (experimental):
|
||||
// RunCLI handles the CLI for image generation models.
|
||||
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
||||
// Supports flags: --width, --height, --steps, --seed, --negative
|
||||
// Image paths can be included in the prompt and will be extracted automatically.
|
||||
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
|
||||
// Get options from flags (with env var defaults)
|
||||
opts := DefaultOptions()
|
||||
@@ -111,9 +115,16 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract any image paths from the prompt
|
||||
prompt, images, err := extractFileData(prompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
@@ -254,14 +265,33 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
printCurrentSettings(opts)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
|
||||
// Check if it's a file path, not a command
|
||||
args := strings.Fields(line)
|
||||
isFile := false
|
||||
for _, f := range extractFileNames(line) {
|
||||
if strings.HasPrefix(f, args[0]) {
|
||||
isFile = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isFile {
|
||||
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", args[0])
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Extract any image paths from the input
|
||||
prompt, images, err := extractFileData(line)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generate image with current options
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: line,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
@@ -486,3 +516,59 @@ func displayImageInTerminal(imagePath string) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// extractFileNames finds image file paths in the input string.
|
||||
func extractFileNames(input string) []string {
|
||||
// Regex to match file paths with image extensions
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
return re.FindAllString(input, -1)
|
||||
}
|
||||
|
||||
// extractFileData extracts image data from file paths found in the input.
|
||||
// Returns the cleaned prompt (with file paths removed) and the image data.
|
||||
func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
filePaths := extractFileNames(input)
|
||||
var imgs []api.ImageData
|
||||
|
||||
for _, fp := range filePaths {
|
||||
// Normalize escaped spaces
|
||||
nfp := strings.ReplaceAll(fp, "\\ ", " ")
|
||||
nfp = strings.ReplaceAll(nfp, "%20", " ")
|
||||
|
||||
data, err := getImageData(nfp)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
}
|
||||
return strings.TrimSpace(input), imgs, nil
|
||||
}
|
||||
|
||||
// getImageData reads and validates image data from a file.
|
||||
func getImageData(filePath string) ([]byte, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
buf := make([]byte, 512)
|
||||
_, err = file.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||
}
|
||||
|
||||
// Re-read the full file
|
||||
return os.ReadFile(filePath)
|
||||
}
|
||||
|
||||
@@ -21,8 +21,6 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
@@ -61,14 +59,11 @@ func main() {
|
||||
listTensors := flag.Bool("list", false, "List tensors only")
|
||||
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
|
||||
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
|
||||
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
|
||||
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
|
||||
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
|
||||
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
|
||||
var inputImages stringSlice
|
||||
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
||||
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
||||
@@ -166,60 +161,6 @@ func main() {
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImage:
|
||||
m, loadErr := qwen_image.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
LayerCache: *layerCache,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImageEdit:
|
||||
if len(inputImages) == 0 {
|
||||
log.Fatal("qwen-image-edit requires at least one -input-image")
|
||||
}
|
||||
|
||||
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
// For image editing, use 0 for dimensions to auto-detect from input image
|
||||
// unless explicitly overridden from defaults
|
||||
editWidth := int32(0)
|
||||
editHeight := int32(0)
|
||||
if *width != 1024 {
|
||||
editWidth = int32(*width)
|
||||
}
|
||||
if *height != 1024 {
|
||||
editHeight = int32(*height)
|
||||
}
|
||||
|
||||
cfg := &qwen_image_edit.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: editWidth,
|
||||
Height: editHeight,
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
}
|
||||
|
||||
var img *mlx.Array
|
||||
img, err = m.EditFromConfig(inputImages, cfg)
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *listTensors:
|
||||
err = listModelTensors(*modelPath)
|
||||
default:
|
||||
|
||||
@@ -6,8 +6,9 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// ManifestLayer represents a layer in the manifest.
|
||||
@@ -32,31 +33,15 @@ type ModelManifest struct {
|
||||
BlobDir string
|
||||
}
|
||||
|
||||
// DefaultBlobDir returns the default blob storage directory.
|
||||
func DefaultBlobDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
case "linux":
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
case "windows":
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
default:
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
}
|
||||
return filepath.Join(envconfig.Models(), "blobs")
|
||||
}
|
||||
|
||||
// DefaultManifestDir returns the default manifest storage directory.
|
||||
// DefaultManifestDir returns the manifest storage directory.
|
||||
// Respects OLLAMA_MODELS.
|
||||
|
||||
func DefaultManifestDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "models", "manifests")
|
||||
return filepath.Join(envconfig.Models(), "manifests")
|
||||
}
|
||||
|
||||
// LoadManifest loads a manifest for the given model name.
|
||||
@@ -176,6 +161,17 @@ func (m *ModelManifest) HasTensorLayers() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// TotalTensorSize returns the total size in bytes of all tensor layers.
|
||||
func (m *ModelManifest) TotalTensorSize() int64 {
|
||||
var total int64
|
||||
for _, layer := range m.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
total += layer.Size
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// ModelInfo contains metadata about an image generation model.
|
||||
type ModelInfo struct {
|
||||
Architecture string
|
||||
|
||||
57
x/imagegen/manifest_test.go
Normal file
57
x/imagegen/manifest_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTotalTensorSize(t *testing.T) {
|
||||
m := &ModelManifest{
|
||||
Manifest: &Manifest{
|
||||
Layers: []ManifestLayer{
|
||||
{MediaType: "application/vnd.ollama.image.tensor", Size: 1000},
|
||||
{MediaType: "application/vnd.ollama.image.tensor", Size: 2000},
|
||||
{MediaType: "application/vnd.ollama.image.json", Size: 500}, // not a tensor
|
||||
{MediaType: "application/vnd.ollama.image.tensor", Size: 3000},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := m.TotalTensorSize()
|
||||
want := int64(6000)
|
||||
if got != want {
|
||||
t.Errorf("TotalTensorSize() = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTotalTensorSizeEmpty(t *testing.T) {
|
||||
m := &ModelManifest{
|
||||
Manifest: &Manifest{
|
||||
Layers: []ManifestLayer{},
|
||||
},
|
||||
}
|
||||
|
||||
if got := m.TotalTensorSize(); got != 0 {
|
||||
t.Errorf("TotalTensorSize() = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
|
||||
modelsDir := filepath.Join(t.TempDir(), "models")
|
||||
|
||||
// Simulate packaged/systemd environment
|
||||
t.Setenv("OLLAMA_MODELS", modelsDir)
|
||||
t.Setenv("HOME", "/usr/share/ollama")
|
||||
|
||||
// Manifest dir must respect OLLAMA_MODELS
|
||||
wantManifest := filepath.Join(modelsDir, "manifests")
|
||||
if got := DefaultManifestDir(); got != wantManifest {
|
||||
t.Fatalf("DefaultManifestDir() = %q, want %q", got, wantManifest)
|
||||
}
|
||||
|
||||
// Blob dir must respect OLLAMA_MODELS
|
||||
wantBlobs := filepath.Join(modelsDir, "blobs")
|
||||
if got := DefaultBlobDir(); got != wantBlobs {
|
||||
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
|
||||
}
|
||||
}
|
||||
@@ -16,18 +16,9 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// GB is a convenience constant for gigabytes.
|
||||
const GB = 1024 * 1024 * 1024
|
||||
|
||||
// SupportedBackends lists the backends that support image generation.
|
||||
var SupportedBackends = []string{"metal", "cuda", "cpu"}
|
||||
|
||||
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
|
||||
var modelVRAMEstimates = map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
|
||||
"FluxPipeline": 20 * GB, // ~20GB for Flux
|
||||
}
|
||||
|
||||
// CheckPlatformSupport validates that image generation is supported on the current platform.
|
||||
// Returns nil if supported, or an error describing why it's not supported.
|
||||
func CheckPlatformSupport() error {
|
||||
@@ -47,17 +38,6 @@ func CheckPlatformSupport() error {
|
||||
}
|
||||
}
|
||||
|
||||
// CheckMemoryRequirements validates that there's enough memory for image generation.
|
||||
// Returns nil if memory is sufficient, or an error if not.
|
||||
func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
|
||||
required := EstimateVRAM(modelName)
|
||||
if availableMemory < required {
|
||||
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
|
||||
required/GB, availableMemory/GB)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResolveModelName checks if a model name is a known image generation model.
|
||||
// Returns the normalized model name if found, empty string otherwise.
|
||||
func ResolveModelName(modelName string) string {
|
||||
@@ -68,16 +48,6 @@ func ResolveModelName(modelName string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
|
||||
// Returns a conservative default of 21GB if the model type cannot be determined.
|
||||
func EstimateVRAM(modelName string) uint64 {
|
||||
className := DetectModelType(modelName)
|
||||
if estimate, ok := modelVRAMEstimates[className]; ok {
|
||||
return estimate
|
||||
}
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// DetectModelType reads model_index.json and returns the model type.
|
||||
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
|
||||
// Returns empty string if detection fails.
|
||||
|
||||
@@ -30,69 +30,6 @@ func TestCheckPlatformSupport(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckMemoryRequirements(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
availableMemory uint64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "sufficient memory",
|
||||
availableMemory: 32 * GB,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "exactly enough memory",
|
||||
availableMemory: 21 * GB,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "insufficient memory",
|
||||
availableMemory: 16 * GB,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "zero memory",
|
||||
availableMemory: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Use a non-existent model name which will default to 21GB estimate
|
||||
err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelVRAMEstimates(t *testing.T) {
|
||||
// Verify the VRAM estimates map has expected entries
|
||||
expected := map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB,
|
||||
"FluxPipeline": 20 * GB,
|
||||
}
|
||||
|
||||
for name, expectedVRAM := range expected {
|
||||
if actual, ok := modelVRAMEstimates[name]; !ok {
|
||||
t.Errorf("Missing VRAM estimate for %s", name)
|
||||
} else if actual != expectedVRAM {
|
||||
t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateVRAMDefault(t *testing.T) {
|
||||
// Non-existent model should return default 21GB
|
||||
vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
|
||||
if vram != 21*GB {
|
||||
t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelName(t *testing.T) {
|
||||
// Non-existent model should return empty string
|
||||
result := ResolveModelName("nonexistent-model")
|
||||
|
||||
@@ -177,6 +177,20 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateImageWithInputs implements runner.ImageEditModel interface.
|
||||
// It generates an image conditioned on the provided input images for image editing.
|
||||
func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(ctx, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
InputImages: inputImages,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
|
||||
const MaxOutputPixels = 2048 * 2048
|
||||
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping qwen_image tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestPipelineOutput runs the full pipeline (integration test).
|
||||
// Skips if model weights not found. Requires ~50GB VRAM.
|
||||
func TestPipelineOutput(t *testing.T) {
|
||||
modelPath := "../../../weights/Qwen-Image-2512"
|
||||
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + modelPath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
pm, err := LoadPersistent(modelPath)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: failed to load model: %v", err)
|
||||
}
|
||||
|
||||
// Run 2-step pipeline (minimum for stable scheduler)
|
||||
cfg := &GenerateConfig{
|
||||
Prompt: "a cat",
|
||||
Width: 256,
|
||||
Height: 256,
|
||||
Steps: 2,
|
||||
Seed: 42,
|
||||
}
|
||||
|
||||
output, err := pm.GenerateFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Pipeline failed: %v", err)
|
||||
}
|
||||
mlx.Eval(output)
|
||||
|
||||
// Verify output shape [1, C, H, W]
|
||||
shape := output.Shape()
|
||||
if len(shape) != 4 {
|
||||
t.Errorf("Expected 4D output, got %v", shape)
|
||||
}
|
||||
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
|
||||
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
|
||||
}
|
||||
|
||||
// Verify values in expected range [0, 1]
|
||||
data := output.Data()
|
||||
minVal, maxVal := float32(1.0), float32(0.0)
|
||||
for _, v := range data {
|
||||
if v < minVal {
|
||||
minVal = v
|
||||
}
|
||||
if v > maxVal {
|
||||
maxVal = v
|
||||
}
|
||||
}
|
||||
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
|
||||
|
||||
if minVal < -0.1 || maxVal > 1.1 {
|
||||
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,367 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen_image implements the Qwen-Image diffusion transformer model.
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 30)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
|
||||
// Layer caching (DeepCache/Learning-to-Cache speedup)
|
||||
LayerCache bool // Enable layer caching (default: false)
|
||||
CacheInterval int // Refresh cache every N steps (default: 3)
|
||||
CacheLayers int // Number of shallow layers to cache (default: 25)
|
||||
}
|
||||
|
||||
// Model represents a Qwen-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *Qwen25VL
|
||||
Transformer *Transformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the Qwen-Image model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Qwen-Image model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
||||
tok, err := tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
|
||||
m.TextEncoder = &Qwen25VL{}
|
||||
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
peak := mlx.MetalGetPeakMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
||||
time.Since(start).Seconds(),
|
||||
float64(mem)/(1024*1024*1024),
|
||||
float64(peak)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
CFGScale: cfgScale,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
if cfg.CacheInterval <= 0 {
|
||||
cfg.CacheInterval = 3
|
||||
}
|
||||
if cfg.CacheLayers <= 0 {
|
||||
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
|
||||
}
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.Config
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
pH := latentH / tcfg.PatchSize
|
||||
pW := latentW / tcfg.PatchSize
|
||||
imgSeqLen := pH * pW
|
||||
|
||||
// Text encoding
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
|
||||
if useCFG {
|
||||
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
} else {
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
}
|
||||
}
|
||||
|
||||
// Pad sequences to same length for CFG
|
||||
txtLen := posEmb.Shape()[1]
|
||||
if useCFG {
|
||||
negLen := negEmb.Shape()[1]
|
||||
if negLen > txtLen {
|
||||
txtLen = negLen
|
||||
}
|
||||
if posEmb.Shape()[1] < txtLen {
|
||||
posEmb = padSequence(posEmb, txtLen)
|
||||
}
|
||||
if negEmb.Shape()[1] < txtLen {
|
||||
negEmb = padSequence(negEmb, txtLen)
|
||||
}
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
||||
var batchedEmb *mlx.Array
|
||||
if useCFG {
|
||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||
mlx.Keep(batchedEmb)
|
||||
mlx.Eval(batchedEmb)
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
|
||||
|
||||
// Init latents [B, C, T, H, W]
|
||||
var latents *mlx.Array
|
||||
{
|
||||
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
}
|
||||
|
||||
// RoPE cache
|
||||
var ropeCache *RoPECache
|
||||
{
|
||||
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
|
||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(ropeCache.ImgFreqs)
|
||||
}
|
||||
|
||||
// Layer cache for DeepCache/Learning-to-Cache speedup
|
||||
var stepCache *cache.StepCache
|
||||
if cfg.LayerCache {
|
||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
||||
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
|
||||
}
|
||||
|
||||
// Denoising loop
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
t := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
||||
|
||||
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
|
||||
latents2D := mlx.Squeeze(latents, 2)
|
||||
patches := PackLatents(latents2D, tcfg.PatchSize)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
// CFG Batching: single forward pass with batch=2
|
||||
// Note: layer caching with CFG is not supported yet (would need 2 caches)
|
||||
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
|
||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||
|
||||
// Single batched forward pass
|
||||
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||
L := batchedOutput.Shape()[1]
|
||||
D := batchedOutput.Shape()[2]
|
||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
|
||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
|
||||
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
combPred := mlx.Add(negOutput, scaledDiff)
|
||||
|
||||
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
|
||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
|
||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
||||
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
||||
} else if stepCache != nil {
|
||||
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
|
||||
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
|
||||
} else {
|
||||
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
}
|
||||
|
||||
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
|
||||
// Keep cached arrays alive across cleanup
|
||||
if stepCache != nil {
|
||||
mlx.Keep(stepCache.Arrays()...)
|
||||
}
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
|
||||
}
|
||||
|
||||
// Free denoising temporaries before VAE decode
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
if batchedEmb != nil {
|
||||
batchedEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
if stepCache != nil {
|
||||
stepCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode (Decode manages its own pools for staged memory)
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
||||
{
|
||||
decoded = mlx.Squeeze(decoded, 2)
|
||||
decoded = mlx.AddScalar(decoded, 1.0)
|
||||
decoded = mlx.DivScalar(decoded, 2.0)
|
||||
mlx.Eval(decoded)
|
||||
}
|
||||
|
||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// padSequence pads a sequence tensor to the target length with zeros
|
||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
// Pad on sequence dimension (axis 1)
|
||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
||||
}
|
||||
|
||||
// LoadPersistent is an alias for backward compatibility.
|
||||
// Use m := &Model{}; m.Load(path) instead.
|
||||
func LoadPersistent(modelPath string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
@@ -1,218 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
|
||||
type SchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
BaseShift float32 `json:"base_shift"` // 0.5
|
||||
MaxShift float32 `json:"max_shift"` // 0.9
|
||||
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
|
||||
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
|
||||
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
|
||||
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
|
||||
}
|
||||
|
||||
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
|
||||
func DefaultSchedulerConfig() *SchedulerConfig {
|
||||
return &SchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
BaseShift: 0.5,
|
||||
MaxShift: 0.9, // Matches scheduler_config.json
|
||||
BaseImageSeqLen: 256,
|
||||
MaxImageSeqLen: 8192,
|
||||
ShiftTerminal: 0.02,
|
||||
UseDynamicShift: true,
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
|
||||
type FlowMatchScheduler struct {
|
||||
Config *SchedulerConfig
|
||||
Timesteps []float32
|
||||
Sigmas []float32
|
||||
NumSteps int
|
||||
}
|
||||
|
||||
// NewFlowMatchScheduler creates a new scheduler
|
||||
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
|
||||
return &FlowMatchScheduler{
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateShift computes the dynamic shift based on image sequence length
|
||||
// This matches Python's calculate_shift function
|
||||
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
|
||||
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
|
||||
b := baseShift - m*float32(baseSeqLen)
|
||||
mu := float32(imageSeqLen)*m + b
|
||||
return mu
|
||||
}
|
||||
|
||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
||||
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
|
||||
// 1. Create sigmas from sigma_max to sigma_min (linspace)
|
||||
// 2. Apply time_shift with mu (if dynamic shifting)
|
||||
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
|
||||
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Calculate mu for dynamic shifting
|
||||
var mu float32
|
||||
if s.Config.UseDynamicShift {
|
||||
mu = CalculateShift(
|
||||
imageSeqLen,
|
||||
s.Config.BaseImageSeqLen,
|
||||
s.Config.MaxImageSeqLen,
|
||||
s.Config.BaseShift,
|
||||
s.Config.MaxShift,
|
||||
)
|
||||
}
|
||||
|
||||
// Step 1: Create sigmas from 1.0 to 1/num_steps
|
||||
// Python (pipeline_qwenimage.py:639):
|
||||
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
|
||||
sigmas := make([]float32, numSteps)
|
||||
sigmaMax := float32(1.0)
|
||||
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
|
||||
if numSteps == 1 {
|
||||
sigmas[0] = sigmaMax
|
||||
} else {
|
||||
for i := 0; i < numSteps; i++ {
|
||||
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Apply time shift if using dynamic shifting
|
||||
if s.Config.UseDynamicShift && mu != 0 {
|
||||
for i := range sigmas {
|
||||
sigmas[i] = s.timeShift(mu, sigmas[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Apply stretch_shift_to_terminal
|
||||
if s.Config.ShiftTerminal > 0 {
|
||||
sigmas = s.stretchShiftToTerminal(sigmas)
|
||||
}
|
||||
|
||||
// Step 4: Append terminal sigma (0) and store
|
||||
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
|
||||
// before passing to transformer. We skip both steps and just use sigmas directly.
|
||||
s.Sigmas = make([]float32, numSteps+1)
|
||||
s.Timesteps = make([]float32, numSteps+1)
|
||||
for i := 0; i < numSteps; i++ {
|
||||
s.Sigmas[i] = sigmas[i]
|
||||
s.Timesteps[i] = sigmas[i]
|
||||
}
|
||||
s.Sigmas[numSteps] = 0.0
|
||||
s.Timesteps[numSteps] = 0.0
|
||||
}
|
||||
|
||||
// stretchShiftToTerminal stretches and shifts the timestep schedule
|
||||
// so the final value equals shift_terminal (matches Python behavior)
|
||||
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
|
||||
if len(sigmas) == 0 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
// one_minus_z = 1 - t
|
||||
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
|
||||
// stretched_t = 1 - (one_minus_z / scale_factor)
|
||||
lastSigma := sigmas[len(sigmas)-1]
|
||||
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
|
||||
|
||||
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
|
||||
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
|
||||
if scaleFactor < 1e-6 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
result := make([]float32, len(sigmas))
|
||||
for i, t := range sigmas {
|
||||
oneMinusZ := 1.0 - t
|
||||
result[i] = 1.0 - (oneMinusZ / scaleFactor)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// timeShift applies the dynamic time shift (exponential)
|
||||
// exp(mu) / (exp(mu) + (1/t - 1))
|
||||
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + (1.0/t - 1.0))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
// modelOutput: predicted velocity from the transformer
|
||||
// sample: current noisy sample
|
||||
// timestepIdx: current timestep index
|
||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Get current and next sigma
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
||||
dt := sigmaNext - sigma
|
||||
|
||||
// Upcast to float32 to avoid precision issues (matches Python diffusers)
|
||||
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
|
||||
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
|
||||
|
||||
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
|
||||
result := mlx.Add(sampleF32, scaledOutput)
|
||||
|
||||
// Cast back to original dtype
|
||||
return mlx.ToBFloat16(result)
|
||||
}
|
||||
|
||||
// GetTimestep returns the timestep value at the given index
|
||||
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
|
||||
if idx < len(s.Timesteps) {
|
||||
return s.Timesteps[idx]
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
|
||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
|
||||
// This matches how Python diffusers generates noise - directly in packed space.
|
||||
// Generating in unpacked format and then packing produces different spatial
|
||||
// correlation structure, which affects model output quality.
|
||||
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
|
||||
shape := []int32{batchSize, seqLen, channels}
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// GetLatentShape returns the latent shape for a given image size
|
||||
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
|
||||
func GetLatentShape(batchSize, height, width int32) []int32 {
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
|
||||
}
|
||||
|
||||
// GetPatchedLatentShape returns the patchified latent shape
|
||||
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
|
||||
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
pH := latentH / patchSize
|
||||
pW := latentW / patchSize
|
||||
inChannels := int32(64) // 16 * patch_size^2
|
||||
return []int32{batchSize, pH * pW, inChannels}
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
|
||||
// Golden values generated via:
|
||||
//
|
||||
// python3 -c "
|
||||
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
// import numpy as np
|
||||
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
|
||||
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
|
||||
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
|
||||
// sigmas = np.linspace(1.0, 1.0/30, 30)
|
||||
// s.set_timesteps(sigmas=sigmas, mu=mu)
|
||||
// print(s.sigmas.numpy())"
|
||||
func TestSchedulerSetTimesteps(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Golden values from Python diffusers (first 3, last 3 before terminal)
|
||||
wantFirst := []float32{1.000000, 0.982251, 0.963889}
|
||||
wantLast := []float32{0.142924, 0.083384, 0.020000}
|
||||
|
||||
// Check first 3
|
||||
for i, want := range wantFirst {
|
||||
got := scheduler.Sigmas[i]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check last 3 (indices 27, 28, 29)
|
||||
for i, want := range wantLast {
|
||||
idx := 27 + i
|
||||
got := scheduler.Sigmas[idx]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check terminal is 0
|
||||
if scheduler.Sigmas[30] != 0.0 {
|
||||
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(scheduler.Sigmas) != 31 {
|
||||
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerProperties tests mathematical invariants of the scheduler.
|
||||
func TestSchedulerProperties(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Property: sigmas monotonically decreasing
|
||||
for i := 1; i < len(scheduler.Sigmas); i++ {
|
||||
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
|
||||
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
|
||||
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
|
||||
}
|
||||
}
|
||||
|
||||
// Property: first sigma should be ~1.0 (with time shift)
|
||||
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
|
||||
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
|
||||
}
|
||||
|
||||
// Property: terminal sigma should be exactly 0
|
||||
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
|
||||
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
|
||||
}
|
||||
|
||||
// Property: last non-terminal sigma should be shift_terminal (0.02)
|
||||
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
|
||||
if abs32(lastNonTerminal-0.02) > 1e-5 {
|
||||
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
|
||||
}
|
||||
|
||||
// Property: length = steps + 1
|
||||
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
|
||||
t.Errorf("sigmas length should be steps+1: got %d, want %d",
|
||||
len(scheduler.Sigmas), scheduler.NumSteps+1)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateShift verifies the mu calculation against Python reference.
|
||||
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
func TestCalculateShift(t *testing.T) {
|
||||
cases := []struct {
|
||||
imgSeqLen int32
|
||||
want float32
|
||||
}{
|
||||
{256, 0.5}, // base case
|
||||
{8192, 0.9}, // max case
|
||||
{4096, 0.6935}, // middle case (rounded)
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
|
||||
if abs32(got-c.want) > 0.001 {
|
||||
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerStep verifies the Euler step formula.
|
||||
func TestSchedulerStep(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Verify dt calculation for first step
|
||||
sigma0 := scheduler.Sigmas[0]
|
||||
sigma1 := scheduler.Sigmas[1]
|
||||
expectedDt := sigma1 - sigma0
|
||||
|
||||
// dt should be negative (sigmas decrease)
|
||||
if expectedDt >= 0 {
|
||||
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
|
||||
}
|
||||
}
|
||||
|
||||
func abs32(x float32) float32 {
|
||||
return float32(math.Abs(float64(x)))
|
||||
}
|
||||
@@ -1,174 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TinyTextEncoderConfig holds config for the tiny test text encoder
|
||||
type TinyTextEncoderConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
}
|
||||
|
||||
// loadTinyTextEncoder loads the tiny text encoder from testdata
|
||||
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
|
||||
t.Helper()
|
||||
|
||||
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
|
||||
|
||||
// Load config
|
||||
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
|
||||
}
|
||||
|
||||
var tinyCfg TinyTextEncoderConfig
|
||||
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
|
||||
t.Fatalf("Failed to parse config: %v", err)
|
||||
}
|
||||
|
||||
// Create encoder config (using Qwen25VLConfig)
|
||||
cfg := &Qwen25VLConfig{
|
||||
HiddenSize: tinyCfg.HiddenSize,
|
||||
NumHiddenLayers: tinyCfg.NumHiddenLayers,
|
||||
IntermediateSize: tinyCfg.IntermediateSize,
|
||||
NumAttentionHeads: tinyCfg.NumAttentionHeads,
|
||||
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
|
||||
VocabSize: tinyCfg.VocabSize,
|
||||
RMSNormEps: tinyCfg.RMSNormEps,
|
||||
RopeTheta: tinyCfg.RopeTheta,
|
||||
HeadDim: tinyCfg.HeadDim,
|
||||
MRoPESection: tinyCfg.MRoPESection,
|
||||
}
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(testdataDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load weights: %v", err)
|
||||
}
|
||||
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
t.Fatalf("Failed to bulk load weights: %v", err)
|
||||
}
|
||||
|
||||
// Build encoder
|
||||
embedding, err := weights.Get("model.embed_tokens.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get embedding: %v", err)
|
||||
}
|
||||
|
||||
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
block, err := newVLTextBlock(weights, int(i), cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load block %d: %v", i, err)
|
||||
}
|
||||
blocks[i] = block
|
||||
}
|
||||
|
||||
finalNorm, err := weights.Get("model.norm.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get final norm: %v", err)
|
||||
}
|
||||
|
||||
encoder := &Qwen25VL{
|
||||
Config: cfg,
|
||||
Embedding: embedding,
|
||||
Blocks: blocks,
|
||||
FinalNorm: finalNorm,
|
||||
HasVision: false, // Text-only mode
|
||||
}
|
||||
|
||||
return encoder, &tinyCfg
|
||||
}
|
||||
|
||||
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
|
||||
func TestTextEncoderForward(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// Create test tokens (within vocab range)
|
||||
tokens := []int32{1, 2, 3, 4, 5}
|
||||
|
||||
// Forward pass using EncodeTextOnly
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [batch, seq_len, hidden_size]
|
||||
wantShape := []int32{1, 5, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify output is finite (not NaN or Inf)
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
|
||||
t.Errorf("output[%d] is not finite: %v", i, v)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextEncoderBatch tests batch processing.
|
||||
func TestTextEncoderBatch(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// For batch test, we'll use EncodeTextOnly with a single sequence
|
||||
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
|
||||
tokens := []int32{1, 2, 3}
|
||||
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
wantShape := []int32{1, 3, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
|
||||
func TestMRoPEComputation(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
cossin := encoder.computeTextRoPE(10, 1)
|
||||
mlx.Eval(cossin[0], cossin[1])
|
||||
|
||||
// Verify shapes: [3, B, L, head_dim]
|
||||
wantShape := []int32{3, 1, 10, cfg.HeadDim}
|
||||
if !slices.Equal(cossin[0].Shape(), wantShape) {
|
||||
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
|
||||
}
|
||||
if !slices.Equal(cossin[1].Shape(), wantShape) {
|
||||
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify cos/sin values are in valid range [-1, 1]
|
||||
cosData := cossin[0].Data()
|
||||
sinData := cossin[1].Data()
|
||||
for i := 0; i < min(100, len(cosData)); i++ {
|
||||
if cosData[i] < -1.01 || cosData[i] > 1.01 {
|
||||
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
|
||||
}
|
||||
if sinData[i] < -1.01 || sinData[i] > 1.01 {
|
||||
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,868 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TransformerConfig holds Qwen-Image transformer configuration
|
||||
type TransformerConfig struct {
|
||||
HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
|
||||
NHeads int32 `json:"num_attention_heads"` // 24
|
||||
HeadDim int32 `json:"attention_head_dim"` // 128
|
||||
NLayers int32 `json:"num_layers"` // 60
|
||||
InChannels int32 `json:"in_channels"` // 64
|
||||
OutChannels int32 `json:"out_channels"` // 16
|
||||
PatchSize int32 `json:"patch_size"` // 2
|
||||
JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
|
||||
NormEps float32 `json:"norm_eps"` // 1e-6
|
||||
AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
|
||||
GuidanceEmbeds bool `json:"guidance_embeds"` // false
|
||||
}
|
||||
|
||||
// defaultTransformerConfig returns config for Qwen-Image transformer
|
||||
func defaultTransformerConfig() *TransformerConfig {
|
||||
return &TransformerConfig{
|
||||
HiddenDim: 3072, // 24 * 128
|
||||
NHeads: 24,
|
||||
HeadDim: 128,
|
||||
NLayers: 60,
|
||||
InChannels: 64,
|
||||
OutChannels: 16,
|
||||
PatchSize: 2,
|
||||
JointAttentionDim: 3584,
|
||||
NormEps: 1e-6,
|
||||
AxesDimsRope: []int32{16, 56, 56},
|
||||
GuidanceEmbeds: false,
|
||||
}
|
||||
}
|
||||
|
||||
// TimestepEmbedder creates timestep embeddings
|
||||
type TimestepEmbedder struct {
|
||||
Linear1Weight *mlx.Array // [256, hidden_dim]
|
||||
Linear1Bias *mlx.Array
|
||||
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
|
||||
Linear2Bias *mlx.Array
|
||||
}
|
||||
|
||||
// newTimestepEmbedder creates a timestep embedder from weights
|
||||
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
|
||||
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TimestepEmbedder{
|
||||
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
|
||||
Linear1Bias: linear1Bias,
|
||||
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
|
||||
Linear2Bias: linear2Bias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings
|
||||
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
|
||||
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
||||
half := int32(128) // embedding_dim / 2
|
||||
|
||||
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
|
||||
freqs := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
||||
}
|
||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
tExpanded := mlx.ExpandDims(t, 1)
|
||||
args := mlx.Mul(tExpanded, freqsArr)
|
||||
args = mlx.MulScalar(args, 1000.0) // scale
|
||||
|
||||
// [cos, sin] (flip_sin_to_cos=True)
|
||||
sinArgs := mlx.Sin(args)
|
||||
cosArgs := mlx.Cos(args)
|
||||
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
|
||||
|
||||
// MLP: linear1 -> silu -> linear2
|
||||
h := mlx.Linear(embedding, te.Linear1Weight)
|
||||
h = mlx.Add(h, te.Linear1Bias)
|
||||
h = mlx.SiLU(h)
|
||||
h = mlx.Linear(h, te.Linear2Weight)
|
||||
h = mlx.Add(h, te.Linear2Bias)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// JointAttention implements dual-stream joint attention
|
||||
type JointAttention struct {
|
||||
// Image projections
|
||||
ToQ *mlx.Array
|
||||
ToQB *mlx.Array
|
||||
ToK *mlx.Array
|
||||
ToKB *mlx.Array
|
||||
ToV *mlx.Array
|
||||
ToVB *mlx.Array
|
||||
ToOut *mlx.Array
|
||||
ToOutB *mlx.Array
|
||||
NormQ *mlx.Array
|
||||
NormK *mlx.Array
|
||||
|
||||
// Text (added) projections
|
||||
AddQProj *mlx.Array
|
||||
AddQProjB *mlx.Array
|
||||
AddKProj *mlx.Array
|
||||
AddKProjB *mlx.Array
|
||||
AddVProj *mlx.Array
|
||||
AddVProjB *mlx.Array
|
||||
ToAddOut *mlx.Array
|
||||
ToAddOutB *mlx.Array
|
||||
NormAddQ *mlx.Array
|
||||
NormAddK *mlx.Array
|
||||
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// newJointAttention creates a joint attention layer
|
||||
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
|
||||
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
|
||||
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
|
||||
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
|
||||
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
|
||||
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
|
||||
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
|
||||
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
|
||||
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
|
||||
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
|
||||
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
|
||||
|
||||
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
|
||||
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
|
||||
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
|
||||
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
|
||||
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
|
||||
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
|
||||
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
|
||||
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
|
||||
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
|
||||
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
|
||||
|
||||
return &JointAttention{
|
||||
ToQ: mlx.Transpose(toQ, 1, 0),
|
||||
ToQB: toQB,
|
||||
ToK: mlx.Transpose(toK, 1, 0),
|
||||
ToKB: toKB,
|
||||
ToV: mlx.Transpose(toV, 1, 0),
|
||||
ToVB: toVB,
|
||||
ToOut: mlx.Transpose(toOut, 1, 0),
|
||||
ToOutB: toOutB,
|
||||
NormQ: normQ,
|
||||
NormK: normK,
|
||||
AddQProj: mlx.Transpose(addQProj, 1, 0),
|
||||
AddQProjB: addQProjB,
|
||||
AddKProj: mlx.Transpose(addKProj, 1, 0),
|
||||
AddKProjB: addKProjB,
|
||||
AddVProj: mlx.Transpose(addVProj, 1, 0),
|
||||
AddVProjB: addVProjB,
|
||||
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
|
||||
ToAddOutB: toAddOutB,
|
||||
NormAddQ: normAddQ,
|
||||
NormAddK: normAddK,
|
||||
NHeads: cfg.NHeads,
|
||||
HeadDim: cfg.HeadDim,
|
||||
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes joint attention
|
||||
// img: [B, L_img, D], txt: [B, L_txt, D]
|
||||
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
|
||||
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
D := imgShape[2]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// === Image Q/K/V ===
|
||||
imgFlat := mlx.Reshape(img, B*Limg, D)
|
||||
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
|
||||
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
|
||||
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
|
||||
|
||||
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
|
||||
// QK norm (RMSNorm per head)
|
||||
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
|
||||
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
|
||||
|
||||
// Apply RoPE
|
||||
if imgFreqs != nil {
|
||||
qImg = applyRoPE(qImg, imgFreqs)
|
||||
kImg = applyRoPE(kImg, imgFreqs)
|
||||
}
|
||||
|
||||
// === Text Q/K/V ===
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
|
||||
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
|
||||
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
|
||||
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
|
||||
|
||||
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
|
||||
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
|
||||
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
|
||||
|
||||
if txtFreqs != nil {
|
||||
qTxt = applyRoPE(qTxt, txtFreqs)
|
||||
kTxt = applyRoPE(kTxt, txtFreqs)
|
||||
}
|
||||
|
||||
// Concatenate for joint attention: [txt, img] order
|
||||
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
|
||||
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
|
||||
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
|
||||
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
|
||||
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
|
||||
|
||||
// SDPA
|
||||
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
|
||||
|
||||
// Transpose back and split
|
||||
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
|
||||
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
|
||||
|
||||
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
|
||||
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
|
||||
|
||||
// Output projections
|
||||
outImg = mlx.Reshape(outImg, B*Limg, D)
|
||||
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
|
||||
outImg = mlx.Reshape(outImg, B, Limg, D)
|
||||
|
||||
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
|
||||
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
|
||||
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
|
||||
|
||||
return outImg, outTxt
|
||||
}
|
||||
|
||||
// applyRoPE applies rotary embeddings using complex multiplication
|
||||
// x: [B, L, nheads, head_dim]
|
||||
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
|
||||
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
nheads := shape[2]
|
||||
headDim := shape[3]
|
||||
halfDim := headDim / 2
|
||||
|
||||
// Reshape x to pairs: [B, L, nheads, half, 2]
|
||||
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
|
||||
|
||||
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
|
||||
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
|
||||
|
||||
// Extract real/imag parts
|
||||
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
xReal = mlx.Squeeze(xReal, 4)
|
||||
xImag = mlx.Squeeze(xImag, 4)
|
||||
|
||||
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
freqReal = mlx.Squeeze(freqReal, 4)
|
||||
freqImag = mlx.Squeeze(freqImag, 4)
|
||||
|
||||
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
|
||||
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
|
||||
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
|
||||
|
||||
// Interleave back
|
||||
outReal = mlx.ExpandDims(outReal, 4)
|
||||
outImag = mlx.ExpandDims(outImag, 4)
|
||||
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
|
||||
|
||||
return mlx.Reshape(out, B, L, nheads, headDim)
|
||||
}
|
||||
|
||||
// MLP implements GELU MLP (not GEGLU)
|
||||
type MLP struct {
|
||||
ProjWeight *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
OutWeight *mlx.Array
|
||||
OutBias *mlx.Array
|
||||
}
|
||||
|
||||
// newMLP creates a GELU MLP
|
||||
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
|
||||
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
|
||||
outWeight, _ := weights.Get(prefix + ".net.2.weight")
|
||||
outBias, _ := weights.Get(prefix + ".net.2.bias")
|
||||
|
||||
return &MLP{
|
||||
ProjWeight: mlx.Transpose(projWeight, 1, 0),
|
||||
ProjBias: projBias,
|
||||
OutWeight: mlx.Transpose(outWeight, 1, 0),
|
||||
OutBias: outBias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies GELU MLP
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
xFlat := mlx.Reshape(x, B*L, D)
|
||||
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
|
||||
h = geluApprox(h)
|
||||
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
|
||||
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
|
||||
}
|
||||
|
||||
// geluApprox implements approximate GELU
|
||||
func geluApprox(x *mlx.Array) *mlx.Array {
|
||||
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
|
||||
inner = mlx.MulScalar(inner, sqrt2OverPi)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
|
||||
}
|
||||
|
||||
// TransformerBlock is a single dual-stream transformer block
|
||||
type TransformerBlock struct {
|
||||
Attention *JointAttention
|
||||
ImgMLP *MLP
|
||||
TxtMLP *MLP
|
||||
|
||||
ImgModWeight *mlx.Array
|
||||
ImgModBias *mlx.Array
|
||||
TxtModWeight *mlx.Array
|
||||
TxtModBias *mlx.Array
|
||||
|
||||
HiddenDim int32
|
||||
NormEps float32
|
||||
}
|
||||
|
||||
// newTransformerBlock creates a transformer block
|
||||
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
|
||||
attn, err := newJointAttention(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
|
||||
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
|
||||
|
||||
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
|
||||
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
|
||||
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
|
||||
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
|
||||
|
||||
return &TransformerBlock{
|
||||
Attention: attn,
|
||||
ImgMLP: imgMLP,
|
||||
TxtMLP: txtMLP,
|
||||
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
|
||||
ImgModBias: imgModBias,
|
||||
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
|
||||
TxtModBias: txtModBias,
|
||||
HiddenDim: cfg.HiddenDim,
|
||||
NormEps: cfg.NormEps,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the transformer block
|
||||
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
|
||||
siluT := mlx.SiLU(temb)
|
||||
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
|
||||
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
|
||||
|
||||
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
|
||||
imgModParts := splitMod6(imgMod, tb.HiddenDim)
|
||||
txtModParts := splitMod6(txtMod, tb.HiddenDim)
|
||||
|
||||
// Pre-attention: norm + modulate
|
||||
imgNorm := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
|
||||
|
||||
txtNorm := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
|
||||
|
||||
// Joint attention
|
||||
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
|
||||
|
||||
// Pre-MLP: norm + modulate
|
||||
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
|
||||
|
||||
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
|
||||
|
||||
// MLP
|
||||
mlpImg := tb.ImgMLP.Forward(imgNorm2)
|
||||
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
|
||||
|
||||
return img, txt
|
||||
}
|
||||
|
||||
// splitMod6 splits modulation into 6 parts each [B, 1, D]
|
||||
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
|
||||
shape := mod.Shape()
|
||||
B := shape[0]
|
||||
parts := make([]*mlx.Array, 6)
|
||||
for i := int32(0); i < 6; i++ {
|
||||
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
|
||||
parts[i] = mlx.ExpandDims(part, 1)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
|
||||
// Transformer is the full Qwen-Image transformer model
|
||||
type Transformer struct {
|
||||
Config *TransformerConfig
|
||||
|
||||
ImgIn *mlx.Array
|
||||
ImgInBias *mlx.Array
|
||||
TxtIn *mlx.Array
|
||||
TxtInBias *mlx.Array
|
||||
TxtNorm *mlx.Array
|
||||
|
||||
TEmbed *TimestepEmbedder
|
||||
Layers []*TransformerBlock
|
||||
|
||||
NormOutWeight *mlx.Array
|
||||
NormOutBias *mlx.Array
|
||||
ProjOut *mlx.Array
|
||||
ProjOutBias *mlx.Array
|
||||
}
|
||||
|
||||
// Load loads the transformer from a directory
|
||||
func (m *Transformer) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image transformer...")
|
||||
|
||||
cfg := defaultTransformerConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Bulk load all weights as bf16
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading input projections... ")
|
||||
imgIn, _ := weights.Get("img_in.weight")
|
||||
imgInBias, _ := weights.Get("img_in.bias")
|
||||
txtIn, _ := weights.Get("txt_in.weight")
|
||||
txtInBias, _ := weights.Get("txt_in.bias")
|
||||
txtNorm, _ := weights.Get("txt_norm.weight")
|
||||
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
|
||||
m.ImgInBias = imgInBias
|
||||
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
|
||||
m.TxtInBias = txtInBias
|
||||
m.TxtNorm = txtNorm
|
||||
fmt.Println("✓")
|
||||
|
||||
fmt.Print(" Loading timestep embedder... ")
|
||||
m.TEmbed, err = newTimestepEmbedder(weights)
|
||||
if err != nil {
|
||||
return fmt.Errorf("timestep embedder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
for i := int32(0); i < cfg.NLayers; i++ {
|
||||
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
|
||||
prefix := fmt.Sprintf("transformer_blocks.%d", i)
|
||||
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("layer %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
|
||||
|
||||
fmt.Print(" Loading output layers... ")
|
||||
normOutWeight, _ := weights.Get("norm_out.linear.weight")
|
||||
normOutBias, _ := weights.Get("norm_out.linear.bias")
|
||||
projOut, _ := weights.Get("proj_out.weight")
|
||||
projOutBias, _ := weights.Get("proj_out.bias")
|
||||
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
|
||||
m.NormOutBias = normOutBias
|
||||
m.ProjOut = mlx.Transpose(projOut, 1, 0)
|
||||
m.ProjOutBias = projOutBias
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath is a convenience function to load transformer from path
|
||||
func LoadTransformerFromPath(path string) (*Transformer, error) {
|
||||
m := &Transformer{}
|
||||
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward runs the transformer
|
||||
// img: [B, L_img, in_channels] patchified latents
|
||||
// txt: [B, L_txt, joint_attention_dim] text embeddings
|
||||
// t: [B] timesteps (0-1)
|
||||
// imgFreqs, txtFreqs: RoPE frequencies
|
||||
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
for _, layer := range tr.Layers {
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// ForwardWithCache runs the transformer with layer caching for speedup.
|
||||
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
|
||||
// shallow layers change little between denoising steps, so we cache their
|
||||
// outputs and reuse them on non-refresh steps.
|
||||
//
|
||||
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
|
||||
// step: current denoising step (0-indexed)
|
||||
// cacheInterval: refresh cache every N steps (e.g., 3)
|
||||
// cacheLayers: number of shallow layers to cache (e.g., 15)
|
||||
func (tr *Transformer) ForwardWithCache(
|
||||
img, txt, t *mlx.Array,
|
||||
imgFreqs, txtFreqs *mlx.Array,
|
||||
stepCache *cache.StepCache,
|
||||
step, cacheInterval, cacheLayers int,
|
||||
) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
// Check if we should refresh the cache
|
||||
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
|
||||
|
||||
for i, layer := range tr.Layers {
|
||||
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
|
||||
// Use cached outputs for shallow layers
|
||||
imgH = stepCache.Get(i)
|
||||
txtH = stepCache.Get2(i)
|
||||
} else {
|
||||
// Compute layer
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
// Cache shallow layers on refresh steps
|
||||
if i < cacheLayers && refreshCache {
|
||||
stepCache.Set(i, imgH)
|
||||
stepCache.Set2(i, txtH)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// RoPECache holds precomputed RoPE frequencies
|
||||
type RoPECache struct {
|
||||
ImgFreqs *mlx.Array // [L_img, head_dim]
|
||||
TxtFreqs *mlx.Array // [L_txt, head_dim]
|
||||
}
|
||||
|
||||
// PrepareRoPE computes RoPE for image and text sequences
|
||||
// This matches Python's QwenEmbedRope with scale_rope=True
|
||||
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
|
||||
theta := float64(10000)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Compute base frequencies for each axis dimension
|
||||
freqsT := ComputeAxisFreqs(axesDims[0], theta)
|
||||
freqsH := ComputeAxisFreqs(axesDims[1], theta)
|
||||
freqsW := ComputeAxisFreqs(axesDims[2], theta)
|
||||
|
||||
// Build frequency lookup tables
|
||||
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
|
||||
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
|
||||
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
|
||||
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
|
||||
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
|
||||
|
||||
// Image frequencies with scale_rope=True
|
||||
imgLen := imgH * imgW
|
||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
||||
imgFreqsData := make([]float32, imgLen*headDim)
|
||||
|
||||
hHalf := imgH / 2
|
||||
wHalf := imgW / 2
|
||||
|
||||
idx := int32(0)
|
||||
for y := int32(0); y < imgH; y++ {
|
||||
for x := int32(0); x < imgW; x++ {
|
||||
// Frame = 0
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
|
||||
// Height: scale_rope pattern
|
||||
hNegCount := imgH - hHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
|
||||
// Width: scale_rope pattern
|
||||
wNegCount := imgW - wHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
}
|
||||
|
||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
|
||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
||||
|
||||
// Text frequencies
|
||||
maxVidIdx := max(hHalf, wHalf)
|
||||
txtFreqsData := make([]float32, txtLen*headDim)
|
||||
|
||||
idx = 0
|
||||
for t := int32(0); t < txtLen; t++ {
|
||||
pos := maxVidIdx + t
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
|
||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
||||
|
||||
return &RoPECache{
|
||||
ImgFreqs: imgFreqs,
|
||||
TxtFreqs: txtFreqs,
|
||||
}
|
||||
}
|
||||
|
||||
// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
|
||||
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
|
||||
halfDim := dim / 2
|
||||
freqs := make([]float64, halfDim)
|
||||
for i := int32(0); i < halfDim; i++ {
|
||||
freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
|
||||
}
|
||||
return freqs
|
||||
}
|
||||
|
||||
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
|
||||
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
|
||||
table := make([][]float32, maxIdx)
|
||||
for idx := int32(0); idx < maxIdx; idx++ {
|
||||
var pos float64
|
||||
if negative {
|
||||
pos = float64(-maxIdx + int32(idx))
|
||||
} else {
|
||||
pos = float64(idx)
|
||||
}
|
||||
|
||||
row := make([]float32, len(baseFreqs)*2)
|
||||
for i, f := range baseFreqs {
|
||||
angle := pos * f
|
||||
row[i*2] = float32(math.Cos(angle))
|
||||
row[i*2+1] = float32(math.Sin(angle))
|
||||
}
|
||||
table[idx] = row
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func max(a, b int32) int32 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
|
||||
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
|
||||
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
|
||||
// -> [B, pH, pW, C, 2, 2]
|
||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
||||
// -> [B, pH*pW, C*4]
|
||||
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
|
||||
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
|
||||
shape := patches.Shape()
|
||||
B := shape[0]
|
||||
channels := shape[2] / (patchSize * patchSize)
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
|
||||
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
|
||||
// -> [B, C, pH, 2, pW, 2]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
||||
// -> [B, C, H, W]
|
||||
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
|
||||
// Add temporal dimension for VAE: [B, C, 1, H, W]
|
||||
return mlx.ExpandDims(x, 2)
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestTransformerConfig tests configuration invariants.
|
||||
func TestTransformerConfig(t *testing.T) {
|
||||
cfg := defaultTransformerConfig()
|
||||
|
||||
// Property: hidden_dim = n_heads * head_dim
|
||||
if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
|
||||
t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
|
||||
cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
|
||||
}
|
||||
|
||||
// Property: axes_dims_rope sums to head_dim
|
||||
var ropeSum int32
|
||||
for _, d := range cfg.AxesDimsRope {
|
||||
ropeSum += d
|
||||
}
|
||||
if ropeSum != cfg.HeadDim {
|
||||
t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
|
||||
}
|
||||
|
||||
// Property: in_channels = out_channels * patch_size^2
|
||||
expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
|
||||
if cfg.InChannels != expectedIn {
|
||||
t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransformerRoPE tests RoPE frequency computation produces valid values.
|
||||
func TestTransformerRoPE(t *testing.T) {
|
||||
cfg := defaultTransformerConfig()
|
||||
|
||||
// Test with small image dimensions
|
||||
imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
|
||||
txtLen := int32(5)
|
||||
|
||||
ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
|
||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Verify shapes: [seq_len, head_dim]
|
||||
imgSeqLen := imgH * imgW
|
||||
if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
|
||||
t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
|
||||
}
|
||||
if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
|
||||
t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
|
||||
}
|
||||
|
||||
if ropeCache.TxtFreqs.Shape()[0] != txtLen {
|
||||
t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
|
||||
}
|
||||
|
||||
// Verify values are finite
|
||||
imgData := ropeCache.ImgFreqs.Data()
|
||||
for i := 0; i < min(100, len(imgData)); i++ {
|
||||
if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
|
||||
t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransformerForward tests full forward pass (integration test).
|
||||
// Skips if model weights are not available.
|
||||
func TestTransformerForward(t *testing.T) {
|
||||
weightsPath := "../../../weights/Qwen-Image-2512/transformer"
|
||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
||||
}
|
||||
|
||||
transformer := &Transformer{}
|
||||
if err := transformer.Load(weightsPath); err != nil {
|
||||
t.Fatalf("Failed to load transformer: %v", err)
|
||||
}
|
||||
mlx.Keep(mlx.Collect(transformer)...)
|
||||
cfg := transformer.Config
|
||||
|
||||
// Small test inputs
|
||||
batchSize := int32(1)
|
||||
imgH, imgW := int32(4), int32(4)
|
||||
imgSeqLen := imgH * imgW
|
||||
txtSeqLen := int32(5)
|
||||
|
||||
hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
|
||||
encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
|
||||
timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})
|
||||
|
||||
ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)
|
||||
|
||||
// Forward pass
|
||||
out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [batch, img_seq_len, in_channels]
|
||||
wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
|
||||
gotShape := out.Shape()
|
||||
if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
|
||||
t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
|
||||
}
|
||||
|
||||
// Verify output is finite
|
||||
outData := out.Data()
|
||||
for i := 0; i < min(100, len(outData)); i++ {
|
||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,854 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds Qwen-Image VAE configuration
|
||||
type VAEConfig struct {
|
||||
ZDim int32 `json:"z_dim"` // 16
|
||||
BaseDim int32 `json:"base_dim"` // 96
|
||||
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
|
||||
NumResBlocks int32 `json:"num_res_blocks"` // 2
|
||||
LatentsMean []float32 `json:"latents_mean"` // 16 values
|
||||
LatentsStd []float32 `json:"latents_std"` // 16 values
|
||||
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
|
||||
}
|
||||
|
||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
||||
func defaultVAEConfig() *VAEConfig {
|
||||
return &VAEConfig{
|
||||
ZDim: 16,
|
||||
BaseDim: 96,
|
||||
DimMult: []int32{1, 2, 4, 4},
|
||||
NumResBlocks: 2,
|
||||
LatentsMean: []float32{
|
||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632,
|
||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
||||
},
|
||||
LatentsStd: []float32{
|
||||
2.8184, 1.4541, 2.3275, 2.6558,
|
||||
1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579,
|
||||
1.6382, 1.1253, 2.8251, 1.916,
|
||||
},
|
||||
TemperalDownsample: []bool{false, true, true},
|
||||
}
|
||||
}
|
||||
|
||||
// CausalConv3d is a causal 3D convolution (for temporal causality)
|
||||
type CausalConv3d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
|
||||
KernelT int32
|
||||
}
|
||||
|
||||
// newCausalConv3d creates a 3D causal conv
|
||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
||||
weight, err := weights.Get(prefix + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
||||
}
|
||||
bias, _ := weights.Get(prefix + ".bias")
|
||||
|
||||
kernelT := weight.Shape()[2]
|
||||
outC := weight.Shape()[0]
|
||||
|
||||
var biasReshaped *mlx.Array
|
||||
if bias != nil {
|
||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
||||
}
|
||||
|
||||
return &CausalConv3d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
BiasReshaped: biasReshaped,
|
||||
KernelT: kernelT,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies causal 3D convolution
|
||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
|
||||
kernelT := shape[2]
|
||||
kernelH := shape[3]
|
||||
kernelW := shape[4]
|
||||
|
||||
// Causal temporal padding, same spatial padding
|
||||
// Input is channels-last: [B, T, H, W, C]
|
||||
padT := kernelT - 1
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Stage 1: Pad
|
||||
{
|
||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
var out *mlx.Array
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// RMSNorm3D applies RMS normalization over channels
|
||||
// Works with channels-last [B, T, H, W, C] format
|
||||
type RMSNorm3D struct {
|
||||
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
|
||||
}
|
||||
|
||||
// newRMSNorm3D creates an RMS norm
|
||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
||||
gamma, err := weights.Get(prefix + ".gamma")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
|
||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
||||
return &RMSNorm3D{Gamma: gamma}, nil
|
||||
}
|
||||
|
||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
||||
// RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
|
||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
||||
return mlx.Mul(normalized, n.Gamma)
|
||||
}
|
||||
|
||||
// ResBlock is a residual block with RMS norm and causal convs
|
||||
type ResBlock struct {
|
||||
Norm1 *RMSNorm3D
|
||||
Conv1 *CausalConv3d
|
||||
Norm2 *RMSNorm3D
|
||||
Conv2 *CausalConv3d
|
||||
Shortcut *CausalConv3d
|
||||
}
|
||||
|
||||
// newResBlock creates a residual block
|
||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shortcut *CausalConv3d
|
||||
if inDim != outDim {
|
||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ResBlock{
|
||||
Norm1: norm1,
|
||||
Conv1: conv1,
|
||||
Norm2: norm2,
|
||||
Conv2: conv2,
|
||||
Shortcut: shortcut,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the residual block
|
||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Use h as working variable, keep x intact for residual (caller will free x)
|
||||
// Conv handles its own pools, so we just need pools for non-conv operations
|
||||
var h *mlx.Array
|
||||
|
||||
// Keep x so it survives Eval() cleanup - needed for residual connection
|
||||
mlx.Keep(x)
|
||||
|
||||
// Stage 1: norm1 + silu
|
||||
{
|
||||
h = r.Norm1.Forward(x)
|
||||
h = silu3D(h)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: conv1 (handles its own pools)
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 3: norm2 + silu
|
||||
{
|
||||
prev := h
|
||||
h = r.Norm2.Forward(h)
|
||||
h = silu3D(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: conv2 (handles its own pools)
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Residual connection (shortcut handles its own pools if present)
|
||||
if r.Shortcut != nil {
|
||||
shortcut := r.Shortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
mlx.Eval(h)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// AttentionBlock is a 2D attention block
|
||||
type AttentionBlock struct {
|
||||
Norm *RMSNorm3D
|
||||
ToQKV *mlx.Array
|
||||
ToQKVBias *mlx.Array
|
||||
Proj *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// newAttentionBlock creates an attention block
|
||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
||||
|
||||
return &AttentionBlock{
|
||||
Norm: norm,
|
||||
ToQKV: toQKV,
|
||||
ToQKVBias: toQKVBias,
|
||||
Proj: proj,
|
||||
ProjBias: projBias,
|
||||
Dim: dim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies 2D attention
|
||||
// Input: [B, T, H, W, C] (channels-last)
|
||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
identity := x
|
||||
|
||||
// Flatten to [B*T, 1, H, W, C] for norm
|
||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
||||
x = a.Norm.Forward(x)
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Flatten spatial to [B*T, H*W, C]
|
||||
x = mlx.Reshape(x, B*T, H*W, C)
|
||||
|
||||
// Linear to get Q, K, V: [B*T, H*W, 3*C]
|
||||
// Weight is [outC, inC] or [outC, inC, 1, 1]
|
||||
wShape := a.ToQKV.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = a.ToQKV
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0) // [inC, outC]
|
||||
|
||||
qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
|
||||
if a.ToQKVBias != nil {
|
||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
||||
}
|
||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
||||
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
// out: [B*T, 1, H*W, C]
|
||||
out = mlx.Reshape(out, B*T, H*W, C)
|
||||
|
||||
// Project back
|
||||
pShape := a.Proj.Shape()
|
||||
var p *mlx.Array
|
||||
if len(pShape) == 4 {
|
||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
||||
} else {
|
||||
p = a.Proj
|
||||
}
|
||||
p = mlx.Transpose(p, 1, 0) // [inC, outC]
|
||||
out = mlx.Linear(out, p) // [B*T, H*W, C]
|
||||
if a.ProjBias != nil {
|
||||
out = mlx.Add(out, a.ProjBias)
|
||||
}
|
||||
|
||||
out = mlx.Reshape(out, B, T, H, W, C)
|
||||
return mlx.Add(out, identity)
|
||||
}
|
||||
|
||||
// UpBlock handles upsampling in decoder
|
||||
type UpBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Upsampler *Upsample
|
||||
}
|
||||
|
||||
// newUpBlock creates an up block
|
||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var upsampler *Upsample
|
||||
if upsampleMode != "" {
|
||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
||||
}
|
||||
|
||||
return &UpBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Upsampler: upsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies up block with staged memory management
|
||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// ResBlocks handle their own pools
|
||||
for _, block := range u.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Upsampler handles its own pools
|
||||
if u.Upsampler != nil {
|
||||
prev := x
|
||||
x = u.Upsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Upsample handles spatial upsampling
|
||||
type Upsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newUpsample creates an upsampler
|
||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Upsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
||||
// Uses staged pools to reduce peak memory during 2x upsampling
|
||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := u.Conv.Shape()[0]
|
||||
|
||||
// Stage 1: 2x nearest neighbor upsample
|
||||
{
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
x = upsample2xChannelsLast(x)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
||||
if u.Bias != nil {
|
||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// MidBlock is the middle block of decoder
|
||||
type MidBlock struct {
|
||||
ResBlock1 *ResBlock
|
||||
Attention *AttentionBlock
|
||||
ResBlock2 *ResBlock
|
||||
}
|
||||
|
||||
// newMidBlock creates a mid block
|
||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MidBlock{
|
||||
ResBlock1: res1,
|
||||
Attention: attn,
|
||||
ResBlock2: res2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies mid block
|
||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Each component handles its own pools; we just free inputs
|
||||
prev := x
|
||||
x = m.ResBlock1.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.ResBlock2.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the full VAE decoder
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
PostQuantConv *CausalConv3d
|
||||
ConvIn *CausalConv3d
|
||||
MidBlock *MidBlock
|
||||
UpBlocks []*UpBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from a directory
|
||||
func (m *VAEDecoder) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image VAE decoder...")
|
||||
|
||||
cfg := defaultVAEConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Bulk load all weights as bf16
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("failed to load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading post_quant_conv... ")
|
||||
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.PostQuantConv = postQuantConv
|
||||
fmt.Println("✓")
|
||||
|
||||
fmt.Print(" Loading conv_in... ")
|
||||
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvIn = convIn
|
||||
fmt.Println("✓")
|
||||
|
||||
// Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
|
||||
fmt.Print(" Loading mid_block... ")
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.MidBlock = midBlock
|
||||
fmt.Println("✓")
|
||||
|
||||
// Up blocks (reversed dim_mult)
|
||||
fmt.Print(" Loading up_blocks... ")
|
||||
numUpBlocks := len(cfg.DimMult)
|
||||
m.UpBlocks = make([]*UpBlock, numUpBlocks)
|
||||
|
||||
dimsMult := make([]int32, numUpBlocks+1)
|
||||
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
|
||||
}
|
||||
|
||||
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
|
||||
for i := range cfg.TemperalDownsample {
|
||||
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
|
||||
}
|
||||
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
inDim := cfg.BaseDim * dimsMult[i]
|
||||
outDim := cfg.BaseDim * dimsMult[i+1]
|
||||
|
||||
if i > 0 {
|
||||
inDim = inDim / 2
|
||||
}
|
||||
|
||||
upsampleMode := ""
|
||||
if i < numUpBlocks-1 {
|
||||
if temporalUpsample[i] {
|
||||
upsampleMode = "upsample3d"
|
||||
} else {
|
||||
upsampleMode = "upsample2d"
|
||||
}
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.UpBlocks[i] = upBlock
|
||||
}
|
||||
fmt.Printf("✓ [%d blocks]\n", numUpBlocks)
|
||||
|
||||
fmt.Print(" Loading output layers... ")
|
||||
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.NormOut = normOut
|
||||
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvOut = convOut
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadVAEDecoderFromPath is a convenience function to load VAE from path
|
||||
func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
|
||||
m := &VAEDecoder{}
|
||||
if err := m.Load(filepath.Join(path, "vae")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Decode converts latents to image
|
||||
// z: [B, C, T, H, W] normalized latents
|
||||
// Uses staged pools to free intermediate arrays and reduce peak memory.
|
||||
func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
|
||||
var x *mlx.Array
|
||||
|
||||
// Stage 1a: Denormalize and transpose
|
||||
{
|
||||
z = vae.Denormalize(z)
|
||||
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
||||
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(z)
|
||||
}
|
||||
|
||||
// Stage 1b: PostQuantConv (handles its own pools)
|
||||
x = vae.PostQuantConv.Forward(z)
|
||||
z.Free()
|
||||
|
||||
// Stage 1c: ConvIn (handles its own pools)
|
||||
{
|
||||
prev := x
|
||||
x = vae.ConvIn.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 2: Mid block (handles its own pools)
|
||||
x = vae.MidBlock.Forward(x)
|
||||
|
||||
// Stage 3: Up blocks (each handles its own pools)
|
||||
for _, upBlock := range vae.UpBlocks {
|
||||
x = upBlock.Forward(x)
|
||||
}
|
||||
|
||||
// Stage 4a: NormOut + silu
|
||||
{
|
||||
prev := x
|
||||
x = vae.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 4b: ConvOut (handles its own pools)
|
||||
{
|
||||
prev := x
|
||||
x = vae.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 4c: Post-processing
|
||||
{
|
||||
prev := x
|
||||
// Clamp to [-1, 1]
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
// Convert back from channels-last to channels-first
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Denormalize reverses the normalization applied during encoding
|
||||
func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
mean = mlx.ToBFloat16(mean)
|
||||
std = mlx.ToBFloat16(std)
|
||||
|
||||
return mlx.Add(mlx.Mul(z, std), mean)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func silu3D(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
||||
}
|
||||
|
||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
// Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
|
||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
||||
}
|
||||
|
||||
func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
|
||||
}
|
||||
|
||||
func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
||||
x = mlx.Reshape(x, B*H*W, shape[1])
|
||||
|
||||
wShape := weight.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(weight, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = weight
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0)
|
||||
|
||||
out := mlx.Linear(x, w)
|
||||
outC := w.Dim(1)
|
||||
out = mlx.Reshape(out, B, H, W, outC)
|
||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
||||
}
|
||||
|
||||
func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
|
||||
x = pad2D(x, 1, 1, 1, 1)
|
||||
return conv2D(x, weight, 1, 1)
|
||||
}
|
||||
|
||||
func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
|
||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
||||
w = mlx.Transpose(w, 0, 2, 3, 1)
|
||||
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
wShape := w.Shape()
|
||||
Cout := wShape[0]
|
||||
kH := wShape[1]
|
||||
kW := wShape[2]
|
||||
|
||||
outH := (H-kH)/strideH + 1
|
||||
outW := (W-kW)/strideW + 1
|
||||
|
||||
patches := extractPatches2D(x, kH, kW, strideH, strideW)
|
||||
wFlat := mlx.Reshape(w, Cout, -1)
|
||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
out = mlx.Reshape(out, B, outH, outW, Cout)
|
||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
||||
}
|
||||
|
||||
func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
outH := (H-kH)/strideH + 1
|
||||
outW := (W-kW)/strideW + 1
|
||||
|
||||
patches := make([]*mlx.Array, outH*outW)
|
||||
idx := 0
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startH := i * strideH
|
||||
startW := j * strideW
|
||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
||||
}
|
||||
|
||||
func upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
x = mlx.Take(x, rowIdx, 2)
|
||||
x = mlx.Take(x, colIdx, 3)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
// Create repeat indices for rows
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
// Create repeat indices for columns
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
// Take along H (axis 1) then W (axis 2)
|
||||
x = mlx.Take(x, rowIdx, 1)
|
||||
x = mlx.Take(x, colIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
||||
// weight: [outC, kH, kW, inC] (MLX channels-last format)
|
||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
||||
// Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
// Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
|
||||
// stride=1, padding=0 (we already padded manually)
|
||||
return mlx.Conv2d(x, weight, 1, 0)
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestVAEConfig tests configuration invariants.
|
||||
func TestVAEConfig(t *testing.T) {
|
||||
cfg := defaultVAEConfig()
|
||||
|
||||
// Property: latents_mean and latents_std have z_dim elements
|
||||
if int32(len(cfg.LatentsMean)) != cfg.ZDim {
|
||||
t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim)
|
||||
}
|
||||
if int32(len(cfg.LatentsStd)) != cfg.ZDim {
|
||||
t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim)
|
||||
}
|
||||
|
||||
// Property: dim_mult defines 4 stages
|
||||
if len(cfg.DimMult) != 4 {
|
||||
t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult))
|
||||
}
|
||||
|
||||
// Property: temperal_downsample has 3 elements (for 3 transitions)
|
||||
if len(cfg.TemperalDownsample) != 3 {
|
||||
t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample))
|
||||
}
|
||||
}
|
||||
|
||||
// TestVAELatentsNormalization tests the latent denormalization values.
|
||||
func TestVAELatentsNormalization(t *testing.T) {
|
||||
cfg := defaultVAEConfig()
|
||||
|
||||
// Verify latents_std values are all positive
|
||||
for i, std := range cfg.LatentsStd {
|
||||
if std <= 0 {
|
||||
t.Errorf("latents_std[%d] should be positive: %v", i, std)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify values are in reasonable range (from actual model)
|
||||
for i, mean := range cfg.LatentsMean {
|
||||
if math.Abs(float64(mean)) > 5 {
|
||||
t.Errorf("latents_mean[%d] seems too large: %v", i, mean)
|
||||
}
|
||||
}
|
||||
for i, std := range cfg.LatentsStd {
|
||||
if std > 10 {
|
||||
t.Errorf("latents_std[%d] seems too large: %v", i, std)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVAEDecoderForward tests full forward pass (integration test).
|
||||
// Skips if model weights are not available.
|
||||
func TestVAEDecoderForward(t *testing.T) {
|
||||
weightsPath := "../../../weights/Qwen-Image-2512/vae"
|
||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
||||
}
|
||||
|
||||
vae := &VAEDecoder{}
|
||||
if err := vae.Load(weightsPath); err != nil {
|
||||
t.Fatalf("Failed to load VAE decoder: %v", err)
|
||||
}
|
||||
mlx.Keep(mlx.Collect(vae)...)
|
||||
|
||||
// Small test input: [B, C, T, H, W]
|
||||
// After 4 upsampling stages (2x each), H/W multiply by 16
|
||||
batchSize := int32(1)
|
||||
channels := int32(16)
|
||||
frames := int32(1)
|
||||
latentH := int32(4)
|
||||
latentW := int32(4)
|
||||
|
||||
latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0)
|
||||
|
||||
// Decode
|
||||
out := vae.Decode(latents)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [B, 3, T, H*16, W*16]
|
||||
outShape := out.Shape()
|
||||
if outShape[0] != batchSize {
|
||||
t.Errorf("batch size: got %d, want %d", outShape[0], batchSize)
|
||||
}
|
||||
if outShape[1] != 3 {
|
||||
t.Errorf("channels: got %d, want 3", outShape[1])
|
||||
}
|
||||
if outShape[2] != frames {
|
||||
t.Errorf("frames: got %d, want %d", outShape[2], frames)
|
||||
}
|
||||
expectedH := latentH * 16 // 4 stages of 2x upsampling
|
||||
expectedW := latentW * 16
|
||||
if outShape[3] != expectedH || outShape[4] != expectedW {
|
||||
t.Errorf("spatial dims: got [%d, %d], want [%d, %d]",
|
||||
outShape[3], outShape[4], expectedH, expectedW)
|
||||
}
|
||||
|
||||
// Verify output is in valid range (should be clamped to [0, 1] by decode)
|
||||
outData := out.Data()
|
||||
for i := 0; i < min(100, len(outData)); i++ {
|
||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,682 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// CausalConv3d is a causal 3D convolution (for temporal causality)
|
||||
type CausalConv3d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
|
||||
KernelT int32
|
||||
}
|
||||
|
||||
// newCausalConv3d creates a 3D causal conv
|
||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
||||
weight, err := weights.Get(prefix + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
||||
}
|
||||
bias, _ := weights.Get(prefix + ".bias")
|
||||
|
||||
kernelT := weight.Shape()[2]
|
||||
outC := weight.Shape()[0]
|
||||
|
||||
var biasReshaped *mlx.Array
|
||||
if bias != nil {
|
||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
||||
}
|
||||
|
||||
return &CausalConv3d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
BiasReshaped: biasReshaped,
|
||||
KernelT: kernelT,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies causal 3D convolution (or 2D if weight is 4D)
|
||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := c.Weight.Shape()
|
||||
|
||||
// Handle both 5D (3D conv) and 4D (2D conv) weights
|
||||
if len(shape) == 4 {
|
||||
// 2D conv: [O, I, kH, kW] - need to apply per-frame
|
||||
return c.forward2D(x)
|
||||
}
|
||||
|
||||
// 3D conv: [O, I, kT, kH, kW]
|
||||
kernelT := shape[2]
|
||||
kernelH := shape[3]
|
||||
kernelW := shape[4]
|
||||
|
||||
// Causal temporal padding, same spatial padding
|
||||
padT := kernelT - 1
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Stage 1: Pad
|
||||
{
|
||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
var out *mlx.Array
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// forward2D applies 2D conv per-frame for [B, T, H, W, C] input
|
||||
func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
|
||||
xShape := x.Shape()
|
||||
B := xShape[0]
|
||||
T := xShape[1]
|
||||
H := xShape[2]
|
||||
W := xShape[3]
|
||||
C := xShape[4]
|
||||
|
||||
wShape := c.Weight.Shape() // [O, I, kH, kW]
|
||||
kernelH := wShape[2]
|
||||
kernelW := wShape[3]
|
||||
outC := wShape[0]
|
||||
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Pad spatially
|
||||
x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})
|
||||
|
||||
// Apply 2D conv
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
||||
x = mlx.Conv2d(x, weight, 1, 0)
|
||||
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
// Get output spatial dims
|
||||
outH := H
|
||||
outW := W
|
||||
|
||||
// Reshape back to [B, T, H, W, C]
|
||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// RMSNorm3D applies RMS normalization over channels
|
||||
type RMSNorm3D struct {
|
||||
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
|
||||
}
|
||||
|
||||
// newRMSNorm3D creates an RMS norm
|
||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
||||
gamma, err := weights.Get(prefix + ".gamma")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
||||
return &RMSNorm3D{Gamma: gamma}, nil
|
||||
}
|
||||
|
||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
||||
return mlx.Mul(normalized, n.Gamma)
|
||||
}
|
||||
|
||||
// ResBlock is a residual block with RMS norm and causal convs
|
||||
type ResBlock struct {
|
||||
Norm1 *RMSNorm3D
|
||||
Conv1 *CausalConv3d
|
||||
Norm2 *RMSNorm3D
|
||||
Conv2 *CausalConv3d
|
||||
Shortcut *CausalConv3d
|
||||
}
|
||||
|
||||
// newResBlock creates a residual block
|
||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shortcut *CausalConv3d
|
||||
if inDim != outDim {
|
||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ResBlock{
|
||||
Norm1: norm1,
|
||||
Conv1: conv1,
|
||||
Norm2: norm2,
|
||||
Conv2: conv2,
|
||||
Shortcut: shortcut,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the residual block
|
||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
|
||||
mlx.Keep(x)
|
||||
|
||||
// Stage 1: norm1 + silu
|
||||
{
|
||||
h = r.Norm1.Forward(x)
|
||||
h = silu3D(h)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: conv1
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 3: norm2 + silu
|
||||
{
|
||||
prev := h
|
||||
h = r.Norm2.Forward(h)
|
||||
h = silu3D(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: conv2
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
if r.Shortcut != nil {
|
||||
shortcut := r.Shortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
mlx.Eval(h)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// AttentionBlock is a 2D attention block
|
||||
type AttentionBlock struct {
|
||||
Norm *RMSNorm3D
|
||||
ToQKV *mlx.Array
|
||||
ToQKVBias *mlx.Array
|
||||
Proj *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// newAttentionBlock creates an attention block
|
||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
||||
|
||||
return &AttentionBlock{
|
||||
Norm: norm,
|
||||
ToQKV: toQKV,
|
||||
ToQKVBias: toQKVBias,
|
||||
Proj: proj,
|
||||
ProjBias: projBias,
|
||||
Dim: dim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies 2D attention
|
||||
// Input: [B, T, H, W, C] (channels-last)
|
||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
identity := x
|
||||
|
||||
// Flatten to [B*T, 1, H, W, C] for norm
|
||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
||||
x = a.Norm.Forward(x)
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Flatten spatial to [B*T, H*W, C]
|
||||
x = mlx.Reshape(x, B*T, H*W, C)
|
||||
|
||||
// Linear to get Q, K, V
|
||||
wShape := a.ToQKV.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = a.ToQKV
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0)
|
||||
|
||||
qkv := mlx.Linear(x, w)
|
||||
if a.ToQKVBias != nil {
|
||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
||||
}
|
||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
||||
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
out = mlx.Reshape(out, B*T, H*W, C)
|
||||
|
||||
// Project back
|
||||
pShape := a.Proj.Shape()
|
||||
var p *mlx.Array
|
||||
if len(pShape) == 4 {
|
||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
||||
} else {
|
||||
p = a.Proj
|
||||
}
|
||||
p = mlx.Transpose(p, 1, 0)
|
||||
out = mlx.Linear(out, p)
|
||||
if a.ProjBias != nil {
|
||||
out = mlx.Add(out, a.ProjBias)
|
||||
}
|
||||
|
||||
out = mlx.Reshape(out, B, T, H, W, C)
|
||||
return mlx.Add(out, identity)
|
||||
}
|
||||
|
||||
// UpBlock handles upsampling in decoder
|
||||
type UpBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Upsampler *Upsample
|
||||
}
|
||||
|
||||
// newUpBlock creates an up block
|
||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var upsampler *Upsample
|
||||
if upsampleMode != "" {
|
||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
||||
}
|
||||
|
||||
return &UpBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Upsampler: upsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies up block
|
||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, block := range u.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if u.Upsampler != nil {
|
||||
prev := x
|
||||
x = u.Upsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Upsample handles spatial upsampling
|
||||
type Upsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newUpsample creates an upsampler
|
||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Upsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := u.Conv.Shape()[0]
|
||||
|
||||
// Stage 1: 2x nearest neighbor upsample
|
||||
{
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
x = upsample2xChannelsLast(x)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
||||
if u.Bias != nil {
|
||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// MidBlock is the middle block
|
||||
type MidBlock struct {
|
||||
ResBlock1 *ResBlock
|
||||
Attention *AttentionBlock
|
||||
ResBlock2 *ResBlock
|
||||
}
|
||||
|
||||
// newMidBlock creates a mid block
|
||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MidBlock{
|
||||
ResBlock1: res1,
|
||||
Attention: attn,
|
||||
ResBlock2: res2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies mid block
|
||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
prev := x
|
||||
x = m.ResBlock1.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.ResBlock2.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func silu3D(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
||||
}
|
||||
|
||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
||||
}
|
||||
|
||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
x = mlx.Take(x, rowIdx, 1)
|
||||
x = mlx.Take(x, colIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
return mlx.Conv2d(x, weight, 1, 0)
|
||||
}
|
||||
|
||||
// conv2DStrided applies conv with stride > 1 using manual patch extraction
|
||||
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
|
||||
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
wShape := weight.Shape()
|
||||
Cout := wShape[0]
|
||||
kH := wShape[1]
|
||||
kW := wShape[2]
|
||||
|
||||
outH := (H - kH) / stride + 1
|
||||
outW := (W - kW) / stride + 1
|
||||
|
||||
patches := extractPatches2DStrided(x, kH, kW, stride)
|
||||
wFlat := mlx.Reshape(weight, Cout, -1)
|
||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
return mlx.Reshape(out, B, outH, outW, Cout)
|
||||
}
|
||||
|
||||
// conv3DStrided applies 3D conv with strides using manual patch extraction
|
||||
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
|
||||
// strideT, strideH, strideW are the strides for each dimension
|
||||
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
|
||||
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
wShape := weight.Shape()
|
||||
Cout := wShape[0]
|
||||
// I := wShape[1]
|
||||
kT := wShape[2]
|
||||
kH := wShape[3]
|
||||
kW := wShape[4]
|
||||
|
||||
// For temporal: if T < kT, we need to repeat frames temporally
|
||||
// For single image with T=1 and kT=2, we duplicate the frame to T=kT
|
||||
// Python Qwen2.5-VL duplicates the frame, not zero-pads
|
||||
if T < kT {
|
||||
// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
|
||||
x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
|
||||
T = kT
|
||||
}
|
||||
|
||||
outT := (T - kT) / strideT + 1
|
||||
outH := (H - kH) / strideH + 1
|
||||
outW := (W - kW) / strideW + 1
|
||||
|
||||
// Extract 3D patches in [C, T, H, W] order to match Python
|
||||
patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
|
||||
// patches shape: [B, outT, outH, outW, C*kT*kH*kW]
|
||||
|
||||
// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
|
||||
wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
|
||||
patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
return mlx.Reshape(out, B, outT, outH, outW, Cout)
|
||||
}
|
||||
|
||||
// extractPatches3DStrided extracts 3D patches with given strides
|
||||
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
|
||||
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
outT := (T - kT) / strideT + 1
|
||||
outH := (H - kH) / strideH + 1
|
||||
outW := (W - kW) / strideW + 1
|
||||
|
||||
numPatches := outT * outH * outW
|
||||
patches := make([]*mlx.Array, numPatches)
|
||||
idx := 0
|
||||
for t := int32(0); t < outT; t++ {
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startT := t * strideT
|
||||
startH := i * strideH
|
||||
startW := j * strideW
|
||||
// Extract patch: [B, kT, kH, kW, C]
|
||||
patch := mlx.Slice(x,
|
||||
[]int32{0, startT, startH, startW, 0},
|
||||
[]int32{B, startT + kT, startH + kH, startW + kW, C})
|
||||
// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
|
||||
patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
|
||||
// Flatten to [B, C*T*H*W]
|
||||
patch = mlx.Reshape(patch, B, C*kT*kH*kW)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
|
||||
}
|
||||
|
||||
// extractPatches2DStrided extracts patches with given stride
|
||||
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
outH := (H - kH) / stride + 1
|
||||
outW := (W - kW) / stride + 1
|
||||
|
||||
patches := make([]*mlx.Array, outH*outW)
|
||||
idx := 0
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startH := i * stride
|
||||
startW := j * stride
|
||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
@@ -1,475 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"math"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"golang.org/x/image/draw"
|
||||
_ "golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// loadImageFile loads an image from disk
|
||||
func loadImageFile(path string) (image.Image, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open image: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
img, _, err := image.Decode(f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode image: %w", err)
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range
|
||||
func imageToFloat32Pixels(img image.Image, width, height int) []float32 {
|
||||
pixels := make([]float32, width*height*3)
|
||||
idx := 0
|
||||
for y := 0; y < height; y++ {
|
||||
for x := 0; x < width; x++ {
|
||||
r, g, b, _ := img.At(x, y).RGBA()
|
||||
pixels[idx] = float32(r) / 65535.0
|
||||
pixels[idx+1] = float32(g) / 65535.0
|
||||
pixels[idx+2] = float32(b) / 65535.0
|
||||
idx += 3
|
||||
}
|
||||
}
|
||||
return pixels
|
||||
}
|
||||
|
||||
// normalizeImageNet applies ImageNet normalization to an image tensor
|
||||
func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array {
|
||||
mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3})
|
||||
std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3})
|
||||
return mlx.Div(mlx.Sub(arr, mean), std)
|
||||
}
|
||||
|
||||
// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16
|
||||
func prepareImageTensor(arr *mlx.Array) *mlx.Array {
|
||||
// Transpose to [C, H, W] and make contiguous
|
||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
||||
// Add batch dimension [1, C, H, W]
|
||||
arr = mlx.ExpandDims(arr, 0)
|
||||
// Convert to bf16
|
||||
arr = mlx.ToBFloat16(arr)
|
||||
mlx.Eval(arr)
|
||||
return arr
|
||||
}
|
||||
|
||||
// clampFloat clamps a value to [0, 255] and returns uint8
|
||||
func clampFloat(v, weightSum float64) uint8 {
|
||||
v /= weightSum
|
||||
if v < 0 {
|
||||
v = 0
|
||||
}
|
||||
if v > 255 {
|
||||
v = 255
|
||||
}
|
||||
return uint8(math.Round(v))
|
||||
}
|
||||
|
||||
// ImageDims holds dimensions for a preprocessed image
|
||||
type ImageDims struct {
|
||||
// Original image dimensions
|
||||
OrigW, OrigH int32
|
||||
// Condition image dimensions (for vision encoder)
|
||||
CondW, CondH int32
|
||||
// VAE image dimensions
|
||||
VaeW, VaeH int32
|
||||
// Latent dimensions (VAE dims / vae_scale_factor)
|
||||
LatentW, LatentH int32
|
||||
// Patch dimensions (latent dims / patch_size)
|
||||
PatchW, PatchH int32
|
||||
}
|
||||
|
||||
// ProcessorConfig holds image processor configuration
|
||||
type ProcessorConfig struct {
|
||||
// Condition image size (target pixel area for vision encoder input)
|
||||
// Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456
|
||||
// Pipeline resizes image to this area before passing to encode_prompt
|
||||
ConditionImageSize int32
|
||||
|
||||
// VAE image size (target pixel area)
|
||||
// Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576
|
||||
VAEImageSize int32
|
||||
|
||||
// Image normalization (ImageNet stats for vision encoder)
|
||||
ImageMean []float32
|
||||
ImageStd []float32
|
||||
}
|
||||
|
||||
// defaultProcessorConfig returns default processor config
|
||||
func defaultProcessorConfig() *ProcessorConfig {
|
||||
return &ProcessorConfig{
|
||||
ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE
|
||||
VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE
|
||||
ImageMean: []float32{0.48145466, 0.4578275, 0.40821073},
|
||||
ImageStd: []float32{0.26862954, 0.26130258, 0.27577711},
|
||||
}
|
||||
}
|
||||
|
||||
// Processor handles image preprocessing for Qwen-Image-Edit
|
||||
type Processor struct {
|
||||
Config *ProcessorConfig
|
||||
}
|
||||
|
||||
// Load loads the processor config
|
||||
func (p *Processor) Load(path string) error {
|
||||
p.Config = defaultProcessorConfig()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAndPreprocess loads an image and preprocesses it for both paths
|
||||
// Returns: condImage (for vision encoder), vaeImage (for VAE encoding)
|
||||
func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) {
|
||||
img, err := loadImageFile(imagePath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
origW := bounds.Dx()
|
||||
origH := bounds.Dy()
|
||||
ratio := float64(origW) / float64(origH)
|
||||
|
||||
// Calculate dimensions for condition image (vision encoder)
|
||||
// Python pipeline does TWO resizes:
|
||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
||||
finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
||||
|
||||
// Calculate dimensions for VAE image (1024x1024 area)
|
||||
// Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32)
|
||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
||||
|
||||
// Preprocess for condition (vision encoder) - two-step resize
|
||||
condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH)
|
||||
|
||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
||||
vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH)
|
||||
|
||||
return condImage, vaeImage, nil
|
||||
}
|
||||
|
||||
// preprocessImageLanczos does single-step Lanczos resize for vision encoder
|
||||
// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default)
|
||||
// Used by edit_plus pipeline for multi-image input
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array {
|
||||
resized := resizeImageLanczos(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
arr = p.normalizeImageNet(arr)
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline
|
||||
// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize)
|
||||
// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize)
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array {
|
||||
intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH))
|
||||
resized := resizeImageBicubic(intermediate, int(finalW), int(finalH))
|
||||
pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH))
|
||||
arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3})
|
||||
arr = p.normalizeImageNet(arr)
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImage converts image to tensor for vision encoder
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array {
|
||||
resized := resizeImageBicubic(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
if normalize {
|
||||
arr = p.normalizeImageNet(arr)
|
||||
}
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImageForVAE converts image to tensor for VAE encoding
|
||||
// Returns: [B, C, T, H, W] tensor in [-1, 1] range
|
||||
func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array {
|
||||
resized := resizeImageLanczos(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
|
||||
// Scale to [-1, 1]: arr * 2 - 1
|
||||
arr = mlx.MulScalar(arr, 2.0)
|
||||
arr = mlx.AddScalar(arr, -1.0)
|
||||
|
||||
// Transpose to [C, H, W] and make contiguous
|
||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
||||
|
||||
// Add batch and temporal dimensions [1, C, 1, H, W]
|
||||
arr = mlx.ExpandDims(arr, 0) // [1, C, H, W]
|
||||
arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W]
|
||||
|
||||
arr = mlx.ToBFloat16(arr)
|
||||
mlx.Eval(arr)
|
||||
return arr
|
||||
}
|
||||
|
||||
// smartResize implements Python Qwen2VL processor's smart_resize logic
|
||||
// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints
|
||||
func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) {
|
||||
// Round to factor
|
||||
hBar := int32(math.Round(float64(height)/float64(factor))) * factor
|
||||
wBar := int32(math.Round(float64(width)/float64(factor))) * factor
|
||||
|
||||
// Ensure minimum factor size
|
||||
if hBar < factor {
|
||||
hBar = factor
|
||||
}
|
||||
if wBar < factor {
|
||||
wBar = factor
|
||||
}
|
||||
|
||||
// Check pixel constraints
|
||||
total := hBar * wBar
|
||||
if total > maxPixels {
|
||||
// Scale down
|
||||
beta := math.Sqrt(float64(maxPixels) / float64(total))
|
||||
hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor
|
||||
} else if total < minPixels {
|
||||
// Scale up
|
||||
beta := math.Sqrt(float64(minPixels) / float64(total))
|
||||
hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar
|
||||
}
|
||||
|
||||
// calculateDimensions calculates width and height for a target area while maintaining ratio
|
||||
// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge)
|
||||
func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) {
|
||||
width := math.Sqrt(float64(targetArea) * ratio)
|
||||
height := width / ratio
|
||||
|
||||
m := float64(multiple)
|
||||
width = math.Round(width/m) * m
|
||||
height = math.Round(height/m) * m
|
||||
|
||||
// Ensure minimum dimensions
|
||||
if width < m {
|
||||
width = m
|
||||
}
|
||||
if height < m {
|
||||
height = m
|
||||
}
|
||||
|
||||
return int32(width), int32(height)
|
||||
}
|
||||
|
||||
// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS)
|
||||
func resizeImageLanczos(img image.Image, width, height int) image.Image {
|
||||
bounds := img.Bounds()
|
||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
// Lanczos3 kernel (a=3) to match PIL.LANCZOS
|
||||
lanczos3 := &draw.Kernel{
|
||||
Support: 3.0,
|
||||
At: func(t float64) float64 {
|
||||
if t == 0 {
|
||||
return 1.0
|
||||
}
|
||||
if t < 0 {
|
||||
t = -t
|
||||
}
|
||||
if t >= 3.0 {
|
||||
return 0.0
|
||||
}
|
||||
// sinc(t) * sinc(t/3)
|
||||
piT := math.Pi * t
|
||||
return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3))
|
||||
},
|
||||
}
|
||||
lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC)
|
||||
// Uses separable interpolation with PIL's coordinate mapping for exact match
|
||||
func resizeImageBicubic(img image.Image, width, height int) image.Image {
|
||||
bounds := img.Bounds()
|
||||
srcW := bounds.Dx()
|
||||
srcH := bounds.Dy()
|
||||
|
||||
// Convert to RGBA if needed
|
||||
var src *image.RGBA
|
||||
if rgba, ok := img.(*image.RGBA); ok {
|
||||
src = rgba
|
||||
} else {
|
||||
src = image.NewRGBA(bounds)
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
src.Set(x, y, img.At(x, y))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keys cubic with a=-0.5 (PIL BICUBIC)
|
||||
cubic := func(x float64) float64 {
|
||||
if x < 0 {
|
||||
x = -x
|
||||
}
|
||||
if x < 1 {
|
||||
return 1.5*x*x*x - 2.5*x*x + 1
|
||||
}
|
||||
if x < 2 {
|
||||
return -0.5*x*x*x + 2.5*x*x - 4*x + 2
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Horizontal pass: srcW -> width, keep srcH rows
|
||||
temp := image.NewRGBA(image.Rect(0, 0, width, srcH))
|
||||
for y := 0; y < srcH; y++ {
|
||||
for dstX := 0; dstX < width; dstX++ {
|
||||
// PIL coordinate mapping: center-to-center
|
||||
srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5
|
||||
baseX := int(math.Floor(srcXf))
|
||||
|
||||
var sumR, sumG, sumB, sumA, weightSum float64
|
||||
for i := -1; i <= 2; i++ {
|
||||
sx := baseX + i
|
||||
if sx < 0 {
|
||||
sx = 0
|
||||
}
|
||||
if sx >= srcW {
|
||||
sx = srcW - 1
|
||||
}
|
||||
|
||||
w := cubic(math.Abs(srcXf - float64(baseX+i)))
|
||||
c := src.RGBAAt(sx, y)
|
||||
sumR += float64(c.R) * w
|
||||
sumG += float64(c.G) * w
|
||||
sumB += float64(c.B) * w
|
||||
sumA += float64(c.A) * w
|
||||
weightSum += w
|
||||
}
|
||||
|
||||
temp.SetRGBA(dstX, y, color.RGBA{
|
||||
clampFloat(sumR, weightSum),
|
||||
clampFloat(sumG, weightSum),
|
||||
clampFloat(sumB, weightSum),
|
||||
clampFloat(sumA, weightSum),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Vertical pass: srcH -> height
|
||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
for x := 0; x < width; x++ {
|
||||
for dstY := 0; dstY < height; dstY++ {
|
||||
srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5
|
||||
baseY := int(math.Floor(srcYf))
|
||||
|
||||
var sumR, sumG, sumB, sumA, weightSum float64
|
||||
for j := -1; j <= 2; j++ {
|
||||
sy := baseY + j
|
||||
if sy < 0 {
|
||||
sy = 0
|
||||
}
|
||||
if sy >= srcH {
|
||||
sy = srcH - 1
|
||||
}
|
||||
|
||||
w := cubic(math.Abs(srcYf - float64(baseY+j)))
|
||||
c := temp.RGBAAt(x, sy)
|
||||
sumR += float64(c.R) * w
|
||||
sumG += float64(c.G) * w
|
||||
sumB += float64(c.B) * w
|
||||
sumA += float64(c.A) * w
|
||||
weightSum += w
|
||||
}
|
||||
|
||||
dst.SetRGBA(x, dstY, color.RGBA{
|
||||
clampFloat(sumR, weightSum),
|
||||
clampFloat(sumG, weightSum),
|
||||
clampFloat(sumB, weightSum),
|
||||
clampFloat(sumA, weightSum),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// LoadAndPreprocessMultiple loads multiple images and preprocesses them
|
||||
// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions)
|
||||
func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) {
|
||||
const vaeScaleFactor int32 = 8
|
||||
const patchSize int32 = 2
|
||||
|
||||
condImages := make([]*mlx.Array, len(imagePaths))
|
||||
vaeImages := make([]*mlx.Array, len(imagePaths))
|
||||
dims := make([]ImageDims, len(imagePaths))
|
||||
|
||||
for i, imagePath := range imagePaths {
|
||||
img, err := loadImageFile(imagePath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("image %d: %w", i, err)
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
origW := int32(bounds.Dx())
|
||||
origH := int32(bounds.Dy())
|
||||
ratio := float64(origW) / float64(origH)
|
||||
|
||||
// Calculate dimensions for condition image (vision encoder)
|
||||
// Python pipeline does TWO resizes:
|
||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
||||
condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
||||
|
||||
// Calculate dimensions for VAE image (1024x1024 area)
|
||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
||||
|
||||
// Calculate derived dimensions
|
||||
latentW := vaeW / vaeScaleFactor
|
||||
latentH := vaeH / vaeScaleFactor
|
||||
patchW := latentW / patchSize
|
||||
patchH := latentH / patchSize
|
||||
|
||||
dims[i] = ImageDims{
|
||||
OrigW: origW,
|
||||
OrigH: origH,
|
||||
CondW: condW,
|
||||
CondH: condH,
|
||||
VaeW: vaeW,
|
||||
VaeH: vaeH,
|
||||
LatentW: latentW,
|
||||
LatentH: latentH,
|
||||
PatchW: patchW,
|
||||
PatchH: patchH,
|
||||
}
|
||||
|
||||
fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n",
|
||||
i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH)
|
||||
|
||||
// Preprocess for condition (vision encoder) - two-step resize to match Python pipeline
|
||||
condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH)
|
||||
|
||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
||||
vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH)
|
||||
}
|
||||
|
||||
return condImages, vaeImages, dims, nil
|
||||
}
|
||||
@@ -1,625 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing.
|
||||
// It reuses components from qwen_image where possible.
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image editing.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
|
||||
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
|
||||
Width int32 // Output width (default: from input image)
|
||||
Height int32 // Output height (default: from input image)
|
||||
Steps int // Denoising steps (default: 50)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
}
|
||||
|
||||
// Model represents a Qwen-Image-Edit diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Processor *Processor // Image processor for vision encoder
|
||||
TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image)
|
||||
Transformer *qwen_image.Transformer // Reuse qwen_image transformer
|
||||
VAE *VAE // Combined encoder + decoder
|
||||
}
|
||||
|
||||
// Load loads the Qwen-Image-Edit model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Qwen-Image-Edit model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer from processor directory
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
processorPath := filepath.Join(modelPath, "processor")
|
||||
tok, err := tokenizer.Load(processorPath)
|
||||
if err != nil {
|
||||
// Fallback to tokenizer directory
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
||||
tok, err = tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load processor (image preprocessing config)
|
||||
fmt.Print(" Loading processor... ")
|
||||
m.Processor = &Processor{}
|
||||
if err := m.Processor.Load(processorPath); err != nil {
|
||||
return fmt.Errorf("processor: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load vision-language text encoder (Qwen2.5-VL from qwen_image package)
|
||||
m.TextEncoder = &qwen_image.Qwen25VL{}
|
||||
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer (reuse qwen_image)
|
||||
m.Transformer = &qwen_image.Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE (encoder + decoder)
|
||||
m.VAE = &VAE{}
|
||||
if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAE)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
peak := mlx.MetalGetPeakMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
||||
time.Since(start).Seconds(),
|
||||
float64(mem)/(1024*1024*1024),
|
||||
float64(peak)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Edit edits an image based on a text prompt.
|
||||
// inputImagePath: path to input image
|
||||
// prompt: text description of desired edit
|
||||
func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// EditFromConfig edits images using the unified config struct.
|
||||
// Accepts one or more input images.
|
||||
func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
if len(inputImagePaths) == 0 {
|
||||
return nil, fmt.Errorf("no input images provided")
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result, err := m.edit(inputImagePaths, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n",
|
||||
len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n",
|
||||
len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// EditImage implements model.ImageEditModel interface.
|
||||
func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Edit(inputImagePath, prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// EditMultiImage edits using multiple source images.
|
||||
// This matches diffusers' QwenImageEditPlusPipeline behavior.
|
||||
func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
return m.EditFromConfig(inputImagePaths, cfg)
|
||||
}
|
||||
|
||||
// edit is the internal editing pipeline that handles one or more images.
|
||||
func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
|
||||
// Load and preprocess all input images
|
||||
fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths))
|
||||
condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preprocess images: %w", err)
|
||||
}
|
||||
for _, img := range condImages {
|
||||
mlx.Keep(img)
|
||||
}
|
||||
for _, img := range vaeImages {
|
||||
mlx.Keep(img)
|
||||
}
|
||||
mlx.Eval(append(condImages, vaeImages...)...)
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.Config
|
||||
vaeScaleFactor := int32(8)
|
||||
|
||||
// Output dimensions - if not specified, use first input image dimensions
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = inputDims[0].VaeW
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = inputDims[0].VaeH
|
||||
}
|
||||
|
||||
// Output (noise) latent dimensions
|
||||
outLatentH := cfg.Height / vaeScaleFactor
|
||||
outLatentW := cfg.Width / vaeScaleFactor
|
||||
outPH := outLatentH / tcfg.PatchSize
|
||||
outPW := outLatentW / tcfg.PatchSize
|
||||
noiseSeqLen := outPH * outPW
|
||||
imgSeqLen := noiseSeqLen
|
||||
|
||||
// Encode prompt with all images for conditioning
|
||||
posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding prompt: %w", err)
|
||||
}
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
|
||||
var negEmb *mlx.Array
|
||||
if useCFG {
|
||||
negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding negative prompt: %w", err)
|
||||
}
|
||||
mlx.Keep(negEmb)
|
||||
mlx.Eval(negEmb)
|
||||
}
|
||||
|
||||
// Pad sequences to same length for CFG
|
||||
txtLen := posEmb.Shape()[1]
|
||||
if useCFG {
|
||||
negLen := negEmb.Shape()[1]
|
||||
if negLen > txtLen {
|
||||
txtLen = negLen
|
||||
}
|
||||
if posEmb.Shape()[1] < txtLen {
|
||||
posEmb = padSequence(posEmb, txtLen)
|
||||
}
|
||||
if negEmb.Shape()[1] < txtLen {
|
||||
negEmb = padSequence(negEmb, txtLen)
|
||||
}
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
||||
var batchedEmb *mlx.Array
|
||||
if useCFG {
|
||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||
mlx.Keep(batchedEmb)
|
||||
mlx.Eval(batchedEmb)
|
||||
}
|
||||
|
||||
// Encode all input images to latents and concatenate
|
||||
fmt.Println("Encoding images to latents...")
|
||||
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
|
||||
for i, vaeImage := range vaeImages {
|
||||
imageLatents := m.VAE.Encode(vaeImage)
|
||||
imageLatents = m.VAE.Normalize(imageLatents)
|
||||
imageLatents2D := mlx.Squeeze(imageLatents, 2)
|
||||
packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize)
|
||||
mlx.Keep(packed)
|
||||
mlx.Eval(packed)
|
||||
allImageLatentsPacked[i] = packed
|
||||
}
|
||||
|
||||
imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1)
|
||||
mlx.Keep(imageLatentsPacked)
|
||||
mlx.Eval(imageLatentsPacked)
|
||||
|
||||
// Scheduler
|
||||
scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, noiseSeqLen)
|
||||
|
||||
// Init noise latents in packed format
|
||||
packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize
|
||||
packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed)
|
||||
latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize)
|
||||
mlx.Eval(latents)
|
||||
|
||||
// RoPE cache
|
||||
ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope)
|
||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Denoising loop
|
||||
fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps)
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
t := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
||||
mlx.Eval(timestep)
|
||||
|
||||
latents2D := mlx.Squeeze(latents, 2)
|
||||
patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize)
|
||||
latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
// CFG Batching: single forward pass with batch=2
|
||||
// Tile inputs: [1, L, D] -> [2, L, D]
|
||||
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
|
||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||
|
||||
// Single batched forward pass
|
||||
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||
D := batchedOutput.Shape()[2]
|
||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
|
||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
|
||||
|
||||
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
|
||||
} else {
|
||||
output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]})
|
||||
}
|
||||
|
||||
noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds())
|
||||
}
|
||||
|
||||
// Free denoising temporaries
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
if batchedEmb != nil {
|
||||
batchedEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
imageLatentsPacked.Free()
|
||||
|
||||
// Decode latents
|
||||
decoded := m.decodeAndPostprocess(latents)
|
||||
latents.Free()
|
||||
|
||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling.
|
||||
// This prevents CFG from inflating magnitude too much.
|
||||
func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array {
|
||||
// Upcast to float32 for precision
|
||||
posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32)
|
||||
negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32)
|
||||
|
||||
// CFG: pred = neg + scale * (pos - neg)
|
||||
diff := mlx.Sub(posF32, negF32)
|
||||
scaledDiff := mlx.MulScalar(diff, scale)
|
||||
combPred := mlx.Add(negF32, scaledDiff)
|
||||
|
||||
// Norm rescaling: rescale combined prediction to match conditional norm
|
||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true))
|
||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
||||
output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
||||
|
||||
mlx.Eval(output)
|
||||
return mlx.ToBFloat16(output)
|
||||
}
|
||||
|
||||
// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1].
|
||||
func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array {
|
||||
latents = m.VAE.Denormalize(latents)
|
||||
decoded := m.VAE.Decode(latents)
|
||||
|
||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
||||
decoded = mlx.Squeeze(decoded, 2)
|
||||
decoded = mlx.AddScalar(decoded, 1.0)
|
||||
decoded = mlx.DivScalar(decoded, 2.0)
|
||||
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
|
||||
mlx.Eval(decoded)
|
||||
return decoded
|
||||
}
|
||||
|
||||
// padSequence pads a sequence tensor to the target length with zeros
|
||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
// Pad on sequence dimension (axis 1)
|
||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
||||
}
|
||||
|
||||
// LoadPersistent is an alias for backward compatibility.
|
||||
func LoadPersistent(modelPath string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// PrepareRoPEMultiImage computes RoPE with interpolation for image editing.
|
||||
// Handles single or multiple input images with different resolutions.
|
||||
//
|
||||
// Parameters:
|
||||
// - outPH, outPW: output patch dimensions (noise latent resolution)
|
||||
// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...]
|
||||
// - txtLen: text sequence length
|
||||
// - axesDims: RoPE axis dimensions [16, 56, 56]
|
||||
//
|
||||
// Returns RoPE cache where:
|
||||
// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions
|
||||
// - First outPH*outPW positions are for noise latents (standard RoPE at output res)
|
||||
// - Following positions are for each input image (interpolated from output res)
|
||||
func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache {
|
||||
theta := float64(10000)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Compute base frequencies for each axis dimension
|
||||
freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta)
|
||||
freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta)
|
||||
freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta)
|
||||
|
||||
// Build frequency lookup tables
|
||||
posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
||||
posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false)
|
||||
posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false)
|
||||
negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image
|
||||
negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true)
|
||||
negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true)
|
||||
|
||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
||||
|
||||
// Helper to compute RoPE for a single position at output resolution with scale_rope
|
||||
computePosFreqs := func(framePos, y, x int32) []float32 {
|
||||
row := make([]float32, headDim)
|
||||
idx := 0
|
||||
|
||||
// Frame position
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
row[idx+i] = posFreqsT[framePos][i]
|
||||
}
|
||||
idx += len(freqsT) * 2
|
||||
|
||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
||||
outHHalf := outPH / 2
|
||||
hNegCount := outPH - outHHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += len(freqsH) * 2
|
||||
|
||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
||||
outWHalf := outPW / 2
|
||||
wNegCount := outPW - outWHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
// Helper to compute RoPE for frame -1 (used for last condition image)
|
||||
// This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:]
|
||||
computeNegFrameFreqs := func(y, x int32) []float32 {
|
||||
row := make([]float32, headDim)
|
||||
idx := 0
|
||||
|
||||
// Frame -1: use last row of negative frame frequencies
|
||||
negFrameIdx := maxIdx - 1
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
row[idx+i] = negFreqsT[negFrameIdx][i]
|
||||
}
|
||||
idx += len(freqsT) * 2
|
||||
|
||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
||||
outHHalf := outPH / 2
|
||||
hNegCount := outPH - outHHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += len(freqsH) * 2
|
||||
|
||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
||||
outWHalf := outPW / 2
|
||||
wNegCount := outPW - outWHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
// Total image sequence length: noise + all input images
|
||||
noiseSeqLen := outPH * outPW
|
||||
totalImgLen := noiseSeqLen
|
||||
for _, dims := range inputDims {
|
||||
totalImgLen += dims.PatchH * dims.PatchW
|
||||
}
|
||||
|
||||
imgFreqsData := make([]float32, totalImgLen*headDim)
|
||||
idx := int32(0)
|
||||
|
||||
// Segment 0: Noise latents - standard RoPE at output resolution (frame 0)
|
||||
for y := int32(0); y < outPH; y++ {
|
||||
for x := int32(0); x < outPW; x++ {
|
||||
row := computePosFreqs(0, y, x)
|
||||
copy(imgFreqsData[idx:], row)
|
||||
idx += headDim
|
||||
}
|
||||
}
|
||||
|
||||
// Segments 1..N: Edit image latents - INTERPOLATED RoPE
|
||||
// For single image: use frame 1 (matches original PrepareRoPEInterpolated)
|
||||
// For multiple images: Python uses frame -1 for the LAST condition image
|
||||
// (_compute_condition_freqs), positive indices for others.
|
||||
numImages := len(inputDims)
|
||||
lastImgIdx := numImages - 1
|
||||
for imgIdx, dims := range inputDims {
|
||||
inPH := dims.PatchH
|
||||
inPW := dims.PatchW
|
||||
|
||||
// Determine frame index for this image
|
||||
// Single image case: use frame 1 (like original PrepareRoPEInterpolated)
|
||||
// Multi-image case: last image uses frame -1, others use frame 1, 2, etc.
|
||||
useNegFrame := numImages > 1 && imgIdx == lastImgIdx
|
||||
|
||||
// Map each input position to an output position using linear interpolation
|
||||
for y := int32(0); y < inPH; y++ {
|
||||
for x := int32(0); x < inPW; x++ {
|
||||
// Interpolate: map input (y, x) to output grid position
|
||||
// This is the key fix from DiffSynth's forward_sampling
|
||||
var yOut, xOut int32
|
||||
if inPH == 1 {
|
||||
yOut = 0
|
||||
} else {
|
||||
// Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1)
|
||||
yOut = y * (outPH - 1) / (inPH - 1)
|
||||
}
|
||||
if inPW == 1 {
|
||||
xOut = 0
|
||||
} else {
|
||||
xOut = x * (outPW - 1) / (inPW - 1)
|
||||
}
|
||||
|
||||
var row []float32
|
||||
if useNegFrame {
|
||||
// Last image in multi-image uses frame -1
|
||||
row = computeNegFrameFreqs(yOut, xOut)
|
||||
} else {
|
||||
// Single image uses frame 1, multi-image uses frame 1, 2, etc.
|
||||
frameIdx := int32(imgIdx + 1)
|
||||
row = computePosFreqs(frameIdx, yOut, xOut)
|
||||
}
|
||||
copy(imgFreqsData[idx:], row)
|
||||
idx += headDim
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim})
|
||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
||||
|
||||
// Text frequencies - start after max video index
|
||||
maxVidIdx := max(outPH/2, outPW/2)
|
||||
|
||||
txtFreqsData := make([]float32, txtLen*headDim)
|
||||
idx = 0
|
||||
for t := int32(0); t < txtLen; t++ {
|
||||
pos := maxVidIdx + t
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
|
||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
||||
|
||||
return &qwen_image.RoPECache{
|
||||
ImgFreqs: imgFreqs,
|
||||
TxtFreqs: txtFreqs,
|
||||
}
|
||||
}
|
||||
@@ -1,249 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping qwen_image_edit tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestComputeAxisFreqs verifies frequency computation matches Python reference
|
||||
func TestComputeAxisFreqs(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
|
||||
// Expected values from Python:
|
||||
// freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim))
|
||||
expectedFreqsT := []float64{
|
||||
1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684,
|
||||
0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017,
|
||||
}
|
||||
|
||||
expectedFreqsH_first4 := []float64{
|
||||
1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494,
|
||||
}
|
||||
|
||||
expectedFreqsH_last4 := []float64{
|
||||
0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437,
|
||||
}
|
||||
|
||||
// Test temporal frequencies (dim=16)
|
||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
||||
if len(freqsT) != 8 {
|
||||
t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT))
|
||||
}
|
||||
for i, expected := range expectedFreqsT {
|
||||
if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff)
|
||||
}
|
||||
}
|
||||
|
||||
// Test height/width frequencies (dim=56)
|
||||
freqsH := qwen_image.ComputeAxisFreqs(56, theta)
|
||||
if len(freqsH) != 28 {
|
||||
t.Fatalf("expected 28 height frequencies, got %d", len(freqsH))
|
||||
}
|
||||
for i, expected := range expectedFreqsH_first4 {
|
||||
if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff)
|
||||
}
|
||||
}
|
||||
for i, expected := range expectedFreqsH_last4 {
|
||||
idx := 24 + i // last 4 of 28
|
||||
if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions
|
||||
func TestMakeFreqTable(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Test positive table
|
||||
posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
||||
|
||||
// Position 0 should give cos=1, sin=0 for all frequencies
|
||||
for i := 0; i < len(freqsT)*2; i += 2 {
|
||||
if posTable[0][i] != 1.0 {
|
||||
t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i])
|
||||
}
|
||||
if posTable[0][i+1] != 0.0 {
|
||||
t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1])
|
||||
}
|
||||
}
|
||||
|
||||
// Position 1, first frequency (1.0): angle = 1*1 = 1
|
||||
// cos(1) = 0.5403, sin(1) = 0.8415
|
||||
if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 {
|
||||
t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0])
|
||||
}
|
||||
if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 {
|
||||
t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1])
|
||||
}
|
||||
|
||||
// Test negative table
|
||||
negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true)
|
||||
|
||||
// negTable[4095] corresponds to position -1
|
||||
// cos(-1) = cos(1), sin(-1) = -sin(1)
|
||||
if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 {
|
||||
t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0])
|
||||
}
|
||||
if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 {
|
||||
t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1])
|
||||
}
|
||||
|
||||
// negTable[4094] corresponds to position -2
|
||||
// cos(-2) = cos(2), sin(-2) = -sin(2)
|
||||
cos2 := math.Cos(2.0)
|
||||
sin2 := math.Sin(2.0)
|
||||
if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 {
|
||||
t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0])
|
||||
}
|
||||
if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 {
|
||||
t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case
|
||||
func TestPrepareRoPE_QwenImage(t *testing.T) {
|
||||
if !mlx.GPUIsAvailable() {
|
||||
t.Skip("GPU not available")
|
||||
}
|
||||
|
||||
mlx.SetDefaultDeviceCPU()
|
||||
|
||||
// 4x4 patch grid, single image
|
||||
imgH, imgW := int32(4), int32(4)
|
||||
txtLen := int32(5)
|
||||
axesDims := []int32{16, 56, 56}
|
||||
|
||||
cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims)
|
||||
mlx.Eval(cache.ImgFreqs, cache.TxtFreqs)
|
||||
|
||||
// Check shapes
|
||||
imgShape := cache.ImgFreqs.Shape()
|
||||
if imgShape[0] != 16 { // 4*4 patches
|
||||
t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0])
|
||||
}
|
||||
|
||||
// For single image (frame=0), all temporal values should be cos=1, sin=0
|
||||
imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32)
|
||||
mlx.Eval(imgFreqsCPU)
|
||||
imgData := imgFreqsCPU.Data()
|
||||
|
||||
// Check first 16 values of patch 0 (temporal cos/sin pairs)
|
||||
for i := 0; i < 16; i += 2 {
|
||||
cosVal := imgData[i]
|
||||
sinVal := imgData[i+1]
|
||||
if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 {
|
||||
t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal)
|
||||
}
|
||||
if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 {
|
||||
t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal)
|
||||
}
|
||||
}
|
||||
|
||||
cache.ImgFreqs.Free()
|
||||
cache.TxtFreqs.Free()
|
||||
}
|
||||
|
||||
// TestScaleRopePositions verifies the centered position calculation for scale_rope=True
|
||||
func TestScaleRopePositions(t *testing.T) {
|
||||
// For a 4x4 grid with scale_rope=True:
|
||||
// hHalf = 2, wHalf = 2
|
||||
// hNegCount = 4 - 2 = 2 (positions 0,1 are negative)
|
||||
// wNegCount = 4 - 2 = 2 (positions 0,1 are negative)
|
||||
//
|
||||
// Height positions:
|
||||
// y=0: -(4-2) + 0 = -2
|
||||
// y=1: -(4-2) + 1 = -1
|
||||
// y=2: 2 - 2 = 0
|
||||
// y=3: 3 - 2 = 1
|
||||
//
|
||||
// Same for width
|
||||
|
||||
pH, pW := int32(4), int32(4)
|
||||
hHalf := pH / 2
|
||||
wHalf := pW / 2
|
||||
hNegCount := pH - hHalf
|
||||
wNegCount := pW - wHalf
|
||||
|
||||
expectedH := []int32{-2, -1, 0, 1}
|
||||
expectedW := []int32{-2, -1, 0, 1}
|
||||
|
||||
for y := int32(0); y < pH; y++ {
|
||||
var hPos int32
|
||||
if y < hNegCount {
|
||||
hPos = -(pH - hHalf) + y
|
||||
} else {
|
||||
hPos = y - hNegCount
|
||||
}
|
||||
if hPos != expectedH[y] {
|
||||
t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos)
|
||||
}
|
||||
}
|
||||
|
||||
for x := int32(0); x < pW; x++ {
|
||||
var wPos int32
|
||||
if x < wNegCount {
|
||||
wPos = -(pW - wHalf) + x
|
||||
} else {
|
||||
wPos = x - wNegCount
|
||||
}
|
||||
if wPos != expectedW[x] {
|
||||
t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoPEHeadDimensions verifies the head dimension breakdown
|
||||
func TestRoPEHeadDimensions(t *testing.T) {
|
||||
// axes_dims_rope = [16, 56, 56]
|
||||
// Each dimension uses half the values for frequencies
|
||||
// So we get: 8 + 28 + 28 = 64 frequency values
|
||||
// Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position
|
||||
|
||||
axesDims := []int32{16, 56, 56}
|
||||
expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2)
|
||||
expectedHeadDim := expectedFreqs * 2
|
||||
|
||||
if expectedFreqs != 64 {
|
||||
t.Errorf("expected 64 frequency values, got %d", expectedFreqs)
|
||||
}
|
||||
if expectedHeadDim != 128 {
|
||||
t.Errorf("expected head_dim=128, got %d", expectedHeadDim)
|
||||
}
|
||||
|
||||
// This should match the transformer's attention head dimension
|
||||
// hidden_size = 3072, num_heads = 24
|
||||
// head_dim = 3072 / 24 = 128
|
||||
}
|
||||
|
||||
@@ -1,642 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds Qwen-Image VAE configuration
|
||||
type VAEConfig struct {
|
||||
ZDim int32 `json:"z_dim"` // 16
|
||||
BaseDim int32 `json:"base_dim"` // 96
|
||||
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
|
||||
NumResBlocks int32 `json:"num_res_blocks"` // 2
|
||||
LatentsMean []float32 `json:"latents_mean"` // 16 values
|
||||
LatentsStd []float32 `json:"latents_std"` // 16 values
|
||||
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
|
||||
}
|
||||
|
||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
||||
func defaultVAEConfig() *VAEConfig {
|
||||
return &VAEConfig{
|
||||
ZDim: 16,
|
||||
BaseDim: 96,
|
||||
DimMult: []int32{1, 2, 4, 4},
|
||||
NumResBlocks: 2,
|
||||
LatentsMean: []float32{
|
||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632,
|
||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
||||
},
|
||||
LatentsStd: []float32{
|
||||
2.8184, 1.4541, 2.3275, 2.6558,
|
||||
1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579,
|
||||
1.6382, 1.1253, 2.8251, 1.916,
|
||||
},
|
||||
TemperalDownsample: []bool{false, true, true},
|
||||
}
|
||||
}
|
||||
|
||||
// VAE is the full VAE with encoder and decoder
|
||||
type VAE struct {
|
||||
Config *VAEConfig
|
||||
Encoder *VAEEncoder
|
||||
Decoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the VAE from a directory
|
||||
func (m *VAE) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")
|
||||
|
||||
cfg := defaultVAEConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Load weights as f32 for quality (matches Python default behavior)
|
||||
// VAE decoder precision is critical for final image quality
|
||||
fmt.Print(" Loading weights as f32... ")
|
||||
if err := weights.Load(mlx.DtypeFloat32); err != nil {
|
||||
return fmt.Errorf("failed to load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
// Load encoder
|
||||
fmt.Print(" Loading encoder... ")
|
||||
m.Encoder = &VAEEncoder{}
|
||||
if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("encoder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load decoder
|
||||
fmt.Print(" Loading decoder... ")
|
||||
m.Decoder = &VAEDecoder{}
|
||||
if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("decoder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes an image to latents
|
||||
// x: [B, C, T, H, W] image tensor in [-1, 1] range
|
||||
// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
|
||||
func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
|
||||
return m.Encoder.Encode(x)
|
||||
}
|
||||
|
||||
// Decode decodes latents to image
|
||||
// z: [B, C, T, H, W] latents (denormalized)
|
||||
// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
|
||||
func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
|
||||
return m.Decoder.Decode(z)
|
||||
}
|
||||
|
||||
// Normalize applies latent normalization
|
||||
// Input z should be f32 (from VAE encoder), output is f32 for transformer
|
||||
func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
// Mean/std are f32, will match z dtype through broadcasting
|
||||
return mlx.Div(mlx.Sub(z, mean), std)
|
||||
}
|
||||
|
||||
// Denormalize reverses latent normalization
|
||||
// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
|
||||
func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
// Convert latents to f32 for VAE decoder quality
|
||||
z = mlx.AsType(z, mlx.DtypeFloat32)
|
||||
|
||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
return mlx.Add(mlx.Mul(z, std), mean)
|
||||
}
|
||||
|
||||
// VAEEncoder is the encoder part of the VAE
|
||||
// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
|
||||
// - Blocks 0,1: ResBlocks (base_dim)
|
||||
// - Block 2: Downsample
|
||||
// - Blocks 3,4: ResBlocks (base_dim*2)
|
||||
// - Block 5: Downsample + temporal
|
||||
// - Blocks 6,7: ResBlocks (base_dim*4)
|
||||
// - Block 8: Downsample + temporal
|
||||
// - Blocks 9,10: ResBlocks (base_dim*4)
|
||||
type VAEEncoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
ConvIn *CausalConv3d
|
||||
Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers
|
||||
MidBlock *MidBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
QuantConv *CausalConv3d
|
||||
}
|
||||
|
||||
// EncoderBlock is either a ResBlock or a Downsample
|
||||
type EncoderBlock interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
IsDownsample() bool
|
||||
}
|
||||
|
||||
// EncoderResBlock wraps ResBlock
|
||||
type EncoderResBlock struct {
|
||||
*ResBlock
|
||||
}
|
||||
|
||||
func (b *EncoderResBlock) IsDownsample() bool { return false }
|
||||
|
||||
// EncoderDownsample is a downsample layer
|
||||
type EncoderDownsample struct {
|
||||
Resample *CausalConv3d
|
||||
TimeConv *CausalConv3d // Optional temporal downsample
|
||||
}
|
||||
|
||||
func (d *EncoderDownsample) IsDownsample() bool { return true }
|
||||
|
||||
func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Spatial downsample with stride 2
|
||||
// WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
|
||||
x = d.forwardSpatialDownsample(x)
|
||||
|
||||
// NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
|
||||
// with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
|
||||
// The Python forward checks: if feat_cache is not None ... then use time_conv
|
||||
// Since we don't support streaming, we skip time_conv entirely.
|
||||
return x
|
||||
}
|
||||
|
||||
// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
|
||||
func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
|
||||
xShape := x.Shape()
|
||||
B := xShape[0]
|
||||
T := xShape[1]
|
||||
H := xShape[2]
|
||||
W := xShape[3]
|
||||
C := xShape[4]
|
||||
|
||||
wShape := d.Resample.Weight.Shape()
|
||||
outC := wShape[0]
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
|
||||
// ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
|
||||
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W
|
||||
|
||||
// Apply 2D conv with stride 2
|
||||
weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
||||
x = conv2DStrided(x, weight, 2)
|
||||
|
||||
if d.Resample.Bias != nil {
|
||||
bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
// Output dims after stride 2: (H+1)/2, (W+1)/2
|
||||
outH := (H + 1) / 2
|
||||
outW := (W + 1) / 2
|
||||
|
||||
// Reshape back to [B, T, H', W', C]
|
||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// loadFromWeights loads the encoder from pre-loaded weights
|
||||
func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
||||
e.Config = cfg
|
||||
|
||||
// Conv in
|
||||
convIn, err := newCausalConv3d(weights, "encoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.ConvIn = convIn
|
||||
|
||||
// Encoder uses flat block structure:
|
||||
// dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
|
||||
// Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
|
||||
// That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
|
||||
e.Blocks = make([]EncoderBlock, 0, 11)
|
||||
|
||||
// Track dimensions
|
||||
dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
|
||||
blockIdx := 0
|
||||
|
||||
for stage := 0; stage < len(cfg.DimMult); stage++ {
|
||||
inDim := cfg.BaseDim
|
||||
if stage > 0 {
|
||||
inDim = dims[stage-1]
|
||||
}
|
||||
outDim := dims[stage]
|
||||
|
||||
// ResBlocks for this stage (num_res_blocks per stage)
|
||||
for r := int32(0); r < cfg.NumResBlocks; r++ {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
||||
currentInDim := inDim
|
||||
if r > 0 {
|
||||
currentInDim = outDim
|
||||
}
|
||||
block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
|
||||
}
|
||||
e.Blocks = append(e.Blocks, block)
|
||||
blockIdx++
|
||||
}
|
||||
|
||||
// Downsample after each stage except the last
|
||||
if stage < len(cfg.DimMult)-1 {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
||||
down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
|
||||
}
|
||||
e.Blocks = append(e.Blocks, down)
|
||||
blockIdx++
|
||||
}
|
||||
}
|
||||
|
||||
// Mid block
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.MidBlock = midBlock
|
||||
|
||||
// Norm out
|
||||
normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.NormOut = normOut
|
||||
|
||||
// Conv out
|
||||
convOut, err := newCausalConv3d(weights, "encoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.ConvOut = convOut
|
||||
|
||||
// Quant conv
|
||||
quantConv, err := newCausalConv3d(weights, "quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.QuantConv = quantConv
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
|
||||
func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
|
||||
block, err := newResBlock(weights, prefix, inDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &EncoderResBlock{block}, nil
|
||||
}
|
||||
|
||||
// newEncoderDownsample creates a downsample layer for the encoder
|
||||
func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
|
||||
resample, err := newCausalConv3d(weights, prefix+".resample.1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var timeConv *CausalConv3d
|
||||
if temporal {
|
||||
timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
|
||||
}
|
||||
|
||||
return &EncoderDownsample{
|
||||
Resample: resample,
|
||||
TimeConv: timeConv,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Encode encodes an image to latents
|
||||
// x: [B, C, T, H, W] image tensor (channels-first)
|
||||
// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
|
||||
func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
|
||||
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(x)
|
||||
|
||||
// Conv in
|
||||
x = e.ConvIn.Forward(x)
|
||||
|
||||
// Encoder blocks (mix of ResBlocks and Downsamplers)
|
||||
for _, block := range e.Blocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Mid block
|
||||
x = e.MidBlock.Forward(x)
|
||||
|
||||
// Norm + silu
|
||||
{
|
||||
prev := x
|
||||
x = e.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Conv out
|
||||
{
|
||||
prev := x
|
||||
x = e.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Quant conv
|
||||
{
|
||||
prev := x
|
||||
x = e.QuantConv.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Get mode from distribution (first half of channels = mean)
|
||||
// Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
|
||||
shape := x.Shape()
|
||||
latentC := shape[4] / 2
|
||||
x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
|
||||
|
||||
// Convert back to channels-first [N, C, T, H, W]
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the decoder part of the VAE
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
PostQuantConv *CausalConv3d
|
||||
ConvIn *CausalConv3d
|
||||
MidBlock *MidBlock
|
||||
UpBlocks []*UpBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
}
|
||||
|
||||
// loadFromWeights loads the decoder from pre-loaded weights
|
||||
func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
||||
d.Config = cfg
|
||||
|
||||
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.PostQuantConv = postQuantConv
|
||||
|
||||
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ConvIn = convIn
|
||||
|
||||
// Mid block
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.MidBlock = midBlock
|
||||
|
||||
// Up blocks (reversed dim_mult)
|
||||
numUpBlocks := len(cfg.DimMult)
|
||||
d.UpBlocks = make([]*UpBlock, numUpBlocks)
|
||||
|
||||
dimsMult := make([]int32, numUpBlocks+1)
|
||||
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
|
||||
}
|
||||
|
||||
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
|
||||
for i := range cfg.TemperalDownsample {
|
||||
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
|
||||
}
|
||||
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
inDim := cfg.BaseDim * dimsMult[i]
|
||||
outDim := cfg.BaseDim * dimsMult[i+1]
|
||||
|
||||
if i > 0 {
|
||||
inDim = inDim / 2
|
||||
}
|
||||
|
||||
upsampleMode := ""
|
||||
if i < numUpBlocks-1 {
|
||||
if temporalUpsample[i] {
|
||||
upsampleMode = "upsample3d"
|
||||
} else {
|
||||
upsampleMode = "upsample2d"
|
||||
}
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.UpBlocks[i] = upBlock
|
||||
}
|
||||
|
||||
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.NormOut = normOut
|
||||
|
||||
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ConvOut = convOut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode converts latents to image
|
||||
// z: [B, C, T, H, W] denormalized latents
|
||||
func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
|
||||
var x *mlx.Array
|
||||
|
||||
// Convert from channels-first to channels-last
|
||||
{
|
||||
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(z)
|
||||
}
|
||||
|
||||
// PostQuantConv
|
||||
x = d.PostQuantConv.Forward(z)
|
||||
z.Free()
|
||||
|
||||
// ConvIn
|
||||
{
|
||||
prev := x
|
||||
x = d.ConvIn.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Mid block
|
||||
x = d.MidBlock.Forward(x)
|
||||
|
||||
// Up blocks
|
||||
for _, upBlock := range d.UpBlocks {
|
||||
x = upBlock.Forward(x)
|
||||
}
|
||||
|
||||
// NormOut + silu
|
||||
{
|
||||
prev := x
|
||||
x = d.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// ConvOut
|
||||
{
|
||||
prev := x
|
||||
x = d.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Post-processing: clamp and convert back to channels-first
|
||||
{
|
||||
prev := x
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// DownBlock handles downsampling in encoder
|
||||
type DownBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Downsampler *Downsample
|
||||
}
|
||||
|
||||
// newDownBlock creates a down block
|
||||
func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var downsampler *Downsample
|
||||
if downsampleMode != "" {
|
||||
downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
|
||||
}
|
||||
|
||||
return &DownBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Downsampler: downsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies down block
|
||||
func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, block := range d.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if d.Downsampler != nil {
|
||||
prev := x
|
||||
x = d.Downsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Downsample handles spatial downsampling
|
||||
type Downsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newDownsample creates a downsampler
|
||||
func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Downsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies downsampling to channels-last input [B, T, H, W, C]
|
||||
func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := d.Conv.Shape()[0]
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding
|
||||
// For 3x3 stride 2: pad 1 on all sides
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
|
||||
// Conv with stride 2 using manual strided patching
|
||||
weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
|
||||
x = conv2DStrided(x, weight, 2)
|
||||
if d.Bias != nil {
|
||||
bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
x = mlx.Reshape(x, B, T, H/2, W/2, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"image"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -25,11 +26,12 @@ import (
|
||||
|
||||
// Request is the image generation request format
|
||||
type Request struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
|
||||
}
|
||||
|
||||
// Response is streamed back for each progress update
|
||||
@@ -46,6 +48,13 @@ type ImageModel interface {
|
||||
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
// ImageEditModel extends ImageModel with image editing/conditioning capability.
|
||||
// Models that support input images for editing should implement this interface.
|
||||
type ImageEditModel interface {
|
||||
ImageModel
|
||||
GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
// Server holds the model and handles requests
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
@@ -78,14 +87,6 @@ func Execute(args []string) error {
|
||||
slog.Info("MLX library initialized")
|
||||
slog.Info("starting image runner", "model", *modelName, "port", *port)
|
||||
|
||||
// Check memory requirements before loading
|
||||
requiredMemory := imagegen.EstimateVRAM(*modelName)
|
||||
availableMemory := mlx.GetMemoryLimit()
|
||||
if availableMemory > 0 && availableMemory < requiredMemory {
|
||||
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
|
||||
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
|
||||
}
|
||||
|
||||
// Detect model type and load appropriate model
|
||||
modelType := imagegen.DetectModelType(*modelName)
|
||||
slog.Info("detected model type", "type", modelType)
|
||||
@@ -161,6 +162,44 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate and decode input images
|
||||
const maxInputImages = 2
|
||||
if len(req.Images) > maxInputImages {
|
||||
http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var inputImages []image.Image
|
||||
if len(req.Images) > 0 {
|
||||
// TODO: add memory check for input images
|
||||
|
||||
inputImages = make([]image.Image, len(req.Images))
|
||||
for i, imgBytes := range req.Images {
|
||||
img, err := imagegen.DecodeImage(imgBytes)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
inputImages[i] = img
|
||||
}
|
||||
slog.Info("decoded input images", "count", len(inputImages))
|
||||
|
||||
// Default width/height to first input image dimensions, scaled to max 1024
|
||||
bounds := inputImages[0].Bounds()
|
||||
w, h := bounds.Dx(), bounds.Dy()
|
||||
if w > 1024 || h > 1024 {
|
||||
if w > h {
|
||||
h = h * 1024 / w
|
||||
w = 1024
|
||||
} else {
|
||||
w = w * 1024 / h
|
||||
h = 1024
|
||||
}
|
||||
}
|
||||
req.Width = int32(w)
|
||||
req.Height = int32(h)
|
||||
}
|
||||
|
||||
// Serialize generation requests - MLX model may not handle concurrent generation
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -192,7 +231,19 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
|
||||
// Use ImageEditModel if available and images provided, otherwise use basic ImageModel
|
||||
var img *mlx.Array
|
||||
var err error
|
||||
if len(inputImages) > 0 {
|
||||
editModel, ok := s.model.(ImageEditModel)
|
||||
if !ok {
|
||||
http.Error(w, "model does not support image editing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress)
|
||||
} else {
|
||||
img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Don't send error for cancellation
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -104,11 +105,17 @@ func NewServer(modelName string) (*Server, error) {
|
||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||
}
|
||||
|
||||
// Get total weight size from manifest
|
||||
var weightSize uint64
|
||||
if manifest, err := LoadManifest(modelName); err == nil {
|
||||
weightSize = uint64(manifest.TotalTensorSize())
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cmd: cmd,
|
||||
port: port,
|
||||
modelName: modelName,
|
||||
vramSize: EstimateVRAM(modelName),
|
||||
vramSize: weightSize,
|
||||
done: make(chan error, 1),
|
||||
client: &http.Client{Timeout: 10 * time.Minute},
|
||||
}
|
||||
@@ -226,19 +233,27 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
seed = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// Extract raw image bytes from llm.ImageData slice
|
||||
var images [][]byte
|
||||
for _, img := range req.Images {
|
||||
images = append(images, img.Data)
|
||||
}
|
||||
|
||||
// Build request for subprocess
|
||||
creq := struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Images [][]byte `json:"images,omitempty"`
|
||||
}{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: seed,
|
||||
Images: images,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(creq)
|
||||
@@ -260,7 +275,8 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("request failed: %d", resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("%s", strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
|
||||
@@ -38,40 +38,6 @@ func TestPlatformSupport(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestMemoryRequirementsError verifies memory check returns clear error.
|
||||
func TestMemoryRequirementsError(t *testing.T) {
|
||||
// Test with insufficient memory
|
||||
err := CheckMemoryRequirements("test-model", 8*GB)
|
||||
if err == nil {
|
||||
t.Error("Expected error for insufficient memory (8GB < 21GB default)")
|
||||
}
|
||||
|
||||
// Test with sufficient memory
|
||||
err = CheckMemoryRequirements("test-model", 32*GB)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for sufficient memory (32GB), got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEstimateVRAMReturnsReasonableDefaults verifies VRAM estimates are sensible.
|
||||
func TestEstimateVRAMReturnsReasonableDefaults(t *testing.T) {
|
||||
// Unknown model should return default (21GB)
|
||||
vram := EstimateVRAM("unknown-model")
|
||||
if vram < 10*GB || vram > 100*GB {
|
||||
t.Errorf("VRAM estimate %d GB is outside reasonable range (10-100 GB)", vram/GB)
|
||||
}
|
||||
|
||||
// Verify known pipeline estimates exist and are reasonable
|
||||
for name, estimate := range modelVRAMEstimates {
|
||||
if estimate < 10*GB {
|
||||
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously low", name, estimate/GB)
|
||||
}
|
||||
if estimate > 200*GB {
|
||||
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously high", name, estimate/GB)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
|
||||
// This is a compile-time check but we document it as a test.
|
||||
func TestServerInterfaceCompliance(t *testing.T) {
|
||||
|
||||
@@ -9,7 +9,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// modelConfig represents the HuggingFace config.json structure
|
||||
@@ -35,22 +36,22 @@ type modelConfig struct {
|
||||
|
||||
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
|
||||
// It reads the config.json layer and returns a map compatible with GGML's KV format.
|
||||
func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
var config modelConfig
|
||||
if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
|
||||
if err := mf.ReadConfigJSON("config.json", &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
// Calculate total tensor bytes from manifest layers
|
||||
var totalBytes int64
|
||||
var tensorCount int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
for _, layer := range mf.Layers {
|
||||
if layer.MediaType == manifest.MediaTypeImageTensor {
|
||||
totalBytes += layer.Size
|
||||
tensorCount++
|
||||
}
|
||||
@@ -151,27 +152,30 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
|
||||
|
||||
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
|
||||
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
|
||||
func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
return getTensorInfoFromManifest(manifest)
|
||||
return getTensorInfoFromManifest(mf)
|
||||
}
|
||||
|
||||
// getTensorInfoFromManifest extracts tensor info from a manifest.
|
||||
// This is separated for testability.
|
||||
func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
|
||||
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
|
||||
var tensors []api.Tensor
|
||||
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
||||
for _, layer := range mf.Layers {
|
||||
if layer.MediaType != manifest.MediaTypeImageTensor {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read the safetensors header from the blob
|
||||
blobPath := manifest.BlobPath(layer.Digest)
|
||||
blobPath, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
info, err := readSafetensorsHeader(blobPath)
|
||||
if err != nil {
|
||||
// Skip tensors we can't read
|
||||
@@ -197,15 +201,15 @@ func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor,
|
||||
// GetSafetensorsDtype returns the quantization type for a safetensors model.
|
||||
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
|
||||
// Otherwise returns the torch_dtype from config.json.
|
||||
func GetSafetensorsDtype(modelName string) (string, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
func GetSafetensorsDtype(name model.Name) (string, error) {
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Check if model is quantized by looking for _scale tensors
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
for _, layer := range mf.Layers {
|
||||
if layer.MediaType == manifest.MediaTypeImageTensor {
|
||||
if strings.HasSuffix(layer.Name, "_scale") {
|
||||
// Model is quantized - return FP8 (affine quantization)
|
||||
return "FP8", nil
|
||||
@@ -217,7 +221,7 @@ func GetSafetensorsDtype(modelName string) (string, error) {
|
||||
var cfg struct {
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
}
|
||||
if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
|
||||
if err := mf.ReadConfigJSON("config.json", &cfg); err != nil {
|
||||
return "", fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
)
|
||||
|
||||
func TestBuildModelInfo(t *testing.T) {
|
||||
@@ -451,8 +451,14 @@ func TestParseSafetensorsHeader_Errors(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetTensorInfoFromManifest(t *testing.T) {
|
||||
// Create a temp directory for blobs
|
||||
// Create a temp directory for blobs and set OLLAMA_MODELS
|
||||
tempDir := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", tempDir)
|
||||
|
||||
blobDir := filepath.Join(tempDir, "blobs")
|
||||
if err := os.MkdirAll(blobDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create blobs dir: %v", err)
|
||||
}
|
||||
|
||||
// Create test tensor blobs
|
||||
tensors := []struct {
|
||||
@@ -463,26 +469,26 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "model.embed_tokens.weight",
|
||||
digest: "sha256:abc123",
|
||||
digest: "sha256:abc123abc123abc123abc123abc123abc123abc123abc123abc123abc123abc0",
|
||||
dtype: "BF16",
|
||||
shape: []int64{262144, 2560},
|
||||
},
|
||||
{
|
||||
name: "model.layers.0.self_attn.q_proj.weight",
|
||||
digest: "sha256:def456",
|
||||
digest: "sha256:def456def456def456def456def456def456def456def456def456def456def0",
|
||||
dtype: "BF16",
|
||||
shape: []int64{2560, 2560},
|
||||
},
|
||||
{
|
||||
name: "model.norm.weight",
|
||||
digest: "sha256:ghi789",
|
||||
digest: "sha256:789789789789789789789789789789789789789789789789789789789789abc0",
|
||||
dtype: "F32",
|
||||
shape: []int64{2560},
|
||||
},
|
||||
}
|
||||
|
||||
// Create blob files
|
||||
var layers []imagegen.ManifestLayer
|
||||
var layers []manifest.Layer
|
||||
for _, tensor := range tensors {
|
||||
// Create safetensors blob
|
||||
header := map[string]any{
|
||||
@@ -498,15 +504,17 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||
buf.Write(headerJSON)
|
||||
|
||||
// Write blob file
|
||||
blobName := "sha256-" + tensor.digest[7:]
|
||||
blobPath := filepath.Join(tempDir, blobName)
|
||||
// Write blob file using the digest format expected by GetBlobsPath
|
||||
blobPath, err := manifest.BlobsPath(tensor.digest)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get blob path: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
|
||||
t.Fatalf("failed to write blob: %v", err)
|
||||
}
|
||||
|
||||
layers = append(layers, imagegen.ManifestLayer{
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
layers = append(layers, manifest.Layer{
|
||||
MediaType: manifest.MediaTypeImageTensor,
|
||||
Digest: tensor.digest,
|
||||
Size: int64(buf.Len() + 1000), // header + fake data
|
||||
Name: tensor.name,
|
||||
@@ -514,21 +522,20 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add a non-tensor layer (should be skipped)
|
||||
layers = append(layers, imagegen.ManifestLayer{
|
||||
layers = append(layers, manifest.Layer{
|
||||
MediaType: "application/vnd.ollama.image.json",
|
||||
Digest: "sha256:config",
|
||||
Digest: "sha256:0000000000000000000000000000000000000000000000000000000000000000",
|
||||
Size: 100,
|
||||
Name: "config.json",
|
||||
})
|
||||
|
||||
manifest := &imagegen.ModelManifest{
|
||||
Manifest: &imagegen.Manifest{
|
||||
Layers: layers,
|
||||
},
|
||||
BlobDir: tempDir,
|
||||
mf := &manifest.Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Layers: layers,
|
||||
}
|
||||
|
||||
result, err := getTensorInfoFromManifest(manifest)
|
||||
result, err := getTensorInfoFromManifest(mf)
|
||||
if err != nil {
|
||||
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user