mirror of
https://github.com/ollama/ollama.git
synced 2026-02-20 08:16:07 -05:00
216 lines
6.2 KiB
Go
216 lines
6.2 KiB
Go
//go:build mlx
|
|
|
|
// Package vae provides shared utilities for VAE (Variational Autoencoder) operations.
|
|
package vae
|
|
|
|
import (
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
// TilingConfig holds configuration for tiled VAE decoding.
|
|
// This is a general technique to reduce memory usage when decoding large latents.
|
|
type TilingConfig struct {
|
|
TileSize int32 // Tile size in latent space (e.g., 64 latent → 512 pixels for 8x VAE)
|
|
Overlap int32 // Overlap in latent space (e.g., 16 latent = 25% of 64)
|
|
}
|
|
|
|
// DefaultTilingConfig returns reasonable defaults matching diffusers.
|
|
// tile_latent_min_size=64, tile_overlap_factor=0.25
|
|
func DefaultTilingConfig() *TilingConfig {
|
|
return &TilingConfig{
|
|
TileSize: 64, // 64 latent pixels
|
|
Overlap: 16, // 25% overlap
|
|
}
|
|
}
|
|
|
|
// decodedTile holds a decoded tile's pixel data and dimensions
|
|
type decodedTile struct {
|
|
data []float32
|
|
height int32
|
|
width int32
|
|
}
|
|
|
|
// DecodeTiled decodes latents using tiled processing with overlap blending.
|
|
// This reduces memory usage for large images by processing in overlapping tiles.
|
|
//
|
|
// Parameters:
|
|
// - latents: [1, H, W, C] latent tensor in NHWC format
|
|
// - cfg: tiling configuration (tile size and overlap)
|
|
// - decoder: function to decode a single tile [1, H, W, C] -> [1, H*scale, W*scale, 3]
|
|
//
|
|
// Returns: [1, 3, H*scale, W*scale] decoded image in NCHW format
|
|
func DecodeTiled(latents *mlx.Array, cfg *TilingConfig, decoder func(*mlx.Array) *mlx.Array) *mlx.Array {
|
|
shape := latents.Shape()
|
|
H := shape[1] // latent height
|
|
W := shape[2] // latent width
|
|
C := shape[3]
|
|
|
|
tileLatentSize := cfg.TileSize
|
|
overlapLatent := cfg.Overlap
|
|
|
|
// If image is small enough, just decode normally
|
|
if H <= tileLatentSize && W <= tileLatentSize {
|
|
decoded := decoder(latents)
|
|
decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
|
|
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
|
|
decoded = mlx.Transpose(decoded, 0, 3, 1, 2) // NHWC -> NCHW
|
|
return decoded
|
|
}
|
|
|
|
// Calculate tiling parameters (matching diffusers)
|
|
overlapSize := tileLatentSize - overlapLatent // stride in latent space
|
|
|
|
// Blend extent in pixel space (assumes 8x upscale, adjust if needed)
|
|
// For other scale factors, this could be made configurable
|
|
tileSampleSize := tileLatentSize * 8 // tile size in pixels after 8x upscale
|
|
blendExtent := overlapLatent * 8 // blend region in pixels
|
|
rowLimit := tileSampleSize - blendExtent // non-overlapping region per tile
|
|
|
|
// Phase 1: Decode all tiles and store in 2D grid
|
|
var rows [][]decodedTile
|
|
|
|
for i := int32(0); i < H; i += overlapSize {
|
|
var row []decodedTile
|
|
for j := int32(0); j < W; j += overlapSize {
|
|
// Extract tile (may be smaller at edges)
|
|
i2 := min(i+tileLatentSize, H)
|
|
j2 := min(j+tileLatentSize, W)
|
|
|
|
tile := mlx.Slice(latents, []int32{0, i, j, 0}, []int32{1, i2, j2, C})
|
|
decoded := decoder(tile)
|
|
decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
|
|
mlx.Eval(decoded)
|
|
|
|
decodedShape := decoded.Shape()
|
|
tileH := decodedShape[1]
|
|
tileW := decodedShape[2]
|
|
tileData := decoded.Data()
|
|
decoded.Free()
|
|
|
|
row = append(row, decodedTile{data: tileData, height: tileH, width: tileW})
|
|
}
|
|
rows = append(rows, row)
|
|
}
|
|
|
|
// Phase 2: Blend adjacent tiles (modifies in place)
|
|
for i := range rows {
|
|
for j := range rows[i] {
|
|
tile := &rows[i][j]
|
|
|
|
// Blend with tile above
|
|
if i > 0 {
|
|
above := &rows[i-1][j]
|
|
blendV(above, tile, blendExtent)
|
|
}
|
|
|
|
// Blend with tile to the left
|
|
if j > 0 {
|
|
left := &rows[i][j-1]
|
|
blendH(left, tile, blendExtent)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Phase 3: Calculate crop dimensions for each tile
|
|
colWidths := make([]int32, len(rows[0]))
|
|
for j := range rows[0] {
|
|
keepW := rowLimit
|
|
if int32(j+1)*overlapSize >= W {
|
|
keepW = rows[0][j].width
|
|
}
|
|
colWidths[j] = keepW
|
|
}
|
|
|
|
rowHeights := make([]int32, len(rows))
|
|
for i := range rows {
|
|
keepH := rowLimit
|
|
if int32(i+1)*overlapSize >= H {
|
|
keepH = rows[i][0].height
|
|
}
|
|
rowHeights[i] = keepH
|
|
}
|
|
|
|
// Calculate total dimensions
|
|
var totalW, totalH int32
|
|
for _, w := range colWidths {
|
|
totalW += w
|
|
}
|
|
for _, h := range rowHeights {
|
|
totalH += h
|
|
}
|
|
|
|
// Phase 4: Assemble final image by interleaving tiles row-by-row
|
|
finalData := make([]float32, totalH*totalW*3)
|
|
|
|
dstY := int32(0)
|
|
for i, row := range rows {
|
|
keepH := rowHeights[i]
|
|
|
|
for y := int32(0); y < keepH; y++ {
|
|
dstX := int32(0)
|
|
for j, tile := range row {
|
|
keepW := colWidths[j]
|
|
|
|
for x := int32(0); x < keepW; x++ {
|
|
for c := int32(0); c < 3; c++ {
|
|
srcIdx := (y*tile.width + x) * 3 + c
|
|
dstIdx := ((dstY + y) * totalW + (dstX + x)) * 3 + c
|
|
finalData[dstIdx] = tile.data[srcIdx]
|
|
}
|
|
}
|
|
dstX += keepW
|
|
}
|
|
}
|
|
dstY += keepH
|
|
}
|
|
|
|
// Create mlx array [1, H, W, 3] then transpose to NCHW [1, 3, H, W]
|
|
result := mlx.NewArray(finalData, []int32{1, totalH, totalW, 3})
|
|
result = mlx.Transpose(result, 0, 3, 1, 2)
|
|
result = mlx.ClipScalar(result, 0.0, 1.0, true, true)
|
|
|
|
return result
|
|
}
|
|
|
|
// blendV blends the bottom of 'above' tile into top of 'current' tile (vertical blend)
|
|
// Matches diffusers blend_v formula
|
|
func blendV(above, current *decodedTile, blendExtent int32) {
|
|
blend := min(blendExtent, min(above.height, current.height))
|
|
if blend <= 0 {
|
|
return
|
|
}
|
|
|
|
w := min(above.width, current.width)
|
|
for y := int32(0); y < blend; y++ {
|
|
alpha := float32(y) / float32(blend)
|
|
for x := int32(0); x < w; x++ {
|
|
for c := int32(0); c < 3; c++ {
|
|
aboveIdx := ((above.height - blend + y) * above.width + x) * 3 + c
|
|
currIdx := (y * current.width + x) * 3 + c
|
|
current.data[currIdx] = above.data[aboveIdx]*(1-alpha) + current.data[currIdx]*alpha
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// blendH blends the right of 'left' tile into left of 'current' tile (horizontal blend)
|
|
// Matches diffusers blend_h formula
|
|
func blendH(left, current *decodedTile, blendExtent int32) {
|
|
blend := min(blendExtent, min(left.width, current.width))
|
|
if blend <= 0 {
|
|
return
|
|
}
|
|
|
|
h := min(left.height, current.height)
|
|
for y := int32(0); y < h; y++ {
|
|
for x := int32(0); x < blend; x++ {
|
|
alpha := float32(x) / float32(blend)
|
|
for c := int32(0); c < 3; c++ {
|
|
leftIdx := (y * left.width + (left.width - blend + x)) * 3 + c
|
|
currIdx := (y * current.width + x) * 3 + c
|
|
current.data[currIdx] = left.data[leftIdx]*(1-alpha) + current.data[currIdx]*alpha
|
|
}
|
|
}
|
|
}
|
|
}
|