Compare commits

..

3 Commits

Author SHA1 Message Date
Parth Sareen
2185112d84 x/cmd: connect /set flags to behavior in experimental mode (#13684) 2026-01-12 00:40:44 -08:00
Parth Sareen
91926601dc x: add missing /set, /show, /load, /save commands to experimental mode (#13682) 2026-01-11 23:12:31 -08:00
Jeffrey Morgan
361d6c16c2 x/imagegen/transfer: fix timeout and progress reporting (#13679)
Removes 5-minute HTTP client timeout that caused "context deadline
exceeded" errors on large file downloads. Stall detection (10s)
already handles unresponsive connections.

Fixes progress bar total going down on resume by calculating total
from all blobs upfront and reporting already-downloaded bytes
as completed immediately.
2026-01-11 15:33:53 -08:00
30 changed files with 697 additions and 1153 deletions

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 8 * format.MegaByte
const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader

View File

@@ -100,8 +100,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return imagegenclient.CreateModel(args[0], ".", quantize, p)
return imagegenclient.CreateModel(args[0], ".", p)
}
reader = strings.NewReader("FROM .\n")
} else {
@@ -465,6 +464,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
name := args[0]
// Check if this is a known image generation model (skip Show/Pull)
if imagegen.HasTensorLayers(name) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -526,14 +533,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
@@ -672,11 +671,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest]
if !ok {
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -842,6 +837,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
}
func ShowHandler(cmd *cobra.Command, args []string) error {
// Check if this is an image generation model
if imagegen.HasTensorLayers(args[0]) {
return imagegen.Show(args[0], os.Stdout)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err

View File

@@ -1,3 +0,0 @@
# Troubleshooting
For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)

View File

@@ -1124,15 +1124,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
QuantizationLevel: m.Config.FileType,
}
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
}
}
if req.System != "" {
m.System = req.System
}
@@ -1215,10 +1206,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
return resp, nil
}
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err

View File

@@ -9,6 +9,7 @@ import (
"net/url"
"os"
"os/signal"
"slices"
"strings"
"syscall"
"time"
@@ -24,6 +25,14 @@ import (
"github.com/ollama/ollama/x/tools"
)
// MultilineState tracks the state of multiline input
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilineSystem
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
@@ -130,6 +139,7 @@ type RunOptions struct {
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
Verbose bool
// Agent fields (managed externally for session persistence)
Tools *tools.Registry
@@ -178,6 +188,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
var latest api.ChatResponse
role := "assistant"
messages := opts.Messages
@@ -187,6 +198,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
p.StopAndClear()
}
latest = response
role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
@@ -483,6 +495,10 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Println()
}
if opts.Verbose {
latest.Summary()
}
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
}
@@ -677,6 +693,9 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var messages []api.Message
var sb strings.Builder
var format string
var system string
var multiline MultilineState = MultilineNone
for {
line, err := scanner.Readline()
@@ -688,13 +707,39 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
scanner.Prompt.UseAlt = false
sb.Reset()
multiline = MultilineNone
continue
case err != nil:
return err
}
switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}
switch multiline {
case MultilineSystem:
system = sb.String()
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
}
multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/clear"):
@@ -707,6 +752,10 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load Load a different model")
fmt.Fprintln(os.Stderr, " /save Save session as a model")
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
fmt.Fprintln(os.Stderr, " /bye Exit")
@@ -716,6 +765,303 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/set"):
args := strings.Fields(line)
if len(args) > 1 {
switch args[1] {
case "history":
scanner.HistoryEnable()
case "nohistory":
scanner.HistoryDisable()
case "wordwrap":
wordWrap = true
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
wordWrap = false
fmt.Println("Set 'nowordwrap' mode.")
case "verbose":
if err := cmd.Flags().Set("verbose", "true"); err != nil {
return err
}
fmt.Println("Set 'verbose' mode.")
case "quiet":
if err := cmd.Flags().Set("verbose", "false"); err != nil {
return err
}
fmt.Println("Set 'quiet' mode.")
case "think":
thinkValue := api.ThinkValue{Value: true}
var maybeLevel string
if len(args) > 2 {
maybeLevel = args[2]
}
if maybeLevel != "" {
thinkValue.Value = maybeLevel
}
think = &thinkValue
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
if maybeLevel != "" {
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
} else {
fmt.Println("Set 'think' mode.")
}
case "nothink":
think = &api.ThinkValue{Value: false}
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
fmt.Println("Set 'nothink' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else {
format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
}
case "noformat":
format = ""
fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
fmt.Println("Usage: /set parameter <name> <value>")
continue
}
params := args[3:]
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
continue
}
multiline = MultilineSystem
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
system = sb.String()
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
// Replace the last message
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /set <parameter|system|history|format|wordwrap|think|verbose> [value]")
}
continue
case strings.HasPrefix(line, "/show"):
args := strings.Fields(line)
if len(args) > 1 {
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.ShowRequest{
Name: modelName,
Options: options,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model")
continue
}
switch args[1] {
case "info":
fmt.Fprintf(os.Stderr, " Model\n")
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Name", modelName)
if resp.Details.Family != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Family", resp.Details.Family)
}
if resp.Details.ParameterSize != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Parameter Size", resp.Details.ParameterSize)
}
if resp.Details.QuantizationLevel != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Quantization", resp.Details.QuantizationLevel)
}
if len(resp.Capabilities) > 0 {
caps := make([]string, len(resp.Capabilities))
for i, c := range resp.Capabilities {
caps[i] = string(c)
}
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Capabilities", strings.Join(caps, ", "))
}
fmt.Fprintln(os.Stderr)
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")
} else {
fmt.Println(resp.License)
}
case "modelfile":
fmt.Println(resp.Modelfile)
case "parameters":
fmt.Println("Model defined parameters:")
if resp.Parameters == "" {
fmt.Println(" No additional parameters were specified.")
} else {
for _, l := range strings.Split(resp.Parameters, "\n") {
fmt.Printf(" %s\n", l)
}
}
if len(options) > 0 {
fmt.Println("\nUser defined parameters:")
for k, v := range options {
fmt.Printf(" %-30s %v\n", k, v)
}
}
case "system":
switch {
case system != "":
fmt.Println(system + "\n")
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Println("No system message was specified for this model.")
}
case "template":
if resp.Template != "" {
fmt.Println(resp.Template)
} else {
fmt.Println("No prompt template was specified for this model.")
}
default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /show <info|license|modelfile|parameters|system|template>")
}
continue
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /load <modelname>")
continue
}
newModelName := args[1]
fmt.Printf("Loading model '%s'\n", newModelName)
// Create progress spinner
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
// Get client
client, err := api.ClientFromEnvironment()
if err != nil {
p.StopAndClear()
fmt.Println("error: couldn't connect to ollama server")
continue
}
// Check if model exists and get its info
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else {
fmt.Printf("error: %v\n", err)
}
continue
}
// For cloud models, no need to preload
if info.RemoteHost == "" {
// Preload the model by sending an empty generate request
req := &api.GenerateRequest{
Model: newModelName,
Think: think,
}
err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
return nil
})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
} else {
fmt.Printf("error loading model: %v\n", err)
}
continue
}
}
p.StopAndClear()
modelName = newModelName
messages = []api.Message{}
approval.Reset()
continue
case strings.HasPrefix(line, "/save"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /save <modelname>")
continue
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.CreateRequest{
Model: args[1],
From: modelName,
Parameters: options,
Messages: messages,
}
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
fmt.Printf("error: %v\n", err)
continue
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
continue
@@ -723,14 +1069,16 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
sb.WriteString(line)
}
if sb.Len() > 0 {
if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
verbose, _ := cmd.Flags().GetBool("verbose")
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Format: format,
Options: options,
Think: think,
HideThinking: hideThinking,
@@ -738,6 +1086,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
Verbose: verbose,
}
assistant, err := Chat(cmd.Context(), opts)

View File

@@ -234,17 +234,3 @@ ollama create z-image
3. Copy config files (*.json) as config layers
4. Write manifest
```
## FP8 Quantization
Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.
### Usage
```bash
cd ./weights/Z-Image-Turbo
ollama create z-image-fp8 --quantize fp8
```
This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.

View File

@@ -1,8 +1,10 @@
package api
import (
"encoding/base64"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"time"
@@ -99,10 +101,10 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imageBase64 string
var imagePath string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
imagePath = extractPath(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
@@ -116,14 +118,14 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
return
}
c.SSEvent("done", buildResponse(imageBase64, format))
c.SSEvent("done", buildResponse(imagePath, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imageBase64 string
var imagePath string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
imagePath = extractPath(resp.Content)
}
})
if err != nil {
@@ -131,7 +133,7 @@ func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.
return
}
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
c.JSON(http.StatusOK, buildResponse(imagePath, format))
}
func parseSize(size string) (int32, int32) {
@@ -150,9 +152,9 @@ func parseSize(size string) (int32, int32) {
return int32(w), int32(h)
}
func extractBase64(content string) string {
if strings.HasPrefix(content, "IMAGE_BASE64:") {
return content[13:]
func extractPath(content string) string {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
return strings.TrimSpace(content[idx+16:])
}
return ""
}
@@ -163,21 +165,23 @@ func parseProgress(content string) ImageProgressEvent {
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imageBase64, format string) ImageGenerationResponse {
func buildResponse(imagePath, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imageBase64 == "" {
if imagePath == "" {
return resp
}
if format == "url" {
// URL format not supported when using base64 transfer
resp.Data[0].B64JSON = imageBase64
resp.Data[0].URL = "file://" + imagePath
} else {
resp.Data[0].B64JSON = imageBase64
data, err := os.ReadFile(imagePath)
if err == nil {
resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
}
}
return resp

View File

@@ -1,197 +0,0 @@
//go:build mlx
// Package cache provides caching mechanisms for diffusion model inference.
package cache
import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
// It caches the transformer output and reuses it when timestep values
// are similar between consecutive steps.
//
// For CFG (classifier-free guidance), it caches pos and neg predictions
// separately and always computes CFG fresh to avoid error amplification.
//
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
// https://github.com/ali-vilab/TeaCache
type TeaCache struct {
// Cached transformer output from last computed step (non-CFG mode)
cachedOutput *mlx.Array
// Cached CFG outputs (pos and neg separately)
cachedPosOutput *mlx.Array
cachedNegOutput *mlx.Array
// Previous timestep value for difference calculation
prevTimestep float32
// Accumulated difference for rescaling
accumulatedDiff float32
// Configuration
threshold float32 // Threshold for recomputation decision
rescaleFactor float32 // Model-specific rescaling factor
skipEarlySteps int // Number of early steps to never cache
// Statistics
cacheHits int
cacheMisses int
}
// TeaCacheConfig holds configuration for TeaCache.
type TeaCacheConfig struct {
// Threshold for recomputation. Lower = more cache hits, potential quality loss.
// Recommended: 0.05-0.15 for image models
Threshold float32
// Rescale factor to adjust timestep embedding differences.
// Model-specific, typically 1.0-2.0
RescaleFactor float32
// SkipEarlySteps: number of early steps to always compute (never cache).
// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
SkipEarlySteps int
}
// DefaultTeaCacheConfig returns default configuration for TeaCache.
func DefaultTeaCacheConfig() *TeaCacheConfig {
return &TeaCacheConfig{
Threshold: 0.1,
RescaleFactor: 1.0,
}
}
// NewTeaCache creates a new TeaCache instance.
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
if cfg == nil {
cfg = DefaultTeaCacheConfig()
}
return &TeaCache{
threshold: cfg.Threshold,
rescaleFactor: cfg.RescaleFactor,
skipEarlySteps: cfg.SkipEarlySteps,
}
}
// ShouldCompute determines if we should compute the full forward pass
// or reuse the cached output based on timestep similarity.
//
// Algorithm:
// 1. First step always computes
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
// 3. If accumulated difference > threshold, compute new output
// 4. Otherwise, reuse cached output
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
// Always compute early steps (critical for structure)
// Check both regular cache and CFG cache
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
return true
}
// Compute absolute difference between current and previous timestep
diff := timestep - tc.prevTimestep
if diff < 0 {
diff = -diff
}
// Apply rescaling factor
scaledDiff := diff * tc.rescaleFactor
// Accumulate difference (helps track drift over multiple cached steps)
tc.accumulatedDiff += scaledDiff
// Decision based on accumulated difference
if tc.accumulatedDiff > tc.threshold {
tc.accumulatedDiff = 0 // Reset accumulator
return true
}
return false
}
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
// Free previous cached output
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
}
// Store new cached values
tc.cachedOutput = output
tc.prevTimestep = timestep
tc.cacheMisses++
}
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
// This allows CFG to be computed fresh each step, avoiding error amplification.
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
// Free previous cached outputs
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
}
// Store new cached values
tc.cachedPosOutput = posOutput
tc.cachedNegOutput = negOutput
tc.prevTimestep = timestep
tc.cacheMisses++
}
// GetCached returns the cached output (non-CFG mode).
func (tc *TeaCache) GetCached() *mlx.Array {
tc.cacheHits++
return tc.cachedOutput
}
// GetCFGCached returns cached pos and neg outputs for CFG mode.
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
tc.cacheHits++
return tc.cachedPosOutput, tc.cachedNegOutput
}
// HasCFGCache returns true if CFG cache is available.
func (tc *TeaCache) HasCFGCache() bool {
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
}
// Arrays returns all arrays that should be kept alive.
func (tc *TeaCache) Arrays() []*mlx.Array {
var arrays []*mlx.Array
if tc.cachedOutput != nil {
arrays = append(arrays, tc.cachedOutput)
}
if tc.cachedPosOutput != nil {
arrays = append(arrays, tc.cachedPosOutput)
}
if tc.cachedNegOutput != nil {
arrays = append(arrays, tc.cachedNegOutput)
}
return arrays
}
// Stats returns cache hit/miss statistics.
func (tc *TeaCache) Stats() (hits, misses int) {
return tc.cacheHits, tc.cacheMisses
}
// Free releases all cached arrays.
func (tc *TeaCache) Free() {
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
tc.cachedOutput = nil
}
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
tc.cachedPosOutput = nil
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
tc.cachedNegOutput = nil
}
}

View File

@@ -44,64 +44,62 @@ func DefaultOptions() ImageGenOptions {
}
}
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
// Show displays information about an image generation model.
func Show(modelName string, w io.Writer) error {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
return fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Count total size
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
totalSize += layer.Size
}
}
// Read model_index.json for architecture, parameter count, and quantization
// Read model_index.json for architecture
var architecture string
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
Architecture string `json:"architecture"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
architecture = index.Architecture
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
paramCount := totalSize / 2
paramStr := formatParamCount(paramCount)
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
// Print Model info
fmt.Fprintln(w, " Model")
if architecture != "" {
fmt.Fprintf(w, " %-20s %s\n", "architecture", architecture)
}
fmt.Fprintf(w, " %-20s %s\n", "parameters", paramStr)
fmt.Fprintf(w, " %-20s %s\n", "quantization", "BF16")
fmt.Fprintln(w)
return info, nil
// Print Capabilities
fmt.Fprintln(w, " Capabilities")
fmt.Fprintf(w, " %s\n", "image")
fmt.Fprintln(w)
return nil
}
// formatParamCount formats parameter count as human-readable string.
func formatParamCount(count int64) string {
if count >= 1_000_000_000 {
return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
}
if count >= 1_000_000 {
return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
}
return fmt.Sprintf("%d", count)
}
// RegisterFlags adds image generation flags to the given command.
@@ -185,7 +183,8 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
p.Add("", spinner)
var stepBar *progress.StepBar
var imageBase64 string
var imagePath string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
@@ -204,9 +203,11 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return nil
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
}
return nil
@@ -217,27 +218,9 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return err
}
if imageBase64 != "" {
// Decode base64 and save to CWD
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
return fmt.Errorf("failed to decode image: %w", err)
}
// Create filename from prompt
safeName := sanitizeFilename(prompt)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
return fmt.Errorf("failed to save image: %w", err)
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
if imagePath != "" {
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
}
return nil
@@ -323,7 +306,7 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
p.Add("", spinner)
var stepBar *progress.StepBar
var imageBase64 string
var imagePath string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
@@ -343,9 +326,11 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
return nil
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
}
return nil
@@ -357,30 +342,25 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
continue
}
// Save image to current directory with descriptive name
if imageBase64 != "" {
// Decode base64 image data
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
continue
}
// Copy image to current directory with descriptive name
if imagePath != "" {
// Create filename from prompt (sanitized)
safeName := sanitizeFilename(line)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
continue
// Copy file to CWD
if err := copyFile(imagePath, newName); err != nil {
fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
} else {
displayImageInTerminal(newName)
fmt.Printf("Image saved to: %s\n", newName)
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
fmt.Println()
@@ -401,6 +381,24 @@ func sanitizeFilename(s string) string {
return result.String()
}
// copyFile copies a file from src to dst.
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}
// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
fmt.Fprintln(os.Stderr, "Commands:")

View File

@@ -29,10 +29,9 @@ const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
@@ -59,77 +58,18 @@ func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) err
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
// When quantize is true, returns multiple layers (weight + scales)
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
if doQuantize {
// Check if quantization is supported
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales (use _scale suffix convention)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []imagegen.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name, // Keep original name for weight
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale", // Add _scale suffix
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, imagegen.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias", // Add _qbias suffix
})
}
return layers, nil
}
// Non-quantized path: just create a single layer
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
return imagegen.LayerInfo{}, err
}
layer.Name = name
return []imagegen.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
@@ -179,7 +119,7 @@ func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) err
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err

View File

@@ -1,120 +0,0 @@
//go:build mlx
package client
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
// and returns safetensors data for the quantized weights, scales, and biases.
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
if _, err := io.Copy(tmpFile, r); err != nil {
tmpFile.Close()
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
}
tmpFile.Close()
// Load the tensor using MLX's native loader
st, err := mlx.LoadSafetensorsNative(tmpPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
}
defer st.Free()
// Get the tensor (it's stored as "data" in our minimal safetensors format)
arr := st.Get("data")
if arr == nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
}
// Convert to BFloat16 if needed (quantize expects float type)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr)
}
// Quantize with affine mode: group_size=32, bits=8
// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")
// Eval and make contiguous for data access
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
mlx.Eval(qweight, scales, qbiases)
} else {
mlx.Eval(qweight, scales)
}
// Get shapes
qweightShape = qweight.Shape()
scalesShape = scales.Shape()
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
defer os.Remove(qweightPath)
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
}
qweightData, err = os.ReadFile(qweightPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
}
// Save scales using MLX's native safetensors
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
defer os.Remove(scalesPath)
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
}
scalesData, err = os.ReadFile(scalesPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
}
// Affine mode returns qbiases for zero-point offset
if qbiases != nil {
qbiasShape = qbiases.Shape()
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
defer os.Remove(qbiasPath)
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
}
qbiasData, err = os.ReadFile(qbiasPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
}
}
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
func QuantizeSupported() bool {
return true
}
// ensureTempDir creates the temp directory for quantization if it doesn't exist
func ensureTempDir() string {
tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
os.MkdirAll(tmpDir, 0755)
return tmpDir
}

View File

@@ -1,18 +0,0 @@
//go:build !mlx
package client
import (
"fmt"
"io"
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
return false
}

View File

@@ -67,9 +67,6 @@ func main() {
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
flag.Parse()
@@ -102,17 +99,13 @@ func main() {
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
TeaCache: *teaCache,
TeaCacheThreshold: float32(*teaCacheThreshold),
FusedQKV: *fusedQKV,
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)

View File

@@ -40,12 +40,10 @@ type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo)
// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
var totalParams int64 // Count parameters from original tensor shapes
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"}
@@ -76,11 +74,7 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
}
tensorNames := extractor.ListTensors()
quantizeMsg := ""
if quantize == "fp8" && component != "vae" {
quantizeMsg = ", quantizing to fp8"
}
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
@@ -89,30 +83,16 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Count parameters from original tensor shape
if len(td.Shape) > 0 {
numElements := int64(1)
for _, dim := range td.Shape {
numElements *= int64(dim)
}
totalParams += numElements
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName
// Determine if this tensor should be quantized
doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
}
layers = append(layers, newLayers...)
layers = append(layers, layer)
}
extractor.Close()
@@ -142,7 +122,7 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
var r io.Reader
// For model_index.json, normalize to Ollama format and add metadata
// For model_index.json, normalize to Ollama format
if cfgPath == "model_index.json" {
data, err := os.ReadFile(fullPath)
if err != nil {
@@ -161,16 +141,6 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
}
delete(cfg, "_diffusers_version")
// Add parameter count (counted from tensor shapes during import)
cfg["parameter_count"] = totalParams
// Add quantization info
if quantize == "fp8" {
cfg["quantization"] = "FP8"
} else {
cfg["quantization"] = "BF16"
}
data, err = json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)

View File

@@ -60,12 +60,9 @@ func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
}
// Transform to [H, W, C] for image conversion
// Free intermediate arrays to avoid memory leak
squeezed := mlx.Squeeze(arr, 0)
transposed := mlx.Transpose(squeezed, 1, 2, 0)
squeezed.Free()
img := mlx.Contiguous(transposed)
transposed.Free()
img := mlx.Squeeze(arr, 0)
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
mlx.Eval(img)
imgShape := img.Shape()

View File

@@ -607,11 +607,6 @@ func (a *Array) Valid() bool {
return a != nil && a.c.ctx != nil
}
// Kept returns true if the array is marked to survive Eval() cleanup.
func (a *Array) Kept() bool {
return a != nil && a.kept
}
func int32ToCInt(s []int32) *C.int {
if len(s) == 0 {
return nil
@@ -1485,44 +1480,6 @@ func (a *Array) ItemInt32() int32 {
return int32(val)
}
// Bytes copies the raw bytes out of the array without type conversion.
// Works with common dtypes (float32, int32, uint32, uint8).
// For non-contiguous arrays, call Contiguous() first.
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) Bytes() []byte {
cleanup()
nbytes := a.Nbytes()
if nbytes == 0 {
return nil
}
// Get raw pointer based on dtype
var ptr unsafe.Pointer
switch a.Dtype() {
case DtypeFloat32:
ptr = unsafe.Pointer(C.mlx_array_data_float32(a.c))
case DtypeInt32:
ptr = unsafe.Pointer(C.mlx_array_data_int32(a.c))
case DtypeUint32:
ptr = unsafe.Pointer(C.mlx_array_data_uint32(a.c))
case DtypeUint8:
ptr = unsafe.Pointer(C.mlx_array_data_uint8(a.c))
default:
// For other types (bf16, f16, etc), convert to float32
arr := AsType(a, DtypeFloat32)
arr.Eval()
ptr = unsafe.Pointer(C.mlx_array_data_float32(arr.c))
nbytes = arr.Nbytes()
}
if ptr == nil {
return nil
}
data := make([]byte, nbytes)
copy(data, unsafe.Slice((*byte)(ptr), nbytes))
return data
}
// ============ Utility ============
// String returns a string representation
@@ -1701,34 +1658,6 @@ func (s *SafetensorsFile) Free() {
C.mlx_map_string_to_string_free(s.metadata)
}
// SaveSafetensors saves arrays to a safetensors file using MLX's native implementation.
// This correctly handles all dtypes including uint32 for quantized weights.
func SaveSafetensors(path string, arrays map[string]*Array) error {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
// Create the map
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
// Add each array to the map
for name, arr := range arrays {
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
C.free(unsafe.Pointer(cName))
}
// Create empty metadata (optional)
cMeta := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMeta)
// Save
if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
return fmt.Errorf("failed to save safetensors: %s", path)
}
return nil
}
// ============ NPY Loading ============
// LoadNpy loads a numpy array from an npy file
@@ -2057,8 +1986,7 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
// Returns (quantized_weights, scales, biases).
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
// mode: "affine" (default), "mxfp4", or "mxfp8"
// Note: mxfp8 mode returns nil biases (only weights and scales)
// mode: "affine" (default) or "mxfp4"
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
@@ -2067,21 +1995,14 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
res := C.mlx_vector_array_new()
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
// Result is a vector of arrays: [weights, scales, biases?]
// mxfp8 mode returns only 2 elements (no biases)
vecSize := int(C.mlx_vector_array_size(res))
// Result is a vector of 3 arrays: [weights, scales, biases]
var w0, w1, w2 C.mlx_array
C.mlx_vector_array_get(&w0, res, 0)
C.mlx_vector_array_get(&w1, res, 1)
if vecSize >= 3 {
C.mlx_vector_array_get(&w2, res, 2)
}
C.mlx_vector_array_get(&w2, res, 2)
C.mlx_vector_array_free(res)
if vecSize >= 3 {
return newArray(w0), newArray(w1), newArray(w2)
}
return newArray(w0), newArray(w1), nil
return newArray(w0), newArray(w1), newArray(w2)
}
// Dequantize reconstructs weights from quantized form.

View File

@@ -222,14 +222,6 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
mlx.Keep(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
@@ -272,19 +264,10 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
var output *mlx.Array
if useCFG {
// CFG Batching: single forward pass with batch=2
// True CFG: run twice and combine with norm rescaling
// Note: layer caching with CFG is not supported yet (would need 2 caches)
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
L := batchedOutput.Shape()[1]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
@@ -322,9 +305,6 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {

View File

@@ -241,14 +241,6 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Encode all input images to latents and concatenate
fmt.Println("Encoding images to latents...")
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
@@ -299,18 +291,11 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
var output *mlx.Array
if useCFG {
// CFG Batching: single forward pass with batch=2
// Tile inputs: [1, L, D] -> [2, L, D]
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
} else {
@@ -332,9 +317,6 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
imageLatentsPacked.Free()

View File

@@ -28,12 +28,12 @@ type Qwen3Config struct {
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
QProj *nn.Linear `weight:"q_proj"`
KProj *nn.Linear `weight:"k_proj"`
VProj *nn.Linear `weight:"v_proj"`
OProj *nn.Linear `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
@@ -136,9 +136,9 @@ func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
GateProj *nn.Linear `weight:"gate_proj"`
UpProj *nn.Linear `weight:"up_proj"`
DownProj *nn.Linear `weight:"down_proj"`
}
// Forward applies the MLP

View File

@@ -36,8 +36,8 @@ type TransformerConfig struct {
// TimestepEmbedder creates sinusoidal timestep embeddings
// Output dimension is 256 (fixed), used for AdaLN modulation
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"mlp.0"`
Linear2 nn.LinearLayer `weight:"mlp.2"`
Linear1 *nn.Linear `weight:"mlp.0"`
Linear2 *nn.Linear `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
}
@@ -74,7 +74,7 @@ func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
// XEmbedder embeds image patches to model dimension
type XEmbedder struct {
Linear nn.LinearLayer `weight:"2-1"`
Linear *nn.Linear `weight:"2-1"`
}
// Forward embeds patchified image latents
@@ -86,7 +86,7 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear nn.LinearLayer `weight:"1"`
Linear *nn.Linear `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
@@ -100,13 +100,12 @@ func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
// FeedForward implements SwiGLU FFN
type FeedForward struct {
W1 nn.LinearLayer `weight:"w1"` // gate projection
W2 nn.LinearLayer `weight:"w2"` // down projection
W3 nn.LinearLayer `weight:"w3"` // up projection
W1 *nn.Linear `weight:"w1"` // gate projection
W2 *nn.Linear `weight:"w2"` // down projection
W3 *nn.Linear `weight:"w3"` // up projection
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -116,7 +115,6 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Reshape for matmul
x = mlx.Reshape(x, B*L, D)
gate := ff.W1.Forward(x)
gate = mlx.SiLU(gate)
up := ff.W3.Forward(x)
@@ -128,69 +126,17 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Attention implements multi-head attention with QK norm
type Attention struct {
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
ToQ *nn.Linear `weight:"to_q"`
ToK *nn.Linear `weight:"to_k"`
ToV *nn.Linear `weight:"to_v"`
ToOut *nn.Linear `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
Dim int32 `weight:"-"`
Scale float32 `weight:"-"`
}
// FuseQKV creates a fused QKV projection by concatenating weights.
// This reduces 3 matmuls to 1 for a ~5-10% speedup.
// Note: Fusion is skipped for quantized weights as it would require complex
// dequant-concat-requant operations. The FP8 memory bandwidth savings outweigh
// the ~5% fusion benefit.
func (attn *Attention) FuseQKV() {
if attn.ToQ == nil || attn.ToK == nil || attn.ToV == nil {
return
}
// Skip fusion for quantized weights - type assert to check
toQ, qOk := attn.ToQ.(*nn.Linear)
toK, kOk := attn.ToK.(*nn.Linear)
toV, vOk := attn.ToV.(*nn.Linear)
if !qOk || !kOk || !vOk {
// One or more are QuantizedLinear, skip fusion
return
}
if toQ.Weight == nil || toK.Weight == nil || toV.Weight == nil {
return
}
// Concatenate weights: [dim, dim] x 3 -> [3*dim, dim]
// Weight shapes: ToQ.Weight [out_dim, in_dim], etc.
qWeight := toQ.Weight
kWeight := toK.Weight
vWeight := toV.Weight
// Concatenate along output dimension (axis 0)
fusedWeight := mlx.Concatenate([]*mlx.Array{qWeight, kWeight, vWeight}, 0)
// Evaluate fused weight to ensure it's materialized
mlx.Eval(fusedWeight)
// Create fused linear layer
fusedLinear := &nn.Linear{Weight: fusedWeight}
// Handle bias if present
if toQ.Bias != nil && toK.Bias != nil && toV.Bias != nil {
fusedBias := mlx.Concatenate([]*mlx.Array{toQ.Bias, toK.Bias, toV.Bias}, 0)
mlx.Eval(fusedBias)
fusedLinear.Bias = fusedBias
}
attn.ToQKV = fusedLinear
attn.Fused = true
// Computed fields
NHeads int32
HeadDim int32
Dim int32
Scale float32
}
// Forward computes attention
@@ -200,24 +146,11 @@ func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
L := shape[1]
D := shape[2]
// Project Q, K, V
xFlat := mlx.Reshape(x, B*L, D)
var q, k, v *mlx.Array
if attn.Fused && attn.ToQKV != nil {
// Fused QKV path: single matmul then split
qkv := attn.ToQKV.Forward(xFlat) // [B*L, 3*dim]
// Split into Q, K, V along last dimension
// Each has shape [B*L, dim]
q = mlx.Slice(qkv, []int32{0, 0}, []int32{B * L, attn.Dim})
k = mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{B * L, 2 * attn.Dim})
v = mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{B * L, 3 * attn.Dim})
} else {
// Separate Q, K, V projections
q = attn.ToQ.Forward(xFlat)
k = attn.ToK.Forward(xFlat)
v = attn.ToV.Forward(xFlat)
}
q := attn.ToQ.Forward(xFlat)
k := attn.ToK.Forward(xFlat)
v := attn.ToV.Forward(xFlat)
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
@@ -294,7 +227,7 @@ type TransformerBlock struct {
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -348,8 +281,8 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
// FinalLayer outputs the denoised patches
type FinalLayer struct {
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
}
@@ -417,11 +350,12 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
// Load weights from tensor blobs with BF16 conversion
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
@@ -443,7 +377,7 @@ func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
func (m *Transformer) initComputedFields() {
cfg := m.TransformerConfig
m.TEmbed.FreqEmbedSize = 256
m.FinalLayer.OutDim = m.FinalLayer.Output.OutputDim()
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
m.CapEmbed.Norm.Eps = 1e-6
for _, block := range m.NoiseRefiners {
@@ -457,20 +391,6 @@ func (m *Transformer) initComputedFields() {
}
}
// FuseAllQKV fuses QKV projections in all attention layers for efficiency.
// This reduces 3 matmuls to 1 per attention layer, providing ~5-10% speedup.
func (m *Transformer) FuseAllQKV() {
for _, block := range m.NoiseRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.ContextRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.Layers {
block.Attention.FuseQKV()
}
}
// initTransformerBlock sets computed fields on a transformer block
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
block.Dim = cfg.Dim
@@ -484,7 +404,7 @@ func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
// Init feedforward OutDim
block.FeedForward.OutDim = block.FeedForward.W2.OutputDim()
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
// Set eps on all RMSNorm layers
block.AttentionNorm1.Eps = cfg.NormEps
@@ -503,8 +423,6 @@ type RoPECache struct {
UnifiedSin *mlx.Array
ImgLen int32
CapLen int32
GridH int32 // Image token grid height
GridW int32 // Image token grid width
}
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
@@ -538,8 +456,6 @@ func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
UnifiedSin: unifiedSin,
ImgLen: imgLen,
CapLen: capLen,
GridH: hTok,
GridW: wTok,
}
}

View File

@@ -104,8 +104,6 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
groupSize := C / gn.NumGroups
// Keep the input - we need it for slicing tiles later
// Track if we were the ones who kept it, so we can restore state after
wasKept := x.Kept()
mlx.Keep(x)
// Compute per-group mean and variance using flattened spatial dimensions
@@ -207,10 +205,6 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
}
// Clean up kept arrays
// Restore x's kept state - only free if we were the ones who kept it
if !wasKept {
x.Free()
}
mean.Free()
invStd.Free()
if weightGN != nil {
@@ -740,26 +734,18 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
h := vae.ConvIn.Forward(z)
mlx.Eval(h)
prev := h
h = vae.MidBlock.Forward(h)
prev.Free()
for _, upBlock := range vae.UpBlocks {
prev = h
h = upBlock.Forward(h)
prev.Free()
}
prev = h
prev := h
h = vae.ConvNormOut.Forward(h)
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
prev.Free()
prev = h
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
mlx.Eval(h)
prev.Free()
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
@@ -768,6 +754,7 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
// Convert NHWC -> NCHW for output
h = mlx.Transpose(h, 0, 3, 1, 2)
prev.Free()
mlx.Eval(h)
return h

View File

@@ -26,12 +26,10 @@ type GenerateConfig struct {
Progress ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
TeaCacheThreshold float32 // Threshold for cache reuse (default: 0.1, lower = more aggressive)
// Fused QKV (fuse Q/K/V projections into single matmul)
FusedQKV bool // Enable fused QKV projection (default: false)
// Layer caching options (speedup via shallow layer reuse)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 15)
}
// ProgressFunc is called during generation with step progress.
@@ -44,7 +42,6 @@ type Model struct {
TextEncoder *Qwen3TextEncoder
Transformer *Transformer
VAEDecoder *VAEDecoder
qkvFused bool // Track if QKV has been fused (do only once)
}
// Load loads the Z-Image model from ollama blob storage.
@@ -199,17 +196,13 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
// TeaCache enabled by default
cfg.TeaCache = true
if cfg.TeaCacheThreshold <= 0 {
cfg.TeaCacheThreshold = 0.15
}
// Enable fused QKV if requested (only fuse once)
if cfg.FusedQKV && !m.qkvFused {
m.Transformer.FuseAllQKV()
m.qkvFused = true
fmt.Println(" Fused QKV enabled")
if cfg.LayerCache {
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 15 // Half of 30 layers
}
}
useCFG := cfg.NegativePrompt != ""
@@ -267,54 +260,12 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(ropeCache.UnifiedCos)
}
// Pre-compute batched embeddings for CFG (outside the loop for efficiency)
var batchedEmb *mlx.Array
if useCFG {
// Concatenate embeddings once: [1, L, D] + [1, L, D] -> [2, L, D]
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// TeaCache for timestep-aware caching
// For CFG mode, we cache pos/neg separately, skip early steps, and always compute CFG fresh
var teaCache *cache.TeaCache
if cfg.TeaCache {
skipEarly := 0
if useCFG {
skipEarly = 3 // Skip first 3 steps for CFG to preserve structure
}
teaCache = cache.NewTeaCache(&cache.TeaCacheConfig{
Threshold: cfg.TeaCacheThreshold,
RescaleFactor: 1.0,
SkipEarlySteps: skipEarly,
})
if useCFG {
fmt.Printf(" TeaCache enabled (CFG mode): threshold=%.2f, skip first %d steps\n", cfg.TeaCacheThreshold, skipEarly)
} else {
fmt.Printf(" TeaCache enabled: threshold=%.2f\n", cfg.TeaCacheThreshold)
}
}
// cleanup frees all kept arrays when we need to abort early
cleanup := func() {
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgCos.Free()
ropeCache.ImgSin.Free()
ropeCache.CapCos.Free()
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
teaCache.Free()
}
latents.Free()
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
@@ -326,7 +277,6 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
if ctx != nil {
select {
case <-ctx.Done():
cleanup()
return nil, ctx.Err()
default:
}
@@ -339,77 +289,50 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
}
tCurr := scheduler.Timesteps[i]
var noisePred *mlx.Array
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
// TeaCache: check if we should compute or reuse cached output
shouldCompute := teaCache == nil || teaCache.ShouldCompute(i, tCurr)
patches := PatchifyLatents(latents, tcfg.PatchSize)
if shouldCompute {
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
patches := PatchifyLatents(latents, tcfg.PatchSize)
var output *mlx.Array
var output *mlx.Array
if stepCache != nil {
// Use layer caching for faster inference
if useCFG {
// CFG Batching: single forward pass with batch=2
// Tile patches: [1, L, D] -> [2, L, D]
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
// Tile timestep: [1] -> [2]
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass (RoPE broadcasts from [1,L,H,D] to [2,L,H,D])
batchedOutput := m.Transformer.Forward(batchedPatches, batchedTimestep, batchedEmb, ropeCache)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
outputShape := batchedOutput.Shape()
L := outputShape[1]
D := outputShape[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
// Convert to noise predictions (unpatchify and negate)
posPred := UnpatchifyLatents(posOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
posPred = mlx.Neg(posPred)
negPred := UnpatchifyLatents(negOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
negPred = mlx.Neg(negPred)
// Cache pos/neg separately for TeaCache
if teaCache != nil {
teaCache.UpdateCFGCache(posPred, negPred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
// Apply CFG: noisePred = neg + scale * (pos - neg)
diff := mlx.Sub(posPred, negPred)
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
// Note: CFG with layer cache shares the cache between pos/neg
// This is approximate but fast - neg prompt uses same cached shallow layers
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
noisePred = mlx.Add(negPred, scaledDiff)
output = mlx.Add(negOutput, scaledDiff)
} else {
// Non-CFG forward pass
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
noisePred = UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
// Update TeaCache
if teaCache != nil {
teaCache.UpdateCache(noisePred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
}
} else if useCFG && teaCache != nil && teaCache.HasCFGCache() {
// CFG mode: get cached pos/neg and compute CFG fresh
posPred, negPred := teaCache.GetCFGCached()
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
noisePred = mlx.Add(negPred, scaledDiff)
fmt.Printf(" [TeaCache: reusing cached pos/neg outputs]\n")
} else {
// Non-CFG mode: reuse cached noise prediction
noisePred = teaCache.GetCached()
fmt.Printf(" [TeaCache: reusing cached output]\n")
// Standard forward without caching
if useCFG {
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
} else {
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
}
}
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep latents and any cached arrays
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
@@ -438,14 +361,8 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
hits, misses := teaCache.Stats()
fmt.Printf(" TeaCache stats: %d hits, %d misses (%.1f%% cache rate)\n",
hits, misses, float64(hits)/float64(hits+misses)*100)
teaCache.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode

View File

@@ -10,13 +10,6 @@ type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
// LinearLayer is an interface for linear layers (both regular and quantized).
// This allows swapping between Linear and QuantizedLinear at runtime.
type LinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
OutputDim() int32 // Returns the output dimension of the layer
}
// Linear applies an affine transformation: y = x @ W.T + b
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
type Linear struct {
@@ -56,11 +49,6 @@ func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
return mlx.Linear(x, w)
}
// OutputDim returns the output dimension of the linear layer.
func (l *Linear) OutputDim() int32 {
return l.Weight.Shape()[0]
}
// ToQuantized converts this Linear to a QuantizedLinear.
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
@@ -96,13 +84,6 @@ func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
return out
}
// OutputDim returns the output dimension of the quantized linear layer.
// For mxfp8/mxfp4, quantized weight shape is [out_features, in_features / group_size].
// The output dimension is the first dimension of the weight.
func (ql *QuantizedLinear) OutputDim() int32 {
return ql.Weight.Shape()[0]
}
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array `weight:"weight"`

View File

@@ -1,22 +0,0 @@
package imagegen
import (
"io"
"strings"
)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
// ShouldQuantize returns true if a tensor should be quantized.
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
func ShouldQuantize(name, component string) bool {
if component == "vae" {
return false
}
if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
return false
}
return strings.HasSuffix(name, ".weight")
}

View File

@@ -13,6 +13,7 @@ import (
"net/http"
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"
@@ -33,8 +34,7 @@ type Request struct {
// Response is streamed back for each progress update
type Response struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"` // Base64-encoded PNG
Content string `json:"content"`
Done bool `json:"done"`
}
@@ -191,10 +191,10 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
// Save image
outPath := filepath.Join(os.TempDir(), fmt.Sprintf("ollama-image-%d.png", time.Now().UnixNano()))
if err := imagegen.SaveImage(img, outPath); err != nil {
resp := Response{Content: fmt.Sprintf("error saving: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
@@ -204,12 +204,11 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()
// Send final response with image data
// Send final response
resp := Response{
Image: imageData,
Done: true,
Content: fmt.Sprintf("\n\nImage saved to: %s\n", outPath),
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)

View File

@@ -8,7 +8,6 @@ import (
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// WeightSource is an interface for loading weights.
@@ -103,22 +102,6 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
}
// Handle nn.LinearLayer interface fields specially
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
if !hasTag {
continue // no tag = skip
}
layer, err := LoadLinearLayer(weights, fullPath)
if err != nil {
if !optional {
*errs = append(*errs, fullPath+": "+err.Error())
}
continue
}
fieldVal.Set(reflect.ValueOf(layer))
continue
}
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
@@ -193,64 +176,3 @@ func joinPath(prefix, suffix string) string {
}
return prefix + "." + suffix
}
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
}
scales, err := weights.GetTensor(scalePath)
if err != nil {
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
if mlx.MetalIsAvailable() {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: 32,
Bits: 8,
Mode: "affine",
}, nil
}
dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
return nn.NewLinear(dequantized, bias), nil
}
// Load as regular Linear
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
return nn.NewLinear(weight, bias), nil
}

View File

@@ -46,8 +46,7 @@ type completionRequest struct {
// completionResponse is received from the subprocess
type completionResponse struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"`
Content string `json:"content"`
Done bool `json:"done"`
}
@@ -251,23 +250,15 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
}
// Stream responses - use large buffer for base64 image data
// Stream responses
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
var cresp completionResponse
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
continue
}
content := cresp.Content
// If this is the final response with an image, encode it in the content
if cresp.Done && cresp.Image != "" {
content = "IMAGE_BASE64:" + cresp.Image
}
fn(llm.CompletionResponse{
Content: content,
Content: cresp.Content,
Done: cresp.Done,
})
if cresp.Done {

View File

@@ -45,24 +45,33 @@ func download(ctx context.Context, opts DownloadOptions) error {
return nil
}
// Filter existing
var blobs []Blob
// Calculate total from all blobs (for accurate progress reporting on resume)
var total int64
for _, b := range opts.Blobs {
total += b.Size
}
// Filter out already-downloaded blobs and track completed bytes
var blobs []Blob
var alreadyCompleted int64
for _, b := range opts.Blobs {
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
if opts.Logger != nil {
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
}
alreadyCompleted += b.Size
continue
}
blobs = append(blobs, b)
total += b.Size
}
if len(blobs) == 0 {
return nil
}
token := opts.Token
progress := newProgressTracker(total, opts.Progress)
progress.add(alreadyCompleted) // Report already-downloaded bytes upfront
d := &downloader{
client: cmp.Or(opts.Client, defaultClient),
baseURL: opts.BaseURL,
@@ -72,7 +81,7 @@ func download(ctx context.Context, opts DownloadOptions) error {
getToken: opts.GetToken,
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
progress: newProgressTracker(total, opts.Progress),
progress: progress,
speeds: &speedTracker{},
logger: opts.Logger,
}

View File

@@ -110,8 +110,6 @@ var defaultClient = &http.Client{
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
},
Timeout: 5 * time.Minute,
// Don't follow redirects automatically - we handle them manually
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},

View File

@@ -284,6 +284,83 @@ func TestDownloadSkipsExisting(t *testing.T) {
}
}
func TestDownloadResumeProgressTotal(t *testing.T) {
// Test that when resuming a download with some blobs already present:
// 1. Total reflects ALL blob sizes (not just remaining)
// 2. Completed starts at the size of already-downloaded blobs
serverDir := t.TempDir()
blob1, data1 := createTestBlob(t, serverDir, 1000)
blob2, data2 := createTestBlob(t, serverDir, 2000)
blob3, data3 := createTestBlob(t, serverDir, 3000)
// Pre-populate client with blob1 and blob2 (simulating partial download)
clientDir := t.TempDir()
for _, b := range []struct {
blob Blob
data []byte
}{{blob1, data1}, {blob2, data2}} {
path := filepath.Join(clientDir, digestToPath(b.blob.Digest))
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(path, b.data, 0o644); err != nil {
t.Fatal(err)
}
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
digest := filepath.Base(r.URL.Path)
path := filepath.Join(serverDir, digestToPath(digest))
data, err := os.ReadFile(path)
if err != nil {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.WriteHeader(http.StatusOK)
w.Write(data)
}))
defer server.Close()
var firstCompleted, firstTotal int64
var gotFirstProgress bool
var mu sync.Mutex
err := Download(context.Background(), DownloadOptions{
Blobs: []Blob{blob1, blob2, blob3},
BaseURL: server.URL,
DestDir: clientDir,
Concurrency: 1,
Progress: func(completed, total int64) {
mu.Lock()
defer mu.Unlock()
if !gotFirstProgress {
firstCompleted = completed
firstTotal = total
gotFirstProgress = true
}
},
})
if err != nil {
t.Fatalf("Download failed: %v", err)
}
// Total should be sum of ALL blobs, not just blob3
expectedTotal := blob1.Size + blob2.Size + blob3.Size
if firstTotal != expectedTotal {
t.Errorf("Total = %d, want %d (should include all blobs)", firstTotal, expectedTotal)
}
// First progress call should show already-completed bytes from blob1+blob2
expectedCompleted := blob1.Size + blob2.Size
if firstCompleted < expectedCompleted {
t.Errorf("First completed = %d, want >= %d (should include already-downloaded blobs)", firstCompleted, expectedCompleted)
}
// Verify blob3 was downloaded
verifyBlob(t, clientDir, blob3, data3)
}
func TestDownloadDigestMismatch(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return wrong data