Compare commits

..

2 Commits

Author SHA1 Message Date
Bruce MacDonald
cc3ac5fee3 docs: update instructions for ollama config command
These tools can be automatically configured using the new ollama config command
2026-01-21 17:03:41 -08:00
Jeffrey Morgan
b5d0f72f16 x/imagegen: remove qwen_image and qwen_image_edit models (#13827)
Remove the Qwen image generation and image editing model packages
to clean up the codebase. These models will be reintroduced later.

- Delete x/imagegen/models/qwen_image/ (10 files)
- Delete x/imagegen/models/qwen_image_edit/ (5 files)
- Remove related CLI flags and imports from cmd/engine/main.go
- Update comments in cache/step.go to remove Qwen-specific references
2026-01-21 13:37:08 -08:00
30 changed files with 147 additions and 8042 deletions

View File

@@ -2,7 +2,7 @@
title: Claude Code
---
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
@@ -26,6 +26,16 @@ irm https://claude.ai/install.ps1 | iex
## Usage with Ollama
Configure Claude Code to use Ollama:
```shell
ollama config claude
```
This will prompt you to select a model and automatically configure Claude Code to use Ollama.
<Accordion title="Manual Configuration">
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
@@ -47,7 +57,9 @@ Or run with environment variables inline:
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
```
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
</Accordion>
<Note>Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.</Note>
## Connecting to ollama.com
@@ -75,4 +87,4 @@ claude --model glm-4.7:cloud
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks

View File

@@ -2,22 +2,31 @@
title: Codex
---
Codex is OpenAI's agentic coding tool for the command line.
## Install
Install the [Codex CLI](https://developers.openai.com/codex/cli/):
```
```shell
npm install -g @openai/codex
```
## Usage with Ollama
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
Configure Codex to use Ollama:
```shell
ollama config codex
```
This will prompt you to select a model and automatically configure Codex to use Ollama.
<Accordion title="Manual Configuration">
To use `codex` with Ollama, use the `--oss` flag:
```
```shell
codex --oss
```
@@ -25,20 +34,22 @@ codex --oss
By default, codex will use the local `gpt-oss:20b` model. However, you can specify a different model with the `-m` flag:
```
```shell
codex --oss -m gpt-oss:120b
```
### Cloud Models
```
```shell
codex --oss -m gpt-oss:120b-cloud
```
</Accordion>
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
## Connecting to ollama.com
Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
To use ollama.com directly, edit your `~/.codex/config.toml` file to point to ollama.com.

View File

@@ -2,6 +2,7 @@
title: Droid
---
Droid is Factory's agentic coding tool for the command line.
## Install
@@ -11,66 +12,80 @@ Install the [Droid CLI](https://factory.ai/):
curl -fsSL https://app.factory.ai/cli | sh
```
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
Add a local configuration block to `~/.factory/config.json`:
Configure Droid to use Ollama:
```shell
ollama config droid
```
This will prompt you to select models and automatically configure Droid to use Ollama.
<Accordion title="Manual Configuration">
Add a local configuration block to `~/.factory/settings.json`:
```json
{
"custom_models": [
"customModels": [
{
"model_display_name": "qwen3-coder [Ollama]",
"model": "qwen3-coder",
"base_url": "http://localhost:11434/v1/",
"api_key": "not-needed",
"displayName": "qwen3-coder [Ollama]",
"baseUrl": "http://localhost:11434/v1",
"apiKey": "ollama",
"provider": "generic-chat-completion-api",
"max_tokens": 32000
"maxOutputTokens": 32000
}
]
}
```
Adjust `maxOutputTokens` based on your model's context length (the automated setup detects this automatically).
### Cloud Models
## Cloud Models
`qwen3-coder:480b-cloud` is the recommended model for use with Droid.
Add the cloud configuration block to `~/.factory/config.json`:
Add the cloud configuration block to `~/.factory/settings.json`:
```json
{
"custom_models": [
"customModels": [
{
"model_display_name": "qwen3-coder [Ollama Cloud]",
"model": "qwen3-coder:480b-cloud",
"base_url": "http://localhost:11434/v1/",
"api_key": "not-needed",
"displayName": "qwen3-coder:480b-cloud [Ollama]",
"baseUrl": "http://localhost:11434/v1",
"apiKey": "ollama",
"provider": "generic-chat-completion-api",
"max_tokens": 128000
"maxOutputTokens": 128000
}
]
}
```
</Accordion>
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
2. Add the cloud configuration block to `~/.factory/config.json`:
2. Add the cloud configuration block to `~/.factory/settings.json`:
```json
{
"custom_models": [
"customModels": [
{
"model_display_name": "qwen3-coder [Ollama Cloud]",
"model": "qwen3-coder:480b",
"base_url": "https://ollama.com/v1/",
"api_key": "OLLAMA_API_KEY",
"displayName": "qwen3-coder:480b [Ollama Cloud]",
"baseUrl": "https://ollama.com/v1",
"apiKey": "OLLAMA_API_KEY",
"provider": "generic-chat-completion-api",
"max_tokens": 128000
"maxOutputTokens": 128000
}
]
}
```
Run `droid` in a new terminal to load the new settings.
Run `droid` in a new terminal to load the new settings.

View File

@@ -0,0 +1,63 @@
---
title: OpenCode
---
OpenCode is an agentic coding tool for the terminal.
## Install
Install [OpenCode](https://opencode.ai):
```shell
curl -fsSL https://opencode.ai/install | bash
```
## Usage with Ollama
Configure OpenCode to use Ollama:
```shell
ollama config opencode
```
This will prompt you to select models and automatically configure OpenCode to use Ollama.
<Accordion title="Manual Configuration">
Add the Ollama provider to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"qwen3-coder": {
"name": "qwen3-coder [Ollama]"
}
}
}
}
}
```
</Accordion>
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
## Recommended Models
### Cloud models
- `qwen3-coder:480b` - Large coding model
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks

View File

@@ -609,49 +609,3 @@ func ImageGenerationsMiddleware() gin.HandlerFunc {
c.Next()
}
}
func ImageEditsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ImageEditRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Prompt == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.Image == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
return
}
genReq, err := openai.FromImageEditRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ImageWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}

View File

@@ -1112,129 +1112,3 @@ func TestImageWriterResponse(t *testing.T) {
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
}
}
func TestImageEditsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.GenerateRequest
err openai.ErrorResponse
}
var capturedRequest *api.GenerateRequest
// Base64-encoded test image (1x1 pixel PNG)
testImage := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
testCases := []testCase{
{
name: "image edit basic",
body: `{
"model": "test-model",
"prompt": "make it blue",
"image": "` + testImage + `"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "make it blue",
Images: []api.ImageData{decodedImage},
},
},
{
name: "image edit with size",
body: `{
"model": "test-model",
"prompt": "make it blue",
"image": "` + testImage + `",
"size": "512x768"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "make it blue",
Images: []api.ImageData{decodedImage},
Width: 512,
Height: 768,
},
},
{
name: "image edit missing prompt",
body: `{
"model": "test-model",
"image": "` + testImage + `"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "prompt is required",
Type: "invalid_request_error",
},
},
},
{
name: "image edit missing model",
body: `{
"prompt": "make it blue",
"image": "` + testImage + `"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "model is required",
Type: "invalid_request_error",
},
},
},
{
name: "image edit missing image",
body: `{
"model": "test-model",
"prompt": "make it blue"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "image is required",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Error.Message != "" {
var errResp openai.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match:\n%s", diff)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match:\n%s", diff)
}
})
}
}

View File

@@ -794,47 +794,3 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
Data: data,
}
}
// ImageEditRequest is an OpenAI-compatible image edit request.
type ImageEditRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Image string `json:"image"` // Base64-encoded image data
Size string `json:"size,omitempty"` // e.g., "1024x1024"
Seed *int64 `json:"seed,omitempty"`
}
// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
req := api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
}
// Decode the input image
if r.Image != "" {
imgData, err := decodeImageURL(r.Image)
if err != nil {
return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
}
req.Images = append(req.Images, imgData)
}
// Parse size if provided (e.g., "1024x768")
if r.Size != "" {
var w, h int32
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
req.Width = w
req.Height = h
}
}
if r.Seed != nil {
if req.Options == nil {
req.Options = map[string]any{}
}
req.Options["seed"] = *r.Seed
}
return req, nil
}

View File

@@ -448,86 +448,3 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
})
}
}
func TestFromImageEditRequest_Basic(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if result.Prompt != "make it blue" {
t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
}
if len(result.Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Images))
}
}
func TestFromImageEditRequest_WithSize(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
Size: "512x768",
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Width != 512 {
t.Errorf("expected width 512, got %d", result.Width)
}
if result.Height != 768 {
t.Errorf("expected height 768, got %d", result.Height)
}
}
func TestFromImageEditRequest_WithSeed(t *testing.T) {
seed := int64(12345)
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
Seed: &seed,
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options == nil {
t.Fatal("expected options to be set")
}
if result.Options["seed"] != seed {
t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
}
}
func TestFromImageEditRequest_InvalidImage(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: "not-valid-base64",
}
_, err := FromImageEditRequest(req)
if err == nil {
t.Error("expected error for invalid image")
}
}

View File

@@ -1604,9 +1604,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoints
// OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
@@ -2524,11 +2523,6 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
}
}
var images []llm.ImageData
for i, imgData := range req.Images {
images = append(images, llm.ImageData{ID: i, Data: imgData})
}
var streamStarted bool
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
@@ -2536,7 +2530,6 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
Height: req.Height,
Steps: req.Steps,
Seed: seed,
Images: images,
}, func(cr llm.CompletionResponse) {
streamStarted = true
res := api.GenerateResponse{

View File

@@ -2193,157 +2193,3 @@ func TestGenerateUnload(t *testing.T) {
}
})
}
func TestGenerateWithImages(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
Files: map[string]string{"file.gguf": digest},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("images passed to completion request", func(t *testing.T) {
testImage := []byte("test-image-data")
mock.CompletionResponse.Content = "Image processed"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Describe this image",
Images: []api.ImageData{testImage},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify images were passed to the completion request
if len(mock.CompletionRequest.Images) != 1 {
t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images))
}
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) {
t.Errorf("image data mismatch in completion request")
}
if mock.CompletionRequest.Images[0].ID != 0 {
t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID)
}
})
t.Run("multiple images passed to completion request", func(t *testing.T) {
testImage1 := []byte("test-image-1")
testImage2 := []byte("test-image-2")
mock.CompletionResponse.Content = "Images processed"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Compare these images",
Images: []api.ImageData{testImage1, testImage2},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify both images were passed
if len(mock.CompletionRequest.Images) != 2 {
t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images))
}
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) {
t.Errorf("first image data mismatch")
}
if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) {
t.Errorf("second image data mismatch")
}
if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 {
t.Errorf("expected image IDs 0 and 1, got %d and %d",
mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID)
}
})
t.Run("no images when none provided", func(t *testing.T) {
mock.CompletionResponse.Content = "No images"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify no images in completion request
if len(mock.CompletionRequest.Images) != 0 {
t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images))
}
})
}

View File

@@ -9,7 +9,7 @@ import "github.com/ollama/ollama/x/imagegen/mlx"
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
// Supports both single-stream and dual-stream architectures:
// - Single-stream: use Get/Set for the single output per layer
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
@@ -87,7 +87,7 @@ func (c *StepCache) Set(layer int, arr *mlx.Array) {
}
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
// Used for dual-stream architectures like Qwen-Image.
// Used for dual-stream architectures.
func (c *StepCache) Get2(layer int) *mlx.Array {
if layer < len(c.layers2) {
return c.layers2[layer]
@@ -96,7 +96,7 @@ func (c *StepCache) Get2(layer int) *mlx.Array {
}
// Set2 stores a layer output (stream 2), freeing any previous value.
// Used for dual-stream architectures like Qwen-Image.
// Used for dual-stream architectures.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
if layer < len(c.layers2) {
if c.layers2[layer] != nil {

View File

@@ -21,8 +21,6 @@ import (
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
@@ -61,14 +59,11 @@ func main() {
listTensors := flag.Bool("list", false, "List tensors only")
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
@@ -166,60 +161,6 @@ func main() {
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImageEdit:
if len(inputImages) == 0 {
log.Fatal("qwen-image-edit requires at least one -input-image")
}
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
// For image editing, use 0 for dimensions to auto-detect from input image
// unless explicitly overridden from defaults
editWidth := int32(0)
editHeight := int32(0)
if *width != 1024 {
editWidth = int32(*width)
}
if *height != 1024 {
editHeight = int32(*height)
}
cfg := &qwen_image_edit.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: editWidth,
Height: editHeight,
Steps: *steps,
Seed: *seed,
}
var img *mlx.Array
img, err = m.EditFromConfig(inputImages, cfg)
if err == nil {
err = saveImageArray(img, *out)
}
case *listTensors:
err = listModelTensors(*modelPath)
default:

View File

@@ -177,20 +177,6 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
})
}
// GenerateImageWithInputs implements runner.ImageEditModel interface.
// It generates an image conditioned on the provided input images for image editing.
func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
InputImages: inputImages,
Progress: progress,
})
}
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
const MaxOutputPixels = 2048 * 2048

View File

@@ -1,87 +0,0 @@
//go:build mlx
package qwen_image
import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := mlx.InitMLX(); err != nil {
fmt.Printf("Skipping qwen_image tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestPipelineOutput runs the full pipeline (integration test).
// Skips if model weights not found. Requires ~50GB VRAM.
func TestPipelineOutput(t *testing.T) {
modelPath := "../../../weights/Qwen-Image-2512"
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + modelPath)
}
// Load model
pm, err := LoadPersistent(modelPath)
if err != nil {
t.Skipf("Skipping: failed to load model: %v", err)
}
// Run 2-step pipeline (minimum for stable scheduler)
cfg := &GenerateConfig{
Prompt: "a cat",
Width: 256,
Height: 256,
Steps: 2,
Seed: 42,
}
output, err := pm.GenerateFromConfig(cfg)
if err != nil {
t.Fatalf("Pipeline failed: %v", err)
}
mlx.Eval(output)
// Verify output shape [1, C, H, W]
shape := output.Shape()
if len(shape) != 4 {
t.Errorf("Expected 4D output, got %v", shape)
}
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
}
// Verify values in expected range [0, 1]
data := output.Data()
minVal, maxVal := float32(1.0), float32(0.0)
for _, v := range data {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
if minVal < -0.1 || maxVal > 1.1 {
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
}
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,367 +0,0 @@
//go:build mlx
// Package qwen_image implements the Qwen-Image diffusion transformer model.
package qwen_image
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
// Layer caching (DeepCache/Learning-to-Cache speedup)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 25)
}
// Model represents a Qwen-Image diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen25VL
Transformer *Transformer
VAEDecoder *VAEDecoder
}
// Load loads the Qwen-Image model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err := tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
m.TextEncoder = &Qwen25VL{}
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
}
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
imgSeqLen := pH * pW
// Text encoding
var posEmb, negEmb *mlx.Array
{
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
if useCFG {
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
} else {
mlx.Keep(posEmb)
mlx.Eval(posEmb)
}
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
// Init latents [B, C, T, H, W]
var latents *mlx.Array
{
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
}
// RoPE cache
var ropeCache *RoPECache
{
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs)
}
// Layer cache for DeepCache/Learning-to-Cache speedup
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
latents2D := mlx.Squeeze(latents, 2)
patches := PackLatents(latents2D, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
// CFG Batching: single forward pass with batch=2
// Note: layer caching with CFG is not supported yet (would need 2 caches)
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
L := batchedOutput.Shape()[1]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
combPred := mlx.Add(negOutput, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
} else if stepCache != nil {
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
} else {
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
}
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep cached arrays alive across cleanup
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
}
// Free denoising temporaries before VAE decode
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode (Decode manages its own pools for staged memory)
decoded := m.VAEDecoder.Decode(latents)
latents.Free()
// Post-process: squeeze temporal dim and rescale to [0, 1]
{
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
mlx.Eval(decoded)
}
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is an alias for backward compatibility.
// Use m := &Model{}; m.Load(path) instead.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}

View File

@@ -1,218 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.5
MaxShift float32 `json:"max_shift"` // 0.9
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
}
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.5,
MaxShift: 0.9, // Matches scheduler_config.json
BaseImageSeqLen: 256,
MaxImageSeqLen: 8192,
ShiftTerminal: 0.02,
UseDynamicShift: true,
}
}
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32
Sigmas []float32
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// CalculateShift computes the dynamic shift based on image sequence length
// This matches Python's calculate_shift function
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
b := baseShift - m*float32(baseSeqLen)
mu := float32(imageSeqLen)*m + b
return mu
}
// SetTimesteps sets up the scheduler for the given number of inference steps
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
// 1. Create sigmas from sigma_max to sigma_min (linspace)
// 2. Apply time_shift with mu (if dynamic shifting)
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
s.NumSteps = numSteps
// Calculate mu for dynamic shifting
var mu float32
if s.Config.UseDynamicShift {
mu = CalculateShift(
imageSeqLen,
s.Config.BaseImageSeqLen,
s.Config.MaxImageSeqLen,
s.Config.BaseShift,
s.Config.MaxShift,
)
}
// Step 1: Create sigmas from 1.0 to 1/num_steps
// Python (pipeline_qwenimage.py:639):
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
sigmas := make([]float32, numSteps)
sigmaMax := float32(1.0)
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
if numSteps == 1 {
sigmas[0] = sigmaMax
} else {
for i := 0; i < numSteps; i++ {
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
}
}
// Step 2: Apply time shift if using dynamic shifting
if s.Config.UseDynamicShift && mu != 0 {
for i := range sigmas {
sigmas[i] = s.timeShift(mu, sigmas[i])
}
}
// Step 3: Apply stretch_shift_to_terminal
if s.Config.ShiftTerminal > 0 {
sigmas = s.stretchShiftToTerminal(sigmas)
}
// Step 4: Append terminal sigma (0) and store
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
// before passing to transformer. We skip both steps and just use sigmas directly.
s.Sigmas = make([]float32, numSteps+1)
s.Timesteps = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
s.Sigmas[i] = sigmas[i]
s.Timesteps[i] = sigmas[i]
}
s.Sigmas[numSteps] = 0.0
s.Timesteps[numSteps] = 0.0
}
// stretchShiftToTerminal stretches and shifts the timestep schedule
// so the final value equals shift_terminal (matches Python behavior)
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
if len(sigmas) == 0 {
return sigmas
}
// one_minus_z = 1 - t
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
// stretched_t = 1 - (one_minus_z / scale_factor)
lastSigma := sigmas[len(sigmas)-1]
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
if scaleFactor < 1e-6 {
return sigmas
}
result := make([]float32, len(sigmas))
for i, t := range sigmas {
oneMinusZ := 1.0 - t
result[i] = 1.0 - (oneMinusZ / scaleFactor)
}
return result
}
// timeShift applies the dynamic time shift (exponential)
// exp(mu) / (exp(mu) + (1/t - 1))
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
// modelOutput: predicted velocity from the transformer
// sample: current noisy sample
// timestepIdx: current timestep index
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
// Get current and next sigma
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 to avoid precision issues (matches Python diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to original dtype
return mlx.ToBFloat16(result)
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormal(shape, uint64(seed))
}
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
// This matches how Python diffusers generates noise - directly in packed space.
// Generating in unpacked format and then packing produces different spatial
// correlation structure, which affects model output quality.
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
shape := []int32{batchSize, seqLen, channels}
return mlx.RandomNormal(shape, uint64(seed))
}
// GetLatentShape returns the latent shape for a given image size
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
func GetLatentShape(batchSize, height, width int32) []int32 {
latentH := height / 8
latentW := width / 8
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
}
// GetPatchedLatentShape returns the patchified latent shape
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
latentH := height / 8
latentW := width / 8
pH := latentH / patchSize
pW := latentW / patchSize
inChannels := int32(64) // 16 * patch_size^2
return []int32{batchSize, pH * pW, inChannels}
}

View File

@@ -1,135 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"testing"
)
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
// Golden values generated via:
//
// python3 -c "
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
// import numpy as np
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
// sigmas = np.linspace(1.0, 1.0/30, 30)
// s.set_timesteps(sigmas=sigmas, mu=mu)
// print(s.sigmas.numpy())"
func TestSchedulerSetTimesteps(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Golden values from Python diffusers (first 3, last 3 before terminal)
wantFirst := []float32{1.000000, 0.982251, 0.963889}
wantLast := []float32{0.142924, 0.083384, 0.020000}
// Check first 3
for i, want := range wantFirst {
got := scheduler.Sigmas[i]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
}
}
// Check last 3 (indices 27, 28, 29)
for i, want := range wantLast {
idx := 27 + i
got := scheduler.Sigmas[idx]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
}
}
// Check terminal is 0
if scheduler.Sigmas[30] != 0.0 {
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
}
// Check length
if len(scheduler.Sigmas) != 31 {
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
}
}
// TestSchedulerProperties tests mathematical invariants of the scheduler.
func TestSchedulerProperties(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Property: sigmas monotonically decreasing
for i := 1; i < len(scheduler.Sigmas); i++ {
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
}
}
// Property: first sigma should be ~1.0 (with time shift)
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
}
// Property: terminal sigma should be exactly 0
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
}
// Property: last non-terminal sigma should be shift_terminal (0.02)
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
if abs32(lastNonTerminal-0.02) > 1e-5 {
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
}
// Property: length = steps + 1
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
t.Errorf("sigmas length should be steps+1: got %d, want %d",
len(scheduler.Sigmas), scheduler.NumSteps+1)
}
}
// TestCalculateShift verifies the mu calculation against Python reference.
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
func TestCalculateShift(t *testing.T) {
cases := []struct {
imgSeqLen int32
want float32
}{
{256, 0.5}, // base case
{8192, 0.9}, // max case
{4096, 0.6935}, // middle case (rounded)
}
for _, c := range cases {
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
if abs32(got-c.want) > 0.001 {
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
}
}
}
// TestSchedulerStep verifies the Euler step formula.
func TestSchedulerStep(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Verify dt calculation for first step
sigma0 := scheduler.Sigmas[0]
sigma1 := scheduler.Sigmas[1]
expectedDt := sigma1 - sigma0
// dt should be negative (sigmas decrease)
if expectedDt >= 0 {
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
}
}
func abs32(x float32) float32 {
return float32(math.Abs(float64(x)))
}

View File

@@ -1,174 +0,0 @@
//go:build mlx
package qwen_image
import (
"encoding/json"
"math"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TinyTextEncoderConfig holds config for the tiny test text encoder
type TinyTextEncoderConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
MRoPESection []int32 `json:"mrope_section"`
}
// loadTinyTextEncoder loads the tiny text encoder from testdata
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
t.Helper()
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
// Load config
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
if err != nil {
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
}
var tinyCfg TinyTextEncoderConfig
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
t.Fatalf("Failed to parse config: %v", err)
}
// Create encoder config (using Qwen25VLConfig)
cfg := &Qwen25VLConfig{
HiddenSize: tinyCfg.HiddenSize,
NumHiddenLayers: tinyCfg.NumHiddenLayers,
IntermediateSize: tinyCfg.IntermediateSize,
NumAttentionHeads: tinyCfg.NumAttentionHeads,
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
VocabSize: tinyCfg.VocabSize,
RMSNormEps: tinyCfg.RMSNormEps,
RopeTheta: tinyCfg.RopeTheta,
HeadDim: tinyCfg.HeadDim,
MRoPESection: tinyCfg.MRoPESection,
}
// Load weights
weights, err := safetensors.LoadModelWeights(testdataDir)
if err != nil {
t.Fatalf("Failed to load weights: %v", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
t.Fatalf("Failed to bulk load weights: %v", err)
}
// Build encoder
embedding, err := weights.Get("model.embed_tokens.weight")
if err != nil {
t.Fatalf("Failed to get embedding: %v", err)
}
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
block, err := newVLTextBlock(weights, int(i), cfg)
if err != nil {
t.Fatalf("Failed to load block %d: %v", i, err)
}
blocks[i] = block
}
finalNorm, err := weights.Get("model.norm.weight")
if err != nil {
t.Fatalf("Failed to get final norm: %v", err)
}
encoder := &Qwen25VL{
Config: cfg,
Embedding: embedding,
Blocks: blocks,
FinalNorm: finalNorm,
HasVision: false, // Text-only mode
}
return encoder, &tinyCfg
}
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
func TestTextEncoderForward(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// Create test tokens (within vocab range)
tokens := []int32{1, 2, 3, 4, 5}
// Forward pass using EncodeTextOnly
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
// Verify output shape: [batch, seq_len, hidden_size]
wantShape := []int32{1, 5, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
}
// Verify output is finite (not NaN or Inf)
data := out.Data()
for i, v := range data {
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
t.Errorf("output[%d] is not finite: %v", i, v)
break
}
}
}
// TestTextEncoderBatch tests batch processing.
func TestTextEncoderBatch(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// For batch test, we'll use EncodeTextOnly with a single sequence
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
tokens := []int32{1, 2, 3}
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
wantShape := []int32{1, 3, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
}
}
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
func TestMRoPEComputation(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
cossin := encoder.computeTextRoPE(10, 1)
mlx.Eval(cossin[0], cossin[1])
// Verify shapes: [3, B, L, head_dim]
wantShape := []int32{3, 1, 10, cfg.HeadDim}
if !slices.Equal(cossin[0].Shape(), wantShape) {
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
}
if !slices.Equal(cossin[1].Shape(), wantShape) {
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
}
// Verify cos/sin values are in valid range [-1, 1]
cosData := cossin[0].Data()
sinData := cossin[1].Data()
for i := 0; i < min(100, len(cosData)); i++ {
if cosData[i] < -1.01 || cosData[i] > 1.01 {
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
}
if sinData[i] < -1.01 || sinData[i] > 1.01 {
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
}
}
}

View File

@@ -1,868 +0,0 @@
//go:build mlx
package qwen_image
import (
"fmt"
"math"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Qwen-Image transformer configuration
type TransformerConfig struct {
HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
NHeads int32 `json:"num_attention_heads"` // 24
HeadDim int32 `json:"attention_head_dim"` // 128
NLayers int32 `json:"num_layers"` // 60
InChannels int32 `json:"in_channels"` // 64
OutChannels int32 `json:"out_channels"` // 16
PatchSize int32 `json:"patch_size"` // 2
JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
NormEps float32 `json:"norm_eps"` // 1e-6
AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
GuidanceEmbeds bool `json:"guidance_embeds"` // false
}
// defaultTransformerConfig returns config for Qwen-Image transformer
func defaultTransformerConfig() *TransformerConfig {
return &TransformerConfig{
HiddenDim: 3072, // 24 * 128
NHeads: 24,
HeadDim: 128,
NLayers: 60,
InChannels: 64,
OutChannels: 16,
PatchSize: 2,
JointAttentionDim: 3584,
NormEps: 1e-6,
AxesDimsRope: []int32{16, 56, 56},
GuidanceEmbeds: false,
}
}
// TimestepEmbedder creates timestep embeddings
type TimestepEmbedder struct {
Linear1Weight *mlx.Array // [256, hidden_dim]
Linear1Bias *mlx.Array
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
Linear2Bias *mlx.Array
}
// newTimestepEmbedder creates a timestep embedder from weights
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
if err != nil {
return nil, err
}
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
if err != nil {
return nil, err
}
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
if err != nil {
return nil, err
}
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
if err != nil {
return nil, err
}
return &TimestepEmbedder{
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
Linear1Bias: linear1Bias,
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
Linear2Bias: linear2Bias,
}, nil
}
// Forward computes timestep embeddings
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
half := int32(128) // embedding_dim / 2
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
tExpanded := mlx.ExpandDims(t, 1)
args := mlx.Mul(tExpanded, freqsArr)
args = mlx.MulScalar(args, 1000.0) // scale
// [cos, sin] (flip_sin_to_cos=True)
sinArgs := mlx.Sin(args)
cosArgs := mlx.Cos(args)
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
// MLP: linear1 -> silu -> linear2
h := mlx.Linear(embedding, te.Linear1Weight)
h = mlx.Add(h, te.Linear1Bias)
h = mlx.SiLU(h)
h = mlx.Linear(h, te.Linear2Weight)
h = mlx.Add(h, te.Linear2Bias)
return h
}
// JointAttention implements dual-stream joint attention
type JointAttention struct {
// Image projections
ToQ *mlx.Array
ToQB *mlx.Array
ToK *mlx.Array
ToKB *mlx.Array
ToV *mlx.Array
ToVB *mlx.Array
ToOut *mlx.Array
ToOutB *mlx.Array
NormQ *mlx.Array
NormK *mlx.Array
// Text (added) projections
AddQProj *mlx.Array
AddQProjB *mlx.Array
AddKProj *mlx.Array
AddKProjB *mlx.Array
AddVProj *mlx.Array
AddVProjB *mlx.Array
ToAddOut *mlx.Array
ToAddOutB *mlx.Array
NormAddQ *mlx.Array
NormAddK *mlx.Array
NHeads int32
HeadDim int32
Scale float32
}
// newJointAttention creates a joint attention layer
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
return &JointAttention{
ToQ: mlx.Transpose(toQ, 1, 0),
ToQB: toQB,
ToK: mlx.Transpose(toK, 1, 0),
ToKB: toKB,
ToV: mlx.Transpose(toV, 1, 0),
ToVB: toVB,
ToOut: mlx.Transpose(toOut, 1, 0),
ToOutB: toOutB,
NormQ: normQ,
NormK: normK,
AddQProj: mlx.Transpose(addQProj, 1, 0),
AddQProjB: addQProjB,
AddKProj: mlx.Transpose(addKProj, 1, 0),
AddKProjB: addKProjB,
AddVProj: mlx.Transpose(addVProj, 1, 0),
AddVProjB: addVProjB,
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
ToAddOutB: toAddOutB,
NormAddQ: normAddQ,
NormAddK: normAddK,
NHeads: cfg.NHeads,
HeadDim: cfg.HeadDim,
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
}, nil
}
// Forward computes joint attention
// img: [B, L_img, D], txt: [B, L_txt, D]
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
D := imgShape[2]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// === Image Q/K/V ===
imgFlat := mlx.Reshape(img, B*Limg, D)
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
// QK norm (RMSNorm per head)
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
// Apply RoPE
if imgFreqs != nil {
qImg = applyRoPE(qImg, imgFreqs)
kImg = applyRoPE(kImg, imgFreqs)
}
// === Text Q/K/V ===
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
if txtFreqs != nil {
qTxt = applyRoPE(qTxt, txtFreqs)
kTxt = applyRoPE(kTxt, txtFreqs)
}
// Concatenate for joint attention: [txt, img] order
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
// Transpose to [B, nheads, L, head_dim]
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
// SDPA
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
// Transpose back and split
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
// Output projections
outImg = mlx.Reshape(outImg, B*Limg, D)
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
outImg = mlx.Reshape(outImg, B, Limg, D)
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
return outImg, outTxt
}
// applyRoPE applies rotary embeddings using complex multiplication
// x: [B, L, nheads, head_dim]
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
halfDim := headDim / 2
// Reshape x to pairs: [B, L, nheads, half, 2]
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
// Extract real/imag parts
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
xReal = mlx.Squeeze(xReal, 4)
xImag = mlx.Squeeze(xImag, 4)
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
freqReal = mlx.Squeeze(freqReal, 4)
freqImag = mlx.Squeeze(freqImag, 4)
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
// Interleave back
outReal = mlx.ExpandDims(outReal, 4)
outImag = mlx.ExpandDims(outImag, 4)
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
return mlx.Reshape(out, B, L, nheads, headDim)
}
// MLP implements GELU MLP (not GEGLU)
type MLP struct {
ProjWeight *mlx.Array
ProjBias *mlx.Array
OutWeight *mlx.Array
OutBias *mlx.Array
}
// newMLP creates a GELU MLP
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
outWeight, _ := weights.Get(prefix + ".net.2.weight")
outBias, _ := weights.Get(prefix + ".net.2.bias")
return &MLP{
ProjWeight: mlx.Transpose(projWeight, 1, 0),
ProjBias: projBias,
OutWeight: mlx.Transpose(outWeight, 1, 0),
OutBias: outBias,
}, nil
}
// Forward applies GELU MLP
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
xFlat := mlx.Reshape(x, B*L, D)
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
h = geluApprox(h)
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
}
// geluApprox implements approximate GELU
func geluApprox(x *mlx.Array) *mlx.Array {
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
x3 := mlx.Mul(mlx.Mul(x, x), x)
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
inner = mlx.MulScalar(inner, sqrt2OverPi)
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}
// TransformerBlock is a single dual-stream transformer block
type TransformerBlock struct {
Attention *JointAttention
ImgMLP *MLP
TxtMLP *MLP
ImgModWeight *mlx.Array
ImgModBias *mlx.Array
TxtModWeight *mlx.Array
TxtModBias *mlx.Array
HiddenDim int32
NormEps float32
}
// newTransformerBlock creates a transformer block
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
attn, err := newJointAttention(weights, prefix, cfg)
if err != nil {
return nil, err
}
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
return &TransformerBlock{
Attention: attn,
ImgMLP: imgMLP,
TxtMLP: txtMLP,
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
ImgModBias: imgModBias,
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
TxtModBias: txtModBias,
HiddenDim: cfg.HiddenDim,
NormEps: cfg.NormEps,
}, nil
}
// Forward applies the transformer block
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
siluT := mlx.SiLU(temb)
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
imgModParts := splitMod6(imgMod, tb.HiddenDim)
txtModParts := splitMod6(txtMod, tb.HiddenDim)
// Pre-attention: norm + modulate
imgNorm := layerNormNoAffine(img, tb.NormEps)
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
txtNorm := layerNormNoAffine(txt, tb.NormEps)
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
// Joint attention
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
// Residual with gate
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
// Pre-MLP: norm + modulate
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
// MLP
mlpImg := tb.ImgMLP.Forward(imgNorm2)
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
// Residual with gate
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
return img, txt
}
// splitMod6 splits modulation into 6 parts each [B, 1, D]
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
shape := mod.Shape()
B := shape[0]
parts := make([]*mlx.Array, 6)
for i := int32(0); i < 6; i++ {
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
parts[i] = mlx.ExpandDims(part, 1)
}
return parts
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}
// Transformer is the full Qwen-Image transformer model
type Transformer struct {
Config *TransformerConfig
ImgIn *mlx.Array
ImgInBias *mlx.Array
TxtIn *mlx.Array
TxtInBias *mlx.Array
TxtNorm *mlx.Array
TEmbed *TimestepEmbedder
Layers []*TransformerBlock
NormOutWeight *mlx.Array
NormOutBias *mlx.Array
ProjOut *mlx.Array
ProjOutBias *mlx.Array
}
// Load loads the transformer from a directory
func (m *Transformer) Load(path string) error {
fmt.Println("Loading Qwen-Image transformer...")
cfg := defaultTransformerConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Bulk load all weights as bf16
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading input projections... ")
imgIn, _ := weights.Get("img_in.weight")
imgInBias, _ := weights.Get("img_in.bias")
txtIn, _ := weights.Get("txt_in.weight")
txtInBias, _ := weights.Get("txt_in.bias")
txtNorm, _ := weights.Get("txt_norm.weight")
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
m.ImgInBias = imgInBias
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
m.TxtInBias = txtInBias
m.TxtNorm = txtNorm
fmt.Println("✓")
fmt.Print(" Loading timestep embedder... ")
m.TEmbed, err = newTimestepEmbedder(weights)
if err != nil {
return fmt.Errorf("timestep embedder: %w", err)
}
fmt.Println("✓")
m.Layers = make([]*TransformerBlock, cfg.NLayers)
for i := int32(0); i < cfg.NLayers; i++ {
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
prefix := fmt.Sprintf("transformer_blocks.%d", i)
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
if err != nil {
return fmt.Errorf("layer %d: %w", i, err)
}
}
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
fmt.Print(" Loading output layers... ")
normOutWeight, _ := weights.Get("norm_out.linear.weight")
normOutBias, _ := weights.Get("norm_out.linear.bias")
projOut, _ := weights.Get("proj_out.weight")
projOutBias, _ := weights.Get("proj_out.bias")
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
m.NormOutBias = normOutBias
m.ProjOut = mlx.Transpose(projOut, 1, 0)
m.ProjOutBias = projOutBias
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// LoadFromPath is a convenience function to load transformer from path
func LoadTransformerFromPath(path string) (*Transformer, error) {
m := &Transformer{}
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
return nil, err
}
return m, nil
}
// Forward runs the transformer
// img: [B, L_img, in_channels] patchified latents
// txt: [B, L_txt, joint_attention_dim] text embeddings
// t: [B] timesteps (0-1)
// imgFreqs, txtFreqs: RoPE frequencies
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// Timestep embedding
temb := tr.TEmbed.Forward(t)
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
// Project text: RMSNorm then linear
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
for _, layer := range tr.Layers {
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
}
// Final norm with modulation (AdaLayerNormContinuous)
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
modShape := finalMod.Shape()
halfDim := modShape[1] / 2
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
return mlx.Reshape(out, B, Limg, outChannels)
}
// ForwardWithCache runs the transformer with layer caching for speedup.
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
// shallow layers change little between denoising steps, so we cache their
// outputs and reuse them on non-refresh steps.
//
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
// step: current denoising step (0-indexed)
// cacheInterval: refresh cache every N steps (e.g., 3)
// cacheLayers: number of shallow layers to cache (e.g., 15)
func (tr *Transformer) ForwardWithCache(
img, txt, t *mlx.Array,
imgFreqs, txtFreqs *mlx.Array,
stepCache *cache.StepCache,
step, cacheInterval, cacheLayers int,
) *mlx.Array {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// Timestep embedding
temb := tr.TEmbed.Forward(t)
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
// Project text: RMSNorm then linear
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
// Check if we should refresh the cache
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
for i, layer := range tr.Layers {
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
// Use cached outputs for shallow layers
imgH = stepCache.Get(i)
txtH = stepCache.Get2(i)
} else {
// Compute layer
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
// Cache shallow layers on refresh steps
if i < cacheLayers && refreshCache {
stepCache.Set(i, imgH)
stepCache.Set2(i, txtH)
}
}
}
// Final norm with modulation (AdaLayerNormContinuous)
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
modShape := finalMod.Shape()
halfDim := modShape[1] / 2
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
return mlx.Reshape(out, B, Limg, outChannels)
}
// RoPECache holds precomputed RoPE frequencies
type RoPECache struct {
ImgFreqs *mlx.Array // [L_img, head_dim]
TxtFreqs *mlx.Array // [L_txt, head_dim]
}
// PrepareRoPE computes RoPE for image and text sequences
// This matches Python's QwenEmbedRope with scale_rope=True
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
theta := float64(10000)
maxIdx := int32(4096)
// Compute base frequencies for each axis dimension
freqsT := ComputeAxisFreqs(axesDims[0], theta)
freqsH := ComputeAxisFreqs(axesDims[1], theta)
freqsW := ComputeAxisFreqs(axesDims[2], theta)
// Build frequency lookup tables
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
// Image frequencies with scale_rope=True
imgLen := imgH * imgW
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
imgFreqsData := make([]float32, imgLen*headDim)
hHalf := imgH / 2
wHalf := imgW / 2
idx := int32(0)
for y := int32(0); y < imgH; y++ {
for x := int32(0); x < imgW; x++ {
// Frame = 0
for i := 0; i < len(freqsT)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
}
idx += int32(len(freqsT) * 2)
// Height: scale_rope pattern
hNegCount := imgH - hHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
}
}
idx += int32(len(freqsH) * 2)
// Width: scale_rope pattern
wNegCount := imgW - wHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
}
}
idx += int32(len(freqsW) * 2)
}
}
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
imgFreqs = mlx.ToBFloat16(imgFreqs)
// Text frequencies
maxVidIdx := max(hHalf, wHalf)
txtFreqsData := make([]float32, txtLen*headDim)
idx = 0
for t := int32(0); t < txtLen; t++ {
pos := maxVidIdx + t
for i := 0; i < len(freqsT)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
}
idx += int32(len(freqsT) * 2)
for i := 0; i < len(freqsH)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
}
idx += int32(len(freqsH) * 2)
for i := 0; i < len(freqsW)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
}
idx += int32(len(freqsW) * 2)
}
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
txtFreqs = mlx.ToBFloat16(txtFreqs)
return &RoPECache{
ImgFreqs: imgFreqs,
TxtFreqs: txtFreqs,
}
}
// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
halfDim := dim / 2
freqs := make([]float64, halfDim)
for i := int32(0); i < halfDim; i++ {
freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
}
return freqs
}
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
table := make([][]float32, maxIdx)
for idx := int32(0); idx < maxIdx; idx++ {
var pos float64
if negative {
pos = float64(-maxIdx + int32(idx))
} else {
pos = float64(idx)
}
row := make([]float32, len(baseFreqs)*2)
for i, f := range baseFreqs {
angle := pos * f
row[i*2] = float32(math.Cos(angle))
row[i*2+1] = float32(math.Sin(angle))
}
table[idx] = row
}
return table
}
func max(a, b int32) int32 {
if a > b {
return a
}
return b
}
// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize
pW := W / patchSize
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
// -> [B, pH, pW, C, 2, 2]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// -> [B, pH*pW, C*4]
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
channels := shape[2] / (patchSize * patchSize)
pH := H / patchSize
pW := W / patchSize
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
// -> [B, C, pH, 2, pW, 2]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// -> [B, C, H, W]
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
// Add temporal dimension for VAE: [B, C, 1, H, W]
return mlx.ExpandDims(x, 2)
}

View File

@@ -1,119 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestTransformerConfig tests configuration invariants.
func TestTransformerConfig(t *testing.T) {
cfg := defaultTransformerConfig()
// Property: hidden_dim = n_heads * head_dim
if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
}
// Property: axes_dims_rope sums to head_dim
var ropeSum int32
for _, d := range cfg.AxesDimsRope {
ropeSum += d
}
if ropeSum != cfg.HeadDim {
t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
}
// Property: in_channels = out_channels * patch_size^2
expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
if cfg.InChannels != expectedIn {
t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
}
}
// TestTransformerRoPE tests RoPE frequency computation produces valid values.
func TestTransformerRoPE(t *testing.T) {
cfg := defaultTransformerConfig()
// Test with small image dimensions
imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
txtLen := int32(5)
ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Verify shapes: [seq_len, head_dim]
imgSeqLen := imgH * imgW
if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
}
if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
}
if ropeCache.TxtFreqs.Shape()[0] != txtLen {
t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
}
// Verify values are finite
imgData := ropeCache.ImgFreqs.Data()
for i := 0; i < min(100, len(imgData)); i++ {
if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
break
}
}
}
// TestTransformerForward tests full forward pass (integration test).
// Skips if model weights are not available.
func TestTransformerForward(t *testing.T) {
weightsPath := "../../../weights/Qwen-Image-2512/transformer"
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + weightsPath)
}
transformer := &Transformer{}
if err := transformer.Load(weightsPath); err != nil {
t.Fatalf("Failed to load transformer: %v", err)
}
mlx.Keep(mlx.Collect(transformer)...)
cfg := transformer.Config
// Small test inputs
batchSize := int32(1)
imgH, imgW := int32(4), int32(4)
imgSeqLen := imgH * imgW
txtSeqLen := int32(5)
hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})
ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)
// Forward pass
out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(out)
// Verify output shape: [batch, img_seq_len, in_channels]
wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
gotShape := out.Shape()
if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
}
// Verify output is finite
outData := out.Data()
for i := 0; i < min(100, len(outData)); i++ {
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
t.Errorf("output[%d] not finite: %v", i, outData[i])
break
}
}
}

View File

@@ -1,854 +0,0 @@
//go:build mlx
package qwen_image
import (
"fmt"
"math"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds Qwen-Image VAE configuration
type VAEConfig struct {
ZDim int32 `json:"z_dim"` // 16
BaseDim int32 `json:"base_dim"` // 96
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
NumResBlocks int32 `json:"num_res_blocks"` // 2
LatentsMean []float32 `json:"latents_mean"` // 16 values
LatentsStd []float32 `json:"latents_std"` // 16 values
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
}
// defaultVAEConfig returns config for Qwen-Image VAE
func defaultVAEConfig() *VAEConfig {
return &VAEConfig{
ZDim: 16,
BaseDim: 96,
DimMult: []int32{1, 2, 4, 4},
NumResBlocks: 2,
LatentsMean: []float32{
-0.7571, -0.7089, -0.9113, 0.1075,
-0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632,
-0.1922, -0.9497, 0.2503, -0.2921,
},
LatentsStd: []float32{
2.8184, 1.4541, 2.3275, 2.6558,
1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579,
1.6382, 1.1253, 2.8251, 1.916,
},
TemperalDownsample: []bool{false, true, true},
}
}
// CausalConv3d is a causal 3D convolution (for temporal causality)
type CausalConv3d struct {
Weight *mlx.Array
Bias *mlx.Array
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
KernelT int32
}
// newCausalConv3d creates a 3D causal conv
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
weight, err := weights.Get(prefix + ".weight")
if err != nil {
return nil, fmt.Errorf("weight not found: %s", prefix)
}
bias, _ := weights.Get(prefix + ".bias")
kernelT := weight.Shape()[2]
outC := weight.Shape()[0]
var biasReshaped *mlx.Array
if bias != nil {
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
}
return &CausalConv3d{
Weight: weight,
Bias: bias,
BiasReshaped: biasReshaped,
KernelT: kernelT,
}, nil
}
// Forward applies causal 3D convolution
// x: [B, T, H, W, C] (channels-last, MLX format)
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
kernelT := shape[2]
kernelH := shape[3]
kernelW := shape[4]
// Causal temporal padding, same spatial padding
// Input is channels-last: [B, T, H, W, C]
padT := kernelT - 1
padH := kernelH / 2
padW := kernelW / 2
// Stage 1: Pad
{
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
mlx.Eval(x)
}
// Stage 2: Conv + bias
var out *mlx.Array
{
prev := x
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
out = mlx.Add(out, bias)
}
prev.Free()
mlx.Eval(out)
}
return out
}
// RMSNorm3D applies RMS normalization over channels
// Works with channels-last [B, T, H, W, C] format
type RMSNorm3D struct {
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
}
// newRMSNorm3D creates an RMS norm
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
gamma, err := weights.Get(prefix + ".gamma")
if err != nil {
return nil, err
}
// Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
return &RMSNorm3D{Gamma: gamma}, nil
}
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
// RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
normalized := mlx.RMSNormNoWeight(x, 1e-6)
return mlx.Mul(normalized, n.Gamma)
}
// ResBlock is a residual block with RMS norm and causal convs
type ResBlock struct {
Norm1 *RMSNorm3D
Conv1 *CausalConv3d
Norm2 *RMSNorm3D
Conv2 *CausalConv3d
Shortcut *CausalConv3d
}
// newResBlock creates a residual block
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
if err != nil {
return nil, err
}
conv1, err := newCausalConv3d(weights, prefix+".conv1")
if err != nil {
return nil, err
}
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
if err != nil {
return nil, err
}
conv2, err := newCausalConv3d(weights, prefix+".conv2")
if err != nil {
return nil, err
}
var shortcut *CausalConv3d
if inDim != outDim {
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
if err != nil {
return nil, err
}
}
return &ResBlock{
Norm1: norm1,
Conv1: conv1,
Norm2: norm2,
Conv2: conv2,
Shortcut: shortcut,
}, nil
}
// Forward applies the residual block
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
// Use h as working variable, keep x intact for residual (caller will free x)
// Conv handles its own pools, so we just need pools for non-conv operations
var h *mlx.Array
// Keep x so it survives Eval() cleanup - needed for residual connection
mlx.Keep(x)
// Stage 1: norm1 + silu
{
h = r.Norm1.Forward(x)
h = silu3D(h)
mlx.Eval(h)
}
// Stage 2: conv1 (handles its own pools)
{
prev := h
h = r.Conv1.Forward(h)
prev.Free()
}
// Stage 3: norm2 + silu
{
prev := h
h = r.Norm2.Forward(h)
h = silu3D(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: conv2 (handles its own pools)
{
prev := h
h = r.Conv2.Forward(h)
prev.Free()
}
// Residual connection (shortcut handles its own pools if present)
if r.Shortcut != nil {
shortcut := r.Shortcut.Forward(x)
h = mlx.Add(h, shortcut)
mlx.Eval(h)
} else {
h = mlx.Add(h, x)
mlx.Eval(h)
}
return h
}
// AttentionBlock is a 2D attention block
type AttentionBlock struct {
Norm *RMSNorm3D
ToQKV *mlx.Array
ToQKVBias *mlx.Array
Proj *mlx.Array
ProjBias *mlx.Array
Dim int32
}
// newAttentionBlock creates an attention block
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
if err != nil {
return nil, err
}
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
proj, _ := weights.Get(prefix + ".proj.weight")
projBias, _ := weights.Get(prefix + ".proj.bias")
return &AttentionBlock{
Norm: norm,
ToQKV: toQKV,
ToQKVBias: toQKVBias,
Proj: proj,
ProjBias: projBias,
Dim: dim,
}, nil
}
// Forward applies 2D attention
// Input: [B, T, H, W, C] (channels-last)
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
identity := x
// Flatten to [B*T, 1, H, W, C] for norm
x = mlx.Reshape(x, B*T, 1, H, W, C)
x = a.Norm.Forward(x)
x = mlx.Reshape(x, B*T, H, W, C)
// Flatten spatial to [B*T, H*W, C]
x = mlx.Reshape(x, B*T, H*W, C)
// Linear to get Q, K, V: [B*T, H*W, 3*C]
// Weight is [outC, inC] or [outC, inC, 1, 1]
wShape := a.ToQKV.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
} else {
w = a.ToQKV
}
w = mlx.Transpose(w, 1, 0) // [inC, outC]
qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
if a.ToQKVBias != nil {
qkv = mlx.Add(qkv, a.ToQKVBias)
}
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
// out: [B*T, 1, H*W, C]
out = mlx.Reshape(out, B*T, H*W, C)
// Project back
pShape := a.Proj.Shape()
var p *mlx.Array
if len(pShape) == 4 {
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
} else {
p = a.Proj
}
p = mlx.Transpose(p, 1, 0) // [inC, outC]
out = mlx.Linear(out, p) // [B*T, H*W, C]
if a.ProjBias != nil {
out = mlx.Add(out, a.ProjBias)
}
out = mlx.Reshape(out, B, T, H, W, C)
return mlx.Add(out, identity)
}
// UpBlock handles upsampling in decoder
type UpBlock struct {
ResBlocks []*ResBlock
Upsampler *Upsample
}
// newUpBlock creates an up block
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var upsampler *Upsample
if upsampleMode != "" {
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
}
return &UpBlock{
ResBlocks: resBlocks,
Upsampler: upsampler,
}, nil
}
// Forward applies up block with staged memory management
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
// ResBlocks handle their own pools
for _, block := range u.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
// Upsampler handles its own pools
if u.Upsampler != nil {
prev := x
x = u.Upsampler.Forward(x)
prev.Free()
}
return x
}
// Upsample handles spatial upsampling
type Upsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newUpsample creates an upsampler
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Upsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies upsampling to channels-last input [B, T, H, W, C]
// Uses staged pools to reduce peak memory during 2x upsampling
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := u.Conv.Shape()[0]
// Stage 1: 2x nearest neighbor upsample
{
x = mlx.Reshape(x, B*T, H, W, C)
x = upsample2xChannelsLast(x)
mlx.Eval(x)
}
// Stage 2: Conv + bias
{
prev := x
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
x = conv2D3x3PaddedChannelsLast(x, weight)
if u.Bias != nil {
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
prev.Free()
mlx.Eval(x)
}
return x
}
// MidBlock is the middle block of decoder
type MidBlock struct {
ResBlock1 *ResBlock
Attention *AttentionBlock
ResBlock2 *ResBlock
}
// newMidBlock creates a mid block
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
if err != nil {
return nil, err
}
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
if err != nil {
return nil, err
}
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
if err != nil {
return nil, err
}
return &MidBlock{
ResBlock1: res1,
Attention: attn,
ResBlock2: res2,
}, nil
}
// Forward applies mid block
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
// Each component handles its own pools; we just free inputs
prev := x
x = m.ResBlock1.Forward(x)
prev.Free()
prev = x
x = m.Attention.Forward(x)
prev.Free()
prev = x
x = m.ResBlock2.Forward(x)
prev.Free()
return x
}
// VAEDecoder is the full VAE decoder
type VAEDecoder struct {
Config *VAEConfig
PostQuantConv *CausalConv3d
ConvIn *CausalConv3d
MidBlock *MidBlock
UpBlocks []*UpBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
}
// Load loads the VAE decoder from a directory
func (m *VAEDecoder) Load(path string) error {
fmt.Println("Loading Qwen-Image VAE decoder...")
cfg := defaultVAEConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Bulk load all weights as bf16
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("failed to load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading post_quant_conv... ")
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
if err != nil {
return err
}
m.PostQuantConv = postQuantConv
fmt.Println("✓")
fmt.Print(" Loading conv_in... ")
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
if err != nil {
return err
}
m.ConvIn = convIn
fmt.Println("✓")
// Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
fmt.Print(" Loading mid_block... ")
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
if err != nil {
return err
}
m.MidBlock = midBlock
fmt.Println("✓")
// Up blocks (reversed dim_mult)
fmt.Print(" Loading up_blocks... ")
numUpBlocks := len(cfg.DimMult)
m.UpBlocks = make([]*UpBlock, numUpBlocks)
dimsMult := make([]int32, numUpBlocks+1)
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
for i := 0; i < numUpBlocks; i++ {
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
}
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
for i := range cfg.TemperalDownsample {
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
}
for i := 0; i < numUpBlocks; i++ {
inDim := cfg.BaseDim * dimsMult[i]
outDim := cfg.BaseDim * dimsMult[i+1]
if i > 0 {
inDim = inDim / 2
}
upsampleMode := ""
if i < numUpBlocks-1 {
if temporalUpsample[i] {
upsampleMode = "upsample3d"
} else {
upsampleMode = "upsample2d"
}
}
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
if err != nil {
return err
}
m.UpBlocks[i] = upBlock
}
fmt.Printf("✓ [%d blocks]\n", numUpBlocks)
fmt.Print(" Loading output layers... ")
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
if err != nil {
return err
}
m.NormOut = normOut
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
if err != nil {
return err
}
m.ConvOut = convOut
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// LoadVAEDecoderFromPath is a convenience function to load VAE from path
func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
m := &VAEDecoder{}
if err := m.Load(filepath.Join(path, "vae")); err != nil {
return nil, err
}
return m, nil
}
// Decode converts latents to image
// z: [B, C, T, H, W] normalized latents
// Uses staged pools to free intermediate arrays and reduce peak memory.
func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
var x *mlx.Array
// Stage 1a: Denormalize and transpose
{
z = vae.Denormalize(z)
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
mlx.Eval(z)
}
// Stage 1b: PostQuantConv (handles its own pools)
x = vae.PostQuantConv.Forward(z)
z.Free()
// Stage 1c: ConvIn (handles its own pools)
{
prev := x
x = vae.ConvIn.Forward(x)
prev.Free()
}
// Stage 2: Mid block (handles its own pools)
x = vae.MidBlock.Forward(x)
// Stage 3: Up blocks (each handles its own pools)
for _, upBlock := range vae.UpBlocks {
x = upBlock.Forward(x)
}
// Stage 4a: NormOut + silu
{
prev := x
x = vae.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// Stage 4b: ConvOut (handles its own pools)
{
prev := x
x = vae.ConvOut.Forward(x)
prev.Free()
}
// Stage 4c: Post-processing
{
prev := x
// Clamp to [-1, 1]
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
// Convert back from channels-last to channels-first
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
prev.Free()
mlx.Eval(x)
}
return x
}
// Denormalize reverses the normalization applied during encoding
func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
mean = mlx.ToBFloat16(mean)
std = mlx.ToBFloat16(std)
return mlx.Add(mlx.Mul(z, std), mean)
}
// Helper functions
func silu3D(x *mlx.Array) *mlx.Array {
return mlx.Mul(x, mlx.Sigmoid(x))
}
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
// Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
}
func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
}
func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[2]
W := shape[3]
x = mlx.Transpose(x, 0, 2, 3, 1)
x = mlx.Reshape(x, B*H*W, shape[1])
wShape := weight.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(weight, wShape[0], wShape[1])
} else {
w = weight
}
w = mlx.Transpose(w, 1, 0)
out := mlx.Linear(x, w)
outC := w.Dim(1)
out = mlx.Reshape(out, B, H, W, outC)
return mlx.Transpose(out, 0, 3, 1, 2)
}
func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
x = pad2D(x, 1, 1, 1, 1)
return conv2D(x, weight, 1, 1)
}
func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
x = mlx.Transpose(x, 0, 2, 3, 1)
w = mlx.Transpose(w, 0, 2, 3, 1)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
wShape := w.Shape()
Cout := wShape[0]
kH := wShape[1]
kW := wShape[2]
outH := (H-kH)/strideH + 1
outW := (W-kW)/strideW + 1
patches := extractPatches2D(x, kH, kW, strideH, strideW)
wFlat := mlx.Reshape(w, Cout, -1)
patches = mlx.Reshape(patches, B*outH*outW, -1)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
out = mlx.Reshape(out, B, outH, outW, Cout)
return mlx.Transpose(out, 0, 3, 1, 2)
}
func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
outH := (H-kH)/strideH + 1
outW := (W-kW)/strideW + 1
patches := make([]*mlx.Array, outH*outW)
idx := 0
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startH := i * strideH
startW := j * strideW
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
patch = mlx.Reshape(patch, B, kH*kW*C)
patches[idx] = patch
idx++
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[2]
W := shape[3]
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
x = mlx.Take(x, rowIdx, 2)
x = mlx.Take(x, colIdx, 3)
return x
}
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
// Create repeat indices for rows
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
// Create repeat indices for columns
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
// Take along H (axis 1) then W (axis 2)
x = mlx.Take(x, rowIdx, 1)
x = mlx.Take(x, colIdx, 2)
return x
}
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
// weight: [outC, kH, kW, inC] (MLX channels-last format)
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
// Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
// Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
// stride=1, padding=0 (we already padded manually)
return mlx.Conv2d(x, weight, 1, 0)
}

View File

@@ -1,114 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestVAEConfig tests configuration invariants.
func TestVAEConfig(t *testing.T) {
cfg := defaultVAEConfig()
// Property: latents_mean and latents_std have z_dim elements
if int32(len(cfg.LatentsMean)) != cfg.ZDim {
t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim)
}
if int32(len(cfg.LatentsStd)) != cfg.ZDim {
t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim)
}
// Property: dim_mult defines 4 stages
if len(cfg.DimMult) != 4 {
t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult))
}
// Property: temperal_downsample has 3 elements (for 3 transitions)
if len(cfg.TemperalDownsample) != 3 {
t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample))
}
}
// TestVAELatentsNormalization tests the latent denormalization values.
func TestVAELatentsNormalization(t *testing.T) {
cfg := defaultVAEConfig()
// Verify latents_std values are all positive
for i, std := range cfg.LatentsStd {
if std <= 0 {
t.Errorf("latents_std[%d] should be positive: %v", i, std)
}
}
// Verify values are in reasonable range (from actual model)
for i, mean := range cfg.LatentsMean {
if math.Abs(float64(mean)) > 5 {
t.Errorf("latents_mean[%d] seems too large: %v", i, mean)
}
}
for i, std := range cfg.LatentsStd {
if std > 10 {
t.Errorf("latents_std[%d] seems too large: %v", i, std)
}
}
}
// TestVAEDecoderForward tests full forward pass (integration test).
// Skips if model weights are not available.
func TestVAEDecoderForward(t *testing.T) {
weightsPath := "../../../weights/Qwen-Image-2512/vae"
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + weightsPath)
}
vae := &VAEDecoder{}
if err := vae.Load(weightsPath); err != nil {
t.Fatalf("Failed to load VAE decoder: %v", err)
}
mlx.Keep(mlx.Collect(vae)...)
// Small test input: [B, C, T, H, W]
// After 4 upsampling stages (2x each), H/W multiply by 16
batchSize := int32(1)
channels := int32(16)
frames := int32(1)
latentH := int32(4)
latentW := int32(4)
latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0)
// Decode
out := vae.Decode(latents)
mlx.Eval(out)
// Verify output shape: [B, 3, T, H*16, W*16]
outShape := out.Shape()
if outShape[0] != batchSize {
t.Errorf("batch size: got %d, want %d", outShape[0], batchSize)
}
if outShape[1] != 3 {
t.Errorf("channels: got %d, want 3", outShape[1])
}
if outShape[2] != frames {
t.Errorf("frames: got %d, want %d", outShape[2], frames)
}
expectedH := latentH * 16 // 4 stages of 2x upsampling
expectedW := latentW * 16
if outShape[3] != expectedH || outShape[4] != expectedW {
t.Errorf("spatial dims: got [%d, %d], want [%d, %d]",
outShape[3], outShape[4], expectedH, expectedW)
}
// Verify output is in valid range (should be clamped to [0, 1] by decode)
outData := out.Data()
for i := 0; i < min(100, len(outData)); i++ {
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
t.Errorf("output[%d] not finite: %v", i, outData[i])
break
}
}
}

View File

@@ -1,682 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// CausalConv3d is a causal 3D convolution (for temporal causality)
type CausalConv3d struct {
Weight *mlx.Array
Bias *mlx.Array
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
KernelT int32
}
// newCausalConv3d creates a 3D causal conv
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
weight, err := weights.Get(prefix + ".weight")
if err != nil {
return nil, fmt.Errorf("weight not found: %s", prefix)
}
bias, _ := weights.Get(prefix + ".bias")
kernelT := weight.Shape()[2]
outC := weight.Shape()[0]
var biasReshaped *mlx.Array
if bias != nil {
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
}
return &CausalConv3d{
Weight: weight,
Bias: bias,
BiasReshaped: biasReshaped,
KernelT: kernelT,
}, nil
}
// Forward applies causal 3D convolution (or 2D if weight is 4D)
// x: [B, T, H, W, C] (channels-last, MLX format)
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
shape := c.Weight.Shape()
// Handle both 5D (3D conv) and 4D (2D conv) weights
if len(shape) == 4 {
// 2D conv: [O, I, kH, kW] - need to apply per-frame
return c.forward2D(x)
}
// 3D conv: [O, I, kT, kH, kW]
kernelT := shape[2]
kernelH := shape[3]
kernelW := shape[4]
// Causal temporal padding, same spatial padding
padT := kernelT - 1
padH := kernelH / 2
padW := kernelW / 2
// Stage 1: Pad
{
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
mlx.Eval(x)
}
// Stage 2: Conv + bias
var out *mlx.Array
{
prev := x
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
out = mlx.Add(out, bias)
}
prev.Free()
mlx.Eval(out)
}
return out
}
// forward2D applies 2D conv per-frame for [B, T, H, W, C] input
func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
xShape := x.Shape()
B := xShape[0]
T := xShape[1]
H := xShape[2]
W := xShape[3]
C := xShape[4]
wShape := c.Weight.Shape() // [O, I, kH, kW]
kernelH := wShape[2]
kernelW := wShape[3]
outC := wShape[0]
padH := kernelH / 2
padW := kernelW / 2
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Pad spatially
x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})
// Apply 2D conv
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
x = mlx.Conv2d(x, weight, 1, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
// Get output spatial dims
outH := H
outW := W
// Reshape back to [B, T, H, W, C]
x = mlx.Reshape(x, B, T, outH, outW, outC)
mlx.Eval(x)
return x
}
// RMSNorm3D applies RMS normalization over channels
type RMSNorm3D struct {
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
}
// newRMSNorm3D creates an RMS norm
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
gamma, err := weights.Get(prefix + ".gamma")
if err != nil {
return nil, err
}
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
return &RMSNorm3D{Gamma: gamma}, nil
}
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
normalized := mlx.RMSNormNoWeight(x, 1e-6)
return mlx.Mul(normalized, n.Gamma)
}
// ResBlock is a residual block with RMS norm and causal convs
type ResBlock struct {
Norm1 *RMSNorm3D
Conv1 *CausalConv3d
Norm2 *RMSNorm3D
Conv2 *CausalConv3d
Shortcut *CausalConv3d
}
// newResBlock creates a residual block
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
if err != nil {
return nil, err
}
conv1, err := newCausalConv3d(weights, prefix+".conv1")
if err != nil {
return nil, err
}
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
if err != nil {
return nil, err
}
conv2, err := newCausalConv3d(weights, prefix+".conv2")
if err != nil {
return nil, err
}
var shortcut *CausalConv3d
if inDim != outDim {
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
if err != nil {
return nil, err
}
}
return &ResBlock{
Norm1: norm1,
Conv1: conv1,
Norm2: norm2,
Conv2: conv2,
Shortcut: shortcut,
}, nil
}
// Forward applies the residual block
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
var h *mlx.Array
mlx.Keep(x)
// Stage 1: norm1 + silu
{
h = r.Norm1.Forward(x)
h = silu3D(h)
mlx.Eval(h)
}
// Stage 2: conv1
{
prev := h
h = r.Conv1.Forward(h)
prev.Free()
}
// Stage 3: norm2 + silu
{
prev := h
h = r.Norm2.Forward(h)
h = silu3D(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: conv2
{
prev := h
h = r.Conv2.Forward(h)
prev.Free()
}
// Residual connection
if r.Shortcut != nil {
shortcut := r.Shortcut.Forward(x)
h = mlx.Add(h, shortcut)
mlx.Eval(h)
} else {
h = mlx.Add(h, x)
mlx.Eval(h)
}
return h
}
// AttentionBlock is a 2D attention block
type AttentionBlock struct {
Norm *RMSNorm3D
ToQKV *mlx.Array
ToQKVBias *mlx.Array
Proj *mlx.Array
ProjBias *mlx.Array
Dim int32
}
// newAttentionBlock creates an attention block
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
if err != nil {
return nil, err
}
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
proj, _ := weights.Get(prefix + ".proj.weight")
projBias, _ := weights.Get(prefix + ".proj.bias")
return &AttentionBlock{
Norm: norm,
ToQKV: toQKV,
ToQKVBias: toQKVBias,
Proj: proj,
ProjBias: projBias,
Dim: dim,
}, nil
}
// Forward applies 2D attention
// Input: [B, T, H, W, C] (channels-last)
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
identity := x
// Flatten to [B*T, 1, H, W, C] for norm
x = mlx.Reshape(x, B*T, 1, H, W, C)
x = a.Norm.Forward(x)
x = mlx.Reshape(x, B*T, H, W, C)
// Flatten spatial to [B*T, H*W, C]
x = mlx.Reshape(x, B*T, H*W, C)
// Linear to get Q, K, V
wShape := a.ToQKV.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
} else {
w = a.ToQKV
}
w = mlx.Transpose(w, 1, 0)
qkv := mlx.Linear(x, w)
if a.ToQKVBias != nil {
qkv = mlx.Add(qkv, a.ToQKVBias)
}
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
out = mlx.Reshape(out, B*T, H*W, C)
// Project back
pShape := a.Proj.Shape()
var p *mlx.Array
if len(pShape) == 4 {
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
} else {
p = a.Proj
}
p = mlx.Transpose(p, 1, 0)
out = mlx.Linear(out, p)
if a.ProjBias != nil {
out = mlx.Add(out, a.ProjBias)
}
out = mlx.Reshape(out, B, T, H, W, C)
return mlx.Add(out, identity)
}
// UpBlock handles upsampling in decoder
type UpBlock struct {
ResBlocks []*ResBlock
Upsampler *Upsample
}
// newUpBlock creates an up block
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var upsampler *Upsample
if upsampleMode != "" {
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
}
return &UpBlock{
ResBlocks: resBlocks,
Upsampler: upsampler,
}, nil
}
// Forward applies up block
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
for _, block := range u.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
if u.Upsampler != nil {
prev := x
x = u.Upsampler.Forward(x)
prev.Free()
}
return x
}
// Upsample handles spatial upsampling
type Upsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newUpsample creates an upsampler
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Upsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies upsampling to channels-last input [B, T, H, W, C]
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := u.Conv.Shape()[0]
// Stage 1: 2x nearest neighbor upsample
{
x = mlx.Reshape(x, B*T, H, W, C)
x = upsample2xChannelsLast(x)
mlx.Eval(x)
}
// Stage 2: Conv + bias
{
prev := x
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
x = conv2D3x3PaddedChannelsLast(x, weight)
if u.Bias != nil {
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
prev.Free()
mlx.Eval(x)
}
return x
}
// MidBlock is the middle block
type MidBlock struct {
ResBlock1 *ResBlock
Attention *AttentionBlock
ResBlock2 *ResBlock
}
// newMidBlock creates a mid block
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
if err != nil {
return nil, err
}
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
if err != nil {
return nil, err
}
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
if err != nil {
return nil, err
}
return &MidBlock{
ResBlock1: res1,
Attention: attn,
ResBlock2: res2,
}, nil
}
// Forward applies mid block
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
prev := x
x = m.ResBlock1.Forward(x)
prev.Free()
prev = x
x = m.Attention.Forward(x)
prev.Free()
prev = x
x = m.ResBlock2.Forward(x)
prev.Free()
return x
}
// Helper functions
func silu3D(x *mlx.Array) *mlx.Array {
return mlx.Mul(x, mlx.Sigmoid(x))
}
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
}
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
x = mlx.Take(x, rowIdx, 1)
x = mlx.Take(x, colIdx, 2)
return x
}
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
return mlx.Conv2d(x, weight, 1, 0)
}
// conv2DStrided applies conv with stride > 1 using manual patch extraction
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
wShape := weight.Shape()
Cout := wShape[0]
kH := wShape[1]
kW := wShape[2]
outH := (H - kH) / stride + 1
outW := (W - kW) / stride + 1
patches := extractPatches2DStrided(x, kH, kW, stride)
wFlat := mlx.Reshape(weight, Cout, -1)
patches = mlx.Reshape(patches, B*outH*outW, -1)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
return mlx.Reshape(out, B, outH, outW, Cout)
}
// conv3DStrided applies 3D conv with strides using manual patch extraction
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
// strideT, strideH, strideW are the strides for each dimension
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
wShape := weight.Shape()
Cout := wShape[0]
// I := wShape[1]
kT := wShape[2]
kH := wShape[3]
kW := wShape[4]
// For temporal: if T < kT, we need to repeat frames temporally
// For single image with T=1 and kT=2, we duplicate the frame to T=kT
// Python Qwen2.5-VL duplicates the frame, not zero-pads
if T < kT {
// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
T = kT
}
outT := (T - kT) / strideT + 1
outH := (H - kH) / strideH + 1
outW := (W - kW) / strideW + 1
// Extract 3D patches in [C, T, H, W] order to match Python
patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
// patches shape: [B, outT, outH, outW, C*kT*kH*kW]
// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
return mlx.Reshape(out, B, outT, outH, outW, Cout)
}
// extractPatches3DStrided extracts 3D patches with given strides
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outT := (T - kT) / strideT + 1
outH := (H - kH) / strideH + 1
outW := (W - kW) / strideW + 1
numPatches := outT * outH * outW
patches := make([]*mlx.Array, numPatches)
idx := 0
for t := int32(0); t < outT; t++ {
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startT := t * strideT
startH := i * strideH
startW := j * strideW
// Extract patch: [B, kT, kH, kW, C]
patch := mlx.Slice(x,
[]int32{0, startT, startH, startW, 0},
[]int32{B, startT + kT, startH + kH, startW + kW, C})
// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
// Flatten to [B, C*T*H*W]
patch = mlx.Reshape(patch, B, C*kT*kH*kW)
patches[idx] = patch
idx++
}
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
}
// extractPatches2DStrided extracts patches with given stride
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
outH := (H - kH) / stride + 1
outW := (W - kW) / stride + 1
patches := make([]*mlx.Array, outH*outW)
idx := 0
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startH := i * stride
startW := j * stride
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
patch = mlx.Reshape(patch, B, kH*kW*C)
patches[idx] = patch
idx++
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}

View File

@@ -1,475 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"fmt"
"image"
"image/color"
_ "image/jpeg"
_ "image/png"
"math"
"os"
"github.com/ollama/ollama/x/imagegen/mlx"
"golang.org/x/image/draw"
_ "golang.org/x/image/webp"
)
// loadImageFile loads an image from disk
func loadImageFile(path string) (image.Image, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open image: %w", err)
}
defer f.Close()
img, _, err := image.Decode(f)
if err != nil {
return nil, fmt.Errorf("decode image: %w", err)
}
return img, nil
}
// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range
func imageToFloat32Pixels(img image.Image, width, height int) []float32 {
pixels := make([]float32, width*height*3)
idx := 0
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
r, g, b, _ := img.At(x, y).RGBA()
pixels[idx] = float32(r) / 65535.0
pixels[idx+1] = float32(g) / 65535.0
pixels[idx+2] = float32(b) / 65535.0
idx += 3
}
}
return pixels
}
// normalizeImageNet applies ImageNet normalization to an image tensor
func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array {
mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3})
std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3})
return mlx.Div(mlx.Sub(arr, mean), std)
}
// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16
func prepareImageTensor(arr *mlx.Array) *mlx.Array {
// Transpose to [C, H, W] and make contiguous
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
// Add batch dimension [1, C, H, W]
arr = mlx.ExpandDims(arr, 0)
// Convert to bf16
arr = mlx.ToBFloat16(arr)
mlx.Eval(arr)
return arr
}
// clampFloat clamps a value to [0, 255] and returns uint8
func clampFloat(v, weightSum float64) uint8 {
v /= weightSum
if v < 0 {
v = 0
}
if v > 255 {
v = 255
}
return uint8(math.Round(v))
}
// ImageDims holds dimensions for a preprocessed image
type ImageDims struct {
// Original image dimensions
OrigW, OrigH int32
// Condition image dimensions (for vision encoder)
CondW, CondH int32
// VAE image dimensions
VaeW, VaeH int32
// Latent dimensions (VAE dims / vae_scale_factor)
LatentW, LatentH int32
// Patch dimensions (latent dims / patch_size)
PatchW, PatchH int32
}
// ProcessorConfig holds image processor configuration
type ProcessorConfig struct {
// Condition image size (target pixel area for vision encoder input)
// Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456
// Pipeline resizes image to this area before passing to encode_prompt
ConditionImageSize int32
// VAE image size (target pixel area)
// Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576
VAEImageSize int32
// Image normalization (ImageNet stats for vision encoder)
ImageMean []float32
ImageStd []float32
}
// defaultProcessorConfig returns default processor config
func defaultProcessorConfig() *ProcessorConfig {
return &ProcessorConfig{
ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE
VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE
ImageMean: []float32{0.48145466, 0.4578275, 0.40821073},
ImageStd: []float32{0.26862954, 0.26130258, 0.27577711},
}
}
// Processor handles image preprocessing for Qwen-Image-Edit
type Processor struct {
Config *ProcessorConfig
}
// Load loads the processor config
func (p *Processor) Load(path string) error {
p.Config = defaultProcessorConfig()
return nil
}
// LoadAndPreprocess loads an image and preprocesses it for both paths
// Returns: condImage (for vision encoder), vaeImage (for VAE encoding)
func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) {
img, err := loadImageFile(imagePath)
if err != nil {
return nil, nil, err
}
bounds := img.Bounds()
origW := bounds.Dx()
origH := bounds.Dy()
ratio := float64(origW) / float64(origH)
// Calculate dimensions for condition image (vision encoder)
// Python pipeline does TWO resizes:
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
// Calculate dimensions for VAE image (1024x1024 area)
// Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32)
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
// Preprocess for condition (vision encoder) - two-step resize
condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH)
// Preprocess for VAE ([-1, 1] range, 5D tensor)
vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH)
return condImage, vaeImage, nil
}
// preprocessImageLanczos does single-step Lanczos resize for vision encoder
// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default)
// Used by edit_plus pipeline for multi-image input
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array {
resized := resizeImageLanczos(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
arr = p.normalizeImageNet(arr)
return prepareImageTensor(arr)
}
// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline
// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize)
// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize)
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array {
intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH))
resized := resizeImageBicubic(intermediate, int(finalW), int(finalH))
pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH))
arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3})
arr = p.normalizeImageNet(arr)
return prepareImageTensor(arr)
}
// preprocessImage converts image to tensor for vision encoder
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array {
resized := resizeImageBicubic(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
if normalize {
arr = p.normalizeImageNet(arr)
}
return prepareImageTensor(arr)
}
// preprocessImageForVAE converts image to tensor for VAE encoding
// Returns: [B, C, T, H, W] tensor in [-1, 1] range
func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array {
resized := resizeImageLanczos(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
// Scale to [-1, 1]: arr * 2 - 1
arr = mlx.MulScalar(arr, 2.0)
arr = mlx.AddScalar(arr, -1.0)
// Transpose to [C, H, W] and make contiguous
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
// Add batch and temporal dimensions [1, C, 1, H, W]
arr = mlx.ExpandDims(arr, 0) // [1, C, H, W]
arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W]
arr = mlx.ToBFloat16(arr)
mlx.Eval(arr)
return arr
}
// smartResize implements Python Qwen2VL processor's smart_resize logic
// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints
func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) {
// Round to factor
hBar := int32(math.Round(float64(height)/float64(factor))) * factor
wBar := int32(math.Round(float64(width)/float64(factor))) * factor
// Ensure minimum factor size
if hBar < factor {
hBar = factor
}
if wBar < factor {
wBar = factor
}
// Check pixel constraints
total := hBar * wBar
if total > maxPixels {
// Scale down
beta := math.Sqrt(float64(maxPixels) / float64(total))
hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor
wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor
} else if total < minPixels {
// Scale up
beta := math.Sqrt(float64(minPixels) / float64(total))
hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
// calculateDimensions calculates width and height for a target area while maintaining ratio
// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge)
func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) {
width := math.Sqrt(float64(targetArea) * ratio)
height := width / ratio
m := float64(multiple)
width = math.Round(width/m) * m
height = math.Round(height/m) * m
// Ensure minimum dimensions
if width < m {
width = m
}
if height < m {
height = m
}
return int32(width), int32(height)
}
// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS)
func resizeImageLanczos(img image.Image, width, height int) image.Image {
bounds := img.Bounds()
dst := image.NewRGBA(image.Rect(0, 0, width, height))
// Lanczos3 kernel (a=3) to match PIL.LANCZOS
lanczos3 := &draw.Kernel{
Support: 3.0,
At: func(t float64) float64 {
if t == 0 {
return 1.0
}
if t < 0 {
t = -t
}
if t >= 3.0 {
return 0.0
}
// sinc(t) * sinc(t/3)
piT := math.Pi * t
return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3))
},
}
lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil)
return dst
}
// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC)
// Uses separable interpolation with PIL's coordinate mapping for exact match
func resizeImageBicubic(img image.Image, width, height int) image.Image {
bounds := img.Bounds()
srcW := bounds.Dx()
srcH := bounds.Dy()
// Convert to RGBA if needed
var src *image.RGBA
if rgba, ok := img.(*image.RGBA); ok {
src = rgba
} else {
src = image.NewRGBA(bounds)
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
src.Set(x, y, img.At(x, y))
}
}
}
// Keys cubic with a=-0.5 (PIL BICUBIC)
cubic := func(x float64) float64 {
if x < 0 {
x = -x
}
if x < 1 {
return 1.5*x*x*x - 2.5*x*x + 1
}
if x < 2 {
return -0.5*x*x*x + 2.5*x*x - 4*x + 2
}
return 0
}
// Horizontal pass: srcW -> width, keep srcH rows
temp := image.NewRGBA(image.Rect(0, 0, width, srcH))
for y := 0; y < srcH; y++ {
for dstX := 0; dstX < width; dstX++ {
// PIL coordinate mapping: center-to-center
srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5
baseX := int(math.Floor(srcXf))
var sumR, sumG, sumB, sumA, weightSum float64
for i := -1; i <= 2; i++ {
sx := baseX + i
if sx < 0 {
sx = 0
}
if sx >= srcW {
sx = srcW - 1
}
w := cubic(math.Abs(srcXf - float64(baseX+i)))
c := src.RGBAAt(sx, y)
sumR += float64(c.R) * w
sumG += float64(c.G) * w
sumB += float64(c.B) * w
sumA += float64(c.A) * w
weightSum += w
}
temp.SetRGBA(dstX, y, color.RGBA{
clampFloat(sumR, weightSum),
clampFloat(sumG, weightSum),
clampFloat(sumB, weightSum),
clampFloat(sumA, weightSum),
})
}
}
// Vertical pass: srcH -> height
dst := image.NewRGBA(image.Rect(0, 0, width, height))
for x := 0; x < width; x++ {
for dstY := 0; dstY < height; dstY++ {
srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5
baseY := int(math.Floor(srcYf))
var sumR, sumG, sumB, sumA, weightSum float64
for j := -1; j <= 2; j++ {
sy := baseY + j
if sy < 0 {
sy = 0
}
if sy >= srcH {
sy = srcH - 1
}
w := cubic(math.Abs(srcYf - float64(baseY+j)))
c := temp.RGBAAt(x, sy)
sumR += float64(c.R) * w
sumG += float64(c.G) * w
sumB += float64(c.B) * w
sumA += float64(c.A) * w
weightSum += w
}
dst.SetRGBA(x, dstY, color.RGBA{
clampFloat(sumR, weightSum),
clampFloat(sumG, weightSum),
clampFloat(sumB, weightSum),
clampFloat(sumA, weightSum),
})
}
}
return dst
}
// LoadAndPreprocessMultiple loads multiple images and preprocesses them
// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions)
func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) {
const vaeScaleFactor int32 = 8
const patchSize int32 = 2
condImages := make([]*mlx.Array, len(imagePaths))
vaeImages := make([]*mlx.Array, len(imagePaths))
dims := make([]ImageDims, len(imagePaths))
for i, imagePath := range imagePaths {
img, err := loadImageFile(imagePath)
if err != nil {
return nil, nil, nil, fmt.Errorf("image %d: %w", i, err)
}
bounds := img.Bounds()
origW := int32(bounds.Dx())
origH := int32(bounds.Dy())
ratio := float64(origW) / float64(origH)
// Calculate dimensions for condition image (vision encoder)
// Python pipeline does TWO resizes:
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
// Calculate dimensions for VAE image (1024x1024 area)
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
// Calculate derived dimensions
latentW := vaeW / vaeScaleFactor
latentH := vaeH / vaeScaleFactor
patchW := latentW / patchSize
patchH := latentH / patchSize
dims[i] = ImageDims{
OrigW: origW,
OrigH: origH,
CondW: condW,
CondH: condH,
VaeW: vaeW,
VaeH: vaeH,
LatentW: latentW,
LatentH: latentH,
PatchW: patchW,
PatchH: patchH,
}
fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n",
i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH)
// Preprocess for condition (vision encoder) - two-step resize to match Python pipeline
condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH)
// Preprocess for VAE ([-1, 1] range, 5D tensor)
vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH)
}
return condImages, vaeImages, dims, nil
}

View File

@@ -1,625 +0,0 @@
//go:build mlx
// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing.
// It reuses components from qwen_image where possible.
package qwen_image_edit
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image editing.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
Width int32 // Output width (default: from input image)
Height int32 // Output height (default: from input image)
Steps int // Denoising steps (default: 50)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
}
// Model represents a Qwen-Image-Edit diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
Processor *Processor // Image processor for vision encoder
TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image)
Transformer *qwen_image.Transformer // Reuse qwen_image transformer
VAE *VAE // Combined encoder + decoder
}
// Load loads the Qwen-Image-Edit model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image-Edit model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer from processor directory
fmt.Print(" Loading tokenizer... ")
processorPath := filepath.Join(modelPath, "processor")
tok, err := tokenizer.Load(processorPath)
if err != nil {
// Fallback to tokenizer directory
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err = tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
}
m.Tokenizer = tok
fmt.Println("✓")
// Load processor (image preprocessing config)
fmt.Print(" Loading processor... ")
m.Processor = &Processor{}
if err := m.Processor.Load(processorPath); err != nil {
return fmt.Errorf("processor: %w", err)
}
fmt.Println("✓")
// Load vision-language text encoder (Qwen2.5-VL from qwen_image package)
m.TextEncoder = &qwen_image.Qwen25VL{}
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer (reuse qwen_image)
m.Transformer = &qwen_image.Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE (encoder + decoder)
m.VAE = &VAE{}
if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE: %w", err)
}
mlx.Eval(mlx.Collect(m.VAE)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Edit edits an image based on a text prompt.
// inputImagePath: path to input image
// prompt: text description of desired edit
func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// EditFromConfig edits images using the unified config struct.
// Accepts one or more input images.
func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
if len(inputImagePaths) == 0 {
return nil, fmt.Errorf("no input images provided")
}
start := time.Now()
result, err := m.edit(inputImagePaths, cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n",
len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n",
len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// EditImage implements model.ImageEditModel interface.
func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Edit(inputImagePath, prompt, width, height, steps, seed)
}
// EditMultiImage edits using multiple source images.
// This matches diffusers' QwenImageEditPlusPipeline behavior.
func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
return m.EditFromConfig(inputImagePaths, cfg)
}
// edit is the internal editing pipeline that handles one or more images.
func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
// Load and preprocess all input images
fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths))
condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths)
if err != nil {
return nil, fmt.Errorf("preprocess images: %w", err)
}
for _, img := range condImages {
mlx.Keep(img)
}
for _, img := range vaeImages {
mlx.Keep(img)
}
mlx.Eval(append(condImages, vaeImages...)...)
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
vaeScaleFactor := int32(8)
// Output dimensions - if not specified, use first input image dimensions
if cfg.Width <= 0 {
cfg.Width = inputDims[0].VaeW
}
if cfg.Height <= 0 {
cfg.Height = inputDims[0].VaeH
}
// Output (noise) latent dimensions
outLatentH := cfg.Height / vaeScaleFactor
outLatentW := cfg.Width / vaeScaleFactor
outPH := outLatentH / tcfg.PatchSize
outPW := outLatentW / tcfg.PatchSize
noiseSeqLen := outPH * outPW
imgSeqLen := noiseSeqLen
// Encode prompt with all images for conditioning
posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages)
if err != nil {
return nil, fmt.Errorf("encoding prompt: %w", err)
}
mlx.Keep(posEmb)
mlx.Eval(posEmb)
var negEmb *mlx.Array
if useCFG {
negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages)
if err != nil {
return nil, fmt.Errorf("encoding negative prompt: %w", err)
}
mlx.Keep(negEmb)
mlx.Eval(negEmb)
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Encode all input images to latents and concatenate
fmt.Println("Encoding images to latents...")
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
for i, vaeImage := range vaeImages {
imageLatents := m.VAE.Encode(vaeImage)
imageLatents = m.VAE.Normalize(imageLatents)
imageLatents2D := mlx.Squeeze(imageLatents, 2)
packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize)
mlx.Keep(packed)
mlx.Eval(packed)
allImageLatentsPacked[i] = packed
}
imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1)
mlx.Keep(imageLatentsPacked)
mlx.Eval(imageLatentsPacked)
// Scheduler
scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, noiseSeqLen)
// Init noise latents in packed format
packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize
packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed)
latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize)
mlx.Eval(latents)
// RoPE cache
ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Denoising loop
fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
mlx.Eval(timestep)
latents2D := mlx.Squeeze(latents, 2)
patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize)
latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1)
var output *mlx.Array
if useCFG {
// CFG Batching: single forward pass with batch=2
// Tile inputs: [1, L, D] -> [2, L, D]
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
} else {
output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]})
}
noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
mlx.Eval(latents)
oldLatents.Free()
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds())
}
// Free denoising temporaries
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
imageLatentsPacked.Free()
// Decode latents
decoded := m.decodeAndPostprocess(latents)
latents.Free()
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling.
// This prevents CFG from inflating magnitude too much.
func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array {
// Upcast to float32 for precision
posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32)
negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32)
// CFG: pred = neg + scale * (pos - neg)
diff := mlx.Sub(posF32, negF32)
scaledDiff := mlx.MulScalar(diff, scale)
combPred := mlx.Add(negF32, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional norm
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
mlx.Eval(output)
return mlx.ToBFloat16(output)
}
// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1].
func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array {
latents = m.VAE.Denormalize(latents)
decoded := m.VAE.Decode(latents)
// Post-process: squeeze temporal dim and rescale to [0, 1]
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
mlx.Eval(decoded)
return decoded
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is an alias for backward compatibility.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}
// PrepareRoPEMultiImage computes RoPE with interpolation for image editing.
// Handles single or multiple input images with different resolutions.
//
// Parameters:
// - outPH, outPW: output patch dimensions (noise latent resolution)
// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...]
// - txtLen: text sequence length
// - axesDims: RoPE axis dimensions [16, 56, 56]
//
// Returns RoPE cache where:
// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions
// - First outPH*outPW positions are for noise latents (standard RoPE at output res)
// - Following positions are for each input image (interpolated from output res)
func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache {
theta := float64(10000)
maxIdx := int32(4096)
// Compute base frequencies for each axis dimension
freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta)
freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta)
freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta)
// Build frequency lookup tables
posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false)
posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false)
negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image
negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true)
negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true)
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
// Helper to compute RoPE for a single position at output resolution with scale_rope
computePosFreqs := func(framePos, y, x int32) []float32 {
row := make([]float32, headDim)
idx := 0
// Frame position
for i := 0; i < len(freqsT)*2; i++ {
row[idx+i] = posFreqsT[framePos][i]
}
idx += len(freqsT) * 2
// Height with scale_rope centering (using OUTPUT dimensions)
outHHalf := outPH / 2
hNegCount := outPH - outHHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = posFreqsH[posIdx][i]
}
}
idx += len(freqsH) * 2
// Width with scale_rope centering (using OUTPUT dimensions)
outWHalf := outPW / 2
wNegCount := outPW - outWHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = posFreqsW[posIdx][i]
}
}
return row
}
// Helper to compute RoPE for frame -1 (used for last condition image)
// This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:]
computeNegFrameFreqs := func(y, x int32) []float32 {
row := make([]float32, headDim)
idx := 0
// Frame -1: use last row of negative frame frequencies
negFrameIdx := maxIdx - 1
for i := 0; i < len(freqsT)*2; i++ {
row[idx+i] = negFreqsT[negFrameIdx][i]
}
idx += len(freqsT) * 2
// Height with scale_rope centering (using OUTPUT dimensions)
outHHalf := outPH / 2
hNegCount := outPH - outHHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = posFreqsH[posIdx][i]
}
}
idx += len(freqsH) * 2
// Width with scale_rope centering (using OUTPUT dimensions)
outWHalf := outPW / 2
wNegCount := outPW - outWHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = posFreqsW[posIdx][i]
}
}
return row
}
// Total image sequence length: noise + all input images
noiseSeqLen := outPH * outPW
totalImgLen := noiseSeqLen
for _, dims := range inputDims {
totalImgLen += dims.PatchH * dims.PatchW
}
imgFreqsData := make([]float32, totalImgLen*headDim)
idx := int32(0)
// Segment 0: Noise latents - standard RoPE at output resolution (frame 0)
for y := int32(0); y < outPH; y++ {
for x := int32(0); x < outPW; x++ {
row := computePosFreqs(0, y, x)
copy(imgFreqsData[idx:], row)
idx += headDim
}
}
// Segments 1..N: Edit image latents - INTERPOLATED RoPE
// For single image: use frame 1 (matches original PrepareRoPEInterpolated)
// For multiple images: Python uses frame -1 for the LAST condition image
// (_compute_condition_freqs), positive indices for others.
numImages := len(inputDims)
lastImgIdx := numImages - 1
for imgIdx, dims := range inputDims {
inPH := dims.PatchH
inPW := dims.PatchW
// Determine frame index for this image
// Single image case: use frame 1 (like original PrepareRoPEInterpolated)
// Multi-image case: last image uses frame -1, others use frame 1, 2, etc.
useNegFrame := numImages > 1 && imgIdx == lastImgIdx
// Map each input position to an output position using linear interpolation
for y := int32(0); y < inPH; y++ {
for x := int32(0); x < inPW; x++ {
// Interpolate: map input (y, x) to output grid position
// This is the key fix from DiffSynth's forward_sampling
var yOut, xOut int32
if inPH == 1 {
yOut = 0
} else {
// Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1)
yOut = y * (outPH - 1) / (inPH - 1)
}
if inPW == 1 {
xOut = 0
} else {
xOut = x * (outPW - 1) / (inPW - 1)
}
var row []float32
if useNegFrame {
// Last image in multi-image uses frame -1
row = computeNegFrameFreqs(yOut, xOut)
} else {
// Single image uses frame 1, multi-image uses frame 1, 2, etc.
frameIdx := int32(imgIdx + 1)
row = computePosFreqs(frameIdx, yOut, xOut)
}
copy(imgFreqsData[idx:], row)
idx += headDim
}
}
}
imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim})
imgFreqs = mlx.ToBFloat16(imgFreqs)
// Text frequencies - start after max video index
maxVidIdx := max(outPH/2, outPW/2)
txtFreqsData := make([]float32, txtLen*headDim)
idx = 0
for t := int32(0); t < txtLen; t++ {
pos := maxVidIdx + t
for i := 0; i < len(freqsT)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
}
idx += int32(len(freqsT) * 2)
for i := 0; i < len(freqsH)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
}
idx += int32(len(freqsH) * 2)
for i := 0; i < len(freqsW)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
}
idx += int32(len(freqsW) * 2)
}
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
txtFreqs = mlx.ToBFloat16(txtFreqs)
return &qwen_image.RoPECache{
ImgFreqs: imgFreqs,
TxtFreqs: txtFreqs,
}
}

View File

@@ -1,249 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"fmt"
"math"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
)
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := mlx.InitMLX(); err != nil {
fmt.Printf("Skipping qwen_image_edit tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestComputeAxisFreqs verifies frequency computation matches Python reference
func TestComputeAxisFreqs(t *testing.T) {
theta := float64(10000)
// Expected values from Python:
// freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim))
expectedFreqsT := []float64{
1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684,
0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017,
}
expectedFreqsH_first4 := []float64{
1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494,
}
expectedFreqsH_last4 := []float64{
0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437,
}
// Test temporal frequencies (dim=16)
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
if len(freqsT) != 8 {
t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT))
}
for i, expected := range expectedFreqsT {
if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 {
t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff)
}
}
// Test height/width frequencies (dim=56)
freqsH := qwen_image.ComputeAxisFreqs(56, theta)
if len(freqsH) != 28 {
t.Fatalf("expected 28 height frequencies, got %d", len(freqsH))
}
for i, expected := range expectedFreqsH_first4 {
if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 {
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff)
}
}
for i, expected := range expectedFreqsH_last4 {
idx := 24 + i // last 4 of 28
if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 {
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff)
}
}
}
// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions
func TestMakeFreqTable(t *testing.T) {
theta := float64(10000)
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
maxIdx := int32(4096)
// Test positive table
posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
// Position 0 should give cos=1, sin=0 for all frequencies
for i := 0; i < len(freqsT)*2; i += 2 {
if posTable[0][i] != 1.0 {
t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i])
}
if posTable[0][i+1] != 0.0 {
t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1])
}
}
// Position 1, first frequency (1.0): angle = 1*1 = 1
// cos(1) = 0.5403, sin(1) = 0.8415
if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 {
t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0])
}
if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 {
t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1])
}
// Test negative table
negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true)
// negTable[4095] corresponds to position -1
// cos(-1) = cos(1), sin(-1) = -sin(1)
if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 {
t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0])
}
if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 {
t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1])
}
// negTable[4094] corresponds to position -2
// cos(-2) = cos(2), sin(-2) = -sin(2)
cos2 := math.Cos(2.0)
sin2 := math.Sin(2.0)
if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 {
t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0])
}
if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 {
t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1])
}
}
// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case
func TestPrepareRoPE_QwenImage(t *testing.T) {
if !mlx.GPUIsAvailable() {
t.Skip("GPU not available")
}
mlx.SetDefaultDeviceCPU()
// 4x4 patch grid, single image
imgH, imgW := int32(4), int32(4)
txtLen := int32(5)
axesDims := []int32{16, 56, 56}
cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims)
mlx.Eval(cache.ImgFreqs, cache.TxtFreqs)
// Check shapes
imgShape := cache.ImgFreqs.Shape()
if imgShape[0] != 16 { // 4*4 patches
t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0])
}
// For single image (frame=0), all temporal values should be cos=1, sin=0
imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32)
mlx.Eval(imgFreqsCPU)
imgData := imgFreqsCPU.Data()
// Check first 16 values of patch 0 (temporal cos/sin pairs)
for i := 0; i < 16; i += 2 {
cosVal := imgData[i]
sinVal := imgData[i+1]
if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 {
t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal)
}
if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 {
t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal)
}
}
cache.ImgFreqs.Free()
cache.TxtFreqs.Free()
}
// TestScaleRopePositions verifies the centered position calculation for scale_rope=True
func TestScaleRopePositions(t *testing.T) {
// For a 4x4 grid with scale_rope=True:
// hHalf = 2, wHalf = 2
// hNegCount = 4 - 2 = 2 (positions 0,1 are negative)
// wNegCount = 4 - 2 = 2 (positions 0,1 are negative)
//
// Height positions:
// y=0: -(4-2) + 0 = -2
// y=1: -(4-2) + 1 = -1
// y=2: 2 - 2 = 0
// y=3: 3 - 2 = 1
//
// Same for width
pH, pW := int32(4), int32(4)
hHalf := pH / 2
wHalf := pW / 2
hNegCount := pH - hHalf
wNegCount := pW - wHalf
expectedH := []int32{-2, -1, 0, 1}
expectedW := []int32{-2, -1, 0, 1}
for y := int32(0); y < pH; y++ {
var hPos int32
if y < hNegCount {
hPos = -(pH - hHalf) + y
} else {
hPos = y - hNegCount
}
if hPos != expectedH[y] {
t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos)
}
}
for x := int32(0); x < pW; x++ {
var wPos int32
if x < wNegCount {
wPos = -(pW - wHalf) + x
} else {
wPos = x - wNegCount
}
if wPos != expectedW[x] {
t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos)
}
}
}
// TestRoPEHeadDimensions verifies the head dimension breakdown
func TestRoPEHeadDimensions(t *testing.T) {
// axes_dims_rope = [16, 56, 56]
// Each dimension uses half the values for frequencies
// So we get: 8 + 28 + 28 = 64 frequency values
// Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position
axesDims := []int32{16, 56, 56}
expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2)
expectedHeadDim := expectedFreqs * 2
if expectedFreqs != 64 {
t.Errorf("expected 64 frequency values, got %d", expectedFreqs)
}
if expectedHeadDim != 128 {
t.Errorf("expected head_dim=128, got %d", expectedHeadDim)
}
// This should match the transformer's attention head dimension
// hidden_size = 3072, num_heads = 24
// head_dim = 3072 / 24 = 128
}

View File

@@ -1,642 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"fmt"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds Qwen-Image VAE configuration
type VAEConfig struct {
ZDim int32 `json:"z_dim"` // 16
BaseDim int32 `json:"base_dim"` // 96
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
NumResBlocks int32 `json:"num_res_blocks"` // 2
LatentsMean []float32 `json:"latents_mean"` // 16 values
LatentsStd []float32 `json:"latents_std"` // 16 values
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
}
// defaultVAEConfig returns config for Qwen-Image VAE
func defaultVAEConfig() *VAEConfig {
return &VAEConfig{
ZDim: 16,
BaseDim: 96,
DimMult: []int32{1, 2, 4, 4},
NumResBlocks: 2,
LatentsMean: []float32{
-0.7571, -0.7089, -0.9113, 0.1075,
-0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632,
-0.1922, -0.9497, 0.2503, -0.2921,
},
LatentsStd: []float32{
2.8184, 1.4541, 2.3275, 2.6558,
1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579,
1.6382, 1.1253, 2.8251, 1.916,
},
TemperalDownsample: []bool{false, true, true},
}
}
// VAE is the full VAE with encoder and decoder
type VAE struct {
Config *VAEConfig
Encoder *VAEEncoder
Decoder *VAEDecoder
}
// Load loads the VAE from a directory
func (m *VAE) Load(path string) error {
fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")
cfg := defaultVAEConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Load weights as f32 for quality (matches Python default behavior)
// VAE decoder precision is critical for final image quality
fmt.Print(" Loading weights as f32... ")
if err := weights.Load(mlx.DtypeFloat32); err != nil {
return fmt.Errorf("failed to load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
// Load encoder
fmt.Print(" Loading encoder... ")
m.Encoder = &VAEEncoder{}
if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
return fmt.Errorf("encoder: %w", err)
}
fmt.Println("✓")
// Load decoder
fmt.Print(" Loading decoder... ")
m.Decoder = &VAEDecoder{}
if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
return fmt.Errorf("decoder: %w", err)
}
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// Encode encodes an image to latents
// x: [B, C, T, H, W] image tensor in [-1, 1] range
// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
return m.Encoder.Encode(x)
}
// Decode decodes latents to image
// z: [B, C, T, H, W] latents (denormalized)
// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
return m.Decoder.Decode(z)
}
// Normalize applies latent normalization
// Input z should be f32 (from VAE encoder), output is f32 for transformer
func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
// Mean/std are f32, will match z dtype through broadcasting
return mlx.Div(mlx.Sub(z, mean), std)
}
// Denormalize reverses latent normalization
// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
// Convert latents to f32 for VAE decoder quality
z = mlx.AsType(z, mlx.DtypeFloat32)
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
return mlx.Add(mlx.Mul(z, std), mean)
}
// VAEEncoder is the encoder part of the VAE
// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
// - Blocks 0,1: ResBlocks (base_dim)
// - Block 2: Downsample
// - Blocks 3,4: ResBlocks (base_dim*2)
// - Block 5: Downsample + temporal
// - Blocks 6,7: ResBlocks (base_dim*4)
// - Block 8: Downsample + temporal
// - Blocks 9,10: ResBlocks (base_dim*4)
type VAEEncoder struct {
Config *VAEConfig
ConvIn *CausalConv3d
Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers
MidBlock *MidBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
QuantConv *CausalConv3d
}
// EncoderBlock is either a ResBlock or a Downsample
type EncoderBlock interface {
Forward(x *mlx.Array) *mlx.Array
IsDownsample() bool
}
// EncoderResBlock wraps ResBlock
type EncoderResBlock struct {
*ResBlock
}
func (b *EncoderResBlock) IsDownsample() bool { return false }
// EncoderDownsample is a downsample layer
type EncoderDownsample struct {
Resample *CausalConv3d
TimeConv *CausalConv3d // Optional temporal downsample
}
func (d *EncoderDownsample) IsDownsample() bool { return true }
func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
// Spatial downsample with stride 2
// WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
x = d.forwardSpatialDownsample(x)
// NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
// with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
// The Python forward checks: if feat_cache is not None ... then use time_conv
// Since we don't support streaming, we skip time_conv entirely.
return x
}
// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
xShape := x.Shape()
B := xShape[0]
T := xShape[1]
H := xShape[2]
W := xShape[3]
C := xShape[4]
wShape := d.Resample.Weight.Shape()
outC := wShape[0]
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
// ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W
// Apply 2D conv with stride 2
weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
x = conv2DStrided(x, weight, 2)
if d.Resample.Bias != nil {
bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
// Output dims after stride 2: (H+1)/2, (W+1)/2
outH := (H + 1) / 2
outW := (W + 1) / 2
// Reshape back to [B, T, H', W', C]
x = mlx.Reshape(x, B, T, outH, outW, outC)
mlx.Eval(x)
return x
}
// loadFromWeights loads the encoder from pre-loaded weights
func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
e.Config = cfg
// Conv in
convIn, err := newCausalConv3d(weights, "encoder.conv_in")
if err != nil {
return err
}
e.ConvIn = convIn
// Encoder uses flat block structure:
// dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
// Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
// That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
e.Blocks = make([]EncoderBlock, 0, 11)
// Track dimensions
dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
blockIdx := 0
for stage := 0; stage < len(cfg.DimMult); stage++ {
inDim := cfg.BaseDim
if stage > 0 {
inDim = dims[stage-1]
}
outDim := dims[stage]
// ResBlocks for this stage (num_res_blocks per stage)
for r := int32(0); r < cfg.NumResBlocks; r++ {
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
currentInDim := inDim
if r > 0 {
currentInDim = outDim
}
block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
if err != nil {
return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
}
e.Blocks = append(e.Blocks, block)
blockIdx++
}
// Downsample after each stage except the last
if stage < len(cfg.DimMult)-1 {
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
if err != nil {
return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
}
e.Blocks = append(e.Blocks, down)
blockIdx++
}
}
// Mid block
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
if err != nil {
return err
}
e.MidBlock = midBlock
// Norm out
normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
if err != nil {
return err
}
e.NormOut = normOut
// Conv out
convOut, err := newCausalConv3d(weights, "encoder.conv_out")
if err != nil {
return err
}
e.ConvOut = convOut
// Quant conv
quantConv, err := newCausalConv3d(weights, "quant_conv")
if err != nil {
return err
}
e.QuantConv = quantConv
return nil
}
// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
block, err := newResBlock(weights, prefix, inDim, outDim)
if err != nil {
return nil, err
}
return &EncoderResBlock{block}, nil
}
// newEncoderDownsample creates a downsample layer for the encoder
func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
resample, err := newCausalConv3d(weights, prefix+".resample.1")
if err != nil {
return nil, err
}
var timeConv *CausalConv3d
if temporal {
timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
}
return &EncoderDownsample{
Resample: resample,
TimeConv: timeConv,
}, nil
}
// Encode encodes an image to latents
// x: [B, C, T, H, W] image tensor (channels-first)
// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
mlx.Eval(x)
// Conv in
x = e.ConvIn.Forward(x)
// Encoder blocks (mix of ResBlocks and Downsamplers)
for _, block := range e.Blocks {
prev := x
x = block.Forward(x)
prev.Free()
}
// Mid block
x = e.MidBlock.Forward(x)
// Norm + silu
{
prev := x
x = e.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// Conv out
{
prev := x
x = e.ConvOut.Forward(x)
prev.Free()
}
// Quant conv
{
prev := x
x = e.QuantConv.Forward(x)
prev.Free()
}
// Get mode from distribution (first half of channels = mean)
// Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
shape := x.Shape()
latentC := shape[4] / 2
x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
// Convert back to channels-first [N, C, T, H, W]
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
mlx.Eval(x)
return x
}
// VAEDecoder is the decoder part of the VAE
type VAEDecoder struct {
Config *VAEConfig
PostQuantConv *CausalConv3d
ConvIn *CausalConv3d
MidBlock *MidBlock
UpBlocks []*UpBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
}
// loadFromWeights loads the decoder from pre-loaded weights
func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
d.Config = cfg
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
if err != nil {
return err
}
d.PostQuantConv = postQuantConv
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
if err != nil {
return err
}
d.ConvIn = convIn
// Mid block
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
if err != nil {
return err
}
d.MidBlock = midBlock
// Up blocks (reversed dim_mult)
numUpBlocks := len(cfg.DimMult)
d.UpBlocks = make([]*UpBlock, numUpBlocks)
dimsMult := make([]int32, numUpBlocks+1)
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
for i := 0; i < numUpBlocks; i++ {
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
}
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
for i := range cfg.TemperalDownsample {
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
}
for i := 0; i < numUpBlocks; i++ {
inDim := cfg.BaseDim * dimsMult[i]
outDim := cfg.BaseDim * dimsMult[i+1]
if i > 0 {
inDim = inDim / 2
}
upsampleMode := ""
if i < numUpBlocks-1 {
if temporalUpsample[i] {
upsampleMode = "upsample3d"
} else {
upsampleMode = "upsample2d"
}
}
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
if err != nil {
return err
}
d.UpBlocks[i] = upBlock
}
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
if err != nil {
return err
}
d.NormOut = normOut
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
if err != nil {
return err
}
d.ConvOut = convOut
return nil
}
// Decode converts latents to image
// z: [B, C, T, H, W] denormalized latents
func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
var x *mlx.Array
// Convert from channels-first to channels-last
{
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
mlx.Eval(z)
}
// PostQuantConv
x = d.PostQuantConv.Forward(z)
z.Free()
// ConvIn
{
prev := x
x = d.ConvIn.Forward(x)
prev.Free()
}
// Mid block
x = d.MidBlock.Forward(x)
// Up blocks
for _, upBlock := range d.UpBlocks {
x = upBlock.Forward(x)
}
// NormOut + silu
{
prev := x
x = d.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// ConvOut
{
prev := x
x = d.ConvOut.Forward(x)
prev.Free()
}
// Post-processing: clamp and convert back to channels-first
{
prev := x
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
prev.Free()
mlx.Eval(x)
}
return x
}
// DownBlock handles downsampling in encoder
type DownBlock struct {
ResBlocks []*ResBlock
Downsampler *Downsample
}
// newDownBlock creates a down block
func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var downsampler *Downsample
if downsampleMode != "" {
downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
}
return &DownBlock{
ResBlocks: resBlocks,
Downsampler: downsampler,
}, nil
}
// Forward applies down block
func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
for _, block := range d.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
if d.Downsampler != nil {
prev := x
x = d.Downsampler.Forward(x)
prev.Free()
}
return x
}
// Downsample handles spatial downsampling
type Downsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newDownsample creates a downsampler
func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Downsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies downsampling to channels-last input [B, T, H, W, C]
func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := d.Conv.Shape()[0]
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding
// For 3x3 stride 2: pad 1 on all sides
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
// Conv with stride 2 using manual strided patching
weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
x = conv2DStrided(x, weight, 2)
if d.Bias != nil {
bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H/2, W/2, outC)
mlx.Eval(x)
return x
}

View File

@@ -9,7 +9,6 @@ import (
"encoding/json"
"flag"
"fmt"
"image"
"log/slog"
"net/http"
"os"
@@ -26,12 +25,11 @@ import (
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// Response is streamed back for each progress update
@@ -48,13 +46,6 @@ type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
}
// ImageEditModel extends ImageModel with image editing/conditioning capability.
// Models that support input images for editing should implement this interface.
type ImageEditModel interface {
ImageModel
GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error)
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
@@ -170,29 +161,6 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Validate and decode input images
const maxInputImages = 2
if len(req.Images) > maxInputImages {
http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest)
return
}
var inputImages []image.Image
if len(req.Images) > 0 {
// TODO: add memory check for input images
inputImages = make([]image.Image, len(req.Images))
for i, imgBytes := range req.Images {
img, err := imagegen.DecodeImage(imgBytes)
if err != nil {
http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest)
return
}
inputImages[i] = img
}
slog.Info("decoded input images", "count", len(inputImages))
}
// Serialize generation requests - MLX model may not handle concurrent generation
s.mu.Lock()
defer s.mu.Unlock()
@@ -224,19 +192,7 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
}
// Use ImageEditModel if available and images provided, otherwise use basic ImageModel
var img *mlx.Array
var err error
if len(inputImages) > 0 {
editModel, ok := s.model.(ImageEditModel)
if !ok {
http.Error(w, "model does not support image editing", http.StatusBadRequest)
return
}
img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress)
} else {
img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
}
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
if err != nil {
// Don't send error for cancellation

View File

@@ -226,27 +226,19 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
seed = time.Now().UnixNano()
}
// Extract raw image bytes from llm.ImageData slice
var images [][]byte
for _, img := range req.Images {
images = append(images, img.Data)
}
// Build request for subprocess
creq := struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
Images [][]byte `json:"images,omitempty"`
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
Images: images,
}
body, err := json.Marshal(creq)