Compare commits

...

5 Commits

Author SHA1 Message Date
jmorganca
7c1a9bad6b cleanup 2026-01-18 16:30:15 -08:00
jmorganca
294d0d2b98 load image metadata 2026-01-18 14:23:58 -08:00
jmorganca
27bafe7e9f tiling 2026-01-18 14:00:14 -08:00
jmorganca
4adacca10e fast-ish 2026-01-18 01:51:49 -08:00
jmorganca
92c1d81f95 wip flux2 2026-01-17 22:42:14 -08:00
22 changed files with 3700 additions and 376 deletions

View File

@@ -7,12 +7,17 @@ import (
"encoding/json"
"flag"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"log"
"os"
"path/filepath"
"runtime/pprof"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
@@ -46,8 +51,8 @@ func main() {
imagePath := flag.String("image", "", "Image path for multimodal models")
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
width := flag.Int("width", 0, "Image width (0 = auto from input or 1024)")
height := flag.Int("height", 0, "Image height (0 = auto from input or 1024)")
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
@@ -61,6 +66,7 @@ func main() {
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
@@ -122,6 +128,44 @@ func main() {
if err == nil {
err = saveImageArray(img, *out)
}
case *flux2Flag:
m := &flux2.Model{}
if loadErr := m.Load(*modelPath); loadErr != nil {
log.Fatal(loadErr)
}
// Load input images with EXIF orientation correction
var loadedImages []image.Image
for _, path := range inputImages {
img, loadErr := loadImageWithEXIF(path)
if loadErr != nil {
log.Fatalf("Failed to load image %s: %v", path, loadErr)
}
loadedImages = append(loadedImages, img)
}
// When input images provided and user didn't override dimensions, use 0 to match input
fluxWidth := int32(*width)
fluxHeight := int32(*height)
if len(loadedImages) > 0 && *width == 0 && *height == 0 {
// Both unset, will auto-detect from input
} else if len(loadedImages) > 0 && *width == 0 {
fluxWidth = 0 // Compute from height + aspect ratio
} else if len(loadedImages) > 0 && *height == 0 {
fluxHeight = 0 // Compute from width + aspect ratio
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &flux2.GenerateConfig{
Prompt: *prompt,
Width: fluxWidth,
Height: fluxHeight,
Steps: *steps,
GuidanceScale: float32(*cfgScale),
Seed: *seed,
CapturePath: *gpuCapture,
InputImages: loadedImages,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
@@ -276,6 +320,8 @@ func detectModelKind(modelPath string) (string, error) {
switch index.ClassName {
case "FluxPipeline", "ZImagePipeline":
return "zimage", nil
case "Flux2KleinPipeline":
return "flux2", nil
}
}
return "zimage", nil
@@ -296,3 +342,12 @@ func detectModelKind(modelPath string) (string, error) {
return cfg.ModelType, nil
}
// loadImageWithEXIF loads an image from a file path with EXIF orientation correction.
func loadImageWithEXIF(path string) (image.Image, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read file: %w", err)
}
return imagegen.LoadImageFromBytes(data)
}

View File

@@ -11,6 +11,7 @@ import (
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/meta"
"github.com/ollama/ollama/x/imagegen/mlx"
)
@@ -108,3 +109,13 @@ func clampF(v, min, max float32) float32 {
}
return v
}
// LoadImageFromBytes loads an image from bytes, applying EXIF orientation.
// Supports JPEG and PNG formats.
func LoadImageFromBytes(data []byte) (image.Image, error) {
img, _, err := meta.Decode(data)
if err != nil {
return nil, fmt.Errorf("decode image: %w", err)
}
return img, nil
}

View File

@@ -24,9 +24,8 @@ var SupportedBackends = []string{"metal", "cuda", "cpu"}
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 20 * GB, // ~20GB for Flux
}
// CheckPlatformSupport validates that image generation is supported on the current platform.
@@ -72,26 +71,38 @@ func ResolveModelName(modelName string) string {
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
// Returns a conservative default of 21GB if the model type cannot be determined.
func EstimateVRAM(modelName string) uint64 {
manifest, err := LoadManifest(modelName)
if err != nil {
return 21 * GB
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return 21 * GB
}
// Parse just the class name
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err != nil {
return 21 * GB
}
if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
className := DetectModelType(modelName)
if estimate, ok := modelVRAMEstimates[className]; ok {
return estimate
}
return 21 * GB
}
// DetectModelType reads model_index.json and returns the model type.
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
// Returns empty string if detection fails.
func DetectModelType(modelName string) string {
manifest, err := LoadManifest(modelName)
if err != nil {
return ""
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return ""
}
var index struct {
Architecture string `json:"architecture"`
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err != nil {
return ""
}
// Prefer architecture (Ollama format), fall back to _class_name (diffusers)
if index.Architecture != "" {
return index.Architecture
}
return index.ClassName
}

View File

@@ -72,9 +72,8 @@ func TestCheckMemoryRequirements(t *testing.T) {
func TestModelVRAMEstimates(t *testing.T) {
// Verify the VRAM estimates map has expected entries
expected := map[string]uint64{
"ZImagePipeline": 21 * GB,
"FluxPipeline": 21 * GB,
"QwenImagePipeline": 80 * GB,
"ZImagePipeline": 21 * GB,
"FluxPipeline": 20 * GB,
}
for name, expectedVRAM := range expected {

252
x/imagegen/meta/metadata.go Normal file
View File

@@ -0,0 +1,252 @@
// Package meta provides image metadata reading and transformation utilities.
package meta
import (
"bytes"
"encoding/binary"
"image"
_ "image/jpeg"
_ "image/png"
)
// Metadata contains image metadata extracted from raw bytes.
type Metadata struct {
Width int // Image width in pixels (after orientation correction)
Height int // Image height in pixels (after orientation correction)
Orientation int // EXIF orientation (1-8, 1=normal)
Format string // Image format ("jpeg", "png", etc.)
}
// Read extracts metadata from image bytes without fully decoding pixels.
// Returns nil if the format is not recognized.
func Read(data []byte) *Metadata {
if len(data) < 8 {
return nil
}
// Detect format and read metadata
if data[0] == 0xFF && data[1] == 0xD8 {
return readJPEGMetadata(data)
}
if string(data[:8]) == "\x89PNG\r\n\x1a\n" {
return readPNGMetadata(data)
}
return nil
}
// readJPEGMetadata reads metadata from JPEG bytes.
func readJPEGMetadata(data []byte) *Metadata {
m := &Metadata{
Format: "jpeg",
Orientation: 1,
}
r := bytes.NewReader(data[2:]) // Skip SOI marker
for {
var marker [2]byte
if _, err := r.Read(marker[:]); err != nil {
break
}
if marker[0] != 0xFF {
break
}
switch marker[1] {
case 0xE1: // APP1 (EXIF)
m.Orientation = parseAPP1(r)
case 0xC0, 0xC1, 0xC2: // SOF0, SOF1, SOF2 (Start of Frame)
var lenBytes [2]byte
if _, err := r.Read(lenBytes[:]); err != nil {
break
}
var sof [5]byte
if _, err := r.Read(sof[:]); err != nil {
break
}
m.Height = int(binary.BigEndian.Uint16(sof[1:3]))
m.Width = int(binary.BigEndian.Uint16(sof[3:5]))
// Swap dimensions for 90°/270° rotations
if m.Orientation >= 5 {
m.Width, m.Height = m.Height, m.Width
}
return m
case 0xD9, 0xDA: // EOI or SOS - stop scanning
return m
default:
// Skip marker
if marker[1] >= 0xD0 && marker[1] <= 0xD7 {
continue // RST markers have no length
}
var lenBytes [2]byte
if _, err := r.Read(lenBytes[:]); err != nil {
break
}
segLen := int(binary.BigEndian.Uint16(lenBytes[:])) - 2
if segLen > 0 {
r.Seek(int64(segLen), 1)
}
}
}
return m
}
// parseAPP1 parses an APP1 segment for EXIF orientation.
func parseAPP1(r *bytes.Reader) int {
var lenBytes [2]byte
if _, err := r.Read(lenBytes[:]); err != nil {
return 1
}
segLen := int(binary.BigEndian.Uint16(lenBytes[:])) - 2
if segLen < 14 {
r.Seek(int64(segLen), 1)
return 1
}
data := make([]byte, segLen)
if _, err := r.Read(data); err != nil {
return 1
}
// Check for "Exif\0\0" header
if string(data[:4]) != "Exif" || data[4] != 0 || data[5] != 0 {
return 1
}
return parseTIFFOrientation(data[6:])
}
// parseTIFFOrientation extracts orientation from TIFF header.
func parseTIFFOrientation(tiff []byte) int {
if len(tiff) < 8 {
return 1
}
var byteOrder binary.ByteOrder
switch string(tiff[:2]) {
case "MM":
byteOrder = binary.BigEndian
case "II":
byteOrder = binary.LittleEndian
default:
return 1
}
if byteOrder.Uint16(tiff[2:4]) != 42 {
return 1
}
ifdOffset := byteOrder.Uint32(tiff[4:8])
if int(ifdOffset)+2 > len(tiff) {
return 1
}
numEntries := byteOrder.Uint16(tiff[ifdOffset : ifdOffset+2])
entryStart := ifdOffset + 2
for i := range int(numEntries) {
offset := entryStart + uint32(i)*12
if int(offset)+12 > len(tiff) {
break
}
if byteOrder.Uint16(tiff[offset:offset+2]) == 0x0112 {
orientation := int(byteOrder.Uint16(tiff[offset+8 : offset+10]))
if orientation >= 1 && orientation <= 8 {
return orientation
}
return 1
}
}
return 1
}
// readPNGMetadata reads metadata from PNG bytes.
func readPNGMetadata(data []byte) *Metadata {
m := &Metadata{
Format: "png",
Orientation: 1, // PNG has no EXIF orientation
}
// IHDR chunk starts at offset 8 (after signature)
// Structure: length(4) + type(4) + data + crc(4)
if len(data) < 24 {
return m
}
// Verify IHDR chunk type
if string(data[12:16]) != "IHDR" {
return m
}
// IHDR data: width(4) + height(4) + bit_depth(1) + ...
m.Width = int(binary.BigEndian.Uint32(data[16:20]))
m.Height = int(binary.BigEndian.Uint32(data[20:24]))
return m
}
// Decode decodes an image from bytes with EXIF orientation applied.
func Decode(data []byte) (image.Image, string, error) {
meta := Read(data)
orientation := 1
if meta != nil {
orientation = meta.Orientation
}
img, format, err := image.Decode(bytes.NewReader(data))
if err != nil {
return nil, "", err
}
return ApplyOrientation(img, orientation), format, nil
}
// ApplyOrientation transforms an image according to EXIF orientation.
// Returns the original image if orientation is 1 (normal) or invalid.
func ApplyOrientation(img image.Image, orientation int) image.Image {
if orientation <= 1 || orientation > 8 {
return img
}
bounds := img.Bounds()
w, h := bounds.Dx(), bounds.Dy()
outW, outH := w, h
if orientation >= 5 {
outW, outH = h, w
}
out := image.NewRGBA(image.Rect(0, 0, outW, outH))
for y := range h {
for x := range w {
var dx, dy int
switch orientation {
case 2: // Mirror horizontal
dx, dy = w-1-x, y
case 3: // Rotate 180
dx, dy = w-1-x, h-1-y
case 4: // Mirror vertical
dx, dy = x, h-1-y
case 5: // Mirror horizontal + rotate 270
dx, dy = y, x
case 6: // Rotate 90 CW
dx, dy = h-1-y, x
case 7: // Mirror horizontal + rotate 90
dx, dy = h-1-y, w-1-x
case 8: // Rotate 270 CW
dx, dy = y, w-1-x
}
out.Set(dx, dy, img.At(x+bounds.Min.X, y+bounds.Min.Y))
}
}
return out
}

View File

@@ -0,0 +1,244 @@
package meta
import (
"bytes"
"encoding/binary"
"image"
"image/color"
"image/jpeg"
"image/png"
"testing"
)
func TestRead_JPEG(t *testing.T) {
// Create a simple JPEG
img := image.NewRGBA(image.Rect(0, 0, 100, 50))
var buf bytes.Buffer
if err := jpeg.Encode(&buf, img, nil); err != nil {
t.Fatal(err)
}
m := Read(buf.Bytes())
if m == nil {
t.Fatal("expected metadata, got nil")
}
if m.Format != "jpeg" {
t.Errorf("format = %q, want jpeg", m.Format)
}
if m.Width != 100 {
t.Errorf("width = %d, want 100", m.Width)
}
if m.Height != 50 {
t.Errorf("height = %d, want 50", m.Height)
}
if m.Orientation != 1 {
t.Errorf("orientation = %d, want 1", m.Orientation)
}
}
func TestRead_PNG(t *testing.T) {
img := image.NewRGBA(image.Rect(0, 0, 200, 100))
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
t.Fatal(err)
}
m := Read(buf.Bytes())
if m == nil {
t.Fatal("expected metadata, got nil")
}
if m.Format != "png" {
t.Errorf("format = %q, want png", m.Format)
}
if m.Width != 200 {
t.Errorf("width = %d, want 200", m.Width)
}
if m.Height != 100 {
t.Errorf("height = %d, want 100", m.Height)
}
if m.Orientation != 1 {
t.Errorf("orientation = %d, want 1 (PNG has no EXIF)", m.Orientation)
}
}
func TestRead_JPEGWithEXIF(t *testing.T) {
tests := []struct {
name string
orientation int
wantW, wantH int // after orientation correction
}{
{"normal", 1, 100, 50},
{"mirror_h", 2, 100, 50},
{"rotate_180", 3, 100, 50},
{"mirror_v", 4, 100, 50},
{"mirror_h_rot270", 5, 50, 100}, // swapped
{"rotate_90", 6, 50, 100}, // swapped
{"mirror_h_rot90", 7, 50, 100}, // swapped
{"rotate_270", 8, 50, 100}, // swapped
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data := makeJPEGWithOrientation(100, 50, tt.orientation)
m := Read(data)
if m == nil {
t.Fatal("expected metadata")
}
if m.Orientation != tt.orientation {
t.Errorf("orientation = %d, want %d", m.Orientation, tt.orientation)
}
if m.Width != tt.wantW || m.Height != tt.wantH {
t.Errorf("size = %dx%d, want %dx%d", m.Width, m.Height, tt.wantW, tt.wantH)
}
})
}
}
func TestRead_Invalid(t *testing.T) {
tests := []struct {
name string
data []byte
}{
{"empty", nil},
{"too_short", []byte{0xFF}},
{"random", []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := Read(tt.data)
if m != nil {
t.Errorf("expected nil for invalid data, got %+v", m)
}
})
}
}
func TestApplyOrientation(t *testing.T) {
// Create a 4x2 image with distinct pixels to verify transformations
// Layout: [R G B Y]
// [C M W K]
img := image.NewRGBA(image.Rect(0, 0, 4, 2))
colors := []color.RGBA{
{255, 0, 0, 255}, // R
{0, 255, 0, 255}, // G
{0, 0, 255, 255}, // B
{255, 255, 0, 255}, // Y
{0, 255, 255, 255}, // C
{255, 0, 255, 255}, // M
{255, 255, 255, 255}, // W
{0, 0, 0, 255}, // K
}
for i, c := range colors {
img.Set(i%4, i/4, c)
}
tests := []struct {
orientation int
wantW, wantH int
topLeft color.RGBA // what should be at (0,0) after transform
}{
{1, 4, 2, colors[0]}, // R - no change
{2, 4, 2, colors[3]}, // Y - mirror horizontal
{3, 4, 2, colors[7]}, // K - rotate 180
{4, 4, 2, colors[4]}, // C - mirror vertical
{5, 2, 4, colors[0]}, // R - transpose
{6, 2, 4, colors[4]}, // C - rotate 90 CW
{7, 2, 4, colors[7]}, // K - transverse
{8, 2, 4, colors[3]}, // Y - rotate 270 CW
}
for _, tt := range tests {
t.Run(string(rune('0'+tt.orientation)), func(t *testing.T) {
result := ApplyOrientation(img, tt.orientation)
bounds := result.Bounds()
if bounds.Dx() != tt.wantW || bounds.Dy() != tt.wantH {
t.Errorf("size = %dx%d, want %dx%d", bounds.Dx(), bounds.Dy(), tt.wantW, tt.wantH)
}
got := result.At(0, 0).(color.RGBA)
if got != tt.topLeft {
t.Errorf("top-left = %v, want %v", got, tt.topLeft)
}
})
}
}
func TestDecode(t *testing.T) {
// Test with orientation 6 (90° CW rotation)
data := makeJPEGWithOrientation(100, 50, 6)
img, format, err := Decode(data)
if err != nil {
t.Fatal(err)
}
if format != "jpeg" {
t.Errorf("format = %q, want jpeg", format)
}
bounds := img.Bounds()
// After 90° rotation, 100x50 becomes 50x100
if bounds.Dx() != 50 || bounds.Dy() != 100 {
t.Errorf("decoded size = %dx%d, want 50x100", bounds.Dx(), bounds.Dy())
}
}
// makeJPEGWithOrientation creates a minimal JPEG with EXIF orientation.
func makeJPEGWithOrientation(w, h, orientation int) []byte {
// Build EXIF APP1 segment with orientation
exif := buildEXIF(orientation)
// Create base JPEG
img := image.NewRGBA(image.Rect(0, 0, w, h))
var jpegBuf bytes.Buffer
jpeg.Encode(&jpegBuf, img, nil)
jpegData := jpegBuf.Bytes()
// Insert EXIF after SOI marker (first 2 bytes)
var result bytes.Buffer
result.Write(jpegData[:2]) // SOI
result.Write(exif) // APP1 with EXIF
result.Write(jpegData[2:]) // Rest of JPEG
return result.Bytes()
}
// buildEXIF creates a minimal EXIF APP1 segment with orientation tag.
func buildEXIF(orientation int) []byte {
var buf bytes.Buffer
// APP1 marker
buf.WriteByte(0xFF)
buf.WriteByte(0xE1)
// Build TIFF/EXIF data
var tiff bytes.Buffer
// TIFF header (little endian)
tiff.WriteString("II") // Little endian
binary.Write(&tiff, binary.LittleEndian, uint16(42)) // TIFF magic
binary.Write(&tiff, binary.LittleEndian, uint32(8)) // IFD0 offset
// IFD0
binary.Write(&tiff, binary.LittleEndian, uint16(1)) // 1 entry
// Orientation tag entry (12 bytes)
binary.Write(&tiff, binary.LittleEndian, uint16(0x0112)) // Tag: Orientation
binary.Write(&tiff, binary.LittleEndian, uint16(3)) // Type: SHORT
binary.Write(&tiff, binary.LittleEndian, uint32(1)) // Count: 1
binary.Write(&tiff, binary.LittleEndian, uint16(orientation))
binary.Write(&tiff, binary.LittleEndian, uint16(0)) // Padding
// Next IFD offset (0 = none)
binary.Write(&tiff, binary.LittleEndian, uint32(0))
// Build APP1 segment
exifData := append([]byte("Exif\x00\x00"), tiff.Bytes()...)
// Write length (includes length bytes but not marker)
binary.Write(&buf, binary.BigEndian, uint16(len(exifData)+2))
buf.Write(exifData)
return buf.Bytes()
}

View File

@@ -1137,6 +1137,27 @@ func RMSNormNoWeight(x *Array, eps float32) *Array {
return RMSNorm(x, ones, eps)
}
// LayerNorm applies layer normalization without learnable params
// (x - mean) / sqrt(var + eps)
func LayerNorm(x *Array, eps float32) *Array {
return LayerNormWithWeightBias(x, nil, nil, eps)
}
// LayerNormWithWeightBias computes layer normalization using mlx.fast
// weight and bias can be nil for elementwise_affine=False
func LayerNormWithWeightBias(x, weight, bias *Array, eps float32) *Array {
res := C.mlx_array_new()
var wc, bc C.mlx_array
if weight != nil {
wc = weight.c
}
if bias != nil {
bc = bias.c
}
C.mlx_fast_layer_norm(&res, x.c, wc, bc, C.float(eps), C.default_stream())
return newArray(res)
}
// RoPE applies rotary position embeddings using mlx.fast
func RoPE(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
res := C.mlx_array_new()

View File

@@ -0,0 +1,524 @@
//go:build mlx
// Package flux2 implements the FLUX.2 Klein diffusion transformer model.
// Klein is a 4B parameter distilled model that supports sub-second inference.
package flux2
import (
"context"
"encoding/json"
"fmt"
"image"
"math"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen3"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"golang.org/x/image/draw"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 4 for Klein)
GuidanceScale float32 // Guidance scale (default: 1.0, Klein doesn't need CFG)
Seed int64 // Random seed
Progress imagegen.ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)
InputImages []image.Image // Reference images for image conditioning (already loaded)
}
// Model represents a FLUX.2 Klein model.
type Model struct {
ModelName string
Tokenizer *tokenizer.Tokenizer
TextEncoder *qwen3.TextEncoder
Transformer *Flux2Transformer2DModel
VAE *AutoencoderKLFlux2
SchedulerConfig *SchedulerConfig
}
// TextEncoderLayerIndices are the layers from which to extract text embeddings.
// Diffusers uses hidden_states[9, 18, 27]. In Python, hidden_states[0] is the embedding
// output before any layers, so hidden_states[9] = after layer 8 (0-indexed).
// Go's ForwardWithLayerOutputs captures after layer i runs, so we use [8, 17, 26].
var TextEncoderLayerIndices = []int{8, 17, 26}
// Load loads the FLUX.2 Klein model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading FLUX.2 Klein model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{}
if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
tokConfig.SpecialTokensMapJSON = data
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder
m.TextEncoder = &qwen3.TextEncoder{}
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
// Load transformer
m.Transformer = &Flux2Transformer2DModel{}
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
// Load VAE
m.VAE = &AutoencoderKLFlux2{}
if err := m.VAE.Load(manifest); err != nil {
return fmt.Errorf("VAE: %w", err)
}
// Evaluate all weights in a single batch (reduces GPU sync overhead)
fmt.Print(" Evaluating weights... ")
allWeights := mlx.Collect(m.TextEncoder)
allWeights = append(allWeights, mlx.Collect(m.Transformer)...)
allWeights = append(allWeights, mlx.Collect(m.VAE)...)
mlx.Eval(allWeights...)
fmt.Println("✓")
// Load scheduler config
m.SchedulerConfig = DefaultSchedulerConfig()
if schedData, err := manifest.ReadConfig("scheduler/scheduler_config.json"); err == nil {
if err := json.Unmarshal(schedData, m.SchedulerConfig); err != nil {
fmt.Printf(" Warning: failed to parse scheduler config: %v\n", err)
}
}
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress imagegen.ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
return result, nil
}
// GenerateImage implements runner.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
const MaxOutputPixels = 2048 * 2048
// MaxRefPixels is the maximum resolution for reference images (smaller to reduce attention memory)
const MaxRefPixels = 728 * 728
// generate is the internal denoising pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Enable MLX compilation for fused kernels
mlx.EnableCompile()
// Apply defaults
if cfg.Steps <= 0 {
cfg.Steps = 4 // Klein default: 4 steps for distilled model
}
if cfg.GuidanceScale <= 0 {
cfg.GuidanceScale = 1.0 // Klein doesn't need guidance
}
// Determine output dimensions
if len(cfg.InputImages) > 0 {
// With input images, compute missing dimension from aspect ratio
// Images are already EXIF-rotated by the caller
bounds := cfg.InputImages[0].Bounds()
imgW, imgH := bounds.Dx(), bounds.Dy()
aspectRatio := float64(imgH) / float64(imgW)
if cfg.Width > 0 && cfg.Height <= 0 {
// Width specified, compute height
cfg.Height = int32(math.Round(float64(cfg.Width)*aspectRatio/16) * 16)
} else if cfg.Height > 0 && cfg.Width <= 0 {
// Height specified, compute width
cfg.Width = int32(math.Round(float64(cfg.Height)/aspectRatio/16) * 16)
} else if cfg.Width <= 0 && cfg.Height <= 0 {
// Neither specified, use input dimensions
cfg.Width = int32(imgW)
cfg.Height = int32(imgH)
}
}
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
// Cap to max pixels, preserve aspect ratio, round to multiple of 16
pixels := int(cfg.Width) * int(cfg.Height)
if pixels > MaxOutputPixels {
scale := math.Sqrt(float64(MaxOutputPixels) / float64(pixels))
cfg.Width = int32(math.Round(float64(cfg.Width) * scale / 16) * 16)
cfg.Height = int32(math.Round(float64(cfg.Height) * scale / 16) * 16)
}
cfg.Height = int32((cfg.Height + 8) / 16 * 16) // round to nearest 16
cfg.Width = int32((cfg.Width + 8) / 16 * 16)
fmt.Printf(" Output: %dx%d\n", cfg.Width, cfg.Height)
tcfg := m.Transformer.TransformerConfig
patchSize := m.VAE.Config.PatchSize
// Latent dimensions: image / 8 (VAE downscale) / patch_size
latentH := cfg.Height / 8
latentW := cfg.Width / 8
patchH := latentH / patchSize[0]
patchW := latentW / patchSize[1]
imgSeqLen := patchH * patchW
// Text encoding with multi-layer extraction (no padding, use true sequence length)
fmt.Print(" Encoding prompt... ")
promptEmbeds, textLen := m.TextEncoder.EncodePromptWithLayers(m.Tokenizer, cfg.Prompt, 512, TextEncoderLayerIndices, false)
fmt.Println("✓")
// Encode reference images if provided
var refTokens *ImageCondTokens
var refHeights, refWidths []int32
if len(cfg.InputImages) > 0 {
fmt.Printf(" Encoding %d reference image(s):\n", len(cfg.InputImages))
var err error
refTokens, err = m.EncodeImageRefs(cfg.InputImages)
if err != nil {
return nil, fmt.Errorf("encode reference images: %w", err)
}
// Extract heights/widths for RoPE computation (same limits as EncodeImageRefs)
limitPixels := MaxRefPixels
if len(cfg.InputImages) > 1 {
limitPixels = MaxRefPixels / 2
}
for _, img := range cfg.InputImages {
_, w, h := PrepareImage(img, limitPixels)
refHeights = append(refHeights, int32(h/16))
refWidths = append(refWidths, int32(w/16))
}
}
// Scheduler
scheduler := NewFlowMatchScheduler(m.SchedulerConfig)
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(imgSeqLen, cfg.Steps))
// Init latents in packed form [B, C*4, H/2, W/2] like diffusers
// diffusers creates noise in [B, 128, 64, 64] and packs to [B, 4096, 128]
latentChannels := m.VAE.Config.LatentChannels
packedChannels := latentChannels * 4 // 32 * 4 = 128
latents := scheduler.InitNoise([]int32{1, packedChannels, patchH, patchW}, cfg.Seed)
// Pack latents (transpose): [B, C, H, W] -> [B, H*W, C]
// This matches diffusers' _pack_latents
patches := packLatents(latents)
noiseSeqLen := patches.Shape()[1]
// RoPE cache - includes reference images if present
rope := PrepareRoPECache(textLen, patchH, patchW, tcfg.AxesDimsRoPE, tcfg.RopeTheta, refHeights, refWidths, ImageRefScale)
defer func() {
rope.Cos.Free()
rope.Sin.Free()
}()
// Pre-compute all timesteps before the loop to avoid per-step tensor creation
timesteps := make([]*mlx.Array, cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
tCurr := scheduler.Timesteps[i] / float32(m.SchedulerConfig.NumTrainTimesteps)
timesteps[i] = mlx.ToBFloat16(mlx.NewArray([]float32{tCurr}, []int32{1}))
}
// Evaluate setup arrays
fmt.Print(" Evaluating setup... ")
setupStart := time.Now()
toEval := []*mlx.Array{promptEmbeds, patches, rope.Cos, rope.Sin}
toEval = append(toEval, timesteps...)
if refTokens != nil {
toEval = append(toEval, refTokens.Tokens)
}
mlx.Eval(toEval...)
mlx.MetalResetPeakMemory() // Reset peak to measure generation separately
fmt.Printf("✓ (%.2fs, %.1f GB)\n", time.Since(setupStart).Seconds(),
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
if cfg.Progress != nil {
cfg.Progress(0, cfg.Steps)
}
loopStart := time.Now()
stepStart := time.Now()
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
// GPU capture on step 2 if requested
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStartCapture(cfg.CapturePath)
}
timestep := timesteps[i]
// Prepare input - concatenate noise patches with reference tokens if present
imgInput := patches
if refTokens != nil {
imgInput = mlx.Concatenate([]*mlx.Array{patches, refTokens.Tokens}, 1)
}
// Transformer forward pass
output := m.Transformer.Forward(imgInput, promptEmbeds, timestep, rope)
// If we concatenated reference tokens, slice to only get noise portion
if refTokens != nil {
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, noiseSeqLen, output.Shape()[2]})
}
// Scheduler step (keep reference to old patches for the computation graph)
newPatches := scheduler.Step(output, patches, i)
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStopCapture()
}
mlx.Eval(newPatches)
patches = newPatches
elapsed := time.Since(stepStart).Seconds()
peakGB := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
if i == 0 {
fmt.Printf(" step %d: %.2fs (JIT warmup), peak %.1f GB\n", i+1, elapsed, peakGB)
} else {
fmt.Printf(" step %d: %.2fs, peak %.1f GB\n", i+1, elapsed, peakGB)
}
stepStart = time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
}
loopTime := time.Since(loopStart).Seconds()
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Denoised %d steps in %.2fs (%.2fs/step), peak %.1f GB\n",
cfg.Steps, loopTime, loopTime/float64(cfg.Steps), peakMem)
// VAE decode with tiling for larger images
fmt.Print(" Decoding VAE... ")
vaeStart := time.Now()
// Enable tiling for images > 512x512 (latent > 64x64)
// VAE attention is O(n²) on latent pixels, tiling reduces memory significantly
if patchH*2 > 64 || patchW*2 > 64 {
m.VAE.Tiling = DefaultTilingConfig()
}
decoded := m.VAE.Decode(patches, patchH, patchW)
mlx.Eval(decoded)
fmt.Printf("✓ (%.2fs, peak %.1f GB)\n", time.Since(vaeStart).Seconds(),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// packLatents converts [B, C, H, W] to [B, H*W, C] (matches diffusers _pack_latents)
func packLatents(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]
x = mlx.Reshape(x, B, C, H*W)
return mlx.Transpose(x, 0, 2, 1)
}
// LoadPersistent loads the model and keeps it in memory for repeated use.
func LoadPersistent(modelName string) (*Model, error) {
m := &Model{}
if err := m.Load(modelName); err != nil {
return nil, err
}
return m, nil
}
// ImageRefScale is the time coordinate offset between reference images (matches diffusers scale=10)
const ImageRefScale = 10
// PrepareImage resizes and crops an image to be a multiple of 16, with optional pixel limit.
// Returns the processed image and its dimensions.
func PrepareImage(img image.Image, limitPixels int) (image.Image, int, int) {
bounds := img.Bounds()
w, h := bounds.Dx(), bounds.Dy()
// Cap pixels if needed (like diffusers cap_pixels)
if limitPixels > 0 && w*h > limitPixels {
scale := math.Sqrt(float64(limitPixels) / float64(w*h))
w = int(float64(w) * scale)
h = int(float64(h) * scale)
}
// Round down to multiple of 16
w = (w / 16) * 16
h = (h / 16) * 16
if w < 16 {
w = 16
}
if h < 16 {
h = 16
}
// Resize using high-quality bicubic interpolation (matches diffusers' default lanczos)
resized := image.NewRGBA(image.Rect(0, 0, w, h))
draw.CatmullRom.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
return resized, w, h
}
// ImageToTensor converts an image to a tensor in [-1, 1] range with shape [1, C, H, W].
func ImageToTensor(img image.Image) *mlx.Array {
bounds := img.Bounds()
w, h := bounds.Dx(), bounds.Dy()
// Convert to float32 array in NCHW format [1, 3, H, W] with values in [-1, 1]
data := make([]float32, 3*h*w)
for y := 0; y < h; y++ {
for x := 0; x < w; x++ {
r, g, b, _ := img.At(x+bounds.Min.X, y+bounds.Min.Y).RGBA()
// RGBA returns 16-bit values, convert to [-1, 1]
data[0*h*w+y*w+x] = float32(r>>8)/127.5 - 1.0
data[1*h*w+y*w+x] = float32(g>>8)/127.5 - 1.0
data[2*h*w+y*w+x] = float32(b>>8)/127.5 - 1.0
}
}
arr := mlx.NewArrayFloat32(data, []int32{1, 3, int32(h), int32(w)})
return arr
}
// ImageCondTokens holds encoded reference image tokens.
type ImageCondTokens struct {
Tokens *mlx.Array // [1, total_tokens, C] - concatenated reference tokens
}
// EncodeImageRefs encodes reference images using the VAE.
func (m *Model) EncodeImageRefs(images []image.Image) (*ImageCondTokens, error) {
if len(images) == 0 {
return nil, nil
}
// Limit reference images to reduce attention memory
limitPixels := MaxRefPixels
if len(images) > 1 {
limitPixels = MaxRefPixels / 2
}
var allTokens []*mlx.Array
for _, img := range images {
// Prepare image (resize, crop to multiple of 16)
prepared, prepW, prepH := PrepareImage(img, limitPixels)
fmt.Printf(" Encoding %dx%d image... ", prepW, prepH)
// Convert to tensor [-1, 1]
tensor := ImageToTensor(prepared)
// Encode with VAE - returns [1, L, 128]
encoded := m.VAE.EncodeImage(tensor)
squeezed := mlx.Squeeze(encoded, 0) // [L, C]
// Defer eval - will be done with other setup arrays
allTokens = append(allTokens, squeezed)
fmt.Println("✓")
}
// For single image, just add batch dimension directly
// For multiple images, concatenate first
var tokens *mlx.Array
if len(allTokens) == 1 {
tokens = mlx.ExpandDims(allTokens[0], 0) // [1, L, C]
} else {
tokens = mlx.Concatenate(allTokens, 0) // [total_L, C]
tokens = mlx.ExpandDims(tokens, 0) // [1, total_L, C]
}
return &ImageCondTokens{Tokens: tokens}, nil
}

View File

@@ -0,0 +1,224 @@
//go:build mlx
package flux2
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// RoPEConfig holds 4D RoPE configuration for Flux2
type RoPEConfig struct {
Theta int32 // 2000 for Klein
AxesDims []int32 // [32, 32, 32, 32] - dimensions for T, H, W, L axes
}
// RoPECache holds precomputed RoPE cos/sin values
type RoPECache struct {
Cos *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
Sin *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
TextLen int32 // Length of text sequence
ImageLen int32 // Length of image sequence
}
// PrepareTextIDs creates position IDs for text tokens.
// Text tokens use: T=0, H=0, W=0, L=0..seqLen-1
// Returns: [seqLen, 4]
func PrepareTextIDs(seqLen int32) *mlx.Array {
ids := make([]float32, seqLen*4)
for i := int32(0); i < seqLen; i++ {
idx := i * 4
ids[idx+0] = 0 // T = 0
ids[idx+1] = 0 // H = 0
ids[idx+2] = 0 // W = 0
ids[idx+3] = float32(i) // L = sequence position
}
return mlx.NewArray(ids, []int32{seqLen, 4})
}
// PrepareLatentIDs creates position IDs for image latent tokens.
// Latent tokens use: T=0, H=0..height-1, W=0..width-1, L=0
// The latents are in row-major order (H then W).
// Returns: [height*width, 4]
func PrepareLatentIDs(height, width int32) *mlx.Array {
seqLen := height * width
ids := make([]float32, seqLen*4)
idx := 0
for h := int32(0); h < height; h++ {
for w := int32(0); w < width; w++ {
ids[idx*4+0] = 0 // T = 0
ids[idx*4+1] = float32(h) // H = row
ids[idx*4+2] = float32(w) // W = column
ids[idx*4+3] = 0 // L = 0
idx++
}
}
return mlx.NewArray(ids, []int32{seqLen, 4})
}
// PrepareImageIDs creates position IDs for reference image tokens (used in editing).
// Reference images use: T=scale*(i+1), H=0..h-1, W=0..w-1, L=0
// where i is the image index (0, 1, 2, ...) and scale separates images in T dimension.
// Returns: [total_tokens, 4]
func PrepareImageIDs(imageHeights, imageWidths []int32, scale int32) *mlx.Array {
// Calculate total tokens
totalTokens := int32(0)
for i := range imageHeights {
totalTokens += imageHeights[i] * imageWidths[i]
}
ids := make([]float32, totalTokens*4)
idx := int32(0)
for imgIdx, h := range imageHeights {
w := imageWidths[imgIdx]
tValue := float32(scale * int32(imgIdx+1))
for hi := int32(0); hi < h; hi++ {
for wi := int32(0); wi < w; wi++ {
ids[idx*4+0] = tValue // T = scale * (imgIdx + 1)
ids[idx*4+1] = float32(hi) // H = row
ids[idx*4+2] = float32(wi) // W = column
ids[idx*4+3] = 0 // L = 0
idx++
}
}
}
return mlx.NewArray(ids, []int32{totalTokens, 4})
}
// ComputeRoPE computes cos and sin for 4D rotary position embeddings.
// ids: [L, 4] with (T, H, W, L) coordinates
// axesDims: [32, 32, 32, 32] - each axis has this many dimensions (total = head_dim = 128)
// theta: base frequency (2000 for Klein)
// Returns: cos, sin each [1, L, 1, head_dim] with repeat_interleave applied
func ComputeRoPE(ids *mlx.Array, axesDims []int32, theta int32) (*mlx.Array, *mlx.Array) {
shape := ids.Shape()
seqLen := shape[0]
// Compute total head dim (sum of all axes dims)
headDim := int32(0)
for _, d := range axesDims {
headDim += d
}
// Extract each coordinate dimension
// ids[:, 0] = T, ids[:, 1] = H, ids[:, 2] = W, ids[:, 3] = L
posT := mlx.Slice(ids, []int32{0, 0}, []int32{seqLen, 1}) // [L, 1]
posH := mlx.Slice(ids, []int32{0, 1}, []int32{seqLen, 2}) // [L, 1]
posW := mlx.Slice(ids, []int32{0, 2}, []int32{seqLen, 3}) // [L, 1]
posL := mlx.Slice(ids, []int32{0, 3}, []int32{seqLen, 4}) // [L, 1]
// Compute frequencies for each axis
logTheta := float32(math.Log(float64(theta)))
cosArrs := make([]*mlx.Array, 4)
sinArrs := make([]*mlx.Array, 4)
positions := []*mlx.Array{posT, posH, posW, posL}
for i, axisDim := range axesDims {
half := axisDim / 2
// Create frequency array for this axis: theta^(-2j/dim) for j=0..half-1
// This matches diffusers: 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
freqs := make([]float32, half)
for j := int32(0); j < half; j++ {
freqs[j] = float32(math.Exp(float64(-logTheta * float32(2*j) / float32(axisDim))))
}
freqArr := mlx.NewArray(freqs, []int32{1, half})
// Compute pos * freq -> [L, half]
posExpanded := positions[i] // [L, 1]
args := mlx.Mul(posExpanded, freqArr) // [L, half]
// Compute cos and sin for this axis
cosAxis := mlx.Cos(args) // [L, half]
sinAxis := mlx.Sin(args) // [L, half]
// repeat_interleave(2): [c0, c1, ...] -> [c0, c0, c1, c1, ...]
// Reshape [L, half] -> [L, half, 1], tile to [L, half, 2], reshape to [L, axisDim]
cosAxis = mlx.ExpandDims(cosAxis, 2) // [L, half, 1]
cosAxis = mlx.Tile(cosAxis, []int32{1, 1, 2}) // [L, half, 2]
cosAxis = mlx.Reshape(cosAxis, seqLen, axisDim) // [L, axisDim]
sinAxis = mlx.ExpandDims(sinAxis, 2)
sinAxis = mlx.Tile(sinAxis, []int32{1, 1, 2})
sinAxis = mlx.Reshape(sinAxis, seqLen, axisDim)
cosArrs[i] = cosAxis
sinArrs[i] = sinAxis
}
// Concatenate all axes: [L, headDim]
cos := mlx.Concatenate(cosArrs, 1)
sin := mlx.Concatenate(sinArrs, 1)
// Reshape to [1, L, 1, headDim] for broadcasting with attention
cos = mlx.Reshape(cos, 1, seqLen, 1, headDim)
sin = mlx.Reshape(sin, 1, seqLen, 1, headDim)
return cos, sin
}
// ApplyRoPE4D applies 4D rotary position embeddings to queries and keys.
// x: [B, L, nheads, head_dim]
// cos, sin: [1, L, 1, head_dim] (with repeat_interleave applied)
// Returns: x with RoPE applied
// Matches diffusers apply_rotary_emb with use_real=True, use_real_unbind_dim=-1
func ApplyRoPE4D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
half := headDim / 2
// Reshape x to [B, L, nheads, half, 2] and split into real/imag
xReshaped := mlx.Reshape(x, B, L, nheads, half, 2)
// Extract real (index 0) and imag (index 1) parts
xReal := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, half, 1})
xImag := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, half, 2})
xReal = mlx.Squeeze(xReal, 4) // [B, L, nheads, half]
xImag = mlx.Squeeze(xImag, 4) // [B, L, nheads, half]
// x_rotated = stack([-x_imag, x_real], dim=-1).flatten(-2)
// This creates [-x_imag[0], x_real[0], -x_imag[1], x_real[1], ...]
negXImag := mlx.Neg(xImag)
negXImag = mlx.ExpandDims(negXImag, 4) // [B, L, nheads, half, 1]
xReal = mlx.ExpandDims(xReal, 4) // [B, L, nheads, half, 1]
xRotated := mlx.Concatenate([]*mlx.Array{negXImag, xReal}, 4) // [B, L, nheads, half, 2]
xRotated = mlx.Reshape(xRotated, B, L, nheads, headDim) // [B, L, nheads, headDim]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, cos), mlx.Mul(xRotated, sin))
}
// PrepareRoPECache creates RoPE cache for text + noise, optionally with reference images.
// textLen: number of text tokens
// noiseH, noiseW: dimensions of the noise latent in patch tokens
// axesDims: [32, 32, 32, 32]
// theta: 2000
// refHeights, refWidths: optional reference image dimensions (pass nil/empty for no images)
// scale: time coordinate offset between reference images (e.g., 10)
func PrepareRoPECache(textLen, noiseH, noiseW int32, axesDims []int32, theta int32, refHeights, refWidths []int32, scale int32) *RoPECache {
textIDs := PrepareTextIDs(textLen)
noiseIDs := PrepareLatentIDs(noiseH, noiseW)
var allIDs *mlx.Array
imageLen := noiseH * noiseW
if len(refHeights) > 0 {
refIDs := PrepareImageIDs(refHeights, refWidths, scale)
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs, refIDs}, 0)
for i := range refHeights {
imageLen += refHeights[i] * refWidths[i]
}
} else {
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs}, 0)
}
cos, sin := ComputeRoPE(allIDs, axesDims, theta)
cos = mlx.ToBFloat16(cos)
sin = mlx.ToBFloat16(sin)
return &RoPECache{Cos: cos, Sin: sin, TextLen: textLen, ImageLen: imageLen}
}

View File

@@ -0,0 +1,149 @@
//go:build mlx
package flux2
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds Flow-Match scheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
Shift float32 `json:"shift"` // 3.0 for Klein
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
TimeShiftType string `json:"time_shift_type"` // "exponential" or "linear"
}
// DefaultSchedulerConfig returns default config for Klein
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
Shift: 3.0, // Klein uses 3.0
UseDynamicShifting: true,
TimeShiftType: "exponential",
}
}
// FlowMatchScheduler implements the Flow-Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32 // Discretized timesteps (t from 1 to 0)
Sigmas []float32 // Noise levels at each timestep
NumSteps int // Number of inference steps
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// SetTimesteps sets up the scheduler for the given number of inference steps
func (s *FlowMatchScheduler) SetTimesteps(numSteps int) {
s.SetTimestepsWithMu(numSteps, 0)
}
// SetTimestepsWithMu sets up scheduler matching diffusers set_timesteps(sigmas=..., mu=...)
func (s *FlowMatchScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
s.NumSteps = numSteps
// diffusers: sigmas = linspace(1, 1/num_steps, num_steps)
// Then applies time shift, appends 0.0 at end
s.Sigmas = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
// linspace(1, 1/num_steps, num_steps)
var sigma float32
if numSteps == 1 {
sigma = 1.0
} else {
sigma = 1.0 - float32(i)/float32(numSteps-1)*(1.0-1.0/float32(numSteps))
}
// Apply time shift if using dynamic shifting
if s.Config.UseDynamicShifting && mu != 0 {
sigma = s.timeShift(mu, sigma)
} else {
// If not dynamic shifting, apply fixed shift scaling like diffusers
shift := s.Config.Shift
sigma = shift * sigma / (1 + (shift-1)*sigma)
}
s.Sigmas[i] = sigma
}
// Append terminal zero
s.Sigmas[numSteps] = 0.0
// Timesteps scaled to training range (matches diffusers: timesteps = sigmas * num_train_timesteps)
s.Timesteps = make([]float32, numSteps+1)
for i, v := range s.Sigmas {
s.Timesteps[i] = v * float32(s.Config.NumTrainTimesteps)
}
}
// timeShift applies the dynamic time shift
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
if s.Config.TimeShiftType == "linear" {
return mu / (mu + (1.0/t-1.0))
}
// Default: exponential
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 for precision (matches diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
outputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(outputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to bfloat16
return mlx.ToBFloat16(result)
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// CalculateShift computes the mu shift value for dynamic scheduling
// Matches diffusers compute_empirical_mu function
func CalculateShift(imgSeqLen int32, numSteps int) float32 {
a1, b1 := float32(8.73809524e-05), float32(1.89833333)
a2, b2 := float32(0.00016927), float32(0.45666666)
seqLen := float32(imgSeqLen)
if imgSeqLen > 4300 {
return a2*seqLen + b2
}
m200 := a2*seqLen + b2
m10 := a1*seqLen + b1
a := (m200 - m10) / 190.0
b := m200 - 200.0*a
return a*float32(numSteps) + b
}

View File

@@ -0,0 +1,562 @@
//go:build mlx
package flux2
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Flux2 transformer configuration
type TransformerConfig struct {
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
}
// Computed dimensions
func (c *TransformerConfig) InnerDim() int32 {
return c.NumAttentionHeads * c.AttentionHeadDim // 24 * 128 = 3072
}
func (c *TransformerConfig) MLPHiddenDim() int32 {
return int32(float32(c.InnerDim()) * c.MLPRatio) // 3072 * 3.0 = 9216
}
// TimestepEmbedder creates timestep embeddings
// Weight names: time_guidance_embed.timestep_embedder.linear_1.weight, linear_2.weight
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"linear_1"`
Linear2 nn.LinearLayer `weight:"linear_2"`
EmbedDim int32 // 256
}
// Forward creates sinusoidal embeddings and projects them
func (t *TimestepEmbedder) Forward(timesteps *mlx.Array) *mlx.Array {
half := t.EmbedDim / 2
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
// timesteps: [B] -> [B, 1]
tExpanded := mlx.ExpandDims(timesteps, 1)
// args: [B, half]
args := mlx.Mul(tExpanded, freqsArr)
// [cos(args), sin(args)] -> [B, embed_dim]
sinEmbed := mlx.Concatenate([]*mlx.Array{mlx.Cos(args), mlx.Sin(args)}, 1)
// MLP: linear_1 -> silu -> linear_2
h := t.Linear1.Forward(sinEmbed)
h = mlx.SiLU(h)
return t.Linear2.Forward(h)
}
// TimeGuidanceEmbed wraps the timestep embedder
// Weight names: time_guidance_embed.timestep_embedder.*
type TimeGuidanceEmbed struct {
TimestepEmbedder *TimestepEmbedder `weight:"timestep_embedder"`
}
// Forward computes timestep embeddings
func (t *TimeGuidanceEmbed) Forward(timesteps *mlx.Array) *mlx.Array {
return t.TimestepEmbedder.Forward(timesteps)
}
// Modulation computes adaptive modulation parameters
// Weight names: double_stream_modulation_img.linear.weight, etc.
type Modulation struct {
Linear nn.LinearLayer `weight:"linear"`
}
// Forward computes modulation parameters
func (m *Modulation) Forward(temb *mlx.Array) *mlx.Array {
h := mlx.SiLU(temb)
return m.Linear.Forward(h)
}
// TransformerBlockAttn implements dual-stream attention
// Weight names: transformer_blocks.N.attn.*
type TransformerBlockAttn struct {
// Image stream (separate Q, K, V projections)
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
// Note: to_out has .0 suffix in weights, handled specially
ToOut0 nn.LinearLayer `weight:"to_out.0"`
// Text stream (add_ projections)
AddQProj nn.LinearLayer `weight:"add_q_proj"`
AddKProj nn.LinearLayer `weight:"add_k_proj"`
AddVProj nn.LinearLayer `weight:"add_v_proj"`
ToAddOut nn.LinearLayer `weight:"to_add_out"`
// QK norms for image stream
NormQ *mlx.Array `weight:"norm_q.weight"`
NormK *mlx.Array `weight:"norm_k.weight"`
// QK norms for text stream (added)
NormAddedQ *mlx.Array `weight:"norm_added_q.weight"`
NormAddedK *mlx.Array `weight:"norm_added_k.weight"`
}
// FeedForward implements SwiGLU MLP
// Weight names: transformer_blocks.N.ff.linear_in.weight, linear_out.weight
type FeedForward struct {
LinearIn nn.LinearLayer `weight:"linear_in"`
LinearOut nn.LinearLayer `weight:"linear_out"`
}
// Forward applies SwiGLU MLP
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// LinearIn outputs 2x hidden dim for SwiGLU
h := ff.LinearIn.Forward(x)
shape := h.Shape()
half := shape[len(shape)-1] / 2
// Split into gate and up
gate := mlx.Slice(h, []int32{0, 0, 0}, []int32{shape[0], shape[1], half})
up := mlx.Slice(h, []int32{0, 0, half}, []int32{shape[0], shape[1], shape[2]})
// SwiGLU: silu(gate) * up
h = mlx.Mul(mlx.SiLU(gate), up)
return ff.LinearOut.Forward(h)
}
// TransformerBlock implements a dual-stream transformer block
// Weight names: transformer_blocks.N.*
type TransformerBlock struct {
Attn *TransformerBlockAttn `weight:"attn"`
FF *FeedForward `weight:"ff"`
FFContext *FeedForward `weight:"ff_context"`
// Config (set after loading)
NHeads int32
HeadDim int32
Scale float32
}
// Forward applies the dual-stream block
// imgHidden: [B, imgLen, dim]
// txtHidden: [B, txtLen, dim]
// imgMod, txtMod: modulation params [B, 6*dim] each
// cos, sin: RoPE values
func (block *TransformerBlock) Forward(imgHidden, txtHidden *mlx.Array, imgMod, txtMod *mlx.Array, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) {
imgShape := imgHidden.Shape()
B := imgShape[0]
imgLen := imgShape[1]
dim := imgShape[2]
txtLen := txtHidden.Shape()[1]
// Parse modulation: 6 params each (shift1, scale1, gate1, shift2, scale2, gate2)
imgShift1, imgScale1, imgGate1 := parseModulation3(imgMod, dim, 0)
imgShift2, imgScale2, imgGate2 := parseModulation3(imgMod, dim, 3)
txtShift1, txtScale1, txtGate1 := parseModulation3(txtMod, dim, 0)
txtShift2, txtScale2, txtGate2 := parseModulation3(txtMod, dim, 3)
// === Attention branch ===
// Modulate inputs
imgNorm := modulateLayerNorm(imgHidden, imgShift1, imgScale1)
txtNorm := modulateLayerNorm(txtHidden, txtShift1, txtScale1)
// Compute Q, K, V for image stream (separate projections)
imgQ := block.Attn.ToQ.Forward(imgNorm)
imgK := block.Attn.ToK.Forward(imgNorm)
imgV := block.Attn.ToV.Forward(imgNorm)
// Compute Q, K, V for text stream (add_ projections)
txtQ := block.Attn.AddQProj.Forward(txtNorm)
txtK := block.Attn.AddKProj.Forward(txtNorm)
txtV := block.Attn.AddVProj.Forward(txtNorm)
// Reshape for attention: [B, L, dim] -> [B, L, nheads, headDim]
imgQ = mlx.Reshape(imgQ, B, imgLen, block.NHeads, block.HeadDim)
imgK = mlx.Reshape(imgK, B, imgLen, block.NHeads, block.HeadDim)
imgV = mlx.Reshape(imgV, B, imgLen, block.NHeads, block.HeadDim)
txtQ = mlx.Reshape(txtQ, B, txtLen, block.NHeads, block.HeadDim)
txtK = mlx.Reshape(txtK, B, txtLen, block.NHeads, block.HeadDim)
txtV = mlx.Reshape(txtV, B, txtLen, block.NHeads, block.HeadDim)
// Apply QK norm (RMSNorm with learned scale)
imgQ = applyQKNorm(imgQ, block.Attn.NormQ)
imgK = applyQKNorm(imgK, block.Attn.NormK)
txtQ = applyQKNorm(txtQ, block.Attn.NormAddedQ)
txtK = applyQKNorm(txtK, block.Attn.NormAddedK)
// Concatenate for joint attention: text first, then image
q := mlx.Concatenate([]*mlx.Array{txtQ, imgQ}, 1)
k := mlx.Concatenate([]*mlx.Array{txtK, imgK}, 1)
v := mlx.Concatenate([]*mlx.Array{txtV, imgV}, 1)
// Apply RoPE
q = ApplyRoPE4D(q, cos, sin)
k = ApplyRoPE4D(k, cos, sin)
// Transpose for SDPA: [B, nheads, L, headDim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Scaled dot-product attention
out := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
// Transpose back: [B, L, nheads, headDim]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Split back into txt and img
totalLen := txtLen + imgLen
txtOut := mlx.Slice(out, []int32{0, 0, 0, 0}, []int32{B, txtLen, block.NHeads, block.HeadDim})
imgOut := mlx.Slice(out, []int32{0, txtLen, 0, 0}, []int32{B, totalLen, block.NHeads, block.HeadDim})
// Reshape and project
txtOut = mlx.Reshape(txtOut, B, txtLen, dim)
imgOut = mlx.Reshape(imgOut, B, imgLen, dim)
txtOut = block.Attn.ToAddOut.Forward(txtOut)
imgOut = block.Attn.ToOut0.Forward(imgOut)
// Apply gates and residual
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate1, imgOut))
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate1, txtOut))
// === MLP branch ===
imgNorm = modulateLayerNorm(imgHidden, imgShift2, imgScale2)
txtNorm = modulateLayerNorm(txtHidden, txtShift2, txtScale2)
imgFFOut := block.FF.Forward(imgNorm)
txtFFOut := block.FFContext.Forward(txtNorm)
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate2, imgFFOut))
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate2, txtFFOut))
return imgHidden, txtHidden
}
// SingleTransformerBlockAttn implements attention for single-stream blocks
// Weight names: single_transformer_blocks.N.attn.*
type SingleTransformerBlockAttn struct {
ToQKVMlpProj nn.LinearLayer `weight:"to_qkv_mlp_proj"` // Fused QKV + MLP input
ToOut nn.LinearLayer `weight:"to_out"` // Fused attn_out + MLP out
NormQ *mlx.Array `weight:"norm_q.weight"`
NormK *mlx.Array `weight:"norm_k.weight"`
}
// SingleTransformerBlock implements a single-stream transformer block
// Weight names: single_transformer_blocks.N.*
type SingleTransformerBlock struct {
Attn *SingleTransformerBlockAttn `weight:"attn"`
// Config
NHeads int32
HeadDim int32
InnerDim int32
MLPHidDim int32
Scale float32
}
// Forward applies the single-stream block
// x: [B, L, dim] concatenated text+image
// mod: modulation [B, 3*dim]
func (block *SingleTransformerBlock) Forward(x *mlx.Array, mod *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
dim := shape[2]
// Parse modulation: (shift, scale, gate)
shift, scale, gate := parseModulation3(mod, dim, 0)
// Modulate input
h := modulateLayerNorm(x, shift, scale)
// Fused projection: QKV + MLP gate/up
// linear1 outputs: [q, k, v, mlp_gate, mlp_up] = [dim, dim, dim, mlpHid, mlpHid]
qkvMlp := block.Attn.ToQKVMlpProj.Forward(h)
// Split: first 3*dim is QKV, rest is MLP
qkvDim := 3 * block.InnerDim
qkv := mlx.Slice(qkvMlp, []int32{0, 0, 0}, []int32{B, L, qkvDim})
mlpIn := mlx.Slice(qkvMlp, []int32{0, 0, qkvDim}, []int32{B, L, qkvMlp.Shape()[2]})
// Split QKV
q, k, v := splitQKV(qkv, B, L, block.InnerDim)
// Reshape for attention
q = mlx.Reshape(q, B, L, block.NHeads, block.HeadDim)
k = mlx.Reshape(k, B, L, block.NHeads, block.HeadDim)
v = mlx.Reshape(v, B, L, block.NHeads, block.HeadDim)
// QK norm
q = applyQKNorm(q, block.Attn.NormQ)
k = applyQKNorm(k, block.Attn.NormK)
// Apply RoPE
q = ApplyRoPE4D(q, cos, sin)
k = ApplyRoPE4D(k, cos, sin)
// Transpose for SDPA
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// SDPA
attnOut := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
// Transpose back and reshape
attnOut = mlx.Transpose(attnOut, 0, 2, 1, 3)
attnOut = mlx.Reshape(attnOut, B, L, block.InnerDim)
// MLP: SwiGLU
mlpShape := mlpIn.Shape()
half := mlpShape[2] / 2
mlpGate := mlx.Slice(mlpIn, []int32{0, 0, 0}, []int32{B, L, half})
mlpUp := mlx.Slice(mlpIn, []int32{0, 0, half}, []int32{B, L, mlpShape[2]})
mlpOut := mlx.Mul(mlx.SiLU(mlpGate), mlpUp)
// Concatenate attention and MLP for fused output
combined := mlx.Concatenate([]*mlx.Array{attnOut, mlpOut}, 2)
// Output projection
out := block.Attn.ToOut.Forward(combined)
// Apply gate and residual
return mlx.Add(x, mlx.Mul(gate, out))
}
// NormOut implements the output normalization with modulation
// Weight names: norm_out.linear.weight
type NormOut struct {
Linear nn.LinearLayer `weight:"linear"`
}
// Forward computes final modulated output
func (n *NormOut) Forward(x *mlx.Array, temb *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
dim := shape[2]
// Modulation: temb -> silu -> linear -> [shift, scale]
mod := mlx.SiLU(temb)
mod = n.Linear.Forward(mod)
// Split into scale and shift (diffusers order: scale first, shift second)
scale := mlx.Slice(mod, []int32{0, 0}, []int32{B, dim})
shift := mlx.Slice(mod, []int32{0, dim}, []int32{B, 2 * dim})
shift = mlx.ExpandDims(shift, 1)
scale = mlx.ExpandDims(scale, 1)
// Modulate with RMSNorm
return modulateLayerNorm(x, shift, scale)
}
// Flux2Transformer2DModel is the main Flux2 transformer
// Weight names at top level: time_guidance_embed.*, double_stream_modulation_*.*, etc.
type Flux2Transformer2DModel struct {
// Timestep embedding
TimeGuidanceEmbed *TimeGuidanceEmbed `weight:"time_guidance_embed"`
// Shared modulation
DoubleStreamModulationImg *Modulation `weight:"double_stream_modulation_img"`
DoubleStreamModulationTxt *Modulation `weight:"double_stream_modulation_txt"`
SingleStreamModulation *Modulation `weight:"single_stream_modulation"`
// Embedders
XEmbedder nn.LinearLayer `weight:"x_embedder"`
ContextEmbedder nn.LinearLayer `weight:"context_embedder"`
// Transformer blocks
TransformerBlocks []*TransformerBlock `weight:"transformer_blocks"`
SingleTransformerBlocks []*SingleTransformerBlock `weight:"single_transformer_blocks"`
// Output
NormOut *NormOut `weight:"norm_out"`
ProjOut nn.LinearLayer `weight:"proj_out"`
*TransformerConfig
}
// Load loads the Flux2 transformer from ollama blob storage.
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = &cfg
// Initialize slices
m.TransformerBlocks = make([]*TransformerBlock, cfg.NumLayers)
m.SingleTransformerBlocks = make([]*SingleTransformerBlock, cfg.NumSingleLayers)
// Initialize TimeGuidanceEmbed with embed dim
m.TimeGuidanceEmbed = &TimeGuidanceEmbed{
TimestepEmbedder: &TimestepEmbedder{EmbedDim: cfg.TimestepGuidanceChannels},
}
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Flux2Transformer2DModel) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *Flux2Transformer2DModel) initComputedFields() {
cfg := m.TransformerConfig
innerDim := cfg.InnerDim()
scale := float32(1.0 / math.Sqrt(float64(cfg.AttentionHeadDim)))
// Initialize transformer blocks
for _, block := range m.TransformerBlocks {
block.NHeads = cfg.NumAttentionHeads
block.HeadDim = cfg.AttentionHeadDim
block.Scale = scale
}
// Initialize single transformer blocks
for _, block := range m.SingleTransformerBlocks {
block.NHeads = cfg.NumAttentionHeads
block.HeadDim = cfg.AttentionHeadDim
block.InnerDim = innerDim
block.MLPHidDim = cfg.MLPHiddenDim()
block.Scale = scale
}
}
// Forward runs the Flux2 transformer
func (m *Flux2Transformer2DModel) Forward(patches, txtEmbeds *mlx.Array, timesteps *mlx.Array, rope *RoPECache) *mlx.Array {
patchShape := patches.Shape()
B := patchShape[0]
imgLen := patchShape[1]
txtLen := txtEmbeds.Shape()[1]
// Scale timestep to 0-1000 range (diffusers multiplies by 1000)
scaledTimesteps := mlx.MulScalar(timesteps, 1000.0)
// Compute timestep embedding
temb := m.TimeGuidanceEmbed.Forward(scaledTimesteps)
// Embed patches and text
imgHidden := m.XEmbedder.Forward(patches)
txtHidden := m.ContextEmbedder.Forward(txtEmbeds)
// Compute shared modulation
imgMod := m.DoubleStreamModulationImg.Forward(temb)
txtMod := m.DoubleStreamModulationTxt.Forward(temb)
singleMod := m.SingleStreamModulation.Forward(temb)
// Double (dual-stream) blocks
for _, block := range m.TransformerBlocks {
imgHidden, txtHidden = block.Forward(imgHidden, txtHidden, imgMod, txtMod, rope.Cos, rope.Sin)
}
// Concatenate for single-stream: text first, then image
hidden := mlx.Concatenate([]*mlx.Array{txtHidden, imgHidden}, 1)
// Single-stream blocks
for _, block := range m.SingleTransformerBlocks {
hidden = block.Forward(hidden, singleMod, rope.Cos, rope.Sin)
}
// Extract image portion
totalLen := txtLen + imgLen
imgOut := mlx.Slice(hidden, []int32{0, txtLen, 0}, []int32{B, totalLen, hidden.Shape()[2]})
// Final norm and projection
imgOut = m.NormOut.Forward(imgOut, temb)
return m.ProjOut.Forward(imgOut)
}
// Note: QK normalization uses mlx.RMSNorm (the fast version) directly
// See applyQKNorm function below
// compiledSwiGLU fuses: silu(gate) * up
// Called 30x per step (10 in dual-stream + 20 in single-stream blocks)
var compiledSwiGLU *mlx.CompiledFunc
func getCompiledSwiGLU() *mlx.CompiledFunc {
if compiledSwiGLU == nil {
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
gate, up := inputs[0], inputs[1]
return []*mlx.Array{mlx.Mul(mlx.SiLU(gate), up)}
}, true)
}
return compiledSwiGLU
}
// Helper functions
// parseModulation3 extracts 3 modulation params (shift, scale, gate) starting at offset
func parseModulation3(mod *mlx.Array, dim int32, offset int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
B := mod.Shape()[0]
start := offset * dim
shift := mlx.Slice(mod, []int32{0, start}, []int32{B, start + dim})
scale := mlx.Slice(mod, []int32{0, start + dim}, []int32{B, start + 2*dim})
gate := mlx.Slice(mod, []int32{0, start + 2*dim}, []int32{B, start + 3*dim})
// Expand for broadcasting [B, dim] -> [B, 1, dim]
shift = mlx.ExpandDims(shift, 1)
scale = mlx.ExpandDims(scale, 1)
gate = mlx.ExpandDims(gate, 1)
return shift, scale, gate
}
// modulateLayerNorm applies LayerNorm then shift/scale modulation
// Diffusers uses LayerNorm(elementwise_affine=False) which centers the data
func modulateLayerNorm(x *mlx.Array, shift, scale *mlx.Array) *mlx.Array {
// Fast LayerNorm without learnable params
x = mlx.LayerNorm(x, 1e-6)
// Modulate: x * (1 + scale) + shift
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
return mlx.Add(x, shift)
}
// splitQKV splits a fused QKV tensor into Q, K, V
func splitQKV(qkv *mlx.Array, B, L, dim int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
q := mlx.Slice(qkv, []int32{0, 0, 0}, []int32{B, L, dim})
k := mlx.Slice(qkv, []int32{0, 0, dim}, []int32{B, L, 2 * dim})
v := mlx.Slice(qkv, []int32{0, 0, 2 * dim}, []int32{B, L, 3 * dim})
return q, k, v
}
// applyQKNorm applies RMSNorm with learned scale (no bias)
// Uses the optimized mlx_fast_rms_norm
func applyQKNorm(x *mlx.Array, scale *mlx.Array) *mlx.Array {
return mlx.RMSNorm(x, scale, 1e-6)
}

View File

@@ -0,0 +1,881 @@
//go:build mlx
package flux2
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/vae"
)
// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
}
// BatchNorm2D implements 2D batch normalization with running statistics
type BatchNorm2D struct {
RunningMean *mlx.Array // [C]
RunningVar *mlx.Array // [C]
Weight *mlx.Array // [C] gamma
Bias *mlx.Array // [C] beta
Eps float32
Momentum float32
}
// Forward applies batch normalization (inference mode - uses running stats)
// Input and output are in NHWC format [B, H, W, C]
func (bn *BatchNorm2D) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[3]
// Reshape stats for broadcasting [1, 1, 1, C]
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
// Normalize: (x - mean) / sqrt(var + eps)
xNorm := mlx.Sub(x, mean)
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
// Scale and shift (only if affine=True)
if bn.Weight != nil {
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if bn.Bias != nil {
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Denormalize inverts the batch normalization
// Used when decoding latents
func (bn *BatchNorm2D) Denormalize(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[3]
// Reshape stats for broadcasting [1, 1, 1, C]
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
// Inverse: first undo affine, then undo normalization
// For affine=False: x_denorm = x * sqrt(var + eps) + mean
if bn.Bias != nil {
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
x = mlx.Sub(x, bias)
}
if bn.Weight != nil {
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
x = mlx.Div(x, weight)
}
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
x = mlx.Add(x, mean)
return x
}
// GroupNormLayer implements group normalization
// Reused from zimage package pattern
type GroupNormLayer struct {
Weight *mlx.Array
Bias *mlx.Array
NumGroups int32
Eps float32
}
// Forward applies group normalization
// Input and output are in NHWC format [B, H, W, C]
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
// Reshape to [B, H, W, groups, C/groups]
groupSize := C / gn.NumGroups
x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)
// Compute mean and variance per group
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
sq := mlx.Square(xCentered)
variance := mlx.Mean(sq, 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back to [B, H, W, C]
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Conv2D represents a 2D convolution layer (reused pattern)
type Conv2D struct {
Weight *mlx.Array
Bias *mlx.Array
Stride int32
Padding int32
}
// Forward applies convolution (NHWC format)
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)
if conv.Bias != nil {
bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
out = mlx.Add(out, bias)
}
return out
}
// ResnetBlock2D implements a ResNet block for VAE
type ResnetBlock2D struct {
Norm1 *GroupNormLayer `weight:"norm1"`
Conv1 *Conv2D `weight:"conv1"`
Norm2 *GroupNormLayer `weight:"norm2"`
Conv2 *Conv2D `weight:"conv2"`
ConvShortcut *Conv2D `weight:"conv_shortcut,optional"` // nil if not present
}
// Forward applies the ResNet block
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
h := rb.Norm1.Forward(x)
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
h = rb.Norm2.Forward(h)
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
if rb.ConvShortcut != nil {
x = rb.ConvShortcut.Forward(x)
}
return mlx.Add(h, x)
}
// VAEAttentionBlock implements self-attention for VAE
type VAEAttentionBlock struct {
GroupNorm *GroupNormLayer
ToQWeight *mlx.Array
ToQBias *mlx.Array
ToKWeight *mlx.Array
ToKBias *mlx.Array
ToVWeight *mlx.Array
ToVBias *mlx.Array
ToOutWeight *mlx.Array
ToOutBias *mlx.Array
NumHeads int32
}
// Forward applies attention (NHWC format)
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
h := ab.GroupNorm.Forward(x)
h = mlx.Reshape(h, B, H*W, C)
q := mlx.Linear(h, ab.ToQWeight)
q = mlx.Add(q, ab.ToQBias)
k := mlx.Linear(h, ab.ToKWeight)
k = mlx.Add(k, ab.ToKBias)
v := mlx.Linear(h, ab.ToVWeight)
v = mlx.Add(v, ab.ToVBias)
q = mlx.ExpandDims(q, 1)
k = mlx.ExpandDims(k, 1)
v = mlx.ExpandDims(v, 1)
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
out = mlx.Squeeze(out, 1)
out = mlx.Linear(out, ab.ToOutWeight)
out = mlx.Add(out, ab.ToOutBias)
out = mlx.Reshape(out, B, H, W, C)
out = mlx.Add(out, residual)
return out
}
// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
ResnetBlocks []*ResnetBlock2D
Upsample *Conv2D
}
// Forward applies the up decoder block
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range ub.ResnetBlocks {
x = resnet.Forward(x)
}
if ub.Upsample != nil {
x = upsample2x(x)
x = ub.Upsample.Forward(x)
}
return x
}
// upsample2x performs 2x nearest neighbor upsampling
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
hIdx = mlx.Reshape(hIdx, H, 1)
hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
hIdx = mlx.Reshape(hIdx, H*2)
wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
wIdx = mlx.Reshape(wIdx, W, 1)
wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
wIdx = mlx.Reshape(wIdx, W*2)
x = mlx.Take(x, hIdx, 1)
x = mlx.Take(x, wIdx, 2)
return x
}
// VAEMidBlock is the middle block with attention
type VAEMidBlock struct {
Resnet1 *ResnetBlock2D
Attention *VAEAttentionBlock
Resnet2 *ResnetBlock2D
}
// Forward applies the mid block
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
x = mb.Resnet1.Forward(x)
x = mb.Attention.Forward(x)
x = mb.Resnet2.Forward(x)
return x
}
// DefaultTilingConfig returns reasonable defaults for tiled decoding
// Matches diffusers: tile_latent_min_size=64, tile_overlap_factor=0.25
func DefaultTilingConfig() *vae.TilingConfig {
return vae.DefaultTilingConfig()
}
// AutoencoderKLFlux2 is the Flux2 VAE with BatchNorm
type AutoencoderKLFlux2 struct {
Config *VAEConfig
// Encoder components (for image editing)
EncoderConvIn *Conv2D
EncoderMid *VAEMidBlock
EncoderDown []*DownEncoderBlock2D
EncoderNormOut *GroupNormLayer
EncoderConvOut *Conv2D
// Decoder components
DecoderConvIn *Conv2D
DecoderMid *VAEMidBlock
DecoderUp []*UpDecoderBlock2D
DecoderNormOut *GroupNormLayer
DecoderConvOut *Conv2D
// Quant conv layers
QuantConv *Conv2D
PostQuantConv *Conv2D
// BatchNorm for latent normalization
LatentBN *BatchNorm2D
// Tiling configuration (nil = no tiling)
Tiling *vae.TilingConfig
}
// DownEncoderBlock2D implements a downsampling encoder block
type DownEncoderBlock2D struct {
ResnetBlocks []*ResnetBlock2D
Downsample *Conv2D
}
// Forward applies the down encoder block
func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range db.ResnetBlocks {
x = resnet.Forward(x)
}
if db.Downsample != nil {
// Pad then conv with stride 2
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0})
x = db.Downsample.Forward(x)
}
return x
}
// Load loads the Flux2 VAE from ollama blob storage.
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading VAE... ")
// Load config from blob
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights, &cfg)
}
// loadWeights loads VAE weights from any WeightSource
func (m *AutoencoderKLFlux2) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
var err error
// Load encoder components (for image conditioning)
if err := m.loadEncoderWeights(weights, cfg); err != nil {
return fmt.Errorf("encoder: %w", err)
}
// Load decoder conv_in
convInW, convInB, err := safetensors.LoadConv2D(weights, "decoder.conv_in")
if err != nil {
return fmt.Errorf("decoder.conv_in: %w", err)
}
m.DecoderConvIn = &Conv2D{Weight: convInW, Bias: convInB, Stride: 1, Padding: 1}
// Load mid block
m.DecoderMid, err = loadVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
if err != nil {
return fmt.Errorf("decoder.mid_block: %w", err)
}
// Load up blocks
numBlocks := len(cfg.BlockOutChannels)
m.DecoderUp = make([]*UpDecoderBlock2D, numBlocks)
for i := 0; i < numBlocks; i++ {
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
hasUpsample := i < numBlocks-1
m.DecoderUp[i], err = loadUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
if err != nil {
return fmt.Errorf("%s: %w", prefix, err)
}
}
// Load decoder conv_norm_out and conv_out
normW, normB, err := safetensors.LoadGroupNorm(weights, "decoder.conv_norm_out")
if err != nil {
return fmt.Errorf("decoder.conv_norm_out: %w", err)
}
m.DecoderNormOut = &GroupNormLayer{Weight: normW, Bias: normB, NumGroups: cfg.NormNumGroups, Eps: 1e-5}
convOutW, convOutB, err := safetensors.LoadConv2D(weights, "decoder.conv_out")
if err != nil {
return fmt.Errorf("decoder.conv_out: %w", err)
}
m.DecoderConvOut = &Conv2D{Weight: convOutW, Bias: convOutB, Stride: 1, Padding: 1}
// Load post_quant_conv
if cfg.UsePostQuantConv {
pqW, pqB, err := safetensors.LoadConv2D(weights, "post_quant_conv")
if err != nil {
return fmt.Errorf("post_quant_conv: %w", err)
}
m.PostQuantConv = &Conv2D{Weight: pqW, Bias: pqB, Stride: 1, Padding: 0}
}
// Load latent BatchNorm (affine=False, so no weight/bias)
bnMean, err := weights.GetTensor("bn.running_mean")
if err != nil {
return fmt.Errorf("bn.running_mean: %w", err)
}
bnVar, err := weights.GetTensor("bn.running_var")
if err != nil {
return fmt.Errorf("bn.running_var: %w", err)
}
m.LatentBN = &BatchNorm2D{
RunningMean: bnMean,
RunningVar: bnVar,
Weight: nil, // affine=False
Bias: nil, // affine=False
Eps: cfg.BatchNormEps,
Momentum: cfg.BatchNormMomentum,
}
fmt.Println("✓")
return nil
}
// loadVAEMidBlock loads the mid block
func loadVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
resnet1, err := loadResnetBlock2D(weights, prefix+".resnets.0", numGroups)
if err != nil {
return nil, err
}
attention, err := loadVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
if err != nil {
return nil, err
}
resnet2, err := loadResnetBlock2D(weights, prefix+".resnets.1", numGroups)
if err != nil {
return nil, err
}
return &VAEMidBlock{
Resnet1: resnet1,
Attention: attention,
Resnet2: resnet2,
}, nil
}
// loadResnetBlock2D loads a ResNet block using safetensors helpers.
func loadResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
norm1W, norm1B, err := safetensors.LoadGroupNorm(weights, prefix+".norm1")
if err != nil {
return nil, err
}
conv1W, conv1B, err := safetensors.LoadConv2D(weights, prefix+".conv1")
if err != nil {
return nil, err
}
norm2W, norm2B, err := safetensors.LoadGroupNorm(weights, prefix+".norm2")
if err != nil {
return nil, err
}
conv2W, conv2B, err := safetensors.LoadConv2D(weights, prefix+".conv2")
if err != nil {
return nil, err
}
block := &ResnetBlock2D{
Norm1: &GroupNormLayer{Weight: norm1W, Bias: norm1B, NumGroups: numGroups, Eps: 1e-5},
Conv1: &Conv2D{Weight: conv1W, Bias: conv1B, Stride: 1, Padding: 1},
Norm2: &GroupNormLayer{Weight: norm2W, Bias: norm2B, NumGroups: numGroups, Eps: 1e-5},
Conv2: &Conv2D{Weight: conv2W, Bias: conv2B, Stride: 1, Padding: 1},
}
// ConvShortcut is optional
if weights.HasTensor(prefix + ".conv_shortcut.weight") {
shortcutW, shortcutB, err := safetensors.LoadConv2D(weights, prefix+".conv_shortcut")
if err != nil {
return nil, err
}
block.ConvShortcut = &Conv2D{Weight: shortcutW, Bias: shortcutB, Stride: 1, Padding: 0}
}
return block, nil
}
// loadVAEAttentionBlock loads an attention block
func loadVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
if err != nil {
return nil, err
}
normBias, err := weights.GetTensor(prefix + ".group_norm.bias")
if err != nil {
return nil, err
}
toQWeight, err := weights.GetTensor(prefix + ".to_q.weight")
if err != nil {
return nil, err
}
toQBias, err := weights.GetTensor(prefix + ".to_q.bias")
if err != nil {
return nil, err
}
toKWeight, err := weights.GetTensor(prefix + ".to_k.weight")
if err != nil {
return nil, err
}
toKBias, err := weights.GetTensor(prefix + ".to_k.bias")
if err != nil {
return nil, err
}
toVWeight, err := weights.GetTensor(prefix + ".to_v.weight")
if err != nil {
return nil, err
}
toVBias, err := weights.GetTensor(prefix + ".to_v.bias")
if err != nil {
return nil, err
}
toOutWeight, err := weights.GetTensor(prefix + ".to_out.0.weight")
if err != nil {
return nil, err
}
toOutBias, err := weights.GetTensor(prefix + ".to_out.0.bias")
if err != nil {
return nil, err
}
return &VAEAttentionBlock{
GroupNorm: &GroupNormLayer{Weight: normWeight, Bias: normBias, NumGroups: numGroups, Eps: 1e-5},
ToQWeight: mlx.Transpose(toQWeight, 1, 0),
ToQBias: toQBias,
ToKWeight: mlx.Transpose(toKWeight, 1, 0),
ToKBias: toKBias,
ToVWeight: mlx.Transpose(toVWeight, 1, 0),
ToVBias: toVBias,
ToOutWeight: mlx.Transpose(toOutWeight, 1, 0),
ToOutBias: toOutBias,
NumHeads: 1,
}, nil
}
// loadUpDecoderBlock2D loads an up decoder block
func loadUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
if err != nil {
return nil, err
}
resnets[i] = resnet
}
var upsample *Conv2D
if hasUpsample {
upW, upB, err := safetensors.LoadConv2D(weights, prefix+".upsamplers.0.conv")
if err != nil {
return nil, err
}
upsample = &Conv2D{Weight: upW, Bias: upB, Stride: 1, Padding: 1}
}
return &UpDecoderBlock2D{
ResnetBlocks: resnets,
Upsample: upsample,
}, nil
}
// Patchify converts latents [B, C, H, W] to patches [B, H*W/4, C*4] using 2x2 patches
// This is the inverse of the VAE's patchify for feeding to transformer
func (vae *AutoencoderKLFlux2) Patchify(latents *mlx.Array) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
patchH := vae.Config.PatchSize[0]
patchW := vae.Config.PatchSize[1]
pH := H / patchH
pW := W / patchW
// [B, C, H, W] -> [B, C, pH, patchH, pW, patchW]
x := mlx.Reshape(latents, B, C, pH, patchH, pW, patchW)
// [B, C, pH, patchH, pW, patchW] -> [B, pH, pW, C, patchH, patchW]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// [B, pH, pW, C, patchH, patchW] -> [B, pH*pW, C*patchH*patchW]
return mlx.Reshape(x, B, pH*pW, C*patchH*patchW)
}
// Unpatchify converts patches [B, L, C*4] back to [B, C, H, W]
func (vae *AutoencoderKLFlux2) Unpatchify(patches *mlx.Array, pH, pW, C int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
patchH := vae.Config.PatchSize[0]
patchW := vae.Config.PatchSize[1]
// [B, pH*pW, C*patchH*patchW] -> [B, pH, pW, C, patchH, patchW]
x := mlx.Reshape(patches, B, pH, pW, C, patchH, patchW)
// [B, pH, pW, C, patchH, patchW] -> [B, C, pH, patchH, pW, patchW]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// [B, C, pH, patchH, pW, patchW] -> [B, C, H, W]
H := pH * patchH
W := pW * patchW
return mlx.Reshape(x, B, C, H, W)
}
// denormalizePatchified applies inverse batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] denormalized
func (vae *AutoencoderKLFlux2) denormalizePatchified(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[2] // 128
// Reshape stats for broadcasting [1, 1, C]
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
// Inverse BN (affine=False): x_denorm = x * sqrt(var + eps) + mean
if vae.LatentBN.Bias != nil {
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
x = mlx.Sub(x, bias)
}
if vae.LatentBN.Weight != nil {
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
x = mlx.Div(x, weight)
}
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
x = mlx.Add(x, mean)
return x
}
// Decode decodes latent patches to images.
// If Tiling is set, uses tiled decoding to reduce memory for large images.
// latents: [B, L, C*4] patchified latents from transformer
// pH, pW: patch grid dimensions
// Returns: [B, 3, H, W] image tensor
func (v *AutoencoderKLFlux2) Decode(latents *mlx.Array, pH, pW int32) *mlx.Array {
// Denormalize patchified latents
z := v.denormalizePatchified(latents)
// Unpatchify: [B, L, C*4] -> [B, C, H, W]
z = v.Unpatchify(z, pH, pW, v.Config.LatentChannels)
// Convert NCHW -> NHWC for processing
z = mlx.Transpose(z, 0, 2, 3, 1)
// Use tiled decoding if enabled
if v.Tiling != nil {
mlx.Eval(z)
return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
}
// Direct decode (no tiling)
h := v.decodeTile(z)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
h = mlx.Transpose(h, 0, 3, 1, 2)
return h
}
// decodeTile decodes a single latent tile to pixels (internal helper)
// z: [B, H, W, C] latent tile in NHWC format
// Returns: [B, H*8, W*8, 3] pixel tile in NHWC format (before clipping)
func (vae *AutoencoderKLFlux2) decodeTile(z *mlx.Array) *mlx.Array {
// Post-quant conv
if vae.PostQuantConv != nil {
z = vae.PostQuantConv.Forward(z)
}
// Decoder
h := vae.DecoderConvIn.Forward(z)
h = vae.DecoderMid.Forward(h)
for _, upBlock := range vae.DecoderUp {
h = upBlock.Forward(h)
}
h = vae.DecoderNormOut.Forward(h)
h = mlx.SiLU(h)
h = vae.DecoderConvOut.Forward(h)
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
h = mlx.AddScalar(h, 0.5)
return h
}
// loadEncoderWeights loads the encoder components for image conditioning
func (m *AutoencoderKLFlux2) loadEncoderWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
// Load encoder conv_in
convInW, convInB, err := safetensors.LoadConv2D(weights, "encoder.conv_in")
if err != nil {
return fmt.Errorf("encoder.conv_in: %w", err)
}
m.EncoderConvIn = &Conv2D{Weight: convInW, Bias: convInB, Stride: 1, Padding: 1}
// Load encoder down blocks
numBlocks := len(cfg.BlockOutChannels)
m.EncoderDown = make([]*DownEncoderBlock2D, numBlocks)
for i := 0; i < numBlocks; i++ {
prefix := fmt.Sprintf("encoder.down_blocks.%d", i)
hasDownsample := i < numBlocks-1
m.EncoderDown[i], err = loadDownEncoderBlock2D(weights, prefix, cfg.LayersPerBlock, cfg.NormNumGroups, hasDownsample)
if err != nil {
return fmt.Errorf("%s: %w", prefix, err)
}
}
// Load encoder mid block
m.EncoderMid, err = loadVAEMidBlock(weights, "encoder.mid_block", cfg.NormNumGroups)
if err != nil {
return fmt.Errorf("encoder.mid_block: %w", err)
}
// Load encoder conv_norm_out and conv_out
normW, normB, err := safetensors.LoadGroupNorm(weights, "encoder.conv_norm_out")
if err != nil {
return fmt.Errorf("encoder.conv_norm_out: %w", err)
}
m.EncoderNormOut = &GroupNormLayer{Weight: normW, Bias: normB, NumGroups: cfg.NormNumGroups, Eps: 1e-5}
convOutW, convOutB, err := safetensors.LoadConv2D(weights, "encoder.conv_out")
if err != nil {
return fmt.Errorf("encoder.conv_out: %w", err)
}
m.EncoderConvOut = &Conv2D{Weight: convOutW, Bias: convOutB, Stride: 1, Padding: 1}
// Load quant_conv (for encoding)
if cfg.UseQuantConv {
qW, qB, err := safetensors.LoadConv2D(weights, "quant_conv")
if err != nil {
return fmt.Errorf("quant_conv: %w", err)
}
m.QuantConv = &Conv2D{Weight: qW, Bias: qB, Stride: 1, Padding: 0}
}
return nil
}
// loadDownEncoderBlock2D loads a down encoder block
func loadDownEncoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasDownsample bool) (*DownEncoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
if err != nil {
return nil, err
}
resnets[i] = resnet
}
var downsample *Conv2D
if hasDownsample {
downW, downB, err := safetensors.LoadConv2D(weights, prefix+".downsamplers.0.conv")
if err != nil {
return nil, err
}
downsample = &Conv2D{Weight: downW, Bias: downB, Stride: 2, Padding: 0}
}
return &DownEncoderBlock2D{
ResnetBlocks: resnets,
Downsample: downsample,
}, nil
}
// EncodeImage encodes an image to normalized latents.
// image: [B, 3, H, W] image tensor in [-1, 1]
// Returns: [B, L, C*4] patchified normalized latents
func (vae *AutoencoderKLFlux2) EncodeImage(image *mlx.Array) *mlx.Array {
// Convert NCHW -> NHWC
x := mlx.Transpose(image, 0, 2, 3, 1)
// Encoder
h := vae.EncoderConvIn.Forward(x)
for _, downBlock := range vae.EncoderDown {
h = downBlock.Forward(h)
}
h = vae.EncoderMid.Forward(h)
h = vae.EncoderNormOut.Forward(h)
h = mlx.SiLU(h)
h = vae.EncoderConvOut.Forward(h)
// Quant conv outputs [B, H, W, 2*latent_channels] (mean + logvar)
if vae.QuantConv != nil {
h = vae.QuantConv.Forward(h)
}
// Take only the mean (first latent_channels) - deterministic encoding
// h is [B, H, W, 64] -> take first 32 channels for mean
shape := h.Shape()
latentChannels := vae.Config.LatentChannels // 32
h = mlx.Slice(h, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], latentChannels})
// Convert NHWC -> NCHW for patchifying
h = mlx.Transpose(h, 0, 3, 1, 2)
// Patchify: [B, C, H, W] -> [B, L, C*4]
h = vae.Patchify(h)
// Apply BatchNorm on patchified latents [B, L, 128]
// The BatchNorm has 128 channels matching the patchified dimension
h = vae.normalizePatchified(h)
return h
}
// normalizePatchified applies batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] normalized
func (vae *AutoencoderKLFlux2) normalizePatchified(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[2] // 128
// Reshape stats for broadcasting [1, 1, C]
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
// Normalize: (x - mean) / sqrt(var + eps)
xNorm := mlx.Sub(x, mean)
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
// Scale and shift (only if affine=True)
if vae.LatentBN.Weight != nil {
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if vae.LatentBN.Bias != nil {
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}

View File

@@ -0,0 +1,390 @@
//go:build mlx
// Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models.
package qwen3
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Config holds Qwen3 text encoder configuration
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
}
// Attention implements Qwen3 attention with QK norms
type Attention struct {
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
RopeTheta float32
}
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
freqsArr := make([]float32, half)
logTheta := float32(math.Log(float64(theta)))
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
posArr := make([]float32, seqLen)
for i := int32(0); i < seqLen; i++ {
posArr[i] = float32(i)
}
pos := mlx.NewArray(posArr, []int32{seqLen})
posExpanded := mlx.Reshape(pos, seqLen, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
args := mlx.Mul(posExpanded, freqsExpanded)
cosVals := mlx.Cos(args)
sinVals := mlx.Sin(args)
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}
// Forward computes attention with causal masking and optional padding mask
func (attn *Attention) Forward(x *mlx.Array, mask *mlx.Array, maskMode string) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
q := attn.QProj.Forward(x)
k := attn.KProj.Forward(x)
v := attn.VProj.Forward(x)
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
q = attn.QNorm.Forward(q, 1e-6)
k = attn.KNorm.Forward(k, 1e-6)
q = applyRoPEQwen3(q, L, attn.RopeTheta)
k = applyRoPEQwen3(k, L, attn.RopeTheta)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
k = repeatKV(k, repeats)
v = repeatKV(v, repeats)
}
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, attn.Scale, maskMode, mask, nil)
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
out = attn.OProj.Forward(out)
return out
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
x = mlx.ExpandDims(x, 2)
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// MLP implements Qwen3 SwiGLU MLP
type MLP struct {
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
}
// Forward applies the MLP
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := m.GateProj.Forward(x)
gate = mlx.SiLU(gate)
up := m.UpProj.Forward(x)
h := mlx.Mul(gate, up)
return m.DownProj.Forward(h)
}
// Block represents a single Qwen3 transformer block
type Block struct {
Attention *Attention `weight:"self_attn"`
MLP *MLP `weight:"mlp"`
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the Qwen3 block
func (qb *Block) Forward(x *mlx.Array, eps float32, mask *mlx.Array, maskMode string) *mlx.Array {
h := qb.InputLayerNorm.Forward(x, eps)
attnOut := qb.Attention.Forward(h, mask, maskMode)
x = mlx.Add(x, attnOut)
h = qb.PostAttnLayerNorm.Forward(x, eps)
mlpOut := qb.MLP.Forward(h)
x = mlx.Add(x, mlpOut)
return x
}
// TextEncoder is the full Qwen3 encoder
type TextEncoder struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Block `weight:"model.layers"`
FinalNorm *nn.RMSNorm `weight:"model.norm"`
*Config
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Config
if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
m.Layers = make([]*Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *TextEncoder) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *TextEncoder) initComputedFields() {
cfg := m.Config
m.FinalNorm.Eps = cfg.RMSNormEps
for _, block := range m.Layers {
// Attention
block.Attention.NHeads = cfg.NumAttentionHeads
block.Attention.NKVHeads = cfg.NumKeyValueHeads
block.Attention.HeadDim = cfg.HeadDim
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.Attention.RopeTheta = cfg.RopeTheta
block.Attention.QNorm.Eps = cfg.RMSNormEps
block.Attention.KNorm.Eps = cfg.RMSNormEps
// Block norms
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
}
}
// Forward encodes text tokens with provided attention mask (LxL) and mask mode.
func (te *TextEncoder) Forward(tokens *mlx.Array, attnMask *mlx.Array, maskMode string) *mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
for _, layer := range te.Layers {
h = layer.Forward(h, eps, attnMask, maskMode)
}
// Apply final RMS norm
h = te.FinalNorm.Forward(h, eps)
return h
}
// ForwardWithLayerOutputs encodes text tokens and returns hidden states from specified layers.
// This is used by Flux2 which needs embeddings from specific intermediate layers.
func (te *TextEncoder) ForwardWithLayerOutputs(tokens *mlx.Array, layerIndices []int, attnMask *mlx.Array, maskMode string) []*mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
outputs := make([]*mlx.Array, len(layerIndices))
layerSet := make(map[int]int)
for i, idx := range layerIndices {
layerSet[idx] = i
}
for i, layer := range te.Layers {
h = layer.Forward(h, eps, attnMask, maskMode)
if outIdx, ok := layerSet[i]; ok {
outputs[outIdx] = h
}
}
return outputs
}
// ApplyChatTemplate wraps prompt in Qwen3 chat format.
// If think is true, adds the <think></think> block after the assistant tag
// (matches tokenizer.apply_chat_template with enable_thinking=False in Python).
func ApplyChatTemplate(prompt string, think bool) string {
base := "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
if think {
return base + "<think>\n\n</think>\n\n"
}
return base
}
// EncodePrompt encodes a text prompt using the tokenizer and encoder.
// If think is true, includes the <think></think> block in the chat template.
func (te *TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int, think bool) (*mlx.Array, *mlx.Array) {
formattedPrompt := ApplyChatTemplate(prompt, think)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
maskData := make([]float32, maxLen)
for i := 0; i < len(tokens); i++ {
maskData[i] = 1.0
}
// Get PAD token (different from EOS for Qwen3)
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
paddedTokens := make([]int32, maxLen)
copy(paddedTokens, tokens)
for i := len(tokens); i < maxLen; i++ {
paddedTokens[i] = padToken
}
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
// Build combined causal + PAD mask [L, L]
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
L := int32(maxLen)
validLen := int32(len(tokens))
combinedMaskData := make([]float32, L*L)
negInf := float32(-1e9)
for i := int32(0); i < L; i++ {
for j := int32(0); j < L; j++ {
idx := i*L + j
if j <= i && j < validLen {
combinedMaskData[idx] = 0
} else {
combinedMaskData[idx] = negInf
}
}
}
maskMat := mlx.NewArray(combinedMaskData, []int32{L, L})
embeddings := te.Forward(tokensArr, maskMat, "")
return embeddings, maskArr
}
// EncodePromptWithLayers encodes a text prompt and returns embeddings from specified layers.
// Used by Flux2 which concatenates embeddings from multiple intermediate layers.
// If think is true, includes the <think></think> block in the chat template.
// Returns embeddings and padded sequence length.
func (te *TextEncoder) EncodePromptWithLayers(tok *tokenizer.Tokenizer, prompt string, maxLen int, layerIndices []int, think bool) (*mlx.Array, int32) {
formattedPrompt := ApplyChatTemplate(prompt, think)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
// Pad to maxLen
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
padded := make([]int32, maxLen)
copy(padded, tokens)
for i := len(tokens); i < maxLen; i++ {
padded[i] = padToken
}
tokensArr := mlx.NewArrayInt32(padded, []int32{1, int32(maxLen)})
// Build combined causal + PAD mask [L, L]
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
// This combines causal masking with PAD token masking
L := int32(maxLen)
validLen := int32(len(tokens))
maskData := make([]float32, L*L)
negInf := float32(-1e9)
for i := int32(0); i < L; i++ {
for j := int32(0); j < L; j++ {
idx := i*L + j
if j <= i && j < validLen {
maskData[idx] = 0 // allowed: causal OK and not PAD
} else {
maskData[idx] = negInf // blocked: future or PAD
}
}
}
maskMat := mlx.NewArray(maskData, []int32{L, L})
layerOutputs := te.ForwardWithLayerOutputs(tokensArr, layerIndices, maskMat, "")
// Concatenate layer outputs along the hidden dimension
// Each output is [B, L, hidden_dim], result is [B, L, num_layers * hidden_dim]
embeddings := mlx.Concatenate(layerOutputs, 2)
// Return embeddings and padded length
return embeddings, int32(maxLen)
}

View File

@@ -9,6 +9,7 @@ import (
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -17,13 +18,13 @@ import (
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30)
Seed int64 // Random seed
Progress imagegen.ProgressFunc // Optional progress callback
// Layer caching (DeepCache/Learning-to-Cache speedup)
LayerCache bool // Enable layer caching (default: false)
@@ -31,9 +32,6 @@ type GenerateConfig struct {
CacheLayers int // Number of shallow layers to cache (default: 25)
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Qwen-Image diffusion model.
type Model struct {
ModelPath string
@@ -117,7 +115,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress imagegen.ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
@@ -129,7 +127,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress imagegen.ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,

View File

@@ -10,6 +10,7 @@ import (
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -18,18 +19,15 @@ import (
// GenerateConfig holds all options for image editing.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
Width int32 // Output width (default: from input image)
Height int32 // Output height (default: from input image)
Steps int // Denoising steps (default: 50)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
Width int32 // Output width (default: from input image)
Height int32 // Output height (default: from input image)
Steps int // Denoising steps (default: 50)
Seed int64 // Random seed
Progress imagegen.ProgressFunc // Optional progress callback
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Qwen-Image-Edit diffusion model.
type Model struct {
ModelPath string

View File

@@ -3,287 +3,17 @@
package zimage
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/imagegen/models/qwen3"
)
// Qwen3Config holds Qwen3 text encoder configuration
type Qwen3Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
}
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
RopeTheta float32
}
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
freqsArr := make([]float32, half)
logTheta := float32(math.Log(float64(theta)))
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
posArr := make([]float32, seqLen)
for i := int32(0); i < seqLen; i++ {
posArr[i] = float32(i)
}
pos := mlx.NewArray(posArr, []int32{seqLen})
posExpanded := mlx.Reshape(pos, seqLen, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
args := mlx.Mul(posExpanded, freqsExpanded)
cosVals := mlx.Cos(args)
sinVals := mlx.Sin(args)
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}
// Forward computes attention with causal masking
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
q := attn.QProj.Forward(x)
k := attn.KProj.Forward(x)
v := attn.VProj.Forward(x)
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
q = attn.QNorm.Forward(q, 1e-6)
k = attn.KNorm.Forward(k, 1e-6)
q = applyRoPEQwen3(q, L, attn.RopeTheta)
k = applyRoPEQwen3(k, L, attn.RopeTheta)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
k = repeatKV(k, repeats)
v = repeatKV(v, repeats)
}
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
out = attn.OProj.Forward(out)
return out
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
x = mlx.ExpandDims(x, 2)
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
}
// Forward applies the MLP
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
gate := m.GateProj.Forward(x)
gate = mlx.SiLU(gate)
up := m.UpProj.Forward(x)
h := mlx.Mul(gate, up)
return m.DownProj.Forward(h)
}
// Qwen3Block represents a single Qwen3 transformer block
type Qwen3Block struct {
Attention *Qwen3Attention `weight:"self_attn"`
MLP *Qwen3MLP `weight:"mlp"`
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the Qwen3 block
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
h := qb.InputLayerNorm.Forward(x, eps)
attnOut := qb.Attention.Forward(h)
x = mlx.Add(x, attnOut)
h = qb.PostAttnLayerNorm.Forward(x, eps)
mlpOut := qb.MLP.Forward(h)
x = mlx.Add(x, mlpOut)
return x
}
// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
type Qwen3TextEncoder struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Qwen3Block `weight:"model.layers"`
FinalNorm *nn.RMSNorm `weight:"model.norm"`
*Qwen3Config
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *Qwen3TextEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Qwen3Config
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Qwen3Config = &cfg
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Qwen3TextEncoder) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *Qwen3TextEncoder) initComputedFields() {
cfg := m.Qwen3Config
m.FinalNorm.Eps = cfg.RMSNormEps
for _, block := range m.Layers {
// Attention
block.Attention.NHeads = cfg.NumAttentionHeads
block.Attention.NKVHeads = cfg.NumKeyValueHeads
block.Attention.HeadDim = cfg.HeadDim
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.Attention.RopeTheta = cfg.RopeTheta
block.Attention.QNorm.Eps = cfg.RMSNormEps
block.Attention.KNorm.Eps = cfg.RMSNormEps
// Block norms
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
}
}
// Forward encodes text tokens
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
for _, layer := range te.Layers {
h = layer.Forward(h, eps)
}
// Apply final RMS norm
h = te.FinalNorm.Forward(h, eps)
return h
}
// Re-export types from shared qwen3 package for backwards compatibility
type (
Qwen3Config = qwen3.Config
Qwen3Attention = qwen3.Attention
Qwen3MLP = qwen3.MLP
Qwen3Block = qwen3.Block
Qwen3TextEncoder = qwen3.TextEncoder
)
// ApplyChatTemplate wraps prompt in Qwen3 chat format
func ApplyChatTemplate(prompt string) string {
return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
}
// EncodePrompt encodes a text prompt using the tokenizer and encoder
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
formattedPrompt := ApplyChatTemplate(prompt)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
maskData := make([]float32, maxLen)
for i := 0; i < len(tokens); i++ {
maskData[i] = 1.0
}
// Get PAD token (different from EOS for Qwen3)
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
paddedTokens := make([]int32, maxLen)
copy(paddedTokens, tokens)
for i := len(tokens); i < maxLen; i++ {
paddedTokens[i] = padToken
}
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
embeddings := te.Forward(tokensArr)
return embeddings, maskArr
}
var ApplyChatTemplate = qwen3.ApplyChatTemplate

View File

@@ -17,14 +17,14 @@ import (
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
Progress imagegen.ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
@@ -34,9 +34,6 @@ type GenerateConfig struct {
FusedQKV bool // Enable fused QKV projection (default: false)
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Z-Image diffusion model.
type Model struct {
ModelName string
@@ -93,7 +90,7 @@ func (m *Model) Load(modelName string) error {
// Load text encoder
m.TextEncoder = &Qwen3TextEncoder{}
if err := m.TextEncoder.Load(manifest); err != nil {
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
@@ -139,7 +136,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress imagegen.ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
@@ -151,7 +148,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress imagegen.ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
@@ -179,9 +176,16 @@ func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*m
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
// GenerateImage implements runner.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// generate is the internal denoising pipeline.
@@ -222,9 +226,9 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
// Text encoding with padding to multiple of 32
var posEmb, negEmb *mlx.Array
{
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512)
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512, false)
if useCFG {
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512)
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512, false)
}
// Pad both to same length (multiple of 32)

View File

@@ -19,6 +19,7 @@ import (
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
@@ -40,10 +41,15 @@ type Response struct {
Total int `json:"total,omitempty"`
}
// ImageModel is the interface for image generation models
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model *zimage.Model
model ImageModel
modelName string
}
@@ -80,10 +86,25 @@ func Execute(args []string) error {
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Load model
model := &zimage.Model{}
if err := model.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
// Detect model type and load appropriate model
modelType := imagegen.DetectModelType(*modelName)
slog.Info("detected model type", "type", modelType)
var model ImageModel
switch modelType {
case "Flux2KleinPipeline":
m := &flux2.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
model = m
default:
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
m := &zimage.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
model = m
}
server := &Server{
@@ -159,26 +180,19 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Generate image
// Generate image using the common interface
ctx := r.Context()
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: req.Seed,
Progress: func(step, total int) {
resp := Response{
Step: step,
Total: total,
Done: false,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
},
})
enc := json.NewEncoder(w)
// Progress callback streams step updates
progress := func(step, total int) {
resp := Response{Step: step, Total: total}
enc.Encode(resp)
w.Write([]byte("\n"))
flusher.Flush()
}
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
if err != nil {
// Don't send error for cancellation

View File

@@ -194,6 +194,42 @@ func joinPath(prefix, suffix string) string {
return prefix + "." + suffix
}
// LoadConv2D loads a Conv2D layer from weights.
// Returns weight in OHWI format (transposed from PyTorch's OIHW) and optional bias.
func LoadConv2D(weights WeightSource, path string) (*mlx.Array, *mlx.Array, error) {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, nil, fmt.Errorf("failed to load weight %s: %w", path, err)
}
// Transpose weight from OIHW to OHWI for MLX
weightOHWI := mlx.Transpose(weight, 0, 2, 3, 1)
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
return weightOHWI, bias, nil
}
// LoadGroupNorm loads a GroupNormLayer from weights.
func LoadGroupNorm(weights WeightSource, path string) (*mlx.Array, *mlx.Array, error) {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, nil, fmt.Errorf("failed to load weight %s: %w", path, err)
}
bias, err := weights.GetTensor(path + ".bias")
if err != nil {
return nil, nil, fmt.Errorf("failed to load bias %s: %w", path, err)
}
return weight, bias, nil
}
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {

View File

@@ -510,7 +510,11 @@ func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) {
t.vocab.Merges[merge] = i
}
// Add special tokens to vocabulary
// Add all added_tokens to vocabulary and special tokens map.
// HuggingFace treats ALL added_tokens as special for tokenization purposes -
// they bypass BPE and get their own token ID. The "special" flag just indicates
// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
// to treat all added_tokens as special to match HuggingFace behavior.
for _, tok := range raw.AddedTokens {
if int(tok.ID) >= len(t.vocab.Values) {
newValues := make([]string, tok.ID+1)
@@ -518,9 +522,7 @@ func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) {
t.vocab.Values = newValues
}
t.vocab.Values[tok.ID] = tok.Content
if tok.Special {
t.specialTokens[tok.Content] = tok.ID
}
t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
}
// Load special token configuration from companion files

4
x/imagegen/types.go Normal file
View File

@@ -0,0 +1,4 @@
package imagegen
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)

215
x/imagegen/vae/tiling.go Normal file
View File

@@ -0,0 +1,215 @@
//go:build mlx
// Package vae provides shared utilities for VAE (Variational Autoencoder) operations.
package vae
import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TilingConfig holds configuration for tiled VAE decoding.
// This is a general technique to reduce memory usage when decoding large latents.
type TilingConfig struct {
TileSize int32 // Tile size in latent space (e.g., 64 latent → 512 pixels for 8x VAE)
Overlap int32 // Overlap in latent space (e.g., 16 latent = 25% of 64)
}
// DefaultTilingConfig returns reasonable defaults matching diffusers.
// tile_latent_min_size=64, tile_overlap_factor=0.25
func DefaultTilingConfig() *TilingConfig {
return &TilingConfig{
TileSize: 64, // 64 latent pixels
Overlap: 16, // 25% overlap
}
}
// decodedTile holds a decoded tile's pixel data and dimensions
type decodedTile struct {
data []float32
height int32
width int32
}
// DecodeTiled decodes latents using tiled processing with overlap blending.
// This reduces memory usage for large images by processing in overlapping tiles.
//
// Parameters:
// - latents: [1, H, W, C] latent tensor in NHWC format
// - cfg: tiling configuration (tile size and overlap)
// - decoder: function to decode a single tile [1, H, W, C] -> [1, H*scale, W*scale, 3]
//
// Returns: [1, 3, H*scale, W*scale] decoded image in NCHW format
func DecodeTiled(latents *mlx.Array, cfg *TilingConfig, decoder func(*mlx.Array) *mlx.Array) *mlx.Array {
shape := latents.Shape()
H := shape[1] // latent height
W := shape[2] // latent width
C := shape[3]
tileLatentSize := cfg.TileSize
overlapLatent := cfg.Overlap
// If image is small enough, just decode normally
if H <= tileLatentSize && W <= tileLatentSize {
decoded := decoder(latents)
decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
decoded = mlx.Transpose(decoded, 0, 3, 1, 2) // NHWC -> NCHW
return decoded
}
// Calculate tiling parameters (matching diffusers)
overlapSize := tileLatentSize - overlapLatent // stride in latent space
// Blend extent in pixel space (assumes 8x upscale, adjust if needed)
// For other scale factors, this could be made configurable
tileSampleSize := tileLatentSize * 8 // tile size in pixels after 8x upscale
blendExtent := overlapLatent * 8 // blend region in pixels
rowLimit := tileSampleSize - blendExtent // non-overlapping region per tile
// Phase 1: Decode all tiles and store in 2D grid
var rows [][]decodedTile
for i := int32(0); i < H; i += overlapSize {
var row []decodedTile
for j := int32(0); j < W; j += overlapSize {
// Extract tile (may be smaller at edges)
i2 := min(i+tileLatentSize, H)
j2 := min(j+tileLatentSize, W)
tile := mlx.Slice(latents, []int32{0, i, j, 0}, []int32{1, i2, j2, C})
decoded := decoder(tile)
decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
mlx.Eval(decoded)
decodedShape := decoded.Shape()
tileH := decodedShape[1]
tileW := decodedShape[2]
tileData := decoded.Data()
decoded.Free()
row = append(row, decodedTile{data: tileData, height: tileH, width: tileW})
}
rows = append(rows, row)
}
// Phase 2: Blend adjacent tiles (modifies in place)
for i := range rows {
for j := range rows[i] {
tile := &rows[i][j]
// Blend with tile above
if i > 0 {
above := &rows[i-1][j]
blendV(above, tile, blendExtent)
}
// Blend with tile to the left
if j > 0 {
left := &rows[i][j-1]
blendH(left, tile, blendExtent)
}
}
}
// Phase 3: Calculate crop dimensions for each tile
colWidths := make([]int32, len(rows[0]))
for j := range rows[0] {
keepW := rowLimit
if int32(j+1)*overlapSize >= W {
keepW = rows[0][j].width
}
colWidths[j] = keepW
}
rowHeights := make([]int32, len(rows))
for i := range rows {
keepH := rowLimit
if int32(i+1)*overlapSize >= H {
keepH = rows[i][0].height
}
rowHeights[i] = keepH
}
// Calculate total dimensions
var totalW, totalH int32
for _, w := range colWidths {
totalW += w
}
for _, h := range rowHeights {
totalH += h
}
// Phase 4: Assemble final image by interleaving tiles row-by-row
finalData := make([]float32, totalH*totalW*3)
dstY := int32(0)
for i, row := range rows {
keepH := rowHeights[i]
for y := int32(0); y < keepH; y++ {
dstX := int32(0)
for j, tile := range row {
keepW := colWidths[j]
for x := int32(0); x < keepW; x++ {
for c := int32(0); c < 3; c++ {
srcIdx := (y*tile.width + x) * 3 + c
dstIdx := ((dstY + y) * totalW + (dstX + x)) * 3 + c
finalData[dstIdx] = tile.data[srcIdx]
}
}
dstX += keepW
}
}
dstY += keepH
}
// Create mlx array [1, H, W, 3] then transpose to NCHW [1, 3, H, W]
result := mlx.NewArray(finalData, []int32{1, totalH, totalW, 3})
result = mlx.Transpose(result, 0, 3, 1, 2)
result = mlx.ClipScalar(result, 0.0, 1.0, true, true)
return result
}
// blendV blends the bottom of 'above' tile into top of 'current' tile (vertical blend)
// Matches diffusers blend_v formula
func blendV(above, current *decodedTile, blendExtent int32) {
blend := min(blendExtent, min(above.height, current.height))
if blend <= 0 {
return
}
w := min(above.width, current.width)
for y := int32(0); y < blend; y++ {
alpha := float32(y) / float32(blend)
for x := int32(0); x < w; x++ {
for c := int32(0); c < 3; c++ {
aboveIdx := ((above.height - blend + y) * above.width + x) * 3 + c
currIdx := (y * current.width + x) * 3 + c
current.data[currIdx] = above.data[aboveIdx]*(1-alpha) + current.data[currIdx]*alpha
}
}
}
}
// blendH blends the right of 'left' tile into left of 'current' tile (horizontal blend)
// Matches diffusers blend_h formula
func blendH(left, current *decodedTile, blendExtent int32) {
blend := min(blendExtent, min(left.width, current.width))
if blend <= 0 {
return
}
h := min(left.height, current.height)
for y := int32(0); y < h; y++ {
for x := int32(0); x < blend; x++ {
alpha := float32(x) / float32(blend)
for c := int32(0); c < 3; c++ {
leftIdx := (y * left.width + (left.width - blend + x)) * 3 + c
currIdx := (y * current.width + x) * 3 + c
current.data[currIdx] = left.data[leftIdx]*(1-alpha) + current.data[currIdx]*alpha
}
}
}
}