Compare commits

..

13 Commits

Author SHA1 Message Date
ParthSareen
14499406d2 anthropic: fix ToolCallFunctionArguments type after rebase
Update tests and implementation to use the new ordered map-based
ToolCallFunctionArguments type which replaces the previous map[string]any.

- Add mapToArgs helper to convert map[string]any to ToolCallFunctionArguments
- Add testArgs and testProps helpers in tests
- Use cmpopts.IgnoreUnexported for cmp.Diff comparisons
2026-01-06 13:14:06 -08:00
ParthSareen
2b5093e2e7 middleware: use HTTP status code for Anthropic error mapping
Use w.ResponseWriter.Status() instead of parsing StatusCode from JSON
payload. routes.go typically sends errors as gin.H{"error": "..."}
without a StatusCode field, causing all errors to be mapped to
"api_error" instead of the appropriate type (not_found_error,
invalid_request_error, etc.).

Added tests to verify error handling for common routes.go patterns.
2026-01-06 13:14:06 -08:00
ParthSareen
f4537dd113 docs: add JavaScript example for tool calling 2026-01-06 13:14:06 -08:00
ParthSareen
34257b6e37 middleware: fix test for pointer type Text field 2026-01-06 13:14:06 -08:00
ParthSareen
01d12cd98f anthropic: use pointer types for Text and Thinking fields
Use *string instead of string for Text and Thinking fields in ContentBlock
so that omitempty works correctly:
- nil pointer: field omitted from JSON (for blocks that don't use it)
- ptr(""): field present as "" (for SDK streaming accumulation)
- ptr("content"): field present with content

This keeps the JSON output clean (text blocks don't have thinking field,
thinking blocks don't have text field) while still satisfying SDK
requirements for field presence during streaming.
2026-01-06 13:14:06 -08:00
ParthSareen
9a16ad3857 anthropic: preserve messages with only thinking content
Fix edge case where messages containing only a thinking block (no text,
images, or tool calls) would be dropped. Add thinking != "" to the
condition that creates messages from content blocks.
2026-01-06 13:14:06 -08:00
ParthSareen
b14ba5285f docs: add Claude Code integration guide 2026-01-06 13:14:06 -08:00
ParthSareen
4ec2873ed1 anthropic: add tests for SDK-required empty fields
Add tests documenting that Text and Thinking fields must be present
in JSON output even when empty. The Anthropic SDK requires these fields
in content_block_start events to accumulate streaming deltas properly.

Tests verify:
- ContentBlock JSON includes empty text/thinking fields
- StreamConverter emits content_block_start with required fields
2026-01-06 13:14:06 -08:00
ParthSareen
e9f4ef84fb anthropic: fix streaming with SDK by including empty fields
Remove omitempty from Text and Thinking fields in ContentBlock struct.
The Anthropic SDK requires these fields to be present (even if empty)
in content_block_start events to properly accumulate streaming deltas.
2026-01-06 13:14:06 -08:00
ParthSareen
3531fb5970 anthropic: remove redundant comments
Remove obvious comments that don't add value (e.g., "// Convert messages",
"// Handle done"). Keep godoc comments and those explaining API mappings.
2026-01-06 13:14:06 -08:00
ParthSareen
f5a85e8ac6 anthropic: fix error handling and update docs
- Add proper error handling for JSON marshal in StreamConverter to
  prevent corrupted streams when tool arguments cannot be serialized
- Add tests for unmarshalable arguments and mixed validity scenarios
- Fix documentation typo and update recommended models to qwen3-coder
2026-01-06 13:14:06 -08:00
ParthSareen
214563ab17 anthropic: add unit and integration tests
- Unit tests for transformation functions (FromMessagesRequest, ToMessagesResponse)
- Unit tests for error handling and edge cases
- Middleware integration tests with httptest
- Fix lint issues (gofmt)
- Fix unused struct fields in StreamConverter
- Add fallback for crypto/rand errors
2026-01-06 13:14:06 -08:00
ParthSareen
2b90199b91 api: add Anthropic Messages API compatibility layer
Add middleware to support the Anthropic Messages API format at /v1/messages.
This enables tools like Claude Code to work with Ollama models through the
Anthropic API interface.

Features:
- Request/response transformation between Anthropic and internal formats
- Streaming support with SSE events (message_start, content_block_delta, etc.)
- Tool calling support (tool_use and tool_result content blocks)
- Thinking/extended thinking block support
- Image content block support (base64)
- System prompt handling
- Multi-turn conversation support
- Proper stop_reason mapping (end_turn, max_tokens, tool_use)
- Error responses in Anthropic format

New files:
- anthropic/anthropic.go: Types and transformation functions
- middleware/anthropic.go: Request/response middleware
2026-01-06 13:14:06 -08:00
145 changed files with 3084 additions and 28905 deletions

View File

@@ -12,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
set(CMAKE_CXX_EXTENSIONS OFF)
set(GGML_BUILD ON)
set(GGML_SHARED ON)
@@ -147,48 +147,14 @@ if(CMAKE_HIP_COMPILER)
endif()
endif()
if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()
endif()

View File

@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 4",
"CMAKE_CUDA_FLAGS": "-t 2",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
@@ -83,28 +83,6 @@
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan"
}
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
],
"buildPresets": [
@@ -162,21 +140,6 @@
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
}
]
}

View File

@@ -131,40 +131,7 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
# TODO wire up the actual MLX engine here instead of building the main binary...
RUN mkdir -p dist/bin
RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx-engine .
RUN go build -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS build
@@ -186,8 +153,6 @@ FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/

778
anthropic/anthropic.go Normal file
View File

@@ -0,0 +1,778 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

953
anthropic/anthropic_test.go Normal file
View File

@@ -0,0 +1,953 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}

View File

@@ -520,7 +520,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("yolo")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -548,9 +547,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
// Use experimental agent loop with tools
// Use experimental agent loop with
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
}
return generateInteractive(cmd, opts)
@@ -1765,7 +1764,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().BoolP("yolo", "y", false, "Skip all tool approval prompts (use with caution)")
stopCmd := &cobra.Command{
Use: "stop MODEL",

View File

@@ -6,14 +6,11 @@ import (
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"maps"
"os"
"slices"
"strings"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -21,28 +18,9 @@ type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
// TODO is this needed?
ModelType string `json:"model_type"`
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
HiddenSize uint32 `json:"hidden_size"`
ModelType string `json:"model_type"`
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
// TODO vision config
/*
"vision_config": {
"hidden_size": 1152,
"image_size": 896,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"vision_use_head": false
}
*/
}
type AdapterParameters struct {
@@ -55,91 +33,8 @@ type AdapterParameters struct {
} `json:"lora_parameters"`
}
type KV map[string]any
func (kv KV) Architecture() string {
return kv.String("general.architecture", "unknown")
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
[]uint8 | []int8 | []uint16 | []int16 |
[]uint32 | []int32 | []uint64 | []int64 |
[]string | []float32 | []float64 | []bool
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
}
return defaultValue[0], false
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
return val
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
return val
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
return val
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
return val
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
return val
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (ModelParameters) KV(t *Tokenizer) KV {
kv := KV{
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
kv := ggml.KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
@@ -168,7 +63,7 @@ func (ModelParameters) KV(t *Tokenizer) KV {
return kv
}
func (p AdapterParameters) KV() KV {
func (p AdapterParameters) KV() ggml.KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
@@ -176,7 +71,7 @@ func (p AdapterParameters) KV() KV {
alpha = p.LoraParameters.Alpha
}
kv := KV{
kv := ggml.KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
@@ -193,14 +88,9 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
type ModelKV interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) KV
}
type ModelConverter interface {
ModelKV
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -217,7 +107,7 @@ type moreParser interface {
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ofs.Config) KV
KV(ggml.KV) ggml.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -225,7 +115,7 @@ type AdapterConverter interface {
Replacements() []string
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -236,8 +126,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return err
}
arch := baseKV.Architecture()
if arch == "" {
arch, ok := baseKV["general.architecture"]
if !ok {
return errors.New("architecture not set for the base model")
}
@@ -263,19 +153,23 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
}
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return nil, nil, err
return err
}
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return nil, nil, err
return err
}
if len(p.Architectures) < 1 {
return nil, nil, errors.New("unknown architecture")
return errors.New("unknown architecture")
}
var conv ModelConverter
@@ -323,22 +217,22 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return nil, nil, err
return err
}
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return nil, nil, err
return err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil {
return nil, nil, err
return err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -360,19 +254,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
return conv, t, nil
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
kv, t, err := LoadModelMetadata(fsys)
if err != nil {
return err
}
conv := kv.(ModelConverter)
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
@@ -382,7 +263,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return writeFile(f, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)

View File

@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *bertModel) KV(t *Tokenizer) KV {
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false

View File

@@ -24,7 +24,7 @@ type commandrModel struct {
var _ ModelConverter = (*commandrModel)(nil)
func (p *commandrModel) KV(t *Tokenizer) KV {
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r"

View File

@@ -47,7 +47,7 @@ type deepseek2Model struct {
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) KV {
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"

View File

@@ -41,7 +41,7 @@ type deepseekocr struct {
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) KV {
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers

View File

@@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) KV {
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings

View File

@@ -1,5 +1,7 @@
package convert
import "github.com/ollama/ollama/fs/ggml"
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
@@ -7,7 +9,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) KV {
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings

View File

@@ -6,7 +6,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -16,7 +15,7 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv

View File

@@ -3,6 +3,8 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
@@ -53,7 +55,7 @@ const (
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) KV {
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"

View File

@@ -38,7 +38,7 @@ type gemma3nModel struct {
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) KV {
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {

View File

@@ -37,7 +37,7 @@ type gptossModel struct {
var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) KV {
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)

View File

@@ -48,7 +48,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) KV {
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize

View File

@@ -35,7 +35,7 @@ type llama4Model struct {
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) KV {
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"

View File

@@ -7,7 +7,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -19,13 +18,13 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
return kv
}

View File

@@ -60,7 +60,7 @@ type mistral3Model struct {
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) KV {
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize

View File

@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize

View File

@@ -12,7 +12,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
func (p *mixtralModel) KV(t *Tokenizer) KV {
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 {

View File

@@ -34,7 +34,7 @@ type mllamaModel struct {
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) KV {
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"

View File

@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) KV {
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)

View File

@@ -34,7 +34,7 @@ type olmoModel struct {
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) KV {
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers

View File

@@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) KV {
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings

View File

@@ -22,7 +22,7 @@ type qwen2Model struct {
var _ ModelConverter = (*qwen2Model)(nil)
func (q *qwen2Model) KV(t *Tokenizer) KV {
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen2"
kv["qwen2.block_count"] = q.HiddenLayers

View File

@@ -29,7 +29,7 @@ type qwen25VLModel struct {
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"

View File

@@ -32,7 +32,7 @@ type qwen3Model struct {
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) KV {
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"

View File

@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"

View File

@@ -19,7 +19,6 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
fsc "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -29,7 +28,7 @@ type tensorData struct {
Shape []int `json:"shape"`
}
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -60,10 +59,9 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors)
return r, m.KV(), m.Tensors()
}
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
actual := make(map[string]string)
for k := range kv.Keys() {
v := kv.Value(k)
for k, v := range kv {
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
@@ -279,7 +277,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
func TestConvertAdapter(t *testing.T) {
type AdapterCase struct {
Name string
BaseKV KV
BaseKV map[string]any
Expected map[string]string
}

View File

@@ -14,6 +14,7 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
### Resources

View File

@@ -0,0 +1,406 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='qwen3-coder',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
tools: [
{
name: "get_weather",
description: "Get the current weather in a location",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});
for (const block of message.content) {
if (block.type === "tool_use") {
console.log("Tool:", block.name);
console.log("Input:", block.input);
}
}
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
## Models
Ollama supports both local and cloud models.
### Local models
Pull a local model before use:
```shell
ollama pull qwen3-coder
```
Recommended local models:
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
### Cloud models
Cloud models are available immediately without pulling:
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
```shell
ollama cp qwen3-coder claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

View File

@@ -32,7 +32,9 @@
"codeblocks": "system"
},
"contextual": {
"options": ["copy"]
"options": [
"copy"
]
},
"navbar": {
"links": [
@@ -52,7 +54,9 @@
"display": "simple"
},
"examples": {
"languages": ["curl"]
"languages": [
"curl"
]
}
},
"redirects": [
@@ -97,6 +101,7 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
@@ -139,7 +144,8 @@
"/api/streaming",
"/api/usage",
"/api/errors",
"/api/openai-compatibility"
"/api/openai-compatibility",
"/api/anthropic-compatibility"
]
},
{

View File

@@ -0,0 +1,69 @@
---
title: Claude Code
---
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
<CodeGroup>
```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```
```powershell Windows
irm https://claude.ai/install.ps1 | iex
```
</CodeGroup>
## Usage with Ollama
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
2. Run Claude Code with an Ollama model:
```shell
claude --model qwen3-coder
```
Or run with environment variables inline:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model

View File

@@ -1,7 +1,5 @@
package fs
import "iter"
type Config interface {
Architecture() string
String(string, ...string) string
@@ -13,8 +11,4 @@ type Config interface {
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
Len() int
Keys() iter.Seq[string]
Value(key string) any
}

View File

@@ -6,9 +6,7 @@ import (
"errors"
"fmt"
"io"
"iter"
"log/slog"
"maps"
"math"
"slices"
"strings"
@@ -241,18 +239,6 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
return val.values
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"bert",

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"github.com/ollama/ollama/fs"
"golang.org/x/sync/errgroup"
)
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
return err
}
for _, key := range slices.Sorted(kv.Keys()) {
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
return err
}
}

149
middleware/anthropic.go Normal file
View File

@@ -0,0 +1,149 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
type AnthropicWriter struct {
BaseWriter
stream bool
id string
model string
converter *anthropic.StreamConverter
}
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
var errData struct {
Error string `json:"error"`
}
if err := json.Unmarshal(data, &errData); err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
}
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
if w.stream {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
}
}
return len(data), nil
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := anthropic.ToMessagesResponse(w.id, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
func (w *AnthropicWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req anthropic.MessagesRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.MaxTokens <= 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
return
}
chatReq, err := anthropic.FromMessagesRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
messageID := anthropic.GenerateMessageID()
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
}
if req.Stream {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
c.Next()
}
}

View File

@@ -0,0 +1,584 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
_ = json.Unmarshal(bodyBytes, capturedRequest)
c.Next()
}
}
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAnthropicMessagesMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.ChatRequest
err anthropic.ErrorResponse
}
var capturedRequest *api.ChatRequest
stream := true
testCases := []testCase{
{
name: "basic message",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with system prompt",
body: `{
"model": "test-model",
"max_tokens": 1024,
"system": "You are helpful.",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with options",
body: `{
"model": "test-model",
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{
"num_predict": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop": []string{"\n", "END"},
},
Stream: &False,
},
},
{
name: "streaming",
body: `{
"model": "test-model",
"max_tokens": 1024,
"stream": true,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &stream,
},
},
{
name: "with tools",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"}
],
"tools": [{
"name": "get_weather",
"description": "Get current weather",
"input_schema": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: testProps(map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with tool result",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"},
{"role": "assistant", "content": [
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
]},
{"role": "user", "content": [
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with thinking enabled",
body: `{
"model": "test-model",
"max_tokens": 1024,
"thinking": {"type": "enabled", "budget_tokens": 1000},
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
Think: &api.ThinkValue{Value: true},
},
},
{
name: "missing model error",
body: `{
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "model is required",
},
},
},
{
name: "missing max_tokens error",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "max_tokens is required and must be positive",
},
},
},
{
name: "missing messages error",
body: `{
"model": "test-model",
"max_tokens": 1024
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "messages is required",
},
},
},
{
name: "tool_use missing id error",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "assistant", "content": [
{"type": "tool_use", "name": "test"}
]}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "tool_use block missing required 'id' field",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
router.Handle(http.MethodPost, "/v1/messages", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Type != "" {
// Expect error
if resp.Code == http.StatusOK {
t.Fatalf("expected error response, got 200 OK")
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != tc.err.Type {
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
}
if errResp.Error.Type != tc.err.Error.Type {
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
}
if errResp.Error.Message != tc.err.Error.Message {
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
}
if capturedRequest == nil {
t.Fatal("request was not captured")
}
// Compare relevant fields
if capturedRequest.Model != tc.req.Model {
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
}
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
t.Errorf("messages mismatch (-want +got):\n%s", diff)
}
if tc.req.Stream != nil && capturedRequest.Stream != nil {
if *tc.req.Stream != *capturedRequest.Stream {
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
}
}
if tc.req.Think != nil {
if capturedRequest.Think == nil {
t.Error("expected Think to be set")
} else if capturedRequest.Think.Value != tc.req.Think.Value {
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
}
}
})
}
}
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("streaming sets correct headers", func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Check headers were set
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
}
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
}
c.Status(http.StatusOK)
})
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
})
}
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != "invalid_request_error" {
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
}
}
func TestAnthropicWriter_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate Ollama response
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
data, _ := json.Marshal(resp)
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.Code)
}
var result anthropic.MessagesResponse
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 {
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
}
if result.Usage.OutputTokens != 5 {
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
}
}
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
errorPayload any
wantErrorType string
wantMessage string
}{
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
{
name: "404 with gin.H error (model not found)",
statusCode: http.StatusNotFound,
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
wantErrorType: "not_found_error",
wantMessage: "model 'nonexistent' not found",
},
{
name: "400 with gin.H error (bad request)",
statusCode: http.StatusBadRequest,
errorPayload: gin.H{"error": "model is required"},
wantErrorType: "invalid_request_error",
wantMessage: "model is required",
},
{
name: "500 with gin.H error (internal error)",
statusCode: http.StatusInternalServerError,
errorPayload: gin.H{"error": "something went wrong"},
wantErrorType: "api_error",
wantMessage: "something went wrong",
},
{
name: "404 with api.StatusError",
statusCode: http.StatusNotFound,
errorPayload: api.StatusError{
StatusCode: http.StatusNotFound,
ErrorMessage: "model not found via StatusError",
},
wantErrorType: "not_found_error",
wantMessage: "model not found via StatusError",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate what routes.go does - set status and write error JSON
data, _ := json.Marshal(tt.errorPayload)
c.Writer.WriteHeader(tt.statusCode)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.statusCode {
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != tt.wantErrorType {
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
}
if errResp.Error.Message != tt.wantMessage {
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
}
})
}
}

View File

@@ -21,7 +21,6 @@ import (
"golang.org/x/text/encoding/unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/fs/ggml"
)
@@ -802,8 +801,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
var base convert.KV
base = map[string]any{"general.architecture": "test"}
base := map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {

View File

@@ -6,9 +6,6 @@ import (
var ErrInterrupt = errors.New("Interrupt")
// ErrExpandOutput is returned when user presses Ctrl+O to expand tool output
var ErrExpandOutput = errors.New("ExpandOutput")
type InterruptError struct {
Line []rune
}

View File

@@ -206,9 +206,6 @@ func (i *Instance) Readline() (string, error) {
buf.DeleteBefore()
case CharCtrlL:
buf.ClearScreen()
case CharCtrlO:
// Ctrl+O - expand tool output
return "", ErrExpandOutput
case CharCtrlW:
buf.DeleteWord()
case CharCtrlZ:

View File

@@ -18,7 +18,6 @@ const (
CharCtrlL = 12
CharEnter = 13
CharNext = 14
CharCtrlO = 15 // Ctrl+O - used for expanding tool output
CharPrev = 16
CharBckSearch = 18
CharFwdSearch = 19

View File

@@ -37,55 +37,6 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
.
fi
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
deduplicate_cuda_libs() {
local base_dir="$1"
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
}
# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
deduplicate_cuda_libs "./dist/linux_amd64"
deduplicate_cuda_libs "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
deduplicate_cuda_libs "./dist"
fi
# buildx behavior changes for single vs. multiplatform
echo "Compressing linux tar bundles..."
if echo $PLATFORM | grep "," > /dev/null ; then

View File

@@ -26,7 +26,6 @@ import (
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
@@ -455,7 +454,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return layers, nil
}
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
for _, l := range baseLayers {
if l.GGML != nil {
return l.KV(), nil

View File

@@ -1544,6 +1544,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
if rc != nil {
// wrap old with new
rs := &registry.Local{

View File

@@ -22,7 +22,6 @@ import (
gocmpopts "github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/types/model"
@@ -42,8 +41,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
var base convert.KV
base = map[string]any{"general.architecture": "test"}
base := map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {

View File

@@ -381,28 +381,6 @@ func (t templateTools) String() string {
return string(bts)
}
// templateArgs is a map type with JSON string output for templates.
type templateArgs map[string]any
func (t templateArgs) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
// templateProperties is a map type with JSON string output for templates.
type templateProperties map[string]api.ToolProperty
func (t templateProperties) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
// templateTool is a template-compatible representation of api.Tool
// with Properties as a regular map for template ranging.
type templateTool struct {
@@ -418,11 +396,11 @@ type templateToolFunction struct {
}
type templateToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties templateProperties `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties map[string]api.ToolProperty `json:"properties"`
}
// templateToolCall is a template-compatible representation of api.ToolCall
@@ -435,7 +413,7 @@ type templateToolCall struct {
type templateToolCallFunction struct {
Index int
Name string
Arguments templateArgs
Arguments map[string]any
}
// templateMessage is a template-compatible representation of api.Message
@@ -468,7 +446,7 @@ func convertToolsForTemplate(tools api.Tools) templateTools {
Defs: tool.Function.Parameters.Defs,
Items: tool.Function.Parameters.Items,
Required: tool.Function.Parameters.Required,
Properties: templateProperties(tool.Function.Parameters.Properties.ToMap()),
Properties: tool.Function.Parameters.Properties.ToMap(),
},
},
}
@@ -490,7 +468,7 @@ func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
Function: templateToolCallFunction{
Index: tc.Function.Index,
Name: tc.Function.Name,
Arguments: templateArgs(tc.Function.Arguments.ToMap()),
Arguments: tc.Function.Arguments.ToMap(),
},
})
}

View File

@@ -613,159 +613,3 @@ func TestCollate(t *testing.T) {
})
}
}
func TestTemplateArgumentsJSON(t *testing.T) {
// Test that {{ .Function.Arguments }} outputs valid JSON, not map[key:value]
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{ .Function.Arguments }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("location", "Tokyo")
args.Set("unit", "celsius")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:Tokyo unit:celsius]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Arguments output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Arguments not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplatePropertiesJSON(t *testing.T) {
// Test that {{ .Function.Parameters.Properties }} outputs valid JSON
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{ .Function.Parameters.Properties }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:{...}]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Properties output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Properties not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplateArgumentsRange(t *testing.T) {
// Test that we can range over Arguments in templates
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{- range $k, $v := .Function.Arguments }}{{ $k }}={{ $v }};{{- end }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("city", "Tokyo")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "city=Tokyo;" {
t.Errorf("Range over Arguments failed, got: %s, want: city=Tokyo;", got)
}
}
func TestTemplatePropertiesRange(t *testing.T) {
// Test that we can range over Properties in templates
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{- range $name, $prop := .Function.Parameters.Properties }}{{ $name }}:{{ $prop.Type }};{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "location:string;" {
t.Errorf("Range over Properties failed, got: %s, want: location:string;", got)
}
}

View File

@@ -4,7 +4,6 @@ package agent
import (
"fmt"
"os"
"path"
"path/filepath"
"strings"
"sync"
@@ -180,7 +179,6 @@ func FormatDeniedResult(command string, pattern string) string {
// extractBashPrefix extracts a prefix pattern from a bash command.
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
// For commands without path args, returns empty string.
// Paths with ".." traversal that escape the base directory return empty string for security.
func extractBashPrefix(command string) string {
// Split command by pipes and get the first part
parts := strings.Split(command, "|")
@@ -206,8 +204,8 @@ func extractBashPrefix(command string) string {
return ""
}
// Find the first path-like argument (must contain / or \ or start with .)
// First pass: look for clear paths (containing path separators or starting with .)
// Find the first path-like argument (must contain / or start with .)
// First pass: look for clear paths (containing / or starting with .)
for _, arg := range fields[1:] {
// Skip flags
if strings.HasPrefix(arg, "-") {
@@ -217,49 +215,19 @@ func extractBashPrefix(command string) string {
if isNumeric(arg) {
continue
}
// Only process if it looks like a path (contains / or \ or starts with .)
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
// Only process if it looks like a path (contains / or starts with .)
if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
continue
}
// Normalize to forward slashes for consistent cross-platform matching
arg = strings.ReplaceAll(arg, "\\", "/")
// Security: reject absolute paths
if path.IsAbs(arg) {
return "" // Absolute path - don't create prefix
// If arg ends with /, it's a directory - use it directly
if strings.HasSuffix(arg, "/") {
return fmt.Sprintf("%s:%s", baseCmd, arg)
}
// Normalize the path using stdlib path.Clean (resolves . and ..)
cleaned := path.Clean(arg)
// Security: reject if cleaned path escapes to parent directory
if strings.HasPrefix(cleaned, "..") {
return "" // Path escapes - don't create prefix
}
// Security: if original had "..", verify cleaned path didn't escape to sibling
// e.g., "tools/a/b/../../../etc" -> "etc" (escaped tools/ to sibling)
if strings.Contains(arg, "..") {
origBase := strings.SplitN(arg, "/", 2)[0]
cleanedBase := strings.SplitN(cleaned, "/", 2)[0]
if origBase != cleanedBase {
return "" // Path escaped to sibling directory
}
}
// Check if arg ends with / (explicit directory)
isDir := strings.HasSuffix(arg, "/")
// Get the directory part
var dir string
if isDir {
dir = cleaned
} else {
dir = path.Dir(cleaned)
}
// Get the directory part of a file path
dir := filepath.Dir(arg)
if dir == "." {
return fmt.Sprintf("%s:./", baseCmd)
// Path is just a directory like "tools" or "src" (no trailing /)
return fmt.Sprintf("%s:%s/", baseCmd, arg)
}
return fmt.Sprintf("%s:%s/", baseCmd, dir)
}
@@ -364,8 +332,6 @@ func AllowlistKey(toolName string, args map[string]any) string {
}
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
// For bash commands, hierarchical path matching is used - if "cat:tools/" is allowed,
// then "cat:tools/subdir/" is also allowed (subdirectories inherit parent permissions).
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
a.mu.RLock()
defer a.mu.RUnlock()
@@ -376,20 +342,12 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
return true
}
// For bash commands, check prefix matches with hierarchical path support
// For bash commands, check prefix matches
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
prefix := extractBashPrefix(cmd)
if prefix != "" {
// Check exact prefix match first
if a.prefixes[prefix] {
return true
}
// Check hierarchical match: if any stored prefix is a parent of current prefix
// e.g., stored "cat:tools/" should match current "cat:tools/subdir/"
if a.matchesHierarchicalPrefix(prefix) {
return true
}
if prefix != "" && a.prefixes[prefix] {
return true
}
}
}
@@ -402,40 +360,6 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
return false
}
// matchesHierarchicalPrefix checks if the given prefix matches any stored prefix hierarchically.
// For example, if "cat:tools/" is stored, it will match "cat:tools/subdir/" or "cat:tools/a/b/c/".
func (a *ApprovalManager) matchesHierarchicalPrefix(currentPrefix string) bool {
// Split prefix into command and path parts (format: "cmd:path/")
colonIdx := strings.Index(currentPrefix, ":")
if colonIdx == -1 {
return false
}
currentCmd := currentPrefix[:colonIdx]
currentPath := currentPrefix[colonIdx+1:]
for storedPrefix := range a.prefixes {
storedColonIdx := strings.Index(storedPrefix, ":")
if storedColonIdx == -1 {
continue
}
storedCmd := storedPrefix[:storedColonIdx]
storedPath := storedPrefix[storedColonIdx+1:]
// Commands must match exactly
if currentCmd != storedCmd {
continue
}
// Check if current path starts with stored path (hierarchical match)
// e.g., "tools/subdir/" starts with "tools/"
if strings.HasPrefix(currentPath, storedPath) {
return true
}
}
return false
}
// AddToAllowlist adds a tool/command to the session allowlist.
// For bash commands, it adds the prefix pattern instead of exact command.
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
@@ -519,12 +443,11 @@ func formatToolDisplay(toolName string, args map[string]any) string {
}
}
// For web search, show query and internet notice
// For web search, show query
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
sb.WriteString("Uses internet via ollama.com")
sb.WriteString(fmt.Sprintf("Query: %s", query))
return sb.String()
}
}
@@ -1028,79 +951,3 @@ func FormatDenyResult(toolName string, reason string) string {
}
return fmt.Sprintf("User denied execution of %s.", toolName)
}
// PromptYesNo displays a simple Yes/No prompt and returns the user's choice.
// Returns true for Yes, false for No.
func PromptYesNo(question string) (bool, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
selected := 0 // 0 = Yes, 1 = No
options := []string{"Yes", "No"}
// Hide cursor
fmt.Fprint(os.Stderr, "\033[?25l")
defer fmt.Fprint(os.Stderr, "\033[?25h")
renderYesNo := func() {
// Move to start of line and clear
fmt.Fprintf(os.Stderr, "\r\033[K")
fmt.Fprintf(os.Stderr, "\033[36m%s\033[0m ", question)
for i, opt := range options {
if i == selected {
fmt.Fprintf(os.Stderr, "\033[1;32m[%s]\033[0m ", opt)
} else {
fmt.Fprintf(os.Stderr, "\033[90m %s \033[0m ", opt)
}
}
fmt.Fprintf(os.Stderr, "\033[90m(←/→ or y/n, Enter to confirm)\033[0m")
}
renderYesNo()
buf := make([]byte, 3)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
return false, err
}
if n == 1 {
switch buf[0] {
case 'y', 'Y':
selected = 0
renderYesNo()
case 'n', 'N':
selected = 1
renderYesNo()
case '\r', '\n': // Enter
fmt.Fprintf(os.Stderr, "\r\033[K") // Clear line
return selected == 0, nil
case 3: // Ctrl+C
fmt.Fprintf(os.Stderr, "\r\033[K")
return false, nil
case 27: // Escape - could be arrow key
// Read more bytes for arrow keys
continue
}
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
// Arrow keys
switch buf[2] {
case 'D': // Left
if selected > 0 {
selected--
}
renderYesNo()
case 'C': // Right
if selected < len(options)-1 {
selected++
}
renderYesNo()
}
}
}
}

View File

@@ -151,27 +151,6 @@ func TestExtractBashPrefix(t *testing.T) {
command: "head -n 100",
expected: "",
},
// Path traversal security tests
{
name: "path traversal - parent escape",
command: "cat tools/../../etc/passwd",
expected: "", // Should NOT create a prefix - path escapes
},
{
name: "path traversal - deep escape",
command: "cat tools/a/b/../../../etc/passwd",
expected: "", // Normalizes to "../etc/passwd" - escapes
},
{
name: "path traversal - absolute path",
command: "cat /etc/passwd",
expected: "", // Absolute paths should not create prefix
},
{
name: "path with safe dotdot - normalized",
command: "cat tools/subdir/../file.go",
expected: "cat:tools/", // Normalizes to tools/file.go - safe, creates prefix
},
}
for _, tt := range tests {
@@ -185,34 +164,6 @@ func TestExtractBashPrefix(t *testing.T) {
}
}
func TestApprovalManager_PathTraversalBlocked(t *testing.T) {
am := NewApprovalManager()
// Allow "cat tools/file.go" - creates prefix "cat:tools/"
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Path traversal attack: should NOT be allowed
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../etc/passwd"}) {
t.Error("SECURITY: path traversal attack should NOT be allowed")
}
// Another traversal variant
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../../etc/shadow"}) {
t.Error("SECURITY: deep path traversal should NOT be allowed")
}
// Valid subdirectory access should still work
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
t.Error("expected cat tools/subdir/file.go to be allowed")
}
// Safe ".." that normalizes to within allowed directory should work
// tools/subdir/../other.go normalizes to tools/other.go which is under tools/
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/../other.go"}) {
t.Error("expected cat tools/subdir/../other.go to be allowed (normalizes to tools/other.go)")
}
}
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
am := NewApprovalManager()
@@ -235,119 +186,6 @@ func TestApprovalManager_PrefixAllowlist(t *testing.T) {
}
}
func TestApprovalManager_HierarchicalPrefixAllowlist(t *testing.T) {
am := NewApprovalManager()
// Allow "cat tools/file.go" - this creates prefix "cat:tools/"
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Should allow subdirectories (hierarchical matching)
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
t.Error("expected cat tools/subdir/file.go to be allowed via hierarchical prefix")
}
// Should allow deeply nested subdirectories
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/a/b/c/deep.go"}) {
t.Error("expected cat tools/a/b/c/deep.go to be allowed via hierarchical prefix")
}
// Should still allow same directory
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/another.go"}) {
t.Error("expected cat tools/another.go to be allowed")
}
// Should NOT allow different base directory
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
t.Error("expected cat src/main.go to NOT be allowed")
}
// Should NOT allow different command even in subdirectory
if am.IsAllowed("bash", map[string]any{"command": "ls tools/subdir/"}) {
t.Error("expected ls tools/subdir/ to NOT be allowed (different command)")
}
// Should NOT allow similar but different directory name
if am.IsAllowed("bash", map[string]any{"command": "cat toolsbin/file.go"}) {
t.Error("expected cat toolsbin/file.go to NOT be allowed (different directory)")
}
}
func TestApprovalManager_HierarchicalPrefixAllowlist_CrossPlatform(t *testing.T) {
am := NewApprovalManager()
// Allow with forward slashes (Unix-style)
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Should work with backslashes too (Windows-style) - normalized internally
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\subdir\\file.go"}) {
t.Error("expected cat tools\\subdir\\file.go to be allowed via hierarchical prefix (Windows path)")
}
// Mixed slashes should also work
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\a/b\\c/deep.go"}) {
t.Error("expected mixed slash path to be allowed via hierarchical prefix")
}
}
func TestMatchesHierarchicalPrefix(t *testing.T) {
am := NewApprovalManager()
// Add prefix for "cat:tools/"
am.prefixes["cat:tools/"] = true
tests := []struct {
name string
prefix string
expected bool
}{
{
name: "exact match",
prefix: "cat:tools/",
expected: true, // exact match also passes HasPrefix - caller handles exact match first
},
{
name: "subdirectory",
prefix: "cat:tools/subdir/",
expected: true,
},
{
name: "deeply nested",
prefix: "cat:tools/a/b/c/",
expected: true,
},
{
name: "different base directory",
prefix: "cat:src/",
expected: false,
},
{
name: "different command same path",
prefix: "ls:tools/",
expected: false,
},
{
name: "similar directory name",
prefix: "cat:toolsbin/",
expected: false,
},
{
name: "invalid prefix format",
prefix: "cattools",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := am.matchesHierarchicalPrefix(tt.prefix)
if result != tt.expected {
t.Errorf("matchesHierarchicalPrefix(%q) = %v, expected %v",
tt.prefix, result, tt.expected)
}
})
}
}
func TestFormatApprovalResult(t *testing.T) {
tests := []struct {
name string

View File

@@ -6,12 +6,10 @@ import (
"errors"
"fmt"
"io"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/spf13/cobra"
"golang.org/x/term"
@@ -24,101 +22,6 @@ import (
"github.com/ollama/ollama/x/tools"
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
localModelTokenLimit = 4000
// defaultTokenLimit is the token limit for cloud/remote models.
defaultTokenLimit = 10000
// charsPerToken is a rough estimate of characters per token.
// TODO: Estimate tokens more accurately using tokenizer if available
charsPerToken = 4
)
// isLocalModel checks if the model is running locally (not a cloud model).
// TODO: Improve local/cloud model identification - could check model metadata
func isLocalModel(modelName string) bool {
return !strings.HasSuffix(modelName, "-cloud")
}
// isLocalServer checks if connecting to a local Ollama server.
// TODO: Could also check other indicators of local vs cloud server
func isLocalServer() bool {
host := os.Getenv("OLLAMA_HOST")
if host == "" {
return true // Default is localhost:11434
}
// Parse the URL to check host
parsed, err := url.Parse(host)
if err != nil {
return true // If can't parse, assume local
}
hostname := parsed.Hostname()
return hostname == "localhost" || hostname == "127.0.0.1" || strings.Contains(parsed.Host, ":11434")
}
// truncateToolOutput truncates tool output to prevent context overflow.
// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote.
func truncateToolOutput(output, modelName string) string {
var tokenLimit int
if isLocalModel(modelName) && isLocalServer() {
tokenLimit = localModelTokenLimit
} else {
tokenLimit = defaultTokenLimit
}
maxChars := tokenLimit * charsPerToken
if len(output) > maxChars {
return output[:maxChars] + "\n... (output truncated)"
}
return output
}
// waitForOllamaSignin shows the signin URL and polls until authentication completes.
func waitForOllamaSignin(ctx context.Context) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Get signin URL from initial Whoami call
_, err = client.Whoami(ctx)
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
fmt.Fprintf(os.Stderr, " \033[36m%s\033[0m\n\n", aErr.SigninURL)
fmt.Fprintf(os.Stderr, " \033[90mWaiting for sign in to complete...\033[0m")
// Poll until auth succeeds
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\n")
return ctx.Err()
case <-ticker.C:
user, whoamiErr := client.Whoami(ctx)
if whoamiErr == nil && user != nil && user.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K \033[32mSigned in as %s\033[0m\n", user.Name)
return nil
}
// Still waiting, show dot
fmt.Fprintf(os.Stderr, ".")
}
}
}
return err
}
return nil
}
// RunOptions contains options for running an interactive agent session.
type RunOptions struct {
Model string
@@ -134,16 +37,6 @@ type RunOptions struct {
// Agent fields (managed externally for session persistence)
Tools *tools.Registry
Approval *agent.ApprovalManager
// YoloMode skips all tool approval prompts
YoloMode bool
// LastToolOutput stores the full output of the last tool execution
// for Ctrl+O expansion. Updated by Chat(), read by caller.
LastToolOutput *string
// LastToolOutputTruncated stores the truncated version shown inline
LastToolOutputTruncated *string
}
// Chat runs an agent chat loop with tool support.
@@ -184,7 +77,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagOpened bool = false
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
role := "assistant"
messages := opts.Messages
@@ -267,58 +159,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, nil
}
// Check for 401 Unauthorized - prompt user to sign in
var authErr api.AuthorizationError
if errors.As(err, &authErr) {
p.StopAndClear()
fmt.Fprintf(os.Stderr, "\033[33mAuthentication required to use this cloud model.\033[0m\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the chat request
fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n")
continue // Retry the loop
}
}
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
}
// Check for 500 errors (often tool parsing failures) - inform the model
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode >= 500 {
consecutiveErrors++
p.StopAndClear()
if consecutiveErrors >= 3 {
fmt.Fprintf(os.Stderr, "\033[31m✗ Too many consecutive errors, giving up\033[0m\n")
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
}
fmt.Fprintf(os.Stderr, "\033[33m⚠ Server error (attempt %d/3): %s\033[0m\n", consecutiveErrors, statusErr.ErrorMessage)
// Include both the model's response and the error so it can learn
assistantContent := fullResponse.String()
if assistantContent == "" {
assistantContent = "(empty response)"
}
errorMsg := fmt.Sprintf("Your previous response caused an error: %s\n\nYour response was:\n%s\n\nPlease try again with a valid response.", statusErr.ErrorMessage, assistantContent)
messages = append(messages,
api.Message{Role: "user", Content: errorMsg},
)
// Reset state and retry
fullResponse.Reset()
thinkingContent.Reset()
thinkTagOpened = false
thinkTagClosed = false
pendingToolCalls = nil
state = &displayResponseState{}
p = progress.NewProgress(os.Stderr)
spinner = progress.NewSpinner("")
p.Add("", spinner)
continue
}
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
@@ -328,9 +168,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, err
}
// Reset consecutive error counter on success
consecutiveErrors = 0
// If no tool calls, we're done
if len(pendingToolCalls) == 0 || toolRegistry == nil {
break
@@ -379,12 +216,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
// Check approval (uses prefix matching for bash commands)
// In yolo mode, skip all approval prompts
if opts.YoloMode {
if !skipApproval {
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
}
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
if !skipApproval && !approval.IsAllowed(toolName, args) {
result, err := approval.RequestApproval(toolName, args)
if err != nil {
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
@@ -418,23 +250,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
// Execute the tool
toolResult, err := toolRegistry.Execute(call)
if err != nil {
// Check if web search needs authentication
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
// Prompt user to sign in
fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
// Get signin URL and wait for auth completion
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the web search
fmt.Fprintf(os.Stderr, "\033[90m Retrying web search...\033[0m\n")
toolResult, err = toolRegistry.Execute(call)
if err == nil {
goto toolSuccess
}
}
}
}
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
toolResults = append(toolResults, api.Message{
Role: "tool",
@@ -443,34 +258,20 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
})
continue
}
toolSuccess:
// Display tool output (truncated for display)
truncatedOutput := ""
if toolResult != "" {
output := toolResult
if len(output) > 300 {
output = output[:300] + "... (truncated, press Ctrl+O to expand)"
output = output[:300] + "... (truncated)"
}
truncatedOutput = output
// Show result in grey, indented
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
}
// Store full and truncated output for Ctrl+O toggle
if opts.LastToolOutput != nil {
*opts.LastToolOutput = toolResult
}
if opts.LastToolOutputTruncated != nil {
*opts.LastToolOutputTruncated = truncatedOutput
}
// Truncate output to prevent context overflow
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: toolResultForLLM,
Content: toolResult,
ToolCallID: call.ID,
})
}
@@ -648,8 +449,7 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
// If yoloMode is true, all tool approvals are skipped.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
@@ -674,11 +474,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var toolRegistry *tools.Registry
if supportsTools {
toolRegistry = tools.DefaultRegistry()
if toolRegistry.Count() > 0 {
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
}
if yoloMode {
fmt.Fprintf(os.Stderr, "\033[33m⚠ YOLO mode: All tool approvals will be skipped\033[0m\n")
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
// Check for OLLAMA_API_KEY for web search
if os.Getenv("OLLAMA_API_KEY") == "" {
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
}
} else {
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
@@ -690,11 +490,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var messages []api.Message
var sb strings.Builder
// Track last tool output for Ctrl+O toggle
var lastToolOutput string
var lastToolOutputTruncated string
var toolOutputExpanded bool
for {
line, err := scanner.Readline()
switch {
@@ -707,20 +502,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
sb.Reset()
continue
case errors.Is(err, readline.ErrExpandOutput):
// Ctrl+O pressed - toggle between expanded and collapsed tool output
if lastToolOutput == "" {
fmt.Fprintf(os.Stderr, "\033[90mNo tool output to expand\033[0m\n")
} else if toolOutputExpanded {
// Currently expanded, show truncated
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutputTruncated, "\n", "\n "))
toolOutputExpanded = false
} else {
// Currently collapsed, show full
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutput, "\n", "\n "))
toolOutputExpanded = true
}
continue
case err != nil:
return err
}
@@ -743,9 +524,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Keyboard Shortcuts:")
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
@@ -759,21 +537,16 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
messages = append(messages, newMessage)
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
LastToolOutput: &lastToolOutput,
LastToolOutputTruncated: &lastToolOutputTruncated,
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
}
// Reset expanded state for new tool execution
toolOutputExpanded = false
assistant, err := Chat(cmd.Context(), opts)
if err != nil {

View File

@@ -1,180 +0,0 @@
package cmd
import (
"testing"
)
func TestIsLocalModel(t *testing.T) {
tests := []struct {
name string
modelName string
expected bool
}{
{
name: "local model without suffix",
modelName: "llama3.2",
expected: true,
},
{
name: "local model with version",
modelName: "qwen2.5:7b",
expected: true,
},
{
name: "cloud model",
modelName: "gpt-4-cloud",
expected: false,
},
{
name: "cloud model with version",
modelName: "claude-3-cloud",
expected: false,
},
{
name: "empty model name",
modelName: "",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isLocalModel(tt.modelName)
if result != tt.expected {
t.Errorf("isLocalModel(%q) = %v, expected %v", tt.modelName, result, tt.expected)
}
})
}
}
func TestIsLocalServer(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
{
name: "empty host (default)",
host: "",
expected: true,
},
{
name: "localhost",
host: "http://localhost:11434",
expected: true,
},
{
name: "127.0.0.1",
host: "http://127.0.0.1:11434",
expected: true,
},
{
name: "custom port on localhost",
host: "http://localhost:8080",
expected: true, // localhost is always considered local
},
{
name: "remote host",
host: "http://ollama.example.com:11434",
expected: true, // has :11434
},
{
name: "remote host different port",
host: "http://ollama.example.com:8080",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
result := isLocalServer()
if result != tt.expected {
t.Errorf("isLocalServer() with OLLAMA_HOST=%q = %v, expected %v", tt.host, result, tt.expected)
}
})
}
}
func TestTruncateToolOutput(t *testing.T) {
// Create outputs of different sizes
localLimitOutput := make([]byte, 20000) // > 4k tokens (16k chars)
defaultLimitOutput := make([]byte, 50000) // > 10k tokens (40k chars)
for i := range localLimitOutput {
localLimitOutput[i] = 'a'
}
for i := range defaultLimitOutput {
defaultLimitOutput[i] = 'b'
}
tests := []struct {
name string
output string
modelName string
host string
shouldTrim bool
expectedLimit int
}{
{
name: "short output local model",
output: "hello world",
modelName: "llama3.2",
host: "",
shouldTrim: false,
expectedLimit: localModelTokenLimit,
},
{
name: "long output local model - trimmed at 4k",
output: string(localLimitOutput),
modelName: "llama3.2",
host: "",
shouldTrim: true,
expectedLimit: localModelTokenLimit,
},
{
name: "long output cloud model - uses 10k limit",
output: string(localLimitOutput), // 20k chars, under 10k token limit
modelName: "gpt-4-cloud",
host: "",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
},
{
name: "very long output cloud model - trimmed at 10k",
output: string(defaultLimitOutput),
modelName: "gpt-4-cloud",
host: "",
shouldTrim: true,
expectedLimit: defaultTokenLimit,
},
{
name: "long output remote server - uses 10k limit",
output: string(localLimitOutput),
modelName: "llama3.2",
host: "http://remote.example.com:8080",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
result := truncateToolOutput(tt.output, tt.modelName)
if tt.shouldTrim {
maxLen := tt.expectedLimit * charsPerToken
if len(result) > maxLen+50 { // +50 for the truncation message
t.Errorf("expected output to be truncated to ~%d chars, got %d", maxLen, len(result))
}
if result == tt.output {
t.Error("expected output to be truncated but it wasn't")
}
} else {
if result != tt.output {
t.Error("expected output to not be truncated")
}
}
})
}
}

38
x/imagegen/.gitignore vendored
View File

@@ -1,38 +0,0 @@
# Build directories
build/
dist/
# CMake
CMakeCache.txt
CMakeFiles/
cmake_install.cmake
Makefile
*.cmake
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# macOS
.DS_Store
*.dSYM/
# Go
*.exe
*.exe~
*.dll
*.so
*.dylib
# Python
*.npy
/engine
weights
outputs
prompt.txt
negative.txt

View File

@@ -1,151 +0,0 @@
# MLX engine
This is a small inference engine written in Go using [MLX](https://github.com/ml-explore/mlx), Apple's array framework for machine learning
## Goals
1. Implement multimodal runners: in a dedicated runner but eventually to be integrated into Ollama's primary runner.
2. Optimizing for image generation memory usage and output speed
3. (secondary): implement fast text model inference for gpt-oss, Llama.
## Prerequisites
**macOS:**
- macOS 14.0+ (Sonoma or later)
- Apple Silicon (M1/M2/M3)
- Xcode Command Line Tools
**Linux (building from source):**
- NVIDIA GPU (compute capability 7.0+)
- CUDA 12.0+ toolkit
- cuDNN
**Linux (prebuilt binaries):**
- NVIDIA GPU (compute capability 7.0+)
- NVIDIA driver 525+ (CUDA runtime libs are bundled)
**Both:**
- CMake 3.25+
- Go 1.21+
## Quick Start
### Build MLX
```bash
cmake -B build
cmake --build build --parallel
cmake --install build
```
This fetches MLX and mlx-c, builds them, and installs to `dist/`:
- `dist/lib/libmlxc.so` (or `.dylib`) - MLX C bindings
- `dist/lib/libmlx.a` - MLX static library
- `dist/include/` - Headers (mlx-c, CCCL for CUDA JIT)
To update MLX version, change `MLX_GIT_TAG` in `CMakeLists.txt` and rebuild.
### 2. Download a Model
Download Llama 3.1 8B (or any compatible model) in safetensors format:
```bash
mkdir -p ./weights
# Example using huggingface-cli
hf download meta-llama/Llama-3.1-8B --local-dir ./weights/Llama-3.1-8B
hf download openai/gpt-oss-20b --local-dir ./weights/gpt-oss-20b
```
### 3. Run Inference
```bash
# Build
go build ./cmd/engine
# Text generation
./engine -model ./weights/Llama-3.1-8B -prompt "Hello, world!" -max-tokens 250
# Qwen-Image 2512 (text-to-image)
./engine -qwen-image -model ./weights/Qwen-Image-2512 -prompt "A mountain landscape at sunset" \
-width 1024 -height 1024 -steps 20 -seed 42 -output landscape.png
# Qwen-Image Edit (experimental) - 8 steps for speed, but model recommends 50
./engine -qwen-image-edit -model ./weights/Qwen-Image-Edit-2511 \
-input-image input.png -prompt "Make it winter" -negative-prompt " " -cfg-scale 4.0 \
-steps 8 -seed 42 -output edited.png
```
## Adding a Model
Use Claude Code with this repo. See `models/CLAUDE.md` for the full guide covering:
- Porting Python models to Go (forward pass, weight loading)
- Component testing with Python reference data
- Performance optimization
Reference implementations: `llama` (LLM), `qwen_image` (image generation), `qwen_image_edit` (image editing)
## Memory Management
MLX Python/C++ uses scope-based memory management - arrays are freed when they go out of scope. Go's garbage collector is non-deterministic, so we can't rely on finalizers to free GPU memory promptly.
Instead, arrays are automatically tracked and freed on `Eval()`:
```go
// All arrays are automatically tracked when created
x := mlx.Add(a, b)
y := mlx.Matmul(x, w)
// Eval frees non-kept arrays, evaluates outputs (auto-kept)
mlx.Eval(y)
// After copying to CPU, free the array
data := y.Data()
y.Free()
```
Key points:
- All created arrays are automatically tracked
- `mlx.Eval(outputs...)` frees non-kept arrays, evaluates outputs (outputs auto-kept)
- `mlx.Keep(arrays...)` marks arrays to survive multiple Eval cycles (for weights, caches)
- Call `.Free()` when done with an array
## Testing
### Running Tests
```bash
# Run all tests (tests skip if dependencies missing)
go test ./...
# Run specific model tests
go test ./models/qwen_image/...
```
### Model Weights
Tests require model weights in `./weights/<model-name>/`:
```
weights/
├── Qwen-Image-2512/ # Qwen image generation
│ ├── text_encoder/
│ ├── transformer/
│ ├── vae/
│ └── tokenizer/
├── Llama-3.1-8B/ # LLM
└── ...
```
Download models using `huggingface-cli`:
```bash
hf download ./weights/Qwen-Image-2512 --local-dir ./weights/Qwen-Image-2512
```

View File

@@ -1,154 +0,0 @@
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
type Cache interface {
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
Offset() int
Len() int
State() []*mlx.Array
}
type KVCache struct {
keys, values *mlx.Array
offset int
step int
}
func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
prev := c.offset
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if needed
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
nSteps := (c.step + seqLen - 1) / c.step
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
if c.keys != nil {
if prev%c.step != 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
}
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
c.offset += seqLen
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
}
func (c *KVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
keys, values *mlx.Array
offset int
maxSize int
step int
idx int
}
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, step: 256}
}
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
if seqLen > 1 {
return c.updateConcat(k, v, seqLen)
}
return c.updateInPlace(k, v)
}
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if not yet at max
if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
var cap int
if c.keys != nil {
cap = int(c.keys.Shape()[2])
}
newSize := min(c.step, c.maxSize-cap)
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
if c.keys != nil {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
// Rotate when hitting max
if c.idx >= c.maxSize {
c.idx = 0
}
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
c.offset++
c.idx++
validLen := int32(min(c.offset, c.maxSize))
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
}
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
if c.keys == nil {
c.keys, c.values = k, v
} else {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
}
c.offset += seqLen
// Trim to max_size to maintain sliding window
cap := int(c.keys.Shape()[2])
if trim := cap - c.maxSize; trim > 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
}
c.idx = int(c.keys.Shape()[2])
return c.keys, c.values
}
func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *RotatingKVCache) Offset() int { return c.offset }
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }

View File

@@ -1,162 +0,0 @@
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
// StepCache caches layer outputs across diffusion denoising steps.
// Based on DeepCache (CVPR 2024) and Learning-to-Cache (NeurIPS 2024):
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
// - Single-stream: use Get/Set for the single output per layer
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
// Usage (single-stream):
//
// cache := NewStepCache(15) // cache first 15 layers
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3) // refresh every 3 steps
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// output = cache.Get(i) // reuse cached
// } else {
// output = layer.Forward(input)
// if i < 15 && refresh {
// cache.Set(i, output)
// }
// }
// }
// }
// cache.Free() // cleanup when done
//
// Usage (dual-stream):
//
// cache := NewStepCache(15)
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3)
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// imgH, txtH = cache.Get(i), cache.Get2(i)
// } else {
// imgH, txtH = layer.Forward(imgH, txtH, ...)
// if i < 15 && refresh {
// cache.Set(i, imgH)
// cache.Set2(i, txtH)
// }
// }
// }
// }
type StepCache struct {
layers []*mlx.Array // cached layer outputs (stream 1)
layers2 []*mlx.Array // cached layer outputs (stream 2, for dual-stream models)
constant *mlx.Array // optional constant (e.g., text embeddings)
}
// NewStepCache creates a cache for the given number of layers.
func NewStepCache(numLayers int) *StepCache {
return &StepCache{
layers: make([]*mlx.Array, numLayers),
layers2: make([]*mlx.Array, numLayers),
}
}
// ShouldRefresh returns true if the cache should be refreshed at this step.
// Refresh happens on step 0, interval, 2*interval, etc.
func (c *StepCache) ShouldRefresh(step, interval int) bool {
return step%interval == 0
}
// Get returns the cached output for a layer, or nil if not cached.
func (c *StepCache) Get(layer int) *mlx.Array {
if layer < len(c.layers) {
return c.layers[layer]
}
return nil
}
// Set stores a layer output (stream 1), freeing any previous value.
func (c *StepCache) Set(layer int, arr *mlx.Array) {
if layer < len(c.layers) {
if c.layers[layer] != nil {
c.layers[layer].Free()
}
c.layers[layer] = arr
}
}
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Get2(layer int) *mlx.Array {
if layer < len(c.layers2) {
return c.layers2[layer]
}
return nil
}
// Set2 stores a layer output (stream 2), freeing any previous value.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
if layer < len(c.layers2) {
if c.layers2[layer] != nil {
c.layers2[layer].Free()
}
c.layers2[layer] = arr
}
}
// GetConstant returns the cached constant value.
func (c *StepCache) GetConstant() *mlx.Array {
return c.constant
}
// SetConstant stores a constant value, freeing any previous value.
func (c *StepCache) SetConstant(arr *mlx.Array) {
if c.constant != nil {
c.constant.Free()
}
c.constant = arr
}
// Arrays returns all non-nil cached arrays (for pool.Keep).
func (c *StepCache) Arrays() []*mlx.Array {
var result []*mlx.Array
if c.constant != nil {
result = append(result, c.constant)
}
for _, arr := range c.layers {
if arr != nil {
result = append(result, arr)
}
}
for _, arr := range c.layers2 {
if arr != nil {
result = append(result, arr)
}
}
return result
}
// Free releases all cached arrays. Call when generation completes.
func (c *StepCache) Free() {
if c.constant != nil {
c.constant.Free()
c.constant = nil
}
for i, arr := range c.layers {
if arr != nil {
arr.Free()
c.layers[i] = nil
}
}
for i, arr := range c.layers2 {
if arr != nil {
arr.Free()
c.layers2[i] = nil
}
}
}
// NumLayers returns the number of layers this cache can store.
func (c *StepCache) NumLayers() int {
return len(c.layers)
}

View File

@@ -1,357 +0,0 @@
package main
import (
"context"
"fmt"
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Dedicated stream for generation (like mlx-lm's generation_stream)
var generationStream *mlx.Stream
// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
// This handles cases where tokenizers output partial multi-byte sequences.
type utf8Streamer struct {
buffer []byte
}
// Write adds decoded text to the buffer and returns complete UTF-8 characters.
func (s *utf8Streamer) Write(text string) string {
s.buffer = append(s.buffer, text...)
// Find the last position that ends with a complete UTF-8 character
validLen := 0
for i := 0; i < len(s.buffer); {
r, size := utf8.DecodeRune(s.buffer[i:])
if r == utf8.RuneError && size == 1 {
// Invalid or incomplete UTF-8 sequence at this position
// Check if it could be a valid start of a multi-byte sequence
if len(s.buffer)-i < 4 {
// Might be incomplete, keep it in buffer
break
}
// Definitely invalid, skip this byte
i++
validLen = i
} else {
i += size
validLen = i
}
}
if validLen == 0 {
return ""
}
result := string(s.buffer[:validLen])
s.buffer = s.buffer[validLen:]
return result
}
// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
func (s *utf8Streamer) Flush() string {
if len(s.buffer) == 0 {
return ""
}
result := string(s.buffer)
s.buffer = nil
return result
}
func init() {
generationStream = mlx.NewStream()
}
// withStream runs fn with the generation stream as default
func withStream(fn func()) {
orig := mlx.GetDefaultStream()
mlx.SetDefaultStream(generationStream)
fn()
mlx.SetDefaultStream(orig)
}
type Model interface {
Tokenizer() *tokenizer.Tokenizer
VocabSize() int32
NewCache(maxSeqLen int32) []cache.Cache
Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
}
// ChatModel is an optional interface for models that support chat formatting
type ChatModel interface {
FormatPrompt(prompt string) string
}
// MultimodalModel is for models that support image input
type MultimodalModel interface {
Model
FormatPromptWithImage(prompt string) string
ExpandImageTokens(tokens []int32) []int32
ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
ImageSize() int32 // Returns expected image size for preprocessing
}
// ImageLoader loads and preprocesses an image for multimodal models
// Returns nil if path is empty
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)
type input struct {
Prompt string
Image *mlx.Array // Optional preprocessed image for multimodal models
MaxTokens int
Temperature float32
TopP float32
TopK int
WiredLimitGB int // Metal wired memory limit in GB (default 32)
}
type output struct {
Text string
Done bool
PrefillTokSec float64
GenTokSec float64
}
// Decoder wraps model + cache for autoregressive generation.
type Decoder struct {
model Model
caches []cache.Cache
vocabSize int32
temp float32
topK int
topP float32
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
}
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
caches := m.NewCache(0)
return &Decoder{
model: m,
caches: caches,
vocabSize: m.VocabSize(),
temp: temp,
topK: topK,
topP: topP,
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
}
}
// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
d.image = img
}
func (d *Decoder) prefill(inputIDs []int32) int {
processed := 0
// Track old cache state to free after each chunk
var oldCacheState []*mlx.Array
// For multimodal models with an image, we need to process all tokens together
// in the first forward pass so the image embeddings can be inserted properly.
// Skip chunking for multimodal prefill.
isMultimodal := d.image != nil
// Process all-but-1 tokens in chunks, eval cache state for memory management
// Skip chunking for multimodal - process everything in the final step
if !isMultimodal {
for len(inputIDs) > 1 {
chunkSize := min(2048, len(inputIDs)-1)
if chunkSize <= 0 {
break
}
chunk := inputIDs[:chunkSize]
// Save old cache state before forward
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
var cacheState []*mlx.Array
withStream(func() {
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
d.model.Forward(x, d.caches)
for _, c := range d.caches {
cacheState = append(cacheState, c.State()...)
}
})
mlx.Eval(cacheState...)
// Free old cache state
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
inputIDs = inputIDs[chunkSize:]
processed += chunkSize
}
}
// Save old cache state before final step
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
// Final token + sampling (or all tokens for multimodal)
withStream(func() {
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
mlx.Eval(x) // Materialize before any other evals
var logits *mlx.Array
// Use ForwardWithImage if we have an image and model supports it
if d.image != nil {
if mm, ok := d.model.(MultimodalModel); ok {
logits = mm.ForwardWithImage(x, d.image, d.caches)
d.image = nil // Only use image for first forward
} else {
logits = d.model.Forward(x, d.caches)
}
} else {
logits = d.model.Forward(x, d.caches)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Free old cache state from before final step
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
mlx.ClearCache()
return processed + len(inputIDs)
}
func (d *Decoder) step() int32 {
prevToken := d.token
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
d.oldCacheState = append(d.oldCacheState, c.State()...)
}
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
mlx.Keep(d.token)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
arr.Free()
}
return val
}
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
mlx.EnableCompile()
wiredLimit := in.WiredLimitGB
if wiredLimit <= 0 {
wiredLimit = 32 // default 32GB
}
mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)
temp := in.Temperature
if temp < 0 {
temp = 0.7
}
tok := m.Tokenizer()
dec := NewDecoder(m, temp, in.TopK, in.TopP)
// Apply chat template - use image template if we have an image
prompt := in.Prompt
var tokens []int32
if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
prompt = mm.FormatPromptWithImage(prompt)
tokens = tok.Encode(prompt, true)
tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
dec.SetImage(in.Image)
} else if cm, ok := m.(ChatModel); ok {
prompt = cm.FormatPrompt(prompt)
tokens = tok.Encode(prompt, true)
} else {
tokens = tok.Encode(prompt, true)
}
prefillStart := time.Now()
prefillTokens := dec.prefill(tokens)
// Prefill measurement should include time to first token (like mlx-lm)
// Step() waits for prefill to complete and returns first token
firstToken := dec.step()
prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()
genStart := time.Now()
maxTokens := max(in.MaxTokens, 100)
var genTokens int
// UTF-8 streamer to handle partial multi-byte characters
streamer := &utf8Streamer{}
// Handle first token
genTokens++
if tok.IsEOS(firstToken) {
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
return nil
}
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
cb(output{Text: text})
}
for n := 1; n < maxTokens; n++ {
if ctx.Err() != nil {
return ctx.Err()
}
token := dec.step()
genTokens++
if tok.IsEOS(token) {
break
}
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
cb(output{Text: text})
}
if n%256 == 0 {
mlx.ClearCache()
}
}
// Flush any remaining buffered bytes
if text := streamer.Flush(); text != "" {
cb(output{Text: text})
}
fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
cb(output{Done: true, PrefillTokSec: prefillTokSec,
GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
return nil
}

View File

@@ -1,87 +0,0 @@
package main
import (
"fmt"
"image"
"image/png"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// saveImageArray saves an MLX array as a PNG image.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func saveImageArray(arr *mlx.Array, path string) error {
img, err := arrayToImage(arr)
if err != nil {
return err
}
return savePNG(img, path)
}
func savePNG(img *image.RGBA, path string) error {
if filepath.Ext(path) != ".png" {
path = path + ".png"
}
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
return png.Encode(f, img)
}
func arrayToImage(arr *mlx.Array) (*image.RGBA, error) {
shape := arr.Shape()
if len(shape) != 4 {
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
}
// Transform to [H, W, C] for image conversion
img := mlx.Squeeze(arr, 0)
arr.Free()
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
mlx.Eval(img)
imgShape := img.Shape()
H := int(imgShape[0])
W := int(imgShape[1])
C := int(imgShape[2])
if C != 3 {
img.Free()
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
}
// Copy to CPU and free GPU memory
data := img.Data()
img.Free()
// Write directly to Pix slice (faster than SetRGBA)
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
pix := goImg.Pix
for y := 0; y < H; y++ {
for x := 0; x < W; x++ {
srcIdx := (y*W + x) * C
dstIdx := (y*W + x) * 4
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
pix[dstIdx+3] = 255
}
}
return goImg, nil
}
func clampF(v, min, max float32) float32 {
if v < min {
return min
}
if v > max {
return max
}
return v
}

View File

@@ -1,284 +0,0 @@
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"runtime/pprof"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// stringSlice is a flag type that accumulates multiple values
type stringSlice []string
func (s *stringSlice) String() string {
return fmt.Sprintf("%v", *s)
}
func (s *stringSlice) Set(value string) error {
*s = append(*s, value)
return nil
}
func main() {
modelPath := flag.String("model", "", "Model directory")
prompt := flag.String("prompt", "Hello", "Prompt")
// Text generation params
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
temperature := flag.Float64("temperature", 0.7, "Temperature")
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
topK := flag.Int("top-k", 40, "Top-k sampling")
imagePath := flag.String("image", "", "Image path for multimodal models")
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
steps := flag.Int("steps", 9, "Denoising steps")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
// Utility flags
listTensors := flag.Bool("list", false, "List tensors only")
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
flag.Parse()
if *modelPath == "" {
flag.Usage()
return
}
// CPU profiling
if *cpuProfile != "" {
f, err := os.Create(*cpuProfile)
if err != nil {
log.Fatal(err)
}
defer f.Close()
if err := pprof.StartCPUProfile(f); err != nil {
log.Fatal(err)
}
defer pprof.StopCPUProfile()
}
var err error
// Handle legacy mode flags that aren't unified yet
switch {
case *zimageFlag:
m := &zimage.Model{}
if loadErr := m.Load(*modelPath); loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&zimage.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImageEdit:
if len(inputImages) == 0 {
log.Fatal("qwen-image-edit requires at least one -input-image")
}
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
// For image editing, use 0 for dimensions to auto-detect from input image
// unless explicitly overridden from defaults
editWidth := int32(0)
editHeight := int32(0)
if *width != 1024 {
editWidth = int32(*width)
}
if *height != 1024 {
editHeight = int32(*height)
}
cfg := &qwen_image_edit.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: editWidth,
Height: editHeight,
Steps: *steps,
Seed: *seed,
}
var img *mlx.Array
img, err = m.EditFromConfig(inputImages, cfg)
if err == nil {
err = saveImageArray(img, *out)
}
case *listTensors:
err = listModelTensors(*modelPath)
default:
// llm path
m, err := load(*modelPath)
if err != nil {
log.Fatal(err)
}
// Load image if provided and model supports it
var image *mlx.Array
if *imagePath != "" {
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
if err != nil {
log.Fatal("load image:", err)
}
} else {
log.Fatal("model does not support image input")
}
}
err = generate(context.Background(), m, input{
Prompt: *prompt,
Image: image,
MaxTokens: *maxTokens,
Temperature: float32(*temperature),
TopP: float32(*topP),
TopK: *topK,
WiredLimitGB: *wiredLimitGB,
}, func(out output) {
if out.Text != "" {
fmt.Print(out.Text)
}
if out.Done {
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
}
})
}
if err != nil {
log.Fatal(err)
}
}
func listModelTensors(modelPath string) error {
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return err
}
for _, name := range weights.ListTensors() {
info, _ := weights.GetTensorInfo(name)
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
}
return nil
}
// loadModel builds and evaluates a model using the common load pattern.
// Release safetensors BEFORE eval - lazy arrays have captured their data,
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
func loadModel[T Model](build func() T, cleanup func()) T {
m := build()
weights := mlx.Collect(m)
cleanup()
mlx.Eval(weights...)
return m
}
func load(modelPath string) (Model, error) {
kind, err := detectModelKind(modelPath)
if err != nil {
return nil, fmt.Errorf("detect model kind: %w", err)
}
switch kind {
case "gpt_oss":
return gpt_oss.Load(modelPath)
case "gemma3":
return gemma3.Load(modelPath)
case "gemma3_text":
return gemma3.LoadText(modelPath)
default:
return llama.Load(modelPath)
}
}
func detectModelKind(modelPath string) (string, error) {
indexPath := filepath.Join(modelPath, "model_index.json")
if _, err := os.Stat(indexPath); err == nil {
data, err := os.ReadFile(indexPath)
if err != nil {
return "zimage", nil
}
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err == nil {
switch index.ClassName {
case "FluxPipeline", "ZImagePipeline":
return "zimage", nil
}
}
return "zimage", nil
}
configPath := filepath.Join(modelPath, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
}
var cfg struct {
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return "", fmt.Errorf("parse config.json: %w", err)
}
return cfg.ModelType, nil
}

View File

@@ -1,47 +0,0 @@
package main
import "github.com/ollama/ollama/x/imagegen/mlx"
// sampleTopK samples from top-k logits using global random state
func sampleTopK(scaledLogits *mlx.Array, k int) *mlx.Array {
neg := mlx.Neg(scaledLogits)
indices := mlx.Argpartition(neg, k-1, -1)
topKIdx := mlx.Slice(indices, []int32{0}, []int32{int32(k)})
values := mlx.TakeAlongAxis(scaledLogits, topKIdx, -1)
sampled := mlx.RandomCategorical(values, -1, 1)
return mlx.Take(topKIdx, sampled, -1)
}
// sampleTopP samples using nucleus sampling with global random state
func sampleTopP(scaledLogits *mlx.Array, p float32, vocabSize int32) *mlx.Array {
sorted := mlx.Argsort(mlx.Neg(scaledLogits), -1)
sortedLogits := mlx.TakeAlongAxis(scaledLogits, sorted, -1)
probs := mlx.Softmax(sortedLogits, -1)
cumProbs := mlx.Cumsum(probs, -1)
mask := mlx.LessScalar(cumProbs, p)
negInf := mlx.FullDtype(float32(-1e9), scaledLogits.Dtype(), vocabSize)
masked := mlx.Where(mask, sortedLogits, negInf)
sampled := mlx.RandomCategorical(masked, -1, 1)
return mlx.Take(sorted, sampled, -1)
}
// sample samples from logits at the last position
func sample(logits *mlx.Array, temp float32, topK int, topP float32, vocab int32) *mlx.Array {
// Get last position logits: [1, L, vocab] -> [vocab]
shape := logits.Shape()
seqLen := shape[1]
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocab})
lastLogits = mlx.Reshape(lastLogits, vocab)
if temp == 0 {
return mlx.Argmax(lastLogits, -1, false)
}
scaled := mlx.DivScalar(lastLogits, temp)
if topK > 0 && topK < int(vocab) {
return sampleTopK(scaled, topK)
}
if topP > 0 && topP < 1.0 {
return sampleTopP(scaled, topP, vocab)
}
return mlx.RandomCategorical(scaled, -1, 1)
}

View File

@@ -1,46 +0,0 @@
# MLX Memory Management
| This package will get consolidated with `x/ml/backend/mlx` in the future.
## Automatic Tracking
All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed.
### API
```go
result := mlx.Matmul(x, w) // arrays automatically tracked
mlx.Eval(result) // free non-kept, eval result (auto-kept)
```
### Key Functions
- `mlx.Eval(outputs...)` - free non-kept arrays, then evaluate (outputs auto-kept)
- `mlx.AsyncEval(outputs...)` - async version of Eval (outputs auto-kept)
- `mlx.Keep(arrays...)` - mark arrays to survive cleanup (for weights, caches)
- `array.Free()` - mark array for cleanup on next Eval
### Loop Pattern
```go
for step := 0; step < maxTokens; step++ {
logits := model.Forward(token, caches)
oldToken := token
token = sample(logits)
// Keep cache state across iterations
for _, c := range caches {
mlx.Keep(c.State()...)
}
oldToken.Free() // mark for cleanup
mlx.AsyncEval(token) // frees old, evals new
}
```
### Notes
- `Eval()` and `AsyncEval()` auto-keep their outputs
- `Free()` marks for cleanup - actual free happens during next Eval
- Use `Keep()` for weights and cache state that must survive multiple Eval cycles
- Arrays created inside compiled closures are managed by MLX, not tracked

View File

@@ -1,171 +0,0 @@
package mlx
/*
#include "mlx/c/mlx.h"
#include <stdlib.h>
// Forward declaration for Go callback
extern int goClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
// Destructor for payload (Go handle)
extern void goClosureDestructor(void* payload);
*/
import "C"
import (
"runtime/cgo"
"sync"
"unsafe"
)
// inClosureCallback is set to true during closure callback execution.
var inClosureCallback bool
var closureCallbackMu sync.Mutex
// InClosureCallback returns true if we're currently executing inside a closure callback.
func InClosureCallback() bool {
closureCallbackMu.Lock()
defer closureCallbackMu.Unlock()
return inClosureCallback
}
// CompiledFunc is a compiled MLX function that can be called efficiently.
// All intermediate arrays during execution stay inside MLX - only inputs
// and outputs cross the Go boundary.
type CompiledFunc struct {
closure C.mlx_closure
compiled C.mlx_closure
}
// ClosureFunc is the signature for functions that can be compiled.
// It takes a slice of input arrays and returns a slice of output arrays.
type ClosureFunc func(inputs []*Array) []*Array
// Compile compiles a Go function into an optimized MLX closure.
// The function is traced once during compilation, then subsequent calls
// run the optimized graph without creating Go intermediate arrays.
//
// Example:
//
// compiled := mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
// a, b := inputs[0], inputs[1]
// c := mlx.Add(a, b)
// d := mlx.Mul(c, c)
// return []*mlx.Array{d}
// })
// defer compiled.Free()
//
// result := compiled.Call(x, y)[0]
func Compile(fn ClosureFunc) *CompiledFunc {
return CompileShapeless(fn, false)
}
// CompileShapeless compiles with optional shapeless mode.
// If shapeless=true, the function works for any input shape after tracing.
func CompileShapeless(fn ClosureFunc, shapeless bool) *CompiledFunc {
// Create a cgo.Handle to prevent the Go function from being GC'd
handle := cgo.NewHandle(fn)
// Create the closure from the Go callback
closure := C.mlx_closure_new_func_payload(
(*[0]byte)(C.goClosureCallback),
unsafe.Pointer(handle),
(*[0]byte)(C.goClosureDestructor),
)
// Compile the closure
compiled := C.mlx_closure_new()
C.mlx_compile(&compiled, closure, C.bool(shapeless))
return &CompiledFunc{
closure: closure,
compiled: compiled,
}
}
// Call invokes the compiled function with the given inputs.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
// Pack inputs into vector
inputVec := C.mlx_vector_array_new()
for _, arr := range inputs {
C.mlx_vector_array_append_value(inputVec, arr.c)
}
// Apply compiled closure
outputVec := C.mlx_vector_array_new()
C.mlx_closure_apply(&outputVec, cf.compiled, inputVec)
C.mlx_vector_array_free(inputVec)
// Unpack outputs
numOutputs := int(C.mlx_vector_array_size(outputVec))
outputs := make([]*Array, numOutputs)
for i := 0; i < numOutputs; i++ {
var arr C.mlx_array
C.mlx_vector_array_get(&arr, outputVec, C.size_t(i))
outputs[i] = newArray(arr)
}
C.mlx_vector_array_free(outputVec)
return outputs
}
// CallEval invokes the compiled function and evaluates the results.
func (cf *CompiledFunc) CallEval(inputs ...*Array) []*Array {
outputs := cf.Call(inputs...)
Eval(outputs...)
return outputs
}
// Free releases the compiled function resources.
func (cf *CompiledFunc) Free() {
C.mlx_closure_free(cf.compiled)
C.mlx_closure_free(cf.closure)
}
// borrowArray wraps a C array WITHOUT setting up GC cleanup.
// Use this for arrays we don't own (e.g., borrowed references in callbacks).
func borrowArray(array C.mlx_array) *Array {
return &Array{c: array}
}
//export goClosureCallback
func goClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) C.int {
// Set flag to disable AddCleanup during callback
closureCallbackMu.Lock()
inClosureCallback = true
closureCallbackMu.Unlock()
defer func() {
closureCallbackMu.Lock()
inClosureCallback = false
closureCallbackMu.Unlock()
}()
// Recover the Go function from the handle
handle := cgo.Handle(payload)
fn := handle.Value().(ClosureFunc)
// Convert input vector to Go slice - use borrowArray since MLX owns these
numInputs := int(C.mlx_vector_array_size(input))
inputs := make([]*Array, numInputs)
for i := 0; i < numInputs; i++ {
var arr C.mlx_array
C.mlx_vector_array_get(&arr, input, C.size_t(i))
inputs[i] = borrowArray(arr) // Don't set up cleanup - MLX owns these
}
// Call the Go function
outputs := fn(inputs)
// Build output vector
*res = C.mlx_vector_array_new()
for _, arr := range outputs {
C.mlx_vector_array_append_value(*res, arr.c)
}
return 0
}
//export goClosureDestructor
func goClosureDestructor(payload unsafe.Pointer) {
handle := cgo.Handle(payload)
handle.Delete()
}

View File

File diff suppressed because it is too large Load Diff

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,277 +0,0 @@
# Model Implementation Guide
See `README.md` for memory management (critical for Go + MLX).
## Phase 1: Import & Forward Pass
- Read Python reference implementation (PyTorch/Transformers)
- Create Go struct mirroring layer hierarchy
- Implement weight loading from safetensors (see `safetensors.go`)
- Port forward pass layer-by-layer, bottom-up
- For tokenizers: check if BPE (`bpe.go`) or custom needed
**Key files to reference:** `llama` (dense LLM), `gpt_oss` (MoE LLM), `zimage` (image generation), `qwen_image_edit` (image editing)
### Vision Models: Image Preprocessing
When implementing vision models (image-to-text, image editing, etc.), image preprocessing must match Python exactly. Common pitfalls:
1. **Resolution constraints**: Many vision models use `min_pixels` and `max_pixels` to constrain image size, not a fixed target area. Check the Python processor's `smart_resize` logic.
2. **Patch alignment**: Images must be resized to multiples of `factor = patch_size * spatial_merge_size` (e.g., 14 \* 2 = 28 for Qwen2.5-VL).
3. **Normalization**: Vision encoders use ImageNet stats (mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]), not simple [-1, 1] scaling.
4. **Temporal dimension**: Video/image models may expect a temporal dimension (e.g., `[B, T, C, H, W]`). For single images, duplicate frames if `temporal_patch_size > 1`.
**Verification**: Always compare Go preprocessed image shape and statistics against Python to catch sizing mismatches early.
### Tokenizer & Chat Templates
Most instruction-tuned models require:
1. **BOS token**: Added at start of input (token ID 2 for most models)
2. **Chat template**: Wraps user prompt in model-specific format
**Common chat templates:**
| Model | Format |
| ------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
| Llama 3 | `<\|begin_of_text\|><\|start_header_id\|>user<\|end_header_id\|>\n{prompt}<\|eot_id\|><\|start_header_id\|>assistant<\|end_header_id\|>\n` |
| Gemma 3 | `<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n` |
| Qwen | `<\|im_start\|>user\n{prompt}<\|im_end\|>\n<\|im_start\|>assistant\n` |
**Checking tokenization:**
```bash
source .venv/bin/activate && python3 -c "
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained('./weights/model-name')
tokens = tok.encode('Hello', add_special_tokens=True)
print('Tokens:', tokens)
print('Decoded:', [tok.decode([t]) for t in tokens])
"
```
### Text Model Checklist
Before moving to vision components, ensure the text model is fully working:
1. **Sliding window cache**: Some models (Gemma 3, GPT-OSS) use sliding window attention on certain layers. Use `cache.NewRotatingKVCache(windowSize)` for those layers, not `cache.NewKVCache()`. Check config for `sliding_window` and `sliding_window_pattern`.
2. **Unicode/UTF-8 decoding**: If output shows garbled characters like `Â` before spaces, the tokenizer's byte-level encoding isn't being decoded properly. Check `Decode()` handles UTF-8 byte sequences correctly.
3. **EOS tokens from vocabulary**: Don't hardcode EOS token IDs. The tokenizer should extract them from `added_tokens` in `tokenizer.json`. Multiple EOS tokens are common (e.g., Gemma has both `<eos>` and `<end_of_turn>`).
4. **Chat template**: Instruction-tuned models need chat formatting. Test with and without to ensure the model responds coherently.
5. **Compare with reference**: Always test against `mlx_lm.generate` with same prompt and `--temp 0` to verify outputs match.
## Phase 2: Correctness Testing
Run the model and look at the output. Make sure it outputs something coherent.
To compare correctness, add hooks to the python model and compare the output with debug statements in Go.
## Phase 3: Memory Verification
After loading, verify peak memory is close to final model size:
```bash
# Run and check peak vs active memory
/tmp/engine -model ./weights/MyModel -steps 1 2>&1 | grep -E "(peak|GB)"
```
**Expected:** Peak should be ~1.1x final size (small overhead is OK). If peak is 2-3x final size, you have a memory problem.
### Checking Weight Dtypes
```bash
# Check dtype of weights in safetensors files
python3 -c "
from safetensors import safe_open
f = safe_open('model.safetensors', 'pt')
for k in list(f.keys())[:5]:
print(k, f.get_tensor(k).dtype)
"
```
### f32 Weights Need Special Handling
If weights are f32 but model runs in bf16, use `GetTensorBF16()` instead of `GetTensor()`:
- `GetTensor()` uses MLX's native loader (loads all tensors from file at once)
- `GetTensorBF16()` loads one tensor at a time, converts to bf16, frees f32 immediately
This prevents peak memory from being 2x model size during loading.
## Phase 4: Performance
### Evaluation Strategy
- Call `mlx.Eval()` once per token/step, not inside loops
- Use `mlx.AsyncEval()` to pipeline: build next step's graph while current executes
- Never call `mlx.Eval()` inside attention or MLP - batch it at the end
### Fast Operations (Already Built-in)
These Go functions use MLX's fast fused kernels internally:
- `mlx.RMSNorm(x, weight, eps)` → uses `mlx_fast_rms_norm`
- `mlx.RoPE(x, dims, traditional, base, scale, offset)` → uses `mlx_fast_rope`
- `mlx.ScaledDotProductAttention(q, k, v, scale, causalMask)` → uses `mlx_fast_scaled_dot_product_attention`
### Type Promotion Gotchas
- `mlx.Mul(bf16Array, mlx.Full(shape, 2.0, mlx.Float32))` → upcasts everything to f32
- Use `mlx.MulScalar(bf16Array, 2.0)` to preserve dtype (if available), or ensure scalar arrays match input dtype
### Profiling
- Use `mactop` to check GPU utilization - should be ~100%
- If low, bottleneck is likely Go code (tokenization, data prep), not MLX
- Use `pprof` for CPU profiling to find Go-side overhead (CGO calls, tokenization, etc.)
- Use Metal debugger for kernel-level profiling (see docs/performance.md)
- Profile with `time.Since()` around major blocks
- Compare tok/s against reference (llama.cpp, MLX-LM)
## Phase 5: Polish
- Remove debug prints
- Add proper error handling
- Document config.json fields used
## Tips
- MLX is lazy; call `Eval()` only when you need values
- Check `model.safetensors.index.json` for weight→file mapping
## Common Gotchas
### MLX Transpose requires Contiguous
`mlx.Transpose()` returns a view with modified strides - calling `Data()` returns the original memory layout. Always follow with `mlx.Contiguous()` if you need correct data ordering:
```go
// Wrong - Data() returns original layout
x = mlx.Transpose(x, 0, 2, 3, 4, 1)
data := x.Data() // Bug: data is in wrong order
// Correct
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
data := x.Data() // Data is in transposed order
```
### Missing Biases in Weight Loading
Python layers often have optional biases. Check the safetensors files for bias tensors:
```bash
python3 -c "from safetensors import safe_open; f=safe_open('model.safetensors','pt'); print([k for k in f.keys() if 'bias' in k])"
```
### Don't Spam ClearCache() or Eval()
- `mlx.ClearCache()` clears the GPU cache but doesn't free arrays - it has minimal effect on memory. Don't call it repeatedly.
- `mlx.Eval()` forces synchronous evaluation and frees non-kept arrays. Call it once per step/token, not inside loops.
### Lazy Eval and Free() - The Critical Pattern
MLX arrays are lazy - operations build a graph, actual computation happens at `Eval()`. This has a critical implication for `Free()`:
```go
// WRONG: Lazy array references freed input
func BadForward(x *mlx.Array) *mlx.Array {
return mlx.Add(compute(x), x) // Returns lazy array referencing x
}
func Caller() {
result := BadForward(input)
input.Free() // Frees input, but result still references it!
mlx.Eval(result) // CRASH: "expected a non-empty mlx_array"
}
// CORRECT: Eval before caller can free inputs
func GoodForward(x *mlx.Array) *mlx.Array {
out := mlx.Add(compute(x), x)
mlx.Eval(out) // Materialize before returning
return out
}
```
**Rule**: If your function returns an array that references its input (residual connections, skip connections), you MUST `Eval()` before returning - otherwise the caller may free the input while the result still needs it.
**Debugging**: Errors like "expected a non-empty mlx_array" at Eval time often mean a tensor was freed while still referenced by a lazy graph. Add logging BEFORE the Free() calls to find which one, not inside the lazy operations.
### Data() and DataInt32() Trigger Eval
Calling `.Data()` or `.DataInt32()` on an array does an implicit `Eval()`, which frees any un-eval'd arrays:
```go
// WRONG: tokenArray gets freed when we eval image
tokenArray := mlx.NewArrayInt32(tokens, shape)
image := processImage(path) // This evals image internally
mlx.Eval(image) // This frees tokenArray!
tokenData := tokenArray.DataInt32() // CRASH: tokenArray was freed
// CORRECT: Eval arrays you need to keep before other evals
tokenArray := mlx.NewArrayInt32(tokens, shape)
mlx.Eval(tokenArray) // Materialize it first
image := processImage(path)
tokenData := tokenArray.DataInt32() // Works fine
```
**Rule**: Before calling any function that might do an `Eval()` internally, make sure to `Eval()` any arrays you'll need later. When passing arrays to model forward functions, eval them first if they were just created.
### Diffusers Pipeline vs Scheduler Defaults
Diffusers pipelines often pass custom parameters that override scheduler defaults. When writing tests, match what the **pipeline** does, not the raw scheduler:
```python
# Scheduler default (when no sigmas passed):
# sigmas from 1.0 to 1/1000 = 0.001
# But pipeline passes custom sigmas:
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
scheduler.set_timesteps(sigmas=sigmas, ...) # 1.0 to 1/30 for 30 steps
```
Always check the pipeline source to see what parameters it passes to components.
### Diffusion Models: Timestep Scaling
Diffusion transformers use sinusoidal timestep embeddings with internal scaling. **Critical**: Check what the pipeline actually passes to the transformer, not just what the scheduler stores.
**Common pattern in diffusers (tricky!):**
- `scheduler.sigmas` = values in [0, 1] range (e.g., 1.0, 0.608, 0.02)
- `scheduler.timesteps` = sigmas × 1000 (e.g., 1000, 608, 20)
- **BUT** the pipeline often divides by 1000 before passing to transformer: `timestep=t / 1000`
- Transformer's `Timesteps` class has `scale=1000`, multiplying input by 1000
- Net effect: transformer receives sigma (0.608), scales to 608
**Verification - check the actual pipeline source:**
```bash
grep -A2 "timestep=" .venv/.../pipeline_*.py
# Look for: timestep=timestep / 1000 ← pipeline normalizes!
```
**Go approach (skip the multiply/divide dance):**
```go
// Store sigmas directly as timesteps - equivalent to Python's
// scheduler.timesteps / 1000 that the pipeline passes to transformer
s.Timesteps[i] = sigmas[i] // 0.608
// Transformer does: 0.608 * 1000 = 608 ✓
```
**Symptoms of wrong timestep scaling:**
- Noise predictions have wrong magnitude (off by orders of magnitude)
- Output images are completely noisy/corrupted or have extreme contrast
- Latents diverge from Python after first denoising step
**Key lesson:** Don't assume scheduler.timesteps is what the transformer receives - always check the pipeline's forward pass for any normalization.

View File

@@ -1,612 +0,0 @@
package gemma3
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// TextConfig holds configuration for the text model
type TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
HeadDim int32 `json:"head_dim"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
SlidingWindow int32 `json:"sliding_window"`
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
// Computed fields
Scale float32 `json:"-"`
}
// TextModel is the Gemma 3 text-only model
type TextModel struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*DecoderLayer `weight:"model.layers"`
Norm *nn.RMSNorm `weight:"model.norm"`
Output *nn.Linear `weight:"-"` // Tied to EmbedTokens, set manually
// Precomputed (1 + weight) for Gemma-style RMSNorm to avoid allocation per forward
NormScaled *mlx.Array `weight:"-"`
tok *tokenizer.Tokenizer
*TextConfig
}
// DecoderLayer is a single transformer block
type DecoderLayer struct {
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
Attention *Attention
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
PreFFNorm *nn.RMSNorm `weight:"pre_feedforward_layernorm"`
MLP *MLP
PostFFNorm *nn.RMSNorm `weight:"post_feedforward_layernorm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
InputNormScaled *mlx.Array `weight:"-"`
PostAttnNormScaled *mlx.Array `weight:"-"`
PreFFNormScaled *mlx.Array `weight:"-"`
PostFFNormScaled *mlx.Array `weight:"-"`
// Whether this layer uses sliding window attention
IsSliding bool
LayerIdx int32
}
// Attention implements Gemma 3 attention with Q/K normalization
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
QNorm *nn.RMSNorm `weight:"self_attn.q_norm"`
KNorm *nn.RMSNorm `weight:"self_attn.k_norm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
QNormScaled *mlx.Array `weight:"-"`
KNormScaled *mlx.Array `weight:"-"`
}
// MLP is the feed-forward network with GELU activation
type MLP struct {
GateProj *nn.Linear `weight:"mlp.gate_proj"`
UpProj *nn.Linear `weight:"mlp.up_proj"`
DownProj *nn.Linear `weight:"mlp.down_proj"`
}
// LoadText loads the text-only Gemma 3 model
func LoadText(modelPath string) (*TextModel, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg TextConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Compute scale
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
// Set defaults if not specified
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
if cfg.RopeLocalBaseFreq == 0 {
cfg.RopeLocalBaseFreq = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &TextModel{
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
TextConfig: &cfg,
tok: tok,
}
// Initialize layer metadata
for i := range m.Layers {
m.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), cfg.SlidingWindowPattern),
}
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Tied embeddings for output
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
// Precompute (1 + weight) for Gemma-style RMSNorm to avoid per-forward allocation
precomputeGemmaScaledWeights(m)
return m, nil
}
// precomputeGemmaScaledWeights computes (1 + weight) for all RMSNorm layers
// This avoids creating temporary arrays on every forward pass
func precomputeGemmaScaledWeights(m *TextModel) {
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
for _, layer := range m.Layers {
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
}
// Eval all the precomputed weights
var scaled []*mlx.Array
scaled = append(scaled, m.NormScaled)
for _, layer := range m.Layers {
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
layer.PreFFNormScaled, layer.PostFFNormScaled,
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
}
mlx.Eval(scaled...)
}
// isLayerSliding determines if a layer uses sliding window attention
// Pattern N means: layers 0 to N-1 sliding, N full, N+1 to 2N-1 sliding, 2N full, etc.
func isLayerSliding(layerIdx, pattern int32) bool {
if pattern <= 0 {
return false // No sliding window
}
// Layer is full attention if (layerIdx + 1) % pattern == 0
return (layerIdx+1)%pattern != 0
}
// Forward runs the text model forward pass
func (m *TextModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
// Get embeddings and scale by sqrt(hidden_size)
h := m.EmbedTokens.Forward(tokens)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
for i, layer := range m.Layers {
h = layer.Forward(h, caches[i], B, L, m.TextConfig)
}
// Final norm and output projection
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.RMSNormEps))
}
// Forward runs a decoder layer
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
// Pre-attention norm (use precomputed scaled weight)
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
// Attention
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
// Post-attention norm and residual
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
// Pre-FFN norm
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
// MLP
mlpOut := l.MLP.Forward(normed)
// Post-FFN norm and residual
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
return mlx.Add(h, mlpOut)
}
// Forward runs attention with Q/K normalization
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
// Q/K normalization after reshaping (use precomputed scaled weight)
q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
// Apply RoPE with appropriate theta
ropeTheta := cfg.RopeTheta
if isSliding {
ropeTheta = cfg.RopeLocalBaseFreq
}
q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
// Update cache
k, v = c.Update(k, v, int(L))
// Repeat K/V for GQA if needed
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
// Attention
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
// compiledGeluApprox is a singleton compiled GELU function shared across all layers
var compiledGeluApprox *mlx.CompiledFunc
// getCompiledGeluApprox returns the compiled GELU function, creating it once if needed
func getCompiledGeluApprox() *mlx.CompiledFunc {
if compiledGeluApprox == nil {
compiledGeluApprox = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
return []*mlx.Array{geluApproxImpl(inputs[0])}
}, true)
}
return compiledGeluApprox
}
// Forward runs the MLP with GELU approximation (tanh variant)
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := getCompiledGeluApprox().Call(m.GateProj.Forward(x))[0]
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
}
// geluApproxImpl computes GELU using the tanh approximation (gelu_pytorch_tanh):
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
func geluApproxImpl(x *mlx.Array) *mlx.Array {
// Constants
const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi)
const coeff = 0.044715
// x^3
x3 := mlx.Mul(mlx.Mul(x, x), x)
// x + 0.044715 * x^3
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
// sqrt(2/pi) * (x + 0.044715 * x^3)
scaled := mlx.MulScalar(inner, sqrt2OverPi)
// tanh(...)
tanh := mlx.Tanh(scaled)
// 1 + tanh(...)
onePlusTanh := mlx.AddScalar(tanh, 1.0)
// 0.5 * x * (1 + tanh(...))
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusTanh)
}
// gemmaRMSNorm applies Gemma-style RMS normalization: x * rsqrt(mean(x^2) + eps) * (1 + weight)
// Uses mlx.RMSNorm fast kernel with pre-computed (1 + weight)
func gemmaRMSNorm(x, weight *mlx.Array, eps float32) *mlx.Array {
// Gemma uses (1 + weight) instead of weight
scaledWeight := mlx.AddScalar(weight, 1.0)
return mlx.RMSNorm(x, scaledWeight, eps)
}
// Interface methods
func (m *TextModel) NumLayers() int { return len(m.Layers) }
func (m *TextModel) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
func (m *TextModel) VocabSize() int32 { return m.TextConfig.VocabSize }
// Tokenizer returns the tokenizer wrapped to add BOS and apply chat template
func (m *TextModel) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
// FormatPrompt applies the Gemma 3 chat template to a prompt
func (m *TextModel) FormatPrompt(prompt string) string {
// Gemma 3 chat format: <start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
func (m *TextModel) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
if m.Layers[i].IsSliding {
// Use rotating cache for sliding window layers
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
// Use regular cache for global attention layers
caches[i] = cache.NewKVCache()
}
}
return caches
}
// Config holds config for the full multimodal model
type Config struct {
TextConfig TextConfig `json:"text_config"`
VisionConfig VisionConfig `json:"vision_config"`
// Image token config (from config.json)
BOITokenIndex int32 `json:"boi_token_index"` // <start_of_image> = 255999
EOITokenIndex int32 `json:"eoi_token_index"` // <end_of_image> = 256000
ImageTokenIndex int32 `json:"image_token_index"` // <image_soft_token> = 262144
MMTokensPerImage int32 `json:"mm_tokens_per_image"` // 256
}
// Model is the full Gemma 3 multimodal model
type Model struct {
VisionTower *VisionTower `weight:"vision_tower"`
Projector *MultiModalProjector `weight:"multi_modal_projector"`
TextModel *TextModel `weight:"language_model"`
Config *Config
tok *tokenizer.Tokenizer
}
// Load loads the full multimodal Gemma 3 model
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Set defaults for text config (multimodal config often has incomplete text_config)
// These defaults match transformers.Gemma3TextConfig defaults
tc := &cfg.TextConfig
if tc.HeadDim == 0 {
tc.HeadDim = 256 // Gemma 3 uses head_dim=256
}
if tc.NumAttentionHeads == 0 {
// Gemma 3 4B uses 8 attention heads (cannot infer from hidden_size/head_dim)
tc.NumAttentionHeads = 8
}
if tc.NumKeyValueHeads == 0 {
// Gemma 3 4B uses 4 KV heads (GQA with 2:1 ratio)
tc.NumKeyValueHeads = 4
}
if tc.VocabSize == 0 {
tc.VocabSize = 262208 // Gemma 3 vocab size (not 262144!)
}
if tc.RopeTheta == 0 {
tc.RopeTheta = 1000000
}
if tc.RopeLocalBaseFreq == 0 {
tc.RopeLocalBaseFreq = 10000
}
if tc.RMSNormEps == 0 {
tc.RMSNormEps = 1e-6
}
if tc.SlidingWindowPattern == 0 {
tc.SlidingWindowPattern = 6
}
if tc.MaxPositionEmbeddings == 0 {
tc.MaxPositionEmbeddings = 131072 // Gemma 3 4B default
}
// Compute text model scale
tc.Scale = float32(1.0 / math.Sqrt(float64(tc.HeadDim)))
// Set defaults for image token config
if cfg.BOITokenIndex == 0 {
cfg.BOITokenIndex = 255999 // <start_of_image>
}
if cfg.EOITokenIndex == 0 {
cfg.EOITokenIndex = 256000 // <end_of_image>
}
if cfg.ImageTokenIndex == 0 {
cfg.ImageTokenIndex = 262144 // <image_soft_token>
}
if cfg.MMTokensPerImage == 0 {
cfg.MMTokensPerImage = 256
}
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
VisionTower: &VisionTower{
Embeddings: &VisionEmbeddings{},
Encoder: make([]*VisionEncoderLayer, cfg.VisionConfig.NumHiddenLayers),
Config: &cfg.VisionConfig,
},
Projector: &MultiModalProjector{},
TextModel: &TextModel{
Layers: make([]*DecoderLayer, cfg.TextConfig.NumHiddenLayers),
TextConfig: &cfg.TextConfig,
},
Config: &cfg,
tok: tok,
}
// Initialize text layer metadata
for i := range m.TextModel.Layers {
m.TextModel.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), cfg.TextConfig.SlidingWindowPattern),
}
}
// Initialize vision encoder layers
for i := range m.VisionTower.Encoder {
m.VisionTower.Encoder[i] = &VisionEncoderLayer{}
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Tied embeddings for text output
m.TextModel.Output = nn.NewLinear(m.TextModel.EmbedTokens.Weight, nil)
m.TextModel.tok = tok
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
// Precompute (1 + weight) for Gemma-style RMSNorm
precomputeGemmaScaledWeights(m.TextModel)
// Precompute projector's scaled weight
m.Projector.SoftEmbNormScaled = mlx.AddScalar(m.Projector.SoftEmbNorm.Weight, 1.0)
mlx.Eval(m.Projector.SoftEmbNormScaled)
return m, nil
}
// Forward runs the text-only forward pass
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
return m.TextModel.Forward(tokens, caches)
}
// ForwardWithImage runs the multimodal forward pass
// tokens: [B, L] input token IDs (with image placeholder tokens)
// image: [B, H, W, C] preprocessed image tensor
func (m *Model) ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
cfg := m.Config.TextConfig
// Find image token position FIRST before any eval that might free tokens
imageStartPos := int32(-1)
if image != nil && B == 1 {
tokenData := tokens.DataInt32() // This evals tokens
for i, t := range tokenData {
if t == m.Config.ImageTokenIndex {
imageStartPos = int32(i)
break
}
}
}
// Get text embeddings and scale
h := m.TextModel.EmbedTokens.Forward(tokens)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(cfg.HiddenSize))))
// Process image if provided
if image != nil && imageStartPos >= 0 {
// Vision tower: [B, H, W, C] -> [B, num_patches, vision_hidden]
visionFeatures := m.VisionTower.Forward(image)
// Project to text space: [B, num_patches, vision_hidden] -> [B, 256, text_hidden]
imageEmbeds := m.Projector.Forward(visionFeatures, cfg.RMSNormEps)
// Eval h and imageEmbeds together so neither gets freed
mlx.Eval(h, imageEmbeds)
// Cast imageEmbeds to match text embeddings dtype (bf16)
if imageEmbeds.Dtype() != h.Dtype() {
imageEmbeds = mlx.AsType(imageEmbeds, h.Dtype())
mlx.Eval(imageEmbeds)
}
// Insert image embeddings at the known position
h = m.insertImageEmbeddingsAt(h, imageEmbeds, imageStartPos)
}
// Run through text model layers
for i, layer := range m.TextModel.Layers {
h = layer.Forward(h, caches[i], B, L, m.TextModel.TextConfig)
}
// Final norm and output projection
return m.TextModel.Output.Forward(mlx.RMSNorm(h, m.TextModel.NormScaled, cfg.RMSNormEps))
}
// insertImageEmbeddingsAt replaces image placeholder tokens with actual image embeddings
// at a known position (to avoid re-scanning tokens after eval)
// textEmbeds: [B, L, hidden_size] text embeddings
// imageEmbeds: [B, 256, hidden_size] image embeddings from projector
// startPos: starting position of image tokens in the sequence
func (m *Model) insertImageEmbeddingsAt(textEmbeds, imageEmbeds *mlx.Array, startPos int32) *mlx.Array {
numImageTokens := imageEmbeds.Shape()[1]
L := textEmbeds.Shape()[1]
// Split text embeddings: [0:startPos] + imageEmbeds + [startPos+256:L]
afterStart := startPos + numImageTokens
// Slice before image tokens: textEmbeds[:, 0:startPos, :]
before := mlx.SliceAxis(textEmbeds, 1, 0, startPos)
// Slice after image tokens: textEmbeds[:, startPos+256:L, :]
after := mlx.SliceAxis(textEmbeds, 1, afterStart, L)
// Concatenate: before + imageEmbeds + after along axis 1
return mlx.Concatenate([]*mlx.Array{before, imageEmbeds, after}, 1)
}
// Interface methods for Model
func (m *Model) NumLayers() int { return len(m.TextModel.Layers) }
func (m *Model) MaxContextLength() int32 { return m.Config.TextConfig.MaxPositionEmbeddings }
func (m *Model) VocabSize() int32 { return m.Config.TextConfig.VocabSize }
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { return m.TextModel.NewCache(maxSeqLen) }
func (m *Model) ImageSize() int32 { return m.Config.VisionConfig.ImageSize }
// FormatPrompt applies the Gemma 3 multimodal chat template
func (m *Model) FormatPrompt(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
// FormatPromptWithImage applies the Gemma 3 multimodal chat template with image
func (m *Model) FormatPromptWithImage(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n<start_of_image>%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
// ExpandImageTokens expands <start_of_image> into 256 image placeholder tokens
// Input tokens containing boi_token (255999) are expanded to:
// boi_token + 256 * image_token + eoi_token
func (m *Model) ExpandImageTokens(tokens []int32) []int32 {
result := make([]int32, 0, len(tokens)+int(m.Config.MMTokensPerImage)+1)
for _, t := range tokens {
if t == m.Config.BOITokenIndex {
// Expand: boi + 256 * image_token + eoi
result = append(result, m.Config.BOITokenIndex)
for i := int32(0); i < m.Config.MMTokensPerImage; i++ {
result = append(result, m.Config.ImageTokenIndex)
}
result = append(result, m.Config.EOITokenIndex)
} else {
result = append(result, t)
}
}
return result
}

View File

@@ -1,56 +0,0 @@
package gemma3
import (
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"os"
"github.com/ollama/ollama/x/imagegen/mlx"
"golang.org/x/image/draw"
)
// ProcessImage loads and preprocesses an image for the vision tower
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open image: %w", err)
}
defer f.Close()
img, _, err := image.Decode(f)
if err != nil {
return nil, fmt.Errorf("decode image: %w", err)
}
return ProcessImageData(img, imageSize)
}
// ProcessImageData preprocesses an image.Image for the vision tower
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
// Resize to target size using bilinear interpolation
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
// Convert to float32 array [H, W, C] and normalize
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
data := make([]float32, imageSize*imageSize*3)
idx := 0
for y := int32(0); y < imageSize; y++ {
for x := int32(0); x < imageSize; x++ {
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
// RGBA returns 16-bit values, convert to 8-bit
data[idx] = float32(r>>8)/127.5 - 1.0
data[idx+1] = float32(g>>8)/127.5 - 1.0
data[idx+2] = float32(b>>8)/127.5 - 1.0
idx += 3
}
}
// Create MLX array [1, H, W, C] for NHWC layout
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
mlx.Eval(arr) // Materialize to prevent use-after-free
return arr, nil
}

View File

@@ -1,48 +0,0 @@
package gemma3
import (
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// MultiModalProjector projects vision features to text embedding space
type MultiModalProjector struct {
// mm_input_projection_weight: [vision_hidden, text_hidden]
InputProjection *mlx.Array `weight:"mm_input_projection_weight"`
SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
SoftEmbNormScaled *mlx.Array `weight:"-"`
}
// Forward projects vision features to text space
// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152])
// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560])
func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array {
// Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152]
// 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens
B := visionFeatures.Shape()[0]
visionHidden := visionFeatures.Shape()[2]
// Reshape to [B, 64, 64, hidden]
gridSize := int32(64) // sqrt(4096)
pooledSize := int32(16) // 64/4
h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden)
// Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling
h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden)
// Average over pooling dimensions (axes 2 and 4)
h = mlx.Mean(h, 4, false)
h = mlx.Mean(h, 2, false)
// h is now [B, 16, 16, hidden], reshape to [B, 256, hidden]
numTokens := pooledSize * pooledSize
h = mlx.Reshape(h, B, numTokens, visionHidden)
// Apply Gemma-style RMS norm (use precomputed 1 + weight)
h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps)
// Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden]
return mlx.Linear(h, p.InputProjection)
}

View File

@@ -1,136 +0,0 @@
package gemma3
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// VisionConfig holds configuration for the SigLIP vision tower
type VisionConfig struct {
HiddenSize int32 `json:"hidden_size"`
ImageSize int32 `json:"image_size"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
PatchSize int32 `json:"patch_size"`
}
// VisionTower is the SigLIP vision encoder
type VisionTower struct {
Embeddings *VisionEmbeddings `weight:"vision_model.embeddings"`
Encoder []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
PostLayerNorm *nn.LayerNorm `weight:"vision_model.post_layernorm"`
Config *VisionConfig
}
// VisionEmbeddings handles patch and position embeddings
type VisionEmbeddings struct {
// PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX
PatchWeight *mlx.Array `weight:"patch_embedding.weight"`
PatchBias *mlx.Array `weight:"patch_embedding.bias"`
PosEmbed *nn.Embedding `weight:"position_embedding"`
}
// VisionEncoderLayer is a single transformer encoder layer
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `weight:"layer_norm1"`
Attention *VisionAttention `weight:"self_attn"`
LayerNorm2 *nn.LayerNorm `weight:"layer_norm2"`
MLP *VisionMLP `weight:"mlp"`
}
// VisionAttention implements multi-head self-attention
type VisionAttention struct {
QProj *nn.Linear `weight:"q_proj"`
KProj *nn.Linear `weight:"k_proj"`
VProj *nn.Linear `weight:"v_proj"`
OutProj *nn.Linear `weight:"out_proj"`
}
// VisionMLP is the feed-forward network
type VisionMLP struct {
FC1 *nn.Linear `weight:"fc1"`
FC2 *nn.Linear `weight:"fc2"`
}
// Forward runs the vision tower on preprocessed images
// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX)
// Output: [B, num_patches, hidden_size]
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
// Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O]
// Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C]
weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1)
h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding
// Add bias: [O] -> [1, 1, 1, O] for broadcasting
bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0])
h = mlx.Add(h, bias)
// h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden]
B := h.Shape()[0]
gridH, gridW := h.Shape()[1], h.Shape()[2]
hidden := h.Shape()[3]
numPatches := gridH * gridW
h = mlx.Reshape(h, B, numPatches, hidden)
// Add position embeddings
posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
posEmbed := v.Embeddings.PosEmbed.Forward(posIds)
h = mlx.Add(h, posEmbed)
// Encoder layers
headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
scale := float32(1.0 / math.Sqrt(float64(headDim)))
for _, layer := range v.Encoder {
h = layer.Forward(h, v.Config, scale)
}
// Final layer norm
h = v.PostLayerNorm.Forward(h)
return h
}
// Forward runs a vision encoder layer
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
// Pre-norm attention
h := l.LayerNorm1.Forward(x)
h = l.Attention.Forward(h, cfg, scale)
x = mlx.Add(x, h)
// Pre-norm MLP
h = l.LayerNorm2.Forward(x)
h = l.MLP.Forward(h)
return mlx.Add(x, h)
}
// Forward runs multi-head self-attention
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
B, L := x.Shape()[0], x.Shape()[1]
headDim := cfg.HiddenSize / cfg.NumAttentionHeads
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
// Scaled dot-product attention (no causal mask for vision)
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
// Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden]
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize)
return a.OutProj.Forward(out)
}
// Forward runs the MLP with GELU activation
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
h := mlx.GELU(m.FC1.Forward(x))
return m.FC2.Forward(h)
}

View File

@@ -1,485 +0,0 @@
package gpt_oss
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// RopeScaling holds YaRN or other RoPE scaling configuration
type RopeScaling struct {
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
}
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
SlidingWindow int32 `json:"sliding_window"`
NumLocalExperts int32 `json:"num_local_experts"`
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
LayerTypes []string `json:"layer_types"`
SwiGLULimit float32 `json:"swiglu_limit"`
RopeScaling *RopeScaling `json:"rope_scaling"`
Scale float32 `json:"-"` // computed: 1/sqrt(HeadDim)
}
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
Sinks *mlx.Array `weight:"self_attn.sinks,optional"`
YarnFreqs *mlx.Array // computed
YarnMscale float32
}
// swiGLU applies the GPT-OSS custom SwiGLU activation.
// Formula: (gate * sigmoid(alpha * gate)) * (up + 1)
// with clipping: gate to [None, limit], up to [-limit, limit]
func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array {
// Clip gate to [None, limit]
gateClipped := mlx.ClipScalar(gate, 0, limit, false, true)
// Clip up to [-limit, limit]
upClipped := mlx.ClipScalar(up, -limit, limit, true, true)
// glu_scaled = alpha * gate_clipped
gluScaled := mlx.MulScalar(gateClipped, alpha)
// sig = sigmoid(glu_scaled)
sig := mlx.Sigmoid(gluScaled)
// out_glu = gate_clipped * sig
outGlu := mlx.Mul(gateClipped, sig)
// result = out_glu * (up_clipped + 1)
return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0))
}
// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers
var compiledSwiGLU *mlx.CompiledFunc
// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed
func getCompiledSwiGLU() *mlx.CompiledFunc {
if compiledSwiGLU == nil {
const alpha float32 = 1.702
const limit float32 = 7.0
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
}, true) // shapeless=true so it works for any input size
}
return compiledSwiGLU
}
// ComputeYarnFreqs computes YaRN-modified RoPE frequencies
// Based on mlx-lm's YarnRoPE implementation
func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) {
// yarn_find_correction_dim
yarnFindCorrectionDim := func(numRotations float64) float64 {
return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base)))
}
// yarn_find_correction_range
low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast))))
high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow))))
if low < 0 {
low = 0
}
if high > int(dims)-1 {
high = int(dims) - 1
}
// yarn_get_mscale
yarnGetMscale := func(scale, mscale float64) float64 {
if scale <= 1 {
return 1.0
}
return 0.1*mscale*math.Log(scale) + 1.0
}
mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0))
// Compute frequencies
// freq_extra = base ** (arange(0, dims, 2) / dims)
// freq_inter = scaling_factor * freq_extra
halfDims := dims / 2
freqData := make([]float32, halfDims)
for i := int32(0); i < halfDims; i++ {
exp := float64(2*i) / float64(dims)
freqExtra := math.Pow(float64(base), exp)
freqInter := float64(scalingFactor) * freqExtra
// linear ramp mask
var freqMask float64
if low == high {
freqMask = 0.0
} else {
t := (float64(i) - float64(low)) / float64(high-low)
if t < 0 {
t = 0
}
if t > 1 {
t = 1
}
freqMask = 1.0 - t
}
// Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask))
freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask)))
}
return mlx.NewArray(freqData, []int32{halfDims}), mscale
}
// initYarn initializes YaRN RoPE if configured
func (a *Attention) initYarn(cfg *Config) {
a.YarnMscale = 1.0
if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" {
a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs(
cfg.HeadDim,
cfg.RopeTheta,
cfg.RopeScaling.Factor,
cfg.RopeScaling.OriginalMaxPositionEmbeddings,
cfg.RopeScaling.BetaFast,
cfg.RopeScaling.BetaSlow,
)
}
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
offset := 0
if c != nil {
offset = c.Offset()
}
if a.YarnFreqs != nil {
if a.YarnMscale != 1.0 {
q = mlx.MulScalar(q, a.YarnMscale)
}
q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
} else {
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
}
if c != nil {
k, v = c.Update(k, v, int(L))
}
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
// CreateSlidingWindowMask creates a causal mask with sliding window
// Mirrors mlx-lm's create_causal_mask with window_size
func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array {
// Build mask aligned to actual cache length (may be rotated)
// rinds covers existing keys: [keyStart, keyStart+keyLen)
// linds covers new queries: [queryStart, queryStart+seqLen)
rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1) // [keyLen]
linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen]
linds = mlx.ExpandDims(linds, 1) // [seqLen, 1]
rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen]
causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen]
windowLimit := mlx.AddScalar(rinds, float32(windowSize))
windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen]
return mlx.LogicalAnd(causalMask, windowMask)
}
// MoE represents the Mixture of Experts SwiGLU layer with quantized experts.
type MoE struct {
Router *nn.Linear `weight:"mlp.router"`
TopK int32
HiddenSize int32
GroupSize int
Bits int
// Expert weights (loaded manually via sanitizeExpertWeights)
GateBlocks, GateScales, GateBias *mlx.Array
UpBlocks, UpScales, UpBias *mlx.Array
DownBlocks, DownScales, DownBias *mlx.Array
}
func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array {
logits := moe.Router.Forward(x)
neg := mlx.Neg(logits)
part := mlx.Argpartition(neg, int(moe.TopK)-1, -1)
topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK})
topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1)
weights := mlx.Softmax(topKVal, -1)
xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize)
idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK)
doSort := B*L >= 64
var invOrder *mlx.Array
sorted := false
n := B * L * moe.TopK
if doSort {
idxAll := mlx.Flatten(idxFlat)
order := mlx.Argsort(idxAll, 0)
invOrder = mlx.Argsort(order, 0)
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1)
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
sorted = true
}
gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
if moe.GateBias != nil {
gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2))
}
if moe.UpBias != nil {
up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2))
}
hidden := getCompiledSwiGLU().Call(gate, up)[0]
down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
if moe.DownBias != nil {
down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2))
}
if doSort {
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize)
} else {
down = mlx.Squeeze(down, 2)
}
ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1)
return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize)
}
type Block struct {
Attention *Attention
MLP *MoE
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
LayerType string // "sliding_attention" or "full_attention"
}
func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg))
return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L))
}
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Block `weight:"-"` // loaded manually due to MoE sanitization
Norm *nn.RMSNorm `weight:"model.norm"`
LMHead *nn.Linear `weight:"lm_head"`
tok *tokenizer.Tokenizer
*Config
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) NewCache(int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i, layer := range m.Layers {
if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 {
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
caches[i] = cache.NewKVCache()
}
}
return caches
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
x := m.EmbedTokens.Forward(tokens)
// Find representative cache indices for sliding window attention
var swaIdx int = -1
for i, layer := range m.Layers {
if layer.LayerType == "sliding_attention" {
swaIdx = i
break
}
}
// Create masks once at model level
var fullMask, swaMask *mlx.Array
var fullMaskMode, swaMaskMode string
if L > 1 {
fullMaskMode = "causal"
if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil {
c := caches[swaIdx]
offset := c.Offset()
windowSize := int(m.SlidingWindow)
cacheLen := min(int(L), windowSize)
if offset > 0 {
cacheLen = min(c.Len()+int(L), windowSize)
}
if int(L) > windowSize {
swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize)
} else {
swaMaskMode = "causal"
}
} else {
swaMaskMode = "causal"
}
}
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil {
c = caches[i]
}
mask, maskMode := fullMask, fullMaskMode
if layer.LayerType == "sliding_attention" {
mask, maskMode = swaMask, swaMaskMode
}
x = layer.Forward(x, c, B, L, mask, maskMode, m.Config)
}
return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps))
}
// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays.
// MXFP4 quantized weights require contiguous memory - strided views give wrong results.
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) {
gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks")
gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales")
gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias")
downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks")
downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales")
downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias")
moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias}
if gateUpBlocks != nil {
gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1)
s := gub.Shape()
moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
}
if gateUpScales != nil {
s := gateUpScales.Shape()
moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
}
if gateUpBias != nil {
s := gateUpBias.Shape()
moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2}))
moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2}))
}
if downBlocks != nil {
moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1)
}
return moe
}
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
// Load simple weights via struct tags
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Load layers with custom MoE handling
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
layer := &Block{}
if err := safetensors.LoadModule(layer, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d: %w", i, err)
}
// Initialize attention YaRN
layer.Attention.initYarn(&cfg)
// Load MoE with weight sanitization
moe := sanitizeExpertWeights(weights, prefix)
moe.Router = layer.MLP.Router // Router was loaded by LoadModule
moe.TopK = cfg.NumExpertsPerTok
moe.HiddenSize = cfg.HiddenSize
layer.MLP = moe
// Set layer type
layer.LayerType = "full_attention"
if int(i) < len(cfg.LayerTypes) {
layer.LayerType = cfg.LayerTypes[i]
}
m.Layers[i] = layer
}
// Release safetensors BEFORE eval - lazy arrays have captured data,
// this reduces peak memory by freeing mmap during materialization
weights.ReleaseAll()
mlx.Eval(mlx.Collect(m)...)
return m, nil
}
func (m *Model) MaxContextLength() int32 {
if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 {
return m.RopeScaling.OriginalMaxPositionEmbeddings
}
return 131072
}

View File

@@ -1,150 +0,0 @@
package llama
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
HeadDim int32 `json:"-"`
Scale float32 `json:"-"`
}
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Layer `weight:"model.layers"`
Norm *nn.RMSNorm `weight:"model.norm"`
Output *nn.Linear `weight:"lm_head,optional"`
tok *tokenizer.Tokenizer
*Config
}
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm `weight:"input_layernorm"`
MLPNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
}
type MLP struct {
GateProj *nn.Linear `weight:"mlp.gate_proj"`
UpProj *nn.Linear `weight:"mlp.up_proj"`
DownProj *nn.Linear `weight:"mlp.down_proj"`
}
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
return m, nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
h = layer.Forward(h, caches[i], B, L, m.Config)
}
return m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps))
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
k, v = c.Update(k, v, int(L))
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
}
// Interface methods
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}

View File

@@ -1,64 +0,0 @@
package qwen_image
import (
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestPipelineOutput runs the full pipeline (integration test).
// Skips if model weights not found. Requires ~50GB VRAM.
func TestPipelineOutput(t *testing.T) {
modelPath := "../../../weights/Qwen-Image-2512"
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + modelPath)
}
// Load model
pm, err := LoadPersistent(modelPath)
if err != nil {
t.Skipf("Skipping: failed to load model: %v", err)
}
// Run 2-step pipeline (minimum for stable scheduler)
cfg := &GenerateConfig{
Prompt: "a cat",
Width: 256,
Height: 256,
Steps: 2,
Seed: 42,
}
output, err := pm.GenerateFromConfig(cfg)
if err != nil {
t.Fatalf("Pipeline failed: %v", err)
}
mlx.Eval(output)
// Verify output shape [1, C, H, W]
shape := output.Shape()
if len(shape) != 4 {
t.Errorf("Expected 4D output, got %v", shape)
}
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
}
// Verify values in expected range [0, 1]
data := output.Data()
minVal, maxVal := float32(1.0), float32(0.0)
for _, v := range data {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
if minVal < -0.1 || maxVal > 1.1 {
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
}
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,348 +0,0 @@
// Package qwen_image implements the Qwen-Image diffusion transformer model.
package qwen_image
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
// Layer caching (DeepCache/Learning-to-Cache speedup)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 25)
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Qwen-Image diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen25VL
Transformer *Transformer
VAEDecoder *VAEDecoder
}
// Load loads the Qwen-Image model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err := tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
m.TextEncoder = &Qwen25VL{}
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 30
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
}
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
imgSeqLen := pH * pW
// Text encoding
var posEmb, negEmb *mlx.Array
{
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
if useCFG {
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
} else {
mlx.Keep(posEmb)
mlx.Eval(posEmb)
}
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
// Init latents [B, C, T, H, W]
var latents *mlx.Array
{
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
}
// RoPE cache
var ropeCache *RoPECache
{
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs)
}
// Layer cache for DeepCache/Learning-to-Cache speedup
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
latents2D := mlx.Squeeze(latents, 2)
patches := PackLatents(latents2D, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
// True CFG: run twice and combine with norm rescaling
// Note: layer caching with CFG is not supported yet (would need 2 caches)
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
combPred := mlx.Add(negOutput, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
} else if stepCache != nil {
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
} else {
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
}
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep cached arrays alive across cleanup
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
}
// Free denoising temporaries before VAE decode
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode (Decode manages its own pools for staged memory)
decoded := m.VAEDecoder.Decode(latents)
latents.Free()
// Post-process: squeeze temporal dim and rescale to [0, 1]
{
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
mlx.Eval(decoded)
}
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is an alias for backward compatibility.
// Use m := &Model{}; m.Load(path) instead.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}

View File

@@ -1,216 +0,0 @@
package qwen_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.5
MaxShift float32 `json:"max_shift"` // 0.9
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
}
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.5,
MaxShift: 0.9, // Matches scheduler_config.json
BaseImageSeqLen: 256,
MaxImageSeqLen: 8192,
ShiftTerminal: 0.02,
UseDynamicShift: true,
}
}
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32
Sigmas []float32
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// CalculateShift computes the dynamic shift based on image sequence length
// This matches Python's calculate_shift function
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
b := baseShift - m*float32(baseSeqLen)
mu := float32(imageSeqLen)*m + b
return mu
}
// SetTimesteps sets up the scheduler for the given number of inference steps
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
// 1. Create sigmas from sigma_max to sigma_min (linspace)
// 2. Apply time_shift with mu (if dynamic shifting)
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
s.NumSteps = numSteps
// Calculate mu for dynamic shifting
var mu float32
if s.Config.UseDynamicShift {
mu = CalculateShift(
imageSeqLen,
s.Config.BaseImageSeqLen,
s.Config.MaxImageSeqLen,
s.Config.BaseShift,
s.Config.MaxShift,
)
}
// Step 1: Create sigmas from 1.0 to 1/num_steps
// Python (pipeline_qwenimage.py:639):
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
sigmas := make([]float32, numSteps)
sigmaMax := float32(1.0)
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
if numSteps == 1 {
sigmas[0] = sigmaMax
} else {
for i := 0; i < numSteps; i++ {
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
}
}
// Step 2: Apply time shift if using dynamic shifting
if s.Config.UseDynamicShift && mu != 0 {
for i := range sigmas {
sigmas[i] = s.timeShift(mu, sigmas[i])
}
}
// Step 3: Apply stretch_shift_to_terminal
if s.Config.ShiftTerminal > 0 {
sigmas = s.stretchShiftToTerminal(sigmas)
}
// Step 4: Append terminal sigma (0) and store
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
// before passing to transformer. We skip both steps and just use sigmas directly.
s.Sigmas = make([]float32, numSteps+1)
s.Timesteps = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
s.Sigmas[i] = sigmas[i]
s.Timesteps[i] = sigmas[i]
}
s.Sigmas[numSteps] = 0.0
s.Timesteps[numSteps] = 0.0
}
// stretchShiftToTerminal stretches and shifts the timestep schedule
// so the final value equals shift_terminal (matches Python behavior)
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
if len(sigmas) == 0 {
return sigmas
}
// one_minus_z = 1 - t
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
// stretched_t = 1 - (one_minus_z / scale_factor)
lastSigma := sigmas[len(sigmas)-1]
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
if scaleFactor < 1e-6 {
return sigmas
}
result := make([]float32, len(sigmas))
for i, t := range sigmas {
oneMinusZ := 1.0 - t
result[i] = 1.0 - (oneMinusZ / scaleFactor)
}
return result
}
// timeShift applies the dynamic time shift (exponential)
// exp(mu) / (exp(mu) + (1/t - 1))
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
// modelOutput: predicted velocity from the transformer
// sample: current noisy sample
// timestepIdx: current timestep index
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
// Get current and next sigma
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 to avoid precision issues (matches Python diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to original dtype
return mlx.ToBFloat16(result)
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormal(shape, uint64(seed))
}
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
// This matches how Python diffusers generates noise - directly in packed space.
// Generating in unpacked format and then packing produces different spatial
// correlation structure, which affects model output quality.
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
shape := []int32{batchSize, seqLen, channels}
return mlx.RandomNormal(shape, uint64(seed))
}
// GetLatentShape returns the latent shape for a given image size
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
func GetLatentShape(batchSize, height, width int32) []int32 {
latentH := height / 8
latentW := width / 8
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
}
// GetPatchedLatentShape returns the patchified latent shape
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
latentH := height / 8
latentW := width / 8
pH := latentH / patchSize
pW := latentW / patchSize
inChannels := int32(64) // 16 * patch_size^2
return []int32{batchSize, pH * pW, inChannels}
}

View File

@@ -1,133 +0,0 @@
package qwen_image
import (
"math"
"testing"
)
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
// Golden values generated via:
//
// python3 -c "
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
// import numpy as np
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
// sigmas = np.linspace(1.0, 1.0/30, 30)
// s.set_timesteps(sigmas=sigmas, mu=mu)
// print(s.sigmas.numpy())"
func TestSchedulerSetTimesteps(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Golden values from Python diffusers (first 3, last 3 before terminal)
wantFirst := []float32{1.000000, 0.982251, 0.963889}
wantLast := []float32{0.142924, 0.083384, 0.020000}
// Check first 3
for i, want := range wantFirst {
got := scheduler.Sigmas[i]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
}
}
// Check last 3 (indices 27, 28, 29)
for i, want := range wantLast {
idx := 27 + i
got := scheduler.Sigmas[idx]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
}
}
// Check terminal is 0
if scheduler.Sigmas[30] != 0.0 {
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
}
// Check length
if len(scheduler.Sigmas) != 31 {
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
}
}
// TestSchedulerProperties tests mathematical invariants of the scheduler.
func TestSchedulerProperties(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Property: sigmas monotonically decreasing
for i := 1; i < len(scheduler.Sigmas); i++ {
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
}
}
// Property: first sigma should be ~1.0 (with time shift)
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
}
// Property: terminal sigma should be exactly 0
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
}
// Property: last non-terminal sigma should be shift_terminal (0.02)
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
if abs32(lastNonTerminal-0.02) > 1e-5 {
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
}
// Property: length = steps + 1
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
t.Errorf("sigmas length should be steps+1: got %d, want %d",
len(scheduler.Sigmas), scheduler.NumSteps+1)
}
}
// TestCalculateShift verifies the mu calculation against Python reference.
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
func TestCalculateShift(t *testing.T) {
cases := []struct {
imgSeqLen int32
want float32
}{
{256, 0.5}, // base case
{8192, 0.9}, // max case
{4096, 0.6935}, // middle case (rounded)
}
for _, c := range cases {
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
if abs32(got-c.want) > 0.001 {
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
}
}
}
// TestSchedulerStep verifies the Euler step formula.
func TestSchedulerStep(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Verify dt calculation for first step
sigma0 := scheduler.Sigmas[0]
sigma1 := scheduler.Sigmas[1]
expectedDt := sigma1 - sigma0
// dt should be negative (sigmas decrease)
if expectedDt >= 0 {
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
}
}
func abs32(x float32) float32 {
return float32(math.Abs(float64(x)))
}

View File

@@ -1,172 +0,0 @@
package qwen_image
import (
"encoding/json"
"math"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TinyTextEncoderConfig holds config for the tiny test text encoder
type TinyTextEncoderConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
MRoPESection []int32 `json:"mrope_section"`
}
// loadTinyTextEncoder loads the tiny text encoder from testdata
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
t.Helper()
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
// Load config
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
if err != nil {
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
}
var tinyCfg TinyTextEncoderConfig
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
t.Fatalf("Failed to parse config: %v", err)
}
// Create encoder config (using Qwen25VLConfig)
cfg := &Qwen25VLConfig{
HiddenSize: tinyCfg.HiddenSize,
NumHiddenLayers: tinyCfg.NumHiddenLayers,
IntermediateSize: tinyCfg.IntermediateSize,
NumAttentionHeads: tinyCfg.NumAttentionHeads,
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
VocabSize: tinyCfg.VocabSize,
RMSNormEps: tinyCfg.RMSNormEps,
RopeTheta: tinyCfg.RopeTheta,
HeadDim: tinyCfg.HeadDim,
MRoPESection: tinyCfg.MRoPESection,
}
// Load weights
weights, err := safetensors.LoadModelWeights(testdataDir)
if err != nil {
t.Fatalf("Failed to load weights: %v", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
t.Fatalf("Failed to bulk load weights: %v", err)
}
// Build encoder
embedding, err := weights.Get("model.embed_tokens.weight")
if err != nil {
t.Fatalf("Failed to get embedding: %v", err)
}
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
block, err := newVLTextBlock(weights, int(i), cfg)
if err != nil {
t.Fatalf("Failed to load block %d: %v", i, err)
}
blocks[i] = block
}
finalNorm, err := weights.Get("model.norm.weight")
if err != nil {
t.Fatalf("Failed to get final norm: %v", err)
}
encoder := &Qwen25VL{
Config: cfg,
Embedding: embedding,
Blocks: blocks,
FinalNorm: finalNorm,
HasVision: false, // Text-only mode
}
return encoder, &tinyCfg
}
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
func TestTextEncoderForward(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// Create test tokens (within vocab range)
tokens := []int32{1, 2, 3, 4, 5}
// Forward pass using EncodeTextOnly
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
// Verify output shape: [batch, seq_len, hidden_size]
wantShape := []int32{1, 5, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
}
// Verify output is finite (not NaN or Inf)
data := out.Data()
for i, v := range data {
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
t.Errorf("output[%d] is not finite: %v", i, v)
break
}
}
}
// TestTextEncoderBatch tests batch processing.
func TestTextEncoderBatch(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// For batch test, we'll use EncodeTextOnly with a single sequence
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
tokens := []int32{1, 2, 3}
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
wantShape := []int32{1, 3, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
}
}
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
func TestMRoPEComputation(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
cossin := encoder.computeTextRoPE(10, 1)
mlx.Eval(cossin[0], cossin[1])
// Verify shapes: [3, B, L, head_dim]
wantShape := []int32{3, 1, 10, cfg.HeadDim}
if !slices.Equal(cossin[0].Shape(), wantShape) {
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
}
if !slices.Equal(cossin[1].Shape(), wantShape) {
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
}
// Verify cos/sin values are in valid range [-1, 1]
cosData := cossin[0].Data()
sinData := cossin[1].Data()
for i := 0; i < min(100, len(cosData)); i++ {
if cosData[i] < -1.01 || cosData[i] > 1.01 {
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
}
if sinData[i] < -1.01 || sinData[i] > 1.01 {
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
}
}
}

View File

@@ -1,866 +0,0 @@
package qwen_image
import (
"fmt"
"math"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Qwen-Image transformer configuration
type TransformerConfig struct {
HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
NHeads int32 `json:"num_attention_heads"` // 24
HeadDim int32 `json:"attention_head_dim"` // 128
NLayers int32 `json:"num_layers"` // 60
InChannels int32 `json:"in_channels"` // 64
OutChannels int32 `json:"out_channels"` // 16
PatchSize int32 `json:"patch_size"` // 2
JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
NormEps float32 `json:"norm_eps"` // 1e-6
AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
GuidanceEmbeds bool `json:"guidance_embeds"` // false
}
// defaultTransformerConfig returns config for Qwen-Image transformer
func defaultTransformerConfig() *TransformerConfig {
return &TransformerConfig{
HiddenDim: 3072, // 24 * 128
NHeads: 24,
HeadDim: 128,
NLayers: 60,
InChannels: 64,
OutChannels: 16,
PatchSize: 2,
JointAttentionDim: 3584,
NormEps: 1e-6,
AxesDimsRope: []int32{16, 56, 56},
GuidanceEmbeds: false,
}
}
// TimestepEmbedder creates timestep embeddings
type TimestepEmbedder struct {
Linear1Weight *mlx.Array // [256, hidden_dim]
Linear1Bias *mlx.Array
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
Linear2Bias *mlx.Array
}
// newTimestepEmbedder creates a timestep embedder from weights
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
if err != nil {
return nil, err
}
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
if err != nil {
return nil, err
}
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
if err != nil {
return nil, err
}
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
if err != nil {
return nil, err
}
return &TimestepEmbedder{
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
Linear1Bias: linear1Bias,
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
Linear2Bias: linear2Bias,
}, nil
}
// Forward computes timestep embeddings
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
half := int32(128) // embedding_dim / 2
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
tExpanded := mlx.ExpandDims(t, 1)
args := mlx.Mul(tExpanded, freqsArr)
args = mlx.MulScalar(args, 1000.0) // scale
// [cos, sin] (flip_sin_to_cos=True)
sinArgs := mlx.Sin(args)
cosArgs := mlx.Cos(args)
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
// MLP: linear1 -> silu -> linear2
h := mlx.Linear(embedding, te.Linear1Weight)
h = mlx.Add(h, te.Linear1Bias)
h = mlx.SiLU(h)
h = mlx.Linear(h, te.Linear2Weight)
h = mlx.Add(h, te.Linear2Bias)
return h
}
// JointAttention implements dual-stream joint attention
type JointAttention struct {
// Image projections
ToQ *mlx.Array
ToQB *mlx.Array
ToK *mlx.Array
ToKB *mlx.Array
ToV *mlx.Array
ToVB *mlx.Array
ToOut *mlx.Array
ToOutB *mlx.Array
NormQ *mlx.Array
NormK *mlx.Array
// Text (added) projections
AddQProj *mlx.Array
AddQProjB *mlx.Array
AddKProj *mlx.Array
AddKProjB *mlx.Array
AddVProj *mlx.Array
AddVProjB *mlx.Array
ToAddOut *mlx.Array
ToAddOutB *mlx.Array
NormAddQ *mlx.Array
NormAddK *mlx.Array
NHeads int32
HeadDim int32
Scale float32
}
// newJointAttention creates a joint attention layer
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
return &JointAttention{
ToQ: mlx.Transpose(toQ, 1, 0),
ToQB: toQB,
ToK: mlx.Transpose(toK, 1, 0),
ToKB: toKB,
ToV: mlx.Transpose(toV, 1, 0),
ToVB: toVB,
ToOut: mlx.Transpose(toOut, 1, 0),
ToOutB: toOutB,
NormQ: normQ,
NormK: normK,
AddQProj: mlx.Transpose(addQProj, 1, 0),
AddQProjB: addQProjB,
AddKProj: mlx.Transpose(addKProj, 1, 0),
AddKProjB: addKProjB,
AddVProj: mlx.Transpose(addVProj, 1, 0),
AddVProjB: addVProjB,
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
ToAddOutB: toAddOutB,
NormAddQ: normAddQ,
NormAddK: normAddK,
NHeads: cfg.NHeads,
HeadDim: cfg.HeadDim,
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
}, nil
}
// Forward computes joint attention
// img: [B, L_img, D], txt: [B, L_txt, D]
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
D := imgShape[2]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// === Image Q/K/V ===
imgFlat := mlx.Reshape(img, B*Limg, D)
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
// QK norm (RMSNorm per head)
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
// Apply RoPE
if imgFreqs != nil {
qImg = applyRoPE(qImg, imgFreqs)
kImg = applyRoPE(kImg, imgFreqs)
}
// === Text Q/K/V ===
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
if txtFreqs != nil {
qTxt = applyRoPE(qTxt, txtFreqs)
kTxt = applyRoPE(kTxt, txtFreqs)
}
// Concatenate for joint attention: [txt, img] order
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
// Transpose to [B, nheads, L, head_dim]
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
// SDPA
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
// Transpose back and split
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
// Output projections
outImg = mlx.Reshape(outImg, B*Limg, D)
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
outImg = mlx.Reshape(outImg, B, Limg, D)
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
return outImg, outTxt
}
// applyRoPE applies rotary embeddings using complex multiplication
// x: [B, L, nheads, head_dim]
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
halfDim := headDim / 2
// Reshape x to pairs: [B, L, nheads, half, 2]
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
// Extract real/imag parts
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
xReal = mlx.Squeeze(xReal, 4)
xImag = mlx.Squeeze(xImag, 4)
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
freqReal = mlx.Squeeze(freqReal, 4)
freqImag = mlx.Squeeze(freqImag, 4)
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
// Interleave back
outReal = mlx.ExpandDims(outReal, 4)
outImag = mlx.ExpandDims(outImag, 4)
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
return mlx.Reshape(out, B, L, nheads, headDim)
}
// MLP implements GELU MLP (not GEGLU)
type MLP struct {
ProjWeight *mlx.Array
ProjBias *mlx.Array
OutWeight *mlx.Array
OutBias *mlx.Array
}
// newMLP creates a GELU MLP
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
outWeight, _ := weights.Get(prefix + ".net.2.weight")
outBias, _ := weights.Get(prefix + ".net.2.bias")
return &MLP{
ProjWeight: mlx.Transpose(projWeight, 1, 0),
ProjBias: projBias,
OutWeight: mlx.Transpose(outWeight, 1, 0),
OutBias: outBias,
}, nil
}
// Forward applies GELU MLP
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
xFlat := mlx.Reshape(x, B*L, D)
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
h = geluApprox(h)
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
}
// geluApprox implements approximate GELU
func geluApprox(x *mlx.Array) *mlx.Array {
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
x3 := mlx.Mul(mlx.Mul(x, x), x)
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
inner = mlx.MulScalar(inner, sqrt2OverPi)
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}
// TransformerBlock is a single dual-stream transformer block
type TransformerBlock struct {
Attention *JointAttention
ImgMLP *MLP
TxtMLP *MLP
ImgModWeight *mlx.Array
ImgModBias *mlx.Array
TxtModWeight *mlx.Array
TxtModBias *mlx.Array
HiddenDim int32
NormEps float32
}
// newTransformerBlock creates a transformer block
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
attn, err := newJointAttention(weights, prefix, cfg)
if err != nil {
return nil, err
}
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
return &TransformerBlock{
Attention: attn,
ImgMLP: imgMLP,
TxtMLP: txtMLP,
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
ImgModBias: imgModBias,
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
TxtModBias: txtModBias,
HiddenDim: cfg.HiddenDim,
NormEps: cfg.NormEps,
}, nil
}
// Forward applies the transformer block
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
siluT := mlx.SiLU(temb)
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
imgModParts := splitMod6(imgMod, tb.HiddenDim)
txtModParts := splitMod6(txtMod, tb.HiddenDim)
// Pre-attention: norm + modulate
imgNorm := layerNormNoAffine(img, tb.NormEps)
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
txtNorm := layerNormNoAffine(txt, tb.NormEps)
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
// Joint attention
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
// Residual with gate
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
// Pre-MLP: norm + modulate
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
// MLP
mlpImg := tb.ImgMLP.Forward(imgNorm2)
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
// Residual with gate
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
return img, txt
}
// splitMod6 splits modulation into 6 parts each [B, 1, D]
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
shape := mod.Shape()
B := shape[0]
parts := make([]*mlx.Array, 6)
for i := int32(0); i < 6; i++ {
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
parts[i] = mlx.ExpandDims(part, 1)
}
return parts
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}
// Transformer is the full Qwen-Image transformer model
type Transformer struct {
Config *TransformerConfig
ImgIn *mlx.Array
ImgInBias *mlx.Array
TxtIn *mlx.Array
TxtInBias *mlx.Array
TxtNorm *mlx.Array
TEmbed *TimestepEmbedder
Layers []*TransformerBlock
NormOutWeight *mlx.Array
NormOutBias *mlx.Array
ProjOut *mlx.Array
ProjOutBias *mlx.Array
}
// Load loads the transformer from a directory
func (m *Transformer) Load(path string) error {
fmt.Println("Loading Qwen-Image transformer...")
cfg := defaultTransformerConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Bulk load all weights as bf16
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading input projections... ")
imgIn, _ := weights.Get("img_in.weight")
imgInBias, _ := weights.Get("img_in.bias")
txtIn, _ := weights.Get("txt_in.weight")
txtInBias, _ := weights.Get("txt_in.bias")
txtNorm, _ := weights.Get("txt_norm.weight")
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
m.ImgInBias = imgInBias
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
m.TxtInBias = txtInBias
m.TxtNorm = txtNorm
fmt.Println("✓")
fmt.Print(" Loading timestep embedder... ")
m.TEmbed, err = newTimestepEmbedder(weights)
if err != nil {
return fmt.Errorf("timestep embedder: %w", err)
}
fmt.Println("✓")
m.Layers = make([]*TransformerBlock, cfg.NLayers)
for i := int32(0); i < cfg.NLayers; i++ {
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
prefix := fmt.Sprintf("transformer_blocks.%d", i)
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
if err != nil {
return fmt.Errorf("layer %d: %w", i, err)
}
}
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
fmt.Print(" Loading output layers... ")
normOutWeight, _ := weights.Get("norm_out.linear.weight")
normOutBias, _ := weights.Get("norm_out.linear.bias")
projOut, _ := weights.Get("proj_out.weight")
projOutBias, _ := weights.Get("proj_out.bias")
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
m.NormOutBias = normOutBias
m.ProjOut = mlx.Transpose(projOut, 1, 0)
m.ProjOutBias = projOutBias
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// LoadFromPath is a convenience function to load transformer from path
func LoadTransformerFromPath(path string) (*Transformer, error) {
m := &Transformer{}
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
return nil, err
}
return m, nil
}
// Forward runs the transformer
// img: [B, L_img, in_channels] patchified latents
// txt: [B, L_txt, joint_attention_dim] text embeddings
// t: [B] timesteps (0-1)
// imgFreqs, txtFreqs: RoPE frequencies
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// Timestep embedding
temb := tr.TEmbed.Forward(t)
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
// Project text: RMSNorm then linear
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
for _, layer := range tr.Layers {
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
}
// Final norm with modulation (AdaLayerNormContinuous)
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
modShape := finalMod.Shape()
halfDim := modShape[1] / 2
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
return mlx.Reshape(out, B, Limg, outChannels)
}
// ForwardWithCache runs the transformer with layer caching for speedup.
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
// shallow layers change little between denoising steps, so we cache their
// outputs and reuse them on non-refresh steps.
//
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
// step: current denoising step (0-indexed)
// cacheInterval: refresh cache every N steps (e.g., 3)
// cacheLayers: number of shallow layers to cache (e.g., 15)
func (tr *Transformer) ForwardWithCache(
img, txt, t *mlx.Array,
imgFreqs, txtFreqs *mlx.Array,
stepCache *cache.StepCache,
step, cacheInterval, cacheLayers int,
) *mlx.Array {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// Timestep embedding
temb := tr.TEmbed.Forward(t)
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
// Project text: RMSNorm then linear
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
// Check if we should refresh the cache
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
for i, layer := range tr.Layers {
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
// Use cached outputs for shallow layers
imgH = stepCache.Get(i)
txtH = stepCache.Get2(i)
} else {
// Compute layer
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
// Cache shallow layers on refresh steps
if i < cacheLayers && refreshCache {
stepCache.Set(i, imgH)
stepCache.Set2(i, txtH)
}
}
}
// Final norm with modulation (AdaLayerNormContinuous)
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
modShape := finalMod.Shape()
halfDim := modShape[1] / 2
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
return mlx.Reshape(out, B, Limg, outChannels)
}
// RoPECache holds precomputed RoPE frequencies
type RoPECache struct {
ImgFreqs *mlx.Array // [L_img, head_dim]
TxtFreqs *mlx.Array // [L_txt, head_dim]
}
// PrepareRoPE computes RoPE for image and text sequences
// This matches Python's QwenEmbedRope with scale_rope=True
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
theta := float64(10000)
maxIdx := int32(4096)
// Compute base frequencies for each axis dimension
freqsT := ComputeAxisFreqs(axesDims[0], theta)
freqsH := ComputeAxisFreqs(axesDims[1], theta)
freqsW := ComputeAxisFreqs(axesDims[2], theta)
// Build frequency lookup tables
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
// Image frequencies with scale_rope=True
imgLen := imgH * imgW
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
imgFreqsData := make([]float32, imgLen*headDim)
hHalf := imgH / 2
wHalf := imgW / 2
idx := int32(0)
for y := int32(0); y < imgH; y++ {
for x := int32(0); x < imgW; x++ {
// Frame = 0
for i := 0; i < len(freqsT)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
}
idx += int32(len(freqsT) * 2)
// Height: scale_rope pattern
hNegCount := imgH - hHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
}
}
idx += int32(len(freqsH) * 2)
// Width: scale_rope pattern
wNegCount := imgW - wHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
}
}
idx += int32(len(freqsW) * 2)
}
}
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
imgFreqs = mlx.ToBFloat16(imgFreqs)
// Text frequencies
maxVidIdx := max(hHalf, wHalf)
txtFreqsData := make([]float32, txtLen*headDim)
idx = 0
for t := int32(0); t < txtLen; t++ {
pos := maxVidIdx + t
for i := 0; i < len(freqsT)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
}
idx += int32(len(freqsT) * 2)
for i := 0; i < len(freqsH)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
}
idx += int32(len(freqsH) * 2)
for i := 0; i < len(freqsW)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
}
idx += int32(len(freqsW) * 2)
}
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
txtFreqs = mlx.ToBFloat16(txtFreqs)
return &RoPECache{
ImgFreqs: imgFreqs,
TxtFreqs: txtFreqs,
}
}
// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
halfDim := dim / 2
freqs := make([]float64, halfDim)
for i := int32(0); i < halfDim; i++ {
freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
}
return freqs
}
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
table := make([][]float32, maxIdx)
for idx := int32(0); idx < maxIdx; idx++ {
var pos float64
if negative {
pos = float64(-maxIdx + int32(idx))
} else {
pos = float64(idx)
}
row := make([]float32, len(baseFreqs)*2)
for i, f := range baseFreqs {
angle := pos * f
row[i*2] = float32(math.Cos(angle))
row[i*2+1] = float32(math.Sin(angle))
}
table[idx] = row
}
return table
}
func max(a, b int32) int32 {
if a > b {
return a
}
return b
}
// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize
pW := W / patchSize
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
// -> [B, pH, pW, C, 2, 2]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// -> [B, pH*pW, C*4]
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
channels := shape[2] / (patchSize * patchSize)
pH := H / patchSize
pW := W / patchSize
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
// -> [B, C, pH, 2, pW, 2]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// -> [B, C, H, W]
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
// Add temporal dimension for VAE: [B, C, 1, H, W]
return mlx.ExpandDims(x, 2)
}

View File

@@ -1,117 +0,0 @@
package qwen_image
import (
"math"
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestTransformerConfig tests configuration invariants.
func TestTransformerConfig(t *testing.T) {
cfg := defaultTransformerConfig()
// Property: hidden_dim = n_heads * head_dim
if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
}
// Property: axes_dims_rope sums to head_dim
var ropeSum int32
for _, d := range cfg.AxesDimsRope {
ropeSum += d
}
if ropeSum != cfg.HeadDim {
t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
}
// Property: in_channels = out_channels * patch_size^2
expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
if cfg.InChannels != expectedIn {
t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
}
}
// TestTransformerRoPE tests RoPE frequency computation produces valid values.
func TestTransformerRoPE(t *testing.T) {
cfg := defaultTransformerConfig()
// Test with small image dimensions
imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
txtLen := int32(5)
ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Verify shapes: [seq_len, head_dim]
imgSeqLen := imgH * imgW
if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
}
if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
}
if ropeCache.TxtFreqs.Shape()[0] != txtLen {
t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
}
// Verify values are finite
imgData := ropeCache.ImgFreqs.Data()
for i := 0; i < min(100, len(imgData)); i++ {
if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
break
}
}
}
// TestTransformerForward tests full forward pass (integration test).
// Skips if model weights are not available.
func TestTransformerForward(t *testing.T) {
weightsPath := "../../../weights/Qwen-Image-2512/transformer"
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + weightsPath)
}
transformer := &Transformer{}
if err := transformer.Load(weightsPath); err != nil {
t.Fatalf("Failed to load transformer: %v", err)
}
mlx.Keep(mlx.Collect(transformer)...)
cfg := transformer.Config
// Small test inputs
batchSize := int32(1)
imgH, imgW := int32(4), int32(4)
imgSeqLen := imgH * imgW
txtSeqLen := int32(5)
hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})
ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)
// Forward pass
out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(out)
// Verify output shape: [batch, img_seq_len, in_channels]
wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
gotShape := out.Shape()
if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
}
// Verify output is finite
outData := out.Data()
for i := 0; i < min(100, len(outData)); i++ {
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
t.Errorf("output[%d] not finite: %v", i, outData[i])
break
}
}
}

View File

@@ -1,852 +0,0 @@
package qwen_image
import (
"fmt"
"math"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds Qwen-Image VAE configuration
type VAEConfig struct {
ZDim int32 `json:"z_dim"` // 16
BaseDim int32 `json:"base_dim"` // 96
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
NumResBlocks int32 `json:"num_res_blocks"` // 2
LatentsMean []float32 `json:"latents_mean"` // 16 values
LatentsStd []float32 `json:"latents_std"` // 16 values
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
}
// defaultVAEConfig returns config for Qwen-Image VAE
func defaultVAEConfig() *VAEConfig {
return &VAEConfig{
ZDim: 16,
BaseDim: 96,
DimMult: []int32{1, 2, 4, 4},
NumResBlocks: 2,
LatentsMean: []float32{
-0.7571, -0.7089, -0.9113, 0.1075,
-0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632,
-0.1922, -0.9497, 0.2503, -0.2921,
},
LatentsStd: []float32{
2.8184, 1.4541, 2.3275, 2.6558,
1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579,
1.6382, 1.1253, 2.8251, 1.916,
},
TemperalDownsample: []bool{false, true, true},
}
}
// CausalConv3d is a causal 3D convolution (for temporal causality)
type CausalConv3d struct {
Weight *mlx.Array
Bias *mlx.Array
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
KernelT int32
}
// newCausalConv3d creates a 3D causal conv
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
weight, err := weights.Get(prefix + ".weight")
if err != nil {
return nil, fmt.Errorf("weight not found: %s", prefix)
}
bias, _ := weights.Get(prefix + ".bias")
kernelT := weight.Shape()[2]
outC := weight.Shape()[0]
var biasReshaped *mlx.Array
if bias != nil {
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
}
return &CausalConv3d{
Weight: weight,
Bias: bias,
BiasReshaped: biasReshaped,
KernelT: kernelT,
}, nil
}
// Forward applies causal 3D convolution
// x: [B, T, H, W, C] (channels-last, MLX format)
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
kernelT := shape[2]
kernelH := shape[3]
kernelW := shape[4]
// Causal temporal padding, same spatial padding
// Input is channels-last: [B, T, H, W, C]
padT := kernelT - 1
padH := kernelH / 2
padW := kernelW / 2
// Stage 1: Pad
{
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
mlx.Eval(x)
}
// Stage 2: Conv + bias
var out *mlx.Array
{
prev := x
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
out = mlx.Add(out, bias)
}
prev.Free()
mlx.Eval(out)
}
return out
}
// RMSNorm3D applies RMS normalization over channels
// Works with channels-last [B, T, H, W, C] format
type RMSNorm3D struct {
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
}
// newRMSNorm3D creates an RMS norm
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
gamma, err := weights.Get(prefix + ".gamma")
if err != nil {
return nil, err
}
// Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
return &RMSNorm3D{Gamma: gamma}, nil
}
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
// RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
normalized := mlx.RMSNormNoWeight(x, 1e-6)
return mlx.Mul(normalized, n.Gamma)
}
// ResBlock is a residual block with RMS norm and causal convs
type ResBlock struct {
Norm1 *RMSNorm3D
Conv1 *CausalConv3d
Norm2 *RMSNorm3D
Conv2 *CausalConv3d
Shortcut *CausalConv3d
}
// newResBlock creates a residual block
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
if err != nil {
return nil, err
}
conv1, err := newCausalConv3d(weights, prefix+".conv1")
if err != nil {
return nil, err
}
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
if err != nil {
return nil, err
}
conv2, err := newCausalConv3d(weights, prefix+".conv2")
if err != nil {
return nil, err
}
var shortcut *CausalConv3d
if inDim != outDim {
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
if err != nil {
return nil, err
}
}
return &ResBlock{
Norm1: norm1,
Conv1: conv1,
Norm2: norm2,
Conv2: conv2,
Shortcut: shortcut,
}, nil
}
// Forward applies the residual block
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
// Use h as working variable, keep x intact for residual (caller will free x)
// Conv handles its own pools, so we just need pools for non-conv operations
var h *mlx.Array
// Keep x so it survives Eval() cleanup - needed for residual connection
mlx.Keep(x)
// Stage 1: norm1 + silu
{
h = r.Norm1.Forward(x)
h = silu3D(h)
mlx.Eval(h)
}
// Stage 2: conv1 (handles its own pools)
{
prev := h
h = r.Conv1.Forward(h)
prev.Free()
}
// Stage 3: norm2 + silu
{
prev := h
h = r.Norm2.Forward(h)
h = silu3D(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: conv2 (handles its own pools)
{
prev := h
h = r.Conv2.Forward(h)
prev.Free()
}
// Residual connection (shortcut handles its own pools if present)
if r.Shortcut != nil {
shortcut := r.Shortcut.Forward(x)
h = mlx.Add(h, shortcut)
mlx.Eval(h)
} else {
h = mlx.Add(h, x)
mlx.Eval(h)
}
return h
}
// AttentionBlock is a 2D attention block
type AttentionBlock struct {
Norm *RMSNorm3D
ToQKV *mlx.Array
ToQKVBias *mlx.Array
Proj *mlx.Array
ProjBias *mlx.Array
Dim int32
}
// newAttentionBlock creates an attention block
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
if err != nil {
return nil, err
}
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
proj, _ := weights.Get(prefix + ".proj.weight")
projBias, _ := weights.Get(prefix + ".proj.bias")
return &AttentionBlock{
Norm: norm,
ToQKV: toQKV,
ToQKVBias: toQKVBias,
Proj: proj,
ProjBias: projBias,
Dim: dim,
}, nil
}
// Forward applies 2D attention
// Input: [B, T, H, W, C] (channels-last)
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
identity := x
// Flatten to [B*T, 1, H, W, C] for norm
x = mlx.Reshape(x, B*T, 1, H, W, C)
x = a.Norm.Forward(x)
x = mlx.Reshape(x, B*T, H, W, C)
// Flatten spatial to [B*T, H*W, C]
x = mlx.Reshape(x, B*T, H*W, C)
// Linear to get Q, K, V: [B*T, H*W, 3*C]
// Weight is [outC, inC] or [outC, inC, 1, 1]
wShape := a.ToQKV.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
} else {
w = a.ToQKV
}
w = mlx.Transpose(w, 1, 0) // [inC, outC]
qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
if a.ToQKVBias != nil {
qkv = mlx.Add(qkv, a.ToQKVBias)
}
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
// out: [B*T, 1, H*W, C]
out = mlx.Reshape(out, B*T, H*W, C)
// Project back
pShape := a.Proj.Shape()
var p *mlx.Array
if len(pShape) == 4 {
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
} else {
p = a.Proj
}
p = mlx.Transpose(p, 1, 0) // [inC, outC]
out = mlx.Linear(out, p) // [B*T, H*W, C]
if a.ProjBias != nil {
out = mlx.Add(out, a.ProjBias)
}
out = mlx.Reshape(out, B, T, H, W, C)
return mlx.Add(out, identity)
}
// UpBlock handles upsampling in decoder
type UpBlock struct {
ResBlocks []*ResBlock
Upsampler *Upsample
}
// newUpBlock creates an up block
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var upsampler *Upsample
if upsampleMode != "" {
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
}
return &UpBlock{
ResBlocks: resBlocks,
Upsampler: upsampler,
}, nil
}
// Forward applies up block with staged memory management
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
// ResBlocks handle their own pools
for _, block := range u.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
// Upsampler handles its own pools
if u.Upsampler != nil {
prev := x
x = u.Upsampler.Forward(x)
prev.Free()
}
return x
}
// Upsample handles spatial upsampling
type Upsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newUpsample creates an upsampler
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Upsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies upsampling to channels-last input [B, T, H, W, C]
// Uses staged pools to reduce peak memory during 2x upsampling
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := u.Conv.Shape()[0]
// Stage 1: 2x nearest neighbor upsample
{
x = mlx.Reshape(x, B*T, H, W, C)
x = upsample2xChannelsLast(x)
mlx.Eval(x)
}
// Stage 2: Conv + bias
{
prev := x
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
x = conv2D3x3PaddedChannelsLast(x, weight)
if u.Bias != nil {
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
prev.Free()
mlx.Eval(x)
}
return x
}
// MidBlock is the middle block of decoder
type MidBlock struct {
ResBlock1 *ResBlock
Attention *AttentionBlock
ResBlock2 *ResBlock
}
// newMidBlock creates a mid block
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
if err != nil {
return nil, err
}
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
if err != nil {
return nil, err
}
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
if err != nil {
return nil, err
}
return &MidBlock{
ResBlock1: res1,
Attention: attn,
ResBlock2: res2,
}, nil
}
// Forward applies mid block
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
// Each component handles its own pools; we just free inputs
prev := x
x = m.ResBlock1.Forward(x)
prev.Free()
prev = x
x = m.Attention.Forward(x)
prev.Free()
prev = x
x = m.ResBlock2.Forward(x)
prev.Free()
return x
}
// VAEDecoder is the full VAE decoder
type VAEDecoder struct {
Config *VAEConfig
PostQuantConv *CausalConv3d
ConvIn *CausalConv3d
MidBlock *MidBlock
UpBlocks []*UpBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
}
// Load loads the VAE decoder from a directory
func (m *VAEDecoder) Load(path string) error {
fmt.Println("Loading Qwen-Image VAE decoder...")
cfg := defaultVAEConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Bulk load all weights as bf16
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("failed to load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading post_quant_conv... ")
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
if err != nil {
return err
}
m.PostQuantConv = postQuantConv
fmt.Println("✓")
fmt.Print(" Loading conv_in... ")
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
if err != nil {
return err
}
m.ConvIn = convIn
fmt.Println("✓")
// Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
fmt.Print(" Loading mid_block... ")
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
if err != nil {
return err
}
m.MidBlock = midBlock
fmt.Println("✓")
// Up blocks (reversed dim_mult)
fmt.Print(" Loading up_blocks... ")
numUpBlocks := len(cfg.DimMult)
m.UpBlocks = make([]*UpBlock, numUpBlocks)
dimsMult := make([]int32, numUpBlocks+1)
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
for i := 0; i < numUpBlocks; i++ {
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
}
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
for i := range cfg.TemperalDownsample {
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
}
for i := 0; i < numUpBlocks; i++ {
inDim := cfg.BaseDim * dimsMult[i]
outDim := cfg.BaseDim * dimsMult[i+1]
if i > 0 {
inDim = inDim / 2
}
upsampleMode := ""
if i < numUpBlocks-1 {
if temporalUpsample[i] {
upsampleMode = "upsample3d"
} else {
upsampleMode = "upsample2d"
}
}
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
if err != nil {
return err
}
m.UpBlocks[i] = upBlock
}
fmt.Printf("✓ [%d blocks]\n", numUpBlocks)
fmt.Print(" Loading output layers... ")
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
if err != nil {
return err
}
m.NormOut = normOut
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
if err != nil {
return err
}
m.ConvOut = convOut
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// LoadVAEDecoderFromPath is a convenience function to load VAE from path
func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
m := &VAEDecoder{}
if err := m.Load(filepath.Join(path, "vae")); err != nil {
return nil, err
}
return m, nil
}
// Decode converts latents to image
// z: [B, C, T, H, W] normalized latents
// Uses staged pools to free intermediate arrays and reduce peak memory.
func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
var x *mlx.Array
// Stage 1a: Denormalize and transpose
{
z = vae.Denormalize(z)
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
mlx.Eval(z)
}
// Stage 1b: PostQuantConv (handles its own pools)
x = vae.PostQuantConv.Forward(z)
z.Free()
// Stage 1c: ConvIn (handles its own pools)
{
prev := x
x = vae.ConvIn.Forward(x)
prev.Free()
}
// Stage 2: Mid block (handles its own pools)
x = vae.MidBlock.Forward(x)
// Stage 3: Up blocks (each handles its own pools)
for _, upBlock := range vae.UpBlocks {
x = upBlock.Forward(x)
}
// Stage 4a: NormOut + silu
{
prev := x
x = vae.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// Stage 4b: ConvOut (handles its own pools)
{
prev := x
x = vae.ConvOut.Forward(x)
prev.Free()
}
// Stage 4c: Post-processing
{
prev := x
// Clamp to [-1, 1]
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
// Convert back from channels-last to channels-first
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
prev.Free()
mlx.Eval(x)
}
return x
}
// Denormalize reverses the normalization applied during encoding
func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
mean = mlx.ToBFloat16(mean)
std = mlx.ToBFloat16(std)
return mlx.Add(mlx.Mul(z, std), mean)
}
// Helper functions
func silu3D(x *mlx.Array) *mlx.Array {
return mlx.Mul(x, mlx.Sigmoid(x))
}
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
// Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
}
func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
}
func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[2]
W := shape[3]
x = mlx.Transpose(x, 0, 2, 3, 1)
x = mlx.Reshape(x, B*H*W, shape[1])
wShape := weight.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(weight, wShape[0], wShape[1])
} else {
w = weight
}
w = mlx.Transpose(w, 1, 0)
out := mlx.Linear(x, w)
outC := w.Dim(1)
out = mlx.Reshape(out, B, H, W, outC)
return mlx.Transpose(out, 0, 3, 1, 2)
}
func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
x = pad2D(x, 1, 1, 1, 1)
return conv2D(x, weight, 1, 1)
}
func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
x = mlx.Transpose(x, 0, 2, 3, 1)
w = mlx.Transpose(w, 0, 2, 3, 1)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
wShape := w.Shape()
Cout := wShape[0]
kH := wShape[1]
kW := wShape[2]
outH := (H-kH)/strideH + 1
outW := (W-kW)/strideW + 1
patches := extractPatches2D(x, kH, kW, strideH, strideW)
wFlat := mlx.Reshape(w, Cout, -1)
patches = mlx.Reshape(patches, B*outH*outW, -1)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
out = mlx.Reshape(out, B, outH, outW, Cout)
return mlx.Transpose(out, 0, 3, 1, 2)
}
func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
outH := (H-kH)/strideH + 1
outW := (W-kW)/strideW + 1
patches := make([]*mlx.Array, outH*outW)
idx := 0
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startH := i * strideH
startW := j * strideW
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
patch = mlx.Reshape(patch, B, kH*kW*C)
patches[idx] = patch
idx++
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[2]
W := shape[3]
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
x = mlx.Take(x, rowIdx, 2)
x = mlx.Take(x, colIdx, 3)
return x
}
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
// Create repeat indices for rows
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
// Create repeat indices for columns
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
// Take along H (axis 1) then W (axis 2)
x = mlx.Take(x, rowIdx, 1)
x = mlx.Take(x, colIdx, 2)
return x
}
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
// weight: [outC, kH, kW, inC] (MLX channels-last format)
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
// Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
// Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
// stride=1, padding=0 (we already padded manually)
return mlx.Conv2d(x, weight, 1, 0)
}

View File

@@ -1,112 +0,0 @@
package qwen_image
import (
"math"
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestVAEConfig tests configuration invariants.
func TestVAEConfig(t *testing.T) {
cfg := defaultVAEConfig()
// Property: latents_mean and latents_std have z_dim elements
if int32(len(cfg.LatentsMean)) != cfg.ZDim {
t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim)
}
if int32(len(cfg.LatentsStd)) != cfg.ZDim {
t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim)
}
// Property: dim_mult defines 4 stages
if len(cfg.DimMult) != 4 {
t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult))
}
// Property: temperal_downsample has 3 elements (for 3 transitions)
if len(cfg.TemperalDownsample) != 3 {
t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample))
}
}
// TestVAELatentsNormalization tests the latent denormalization values.
func TestVAELatentsNormalization(t *testing.T) {
cfg := defaultVAEConfig()
// Verify latents_std values are all positive
for i, std := range cfg.LatentsStd {
if std <= 0 {
t.Errorf("latents_std[%d] should be positive: %v", i, std)
}
}
// Verify values are in reasonable range (from actual model)
for i, mean := range cfg.LatentsMean {
if math.Abs(float64(mean)) > 5 {
t.Errorf("latents_mean[%d] seems too large: %v", i, mean)
}
}
for i, std := range cfg.LatentsStd {
if std > 10 {
t.Errorf("latents_std[%d] seems too large: %v", i, std)
}
}
}
// TestVAEDecoderForward tests full forward pass (integration test).
// Skips if model weights are not available.
func TestVAEDecoderForward(t *testing.T) {
weightsPath := "../../../weights/Qwen-Image-2512/vae"
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + weightsPath)
}
vae := &VAEDecoder{}
if err := vae.Load(weightsPath); err != nil {
t.Fatalf("Failed to load VAE decoder: %v", err)
}
mlx.Keep(mlx.Collect(vae)...)
// Small test input: [B, C, T, H, W]
// After 4 upsampling stages (2x each), H/W multiply by 16
batchSize := int32(1)
channels := int32(16)
frames := int32(1)
latentH := int32(4)
latentW := int32(4)
latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0)
// Decode
out := vae.Decode(latents)
mlx.Eval(out)
// Verify output shape: [B, 3, T, H*16, W*16]
outShape := out.Shape()
if outShape[0] != batchSize {
t.Errorf("batch size: got %d, want %d", outShape[0], batchSize)
}
if outShape[1] != 3 {
t.Errorf("channels: got %d, want 3", outShape[1])
}
if outShape[2] != frames {
t.Errorf("frames: got %d, want %d", outShape[2], frames)
}
expectedH := latentH * 16 // 4 stages of 2x upsampling
expectedW := latentW * 16
if outShape[3] != expectedH || outShape[4] != expectedW {
t.Errorf("spatial dims: got [%d, %d], want [%d, %d]",
outShape[3], outShape[4], expectedH, expectedW)
}
// Verify output is in valid range (should be clamped to [0, 1] by decode)
outData := out.Data()
for i := 0; i < min(100, len(outData)); i++ {
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
t.Errorf("output[%d] not finite: %v", i, outData[i])
break
}
}
}

View File

@@ -1,680 +0,0 @@
package qwen_image_edit
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// CausalConv3d is a causal 3D convolution (for temporal causality)
type CausalConv3d struct {
Weight *mlx.Array
Bias *mlx.Array
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
KernelT int32
}
// newCausalConv3d creates a 3D causal conv
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
weight, err := weights.Get(prefix + ".weight")
if err != nil {
return nil, fmt.Errorf("weight not found: %s", prefix)
}
bias, _ := weights.Get(prefix + ".bias")
kernelT := weight.Shape()[2]
outC := weight.Shape()[0]
var biasReshaped *mlx.Array
if bias != nil {
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
}
return &CausalConv3d{
Weight: weight,
Bias: bias,
BiasReshaped: biasReshaped,
KernelT: kernelT,
}, nil
}
// Forward applies causal 3D convolution (or 2D if weight is 4D)
// x: [B, T, H, W, C] (channels-last, MLX format)
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
shape := c.Weight.Shape()
// Handle both 5D (3D conv) and 4D (2D conv) weights
if len(shape) == 4 {
// 2D conv: [O, I, kH, kW] - need to apply per-frame
return c.forward2D(x)
}
// 3D conv: [O, I, kT, kH, kW]
kernelT := shape[2]
kernelH := shape[3]
kernelW := shape[4]
// Causal temporal padding, same spatial padding
padT := kernelT - 1
padH := kernelH / 2
padW := kernelW / 2
// Stage 1: Pad
{
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
mlx.Eval(x)
}
// Stage 2: Conv + bias
var out *mlx.Array
{
prev := x
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
out = mlx.Add(out, bias)
}
prev.Free()
mlx.Eval(out)
}
return out
}
// forward2D applies 2D conv per-frame for [B, T, H, W, C] input
func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
xShape := x.Shape()
B := xShape[0]
T := xShape[1]
H := xShape[2]
W := xShape[3]
C := xShape[4]
wShape := c.Weight.Shape() // [O, I, kH, kW]
kernelH := wShape[2]
kernelW := wShape[3]
outC := wShape[0]
padH := kernelH / 2
padW := kernelW / 2
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Pad spatially
x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})
// Apply 2D conv
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
x = mlx.Conv2d(x, weight, 1, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
// Get output spatial dims
outH := H
outW := W
// Reshape back to [B, T, H, W, C]
x = mlx.Reshape(x, B, T, outH, outW, outC)
mlx.Eval(x)
return x
}
// RMSNorm3D applies RMS normalization over channels
type RMSNorm3D struct {
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
}
// newRMSNorm3D creates an RMS norm
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
gamma, err := weights.Get(prefix + ".gamma")
if err != nil {
return nil, err
}
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
return &RMSNorm3D{Gamma: gamma}, nil
}
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
normalized := mlx.RMSNormNoWeight(x, 1e-6)
return mlx.Mul(normalized, n.Gamma)
}
// ResBlock is a residual block with RMS norm and causal convs
type ResBlock struct {
Norm1 *RMSNorm3D
Conv1 *CausalConv3d
Norm2 *RMSNorm3D
Conv2 *CausalConv3d
Shortcut *CausalConv3d
}
// newResBlock creates a residual block
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
if err != nil {
return nil, err
}
conv1, err := newCausalConv3d(weights, prefix+".conv1")
if err != nil {
return nil, err
}
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
if err != nil {
return nil, err
}
conv2, err := newCausalConv3d(weights, prefix+".conv2")
if err != nil {
return nil, err
}
var shortcut *CausalConv3d
if inDim != outDim {
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
if err != nil {
return nil, err
}
}
return &ResBlock{
Norm1: norm1,
Conv1: conv1,
Norm2: norm2,
Conv2: conv2,
Shortcut: shortcut,
}, nil
}
// Forward applies the residual block
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
var h *mlx.Array
mlx.Keep(x)
// Stage 1: norm1 + silu
{
h = r.Norm1.Forward(x)
h = silu3D(h)
mlx.Eval(h)
}
// Stage 2: conv1
{
prev := h
h = r.Conv1.Forward(h)
prev.Free()
}
// Stage 3: norm2 + silu
{
prev := h
h = r.Norm2.Forward(h)
h = silu3D(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: conv2
{
prev := h
h = r.Conv2.Forward(h)
prev.Free()
}
// Residual connection
if r.Shortcut != nil {
shortcut := r.Shortcut.Forward(x)
h = mlx.Add(h, shortcut)
mlx.Eval(h)
} else {
h = mlx.Add(h, x)
mlx.Eval(h)
}
return h
}
// AttentionBlock is a 2D attention block
type AttentionBlock struct {
Norm *RMSNorm3D
ToQKV *mlx.Array
ToQKVBias *mlx.Array
Proj *mlx.Array
ProjBias *mlx.Array
Dim int32
}
// newAttentionBlock creates an attention block
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
if err != nil {
return nil, err
}
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
proj, _ := weights.Get(prefix + ".proj.weight")
projBias, _ := weights.Get(prefix + ".proj.bias")
return &AttentionBlock{
Norm: norm,
ToQKV: toQKV,
ToQKVBias: toQKVBias,
Proj: proj,
ProjBias: projBias,
Dim: dim,
}, nil
}
// Forward applies 2D attention
// Input: [B, T, H, W, C] (channels-last)
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
identity := x
// Flatten to [B*T, 1, H, W, C] for norm
x = mlx.Reshape(x, B*T, 1, H, W, C)
x = a.Norm.Forward(x)
x = mlx.Reshape(x, B*T, H, W, C)
// Flatten spatial to [B*T, H*W, C]
x = mlx.Reshape(x, B*T, H*W, C)
// Linear to get Q, K, V
wShape := a.ToQKV.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
} else {
w = a.ToQKV
}
w = mlx.Transpose(w, 1, 0)
qkv := mlx.Linear(x, w)
if a.ToQKVBias != nil {
qkv = mlx.Add(qkv, a.ToQKVBias)
}
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
out = mlx.Reshape(out, B*T, H*W, C)
// Project back
pShape := a.Proj.Shape()
var p *mlx.Array
if len(pShape) == 4 {
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
} else {
p = a.Proj
}
p = mlx.Transpose(p, 1, 0)
out = mlx.Linear(out, p)
if a.ProjBias != nil {
out = mlx.Add(out, a.ProjBias)
}
out = mlx.Reshape(out, B, T, H, W, C)
return mlx.Add(out, identity)
}
// UpBlock handles upsampling in decoder
type UpBlock struct {
ResBlocks []*ResBlock
Upsampler *Upsample
}
// newUpBlock creates an up block
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var upsampler *Upsample
if upsampleMode != "" {
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
}
return &UpBlock{
ResBlocks: resBlocks,
Upsampler: upsampler,
}, nil
}
// Forward applies up block
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
for _, block := range u.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
if u.Upsampler != nil {
prev := x
x = u.Upsampler.Forward(x)
prev.Free()
}
return x
}
// Upsample handles spatial upsampling
type Upsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newUpsample creates an upsampler
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Upsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies upsampling to channels-last input [B, T, H, W, C]
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := u.Conv.Shape()[0]
// Stage 1: 2x nearest neighbor upsample
{
x = mlx.Reshape(x, B*T, H, W, C)
x = upsample2xChannelsLast(x)
mlx.Eval(x)
}
// Stage 2: Conv + bias
{
prev := x
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
x = conv2D3x3PaddedChannelsLast(x, weight)
if u.Bias != nil {
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
prev.Free()
mlx.Eval(x)
}
return x
}
// MidBlock is the middle block
type MidBlock struct {
ResBlock1 *ResBlock
Attention *AttentionBlock
ResBlock2 *ResBlock
}
// newMidBlock creates a mid block
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
if err != nil {
return nil, err
}
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
if err != nil {
return nil, err
}
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
if err != nil {
return nil, err
}
return &MidBlock{
ResBlock1: res1,
Attention: attn,
ResBlock2: res2,
}, nil
}
// Forward applies mid block
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
prev := x
x = m.ResBlock1.Forward(x)
prev.Free()
prev = x
x = m.Attention.Forward(x)
prev.Free()
prev = x
x = m.ResBlock2.Forward(x)
prev.Free()
return x
}
// Helper functions
func silu3D(x *mlx.Array) *mlx.Array {
return mlx.Mul(x, mlx.Sigmoid(x))
}
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
}
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
x = mlx.Take(x, rowIdx, 1)
x = mlx.Take(x, colIdx, 2)
return x
}
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
return mlx.Conv2d(x, weight, 1, 0)
}
// conv2DStrided applies conv with stride > 1 using manual patch extraction
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
wShape := weight.Shape()
Cout := wShape[0]
kH := wShape[1]
kW := wShape[2]
outH := (H - kH) / stride + 1
outW := (W - kW) / stride + 1
patches := extractPatches2DStrided(x, kH, kW, stride)
wFlat := mlx.Reshape(weight, Cout, -1)
patches = mlx.Reshape(patches, B*outH*outW, -1)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
return mlx.Reshape(out, B, outH, outW, Cout)
}
// conv3DStrided applies 3D conv with strides using manual patch extraction
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
// strideT, strideH, strideW are the strides for each dimension
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
wShape := weight.Shape()
Cout := wShape[0]
// I := wShape[1]
kT := wShape[2]
kH := wShape[3]
kW := wShape[4]
// For temporal: if T < kT, we need to repeat frames temporally
// For single image with T=1 and kT=2, we duplicate the frame to T=kT
// Python Qwen2.5-VL duplicates the frame, not zero-pads
if T < kT {
// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
T = kT
}
outT := (T - kT) / strideT + 1
outH := (H - kH) / strideH + 1
outW := (W - kW) / strideW + 1
// Extract 3D patches in [C, T, H, W] order to match Python
patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
// patches shape: [B, outT, outH, outW, C*kT*kH*kW]
// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
return mlx.Reshape(out, B, outT, outH, outW, Cout)
}
// extractPatches3DStrided extracts 3D patches with given strides
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outT := (T - kT) / strideT + 1
outH := (H - kH) / strideH + 1
outW := (W - kW) / strideW + 1
numPatches := outT * outH * outW
patches := make([]*mlx.Array, numPatches)
idx := 0
for t := int32(0); t < outT; t++ {
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startT := t * strideT
startH := i * strideH
startW := j * strideW
// Extract patch: [B, kT, kH, kW, C]
patch := mlx.Slice(x,
[]int32{0, startT, startH, startW, 0},
[]int32{B, startT + kT, startH + kH, startW + kW, C})
// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
// Flatten to [B, C*T*H*W]
patch = mlx.Reshape(patch, B, C*kT*kH*kW)
patches[idx] = patch
idx++
}
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
}
// extractPatches2DStrided extracts patches with given stride
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
outH := (H - kH) / stride + 1
outW := (W - kW) / stride + 1
patches := make([]*mlx.Array, outH*outW)
idx := 0
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startH := i * stride
startW := j * stride
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
patch = mlx.Reshape(patch, B, kH*kW*C)
patches[idx] = patch
idx++
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}

View File

@@ -1,473 +0,0 @@
package qwen_image_edit
import (
"fmt"
"image"
"image/color"
_ "image/jpeg"
_ "image/png"
"math"
"os"
"github.com/ollama/ollama/x/imagegen/mlx"
"golang.org/x/image/draw"
_ "golang.org/x/image/webp"
)
// loadImageFile loads an image from disk
func loadImageFile(path string) (image.Image, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open image: %w", err)
}
defer f.Close()
img, _, err := image.Decode(f)
if err != nil {
return nil, fmt.Errorf("decode image: %w", err)
}
return img, nil
}
// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range
func imageToFloat32Pixels(img image.Image, width, height int) []float32 {
pixels := make([]float32, width*height*3)
idx := 0
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
r, g, b, _ := img.At(x, y).RGBA()
pixels[idx] = float32(r) / 65535.0
pixels[idx+1] = float32(g) / 65535.0
pixels[idx+2] = float32(b) / 65535.0
idx += 3
}
}
return pixels
}
// normalizeImageNet applies ImageNet normalization to an image tensor
func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array {
mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3})
std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3})
return mlx.Div(mlx.Sub(arr, mean), std)
}
// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16
func prepareImageTensor(arr *mlx.Array) *mlx.Array {
// Transpose to [C, H, W] and make contiguous
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
// Add batch dimension [1, C, H, W]
arr = mlx.ExpandDims(arr, 0)
// Convert to bf16
arr = mlx.ToBFloat16(arr)
mlx.Eval(arr)
return arr
}
// clampFloat clamps a value to [0, 255] and returns uint8
func clampFloat(v, weightSum float64) uint8 {
v /= weightSum
if v < 0 {
v = 0
}
if v > 255 {
v = 255
}
return uint8(math.Round(v))
}
// ImageDims holds dimensions for a preprocessed image
type ImageDims struct {
// Original image dimensions
OrigW, OrigH int32
// Condition image dimensions (for vision encoder)
CondW, CondH int32
// VAE image dimensions
VaeW, VaeH int32
// Latent dimensions (VAE dims / vae_scale_factor)
LatentW, LatentH int32
// Patch dimensions (latent dims / patch_size)
PatchW, PatchH int32
}
// ProcessorConfig holds image processor configuration
type ProcessorConfig struct {
// Condition image size (target pixel area for vision encoder input)
// Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456
// Pipeline resizes image to this area before passing to encode_prompt
ConditionImageSize int32
// VAE image size (target pixel area)
// Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576
VAEImageSize int32
// Image normalization (ImageNet stats for vision encoder)
ImageMean []float32
ImageStd []float32
}
// defaultProcessorConfig returns default processor config
func defaultProcessorConfig() *ProcessorConfig {
return &ProcessorConfig{
ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE
VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE
ImageMean: []float32{0.48145466, 0.4578275, 0.40821073},
ImageStd: []float32{0.26862954, 0.26130258, 0.27577711},
}
}
// Processor handles image preprocessing for Qwen-Image-Edit
type Processor struct {
Config *ProcessorConfig
}
// Load loads the processor config
func (p *Processor) Load(path string) error {
p.Config = defaultProcessorConfig()
return nil
}
// LoadAndPreprocess loads an image and preprocesses it for both paths
// Returns: condImage (for vision encoder), vaeImage (for VAE encoding)
func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) {
img, err := loadImageFile(imagePath)
if err != nil {
return nil, nil, err
}
bounds := img.Bounds()
origW := bounds.Dx()
origH := bounds.Dy()
ratio := float64(origW) / float64(origH)
// Calculate dimensions for condition image (vision encoder)
// Python pipeline does TWO resizes:
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
// Calculate dimensions for VAE image (1024x1024 area)
// Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32)
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
// Preprocess for condition (vision encoder) - two-step resize
condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH)
// Preprocess for VAE ([-1, 1] range, 5D tensor)
vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH)
return condImage, vaeImage, nil
}
// preprocessImageLanczos does single-step Lanczos resize for vision encoder
// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default)
// Used by edit_plus pipeline for multi-image input
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array {
resized := resizeImageLanczos(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
arr = p.normalizeImageNet(arr)
return prepareImageTensor(arr)
}
// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline
// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize)
// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize)
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array {
intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH))
resized := resizeImageBicubic(intermediate, int(finalW), int(finalH))
pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH))
arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3})
arr = p.normalizeImageNet(arr)
return prepareImageTensor(arr)
}
// preprocessImage converts image to tensor for vision encoder
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array {
resized := resizeImageBicubic(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
if normalize {
arr = p.normalizeImageNet(arr)
}
return prepareImageTensor(arr)
}
// preprocessImageForVAE converts image to tensor for VAE encoding
// Returns: [B, C, T, H, W] tensor in [-1, 1] range
func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array {
resized := resizeImageLanczos(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
// Scale to [-1, 1]: arr * 2 - 1
arr = mlx.MulScalar(arr, 2.0)
arr = mlx.AddScalar(arr, -1.0)
// Transpose to [C, H, W] and make contiguous
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
// Add batch and temporal dimensions [1, C, 1, H, W]
arr = mlx.ExpandDims(arr, 0) // [1, C, H, W]
arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W]
arr = mlx.ToBFloat16(arr)
mlx.Eval(arr)
return arr
}
// smartResize implements Python Qwen2VL processor's smart_resize logic
// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints
func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) {
// Round to factor
hBar := int32(math.Round(float64(height)/float64(factor))) * factor
wBar := int32(math.Round(float64(width)/float64(factor))) * factor
// Ensure minimum factor size
if hBar < factor {
hBar = factor
}
if wBar < factor {
wBar = factor
}
// Check pixel constraints
total := hBar * wBar
if total > maxPixels {
// Scale down
beta := math.Sqrt(float64(maxPixels) / float64(total))
hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor
wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor
} else if total < minPixels {
// Scale up
beta := math.Sqrt(float64(minPixels) / float64(total))
hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
// calculateDimensions calculates width and height for a target area while maintaining ratio
// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge)
func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) {
width := math.Sqrt(float64(targetArea) * ratio)
height := width / ratio
m := float64(multiple)
width = math.Round(width/m) * m
height = math.Round(height/m) * m
// Ensure minimum dimensions
if width < m {
width = m
}
if height < m {
height = m
}
return int32(width), int32(height)
}
// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS)
func resizeImageLanczos(img image.Image, width, height int) image.Image {
bounds := img.Bounds()
dst := image.NewRGBA(image.Rect(0, 0, width, height))
// Lanczos3 kernel (a=3) to match PIL.LANCZOS
lanczos3 := &draw.Kernel{
Support: 3.0,
At: func(t float64) float64 {
if t == 0 {
return 1.0
}
if t < 0 {
t = -t
}
if t >= 3.0 {
return 0.0
}
// sinc(t) * sinc(t/3)
piT := math.Pi * t
return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3))
},
}
lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil)
return dst
}
// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC)
// Uses separable interpolation with PIL's coordinate mapping for exact match
func resizeImageBicubic(img image.Image, width, height int) image.Image {
bounds := img.Bounds()
srcW := bounds.Dx()
srcH := bounds.Dy()
// Convert to RGBA if needed
var src *image.RGBA
if rgba, ok := img.(*image.RGBA); ok {
src = rgba
} else {
src = image.NewRGBA(bounds)
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
src.Set(x, y, img.At(x, y))
}
}
}
// Keys cubic with a=-0.5 (PIL BICUBIC)
cubic := func(x float64) float64 {
if x < 0 {
x = -x
}
if x < 1 {
return 1.5*x*x*x - 2.5*x*x + 1
}
if x < 2 {
return -0.5*x*x*x + 2.5*x*x - 4*x + 2
}
return 0
}
// Horizontal pass: srcW -> width, keep srcH rows
temp := image.NewRGBA(image.Rect(0, 0, width, srcH))
for y := 0; y < srcH; y++ {
for dstX := 0; dstX < width; dstX++ {
// PIL coordinate mapping: center-to-center
srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5
baseX := int(math.Floor(srcXf))
var sumR, sumG, sumB, sumA, weightSum float64
for i := -1; i <= 2; i++ {
sx := baseX + i
if sx < 0 {
sx = 0
}
if sx >= srcW {
sx = srcW - 1
}
w := cubic(math.Abs(srcXf - float64(baseX+i)))
c := src.RGBAAt(sx, y)
sumR += float64(c.R) * w
sumG += float64(c.G) * w
sumB += float64(c.B) * w
sumA += float64(c.A) * w
weightSum += w
}
temp.SetRGBA(dstX, y, color.RGBA{
clampFloat(sumR, weightSum),
clampFloat(sumG, weightSum),
clampFloat(sumB, weightSum),
clampFloat(sumA, weightSum),
})
}
}
// Vertical pass: srcH -> height
dst := image.NewRGBA(image.Rect(0, 0, width, height))
for x := 0; x < width; x++ {
for dstY := 0; dstY < height; dstY++ {
srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5
baseY := int(math.Floor(srcYf))
var sumR, sumG, sumB, sumA, weightSum float64
for j := -1; j <= 2; j++ {
sy := baseY + j
if sy < 0 {
sy = 0
}
if sy >= srcH {
sy = srcH - 1
}
w := cubic(math.Abs(srcYf - float64(baseY+j)))
c := temp.RGBAAt(x, sy)
sumR += float64(c.R) * w
sumG += float64(c.G) * w
sumB += float64(c.B) * w
sumA += float64(c.A) * w
weightSum += w
}
dst.SetRGBA(x, dstY, color.RGBA{
clampFloat(sumR, weightSum),
clampFloat(sumG, weightSum),
clampFloat(sumB, weightSum),
clampFloat(sumA, weightSum),
})
}
}
return dst
}
// LoadAndPreprocessMultiple loads multiple images and preprocesses them
// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions)
func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) {
const vaeScaleFactor int32 = 8
const patchSize int32 = 2
condImages := make([]*mlx.Array, len(imagePaths))
vaeImages := make([]*mlx.Array, len(imagePaths))
dims := make([]ImageDims, len(imagePaths))
for i, imagePath := range imagePaths {
img, err := loadImageFile(imagePath)
if err != nil {
return nil, nil, nil, fmt.Errorf("image %d: %w", i, err)
}
bounds := img.Bounds()
origW := int32(bounds.Dx())
origH := int32(bounds.Dy())
ratio := float64(origW) / float64(origH)
// Calculate dimensions for condition image (vision encoder)
// Python pipeline does TWO resizes:
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
// Calculate dimensions for VAE image (1024x1024 area)
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
// Calculate derived dimensions
latentW := vaeW / vaeScaleFactor
latentH := vaeH / vaeScaleFactor
patchW := latentW / patchSize
patchH := latentH / patchSize
dims[i] = ImageDims{
OrigW: origW,
OrigH: origH,
CondW: condW,
CondH: condH,
VaeW: vaeW,
VaeH: vaeH,
LatentW: latentW,
LatentH: latentH,
PatchW: patchW,
PatchH: patchH,
}
fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n",
i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH)
// Preprocess for condition (vision encoder) - two-step resize to match Python pipeline
condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH)
// Preprocess for VAE ([-1, 1] range, 5D tensor)
vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH)
}
return condImages, vaeImages, dims, nil
}

View File

@@ -1,608 +0,0 @@
// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing.
// It reuses components from qwen_image where possible.
package qwen_image_edit
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image editing.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
Width int32 // Output width (default: from input image)
Height int32 // Output height (default: from input image)
Steps int // Denoising steps (default: 50)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Qwen-Image-Edit diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
Processor *Processor // Image processor for vision encoder
TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image)
Transformer *qwen_image.Transformer // Reuse qwen_image transformer
VAE *VAE // Combined encoder + decoder
}
// Load loads the Qwen-Image-Edit model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image-Edit model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer from processor directory
fmt.Print(" Loading tokenizer... ")
processorPath := filepath.Join(modelPath, "processor")
tok, err := tokenizer.Load(processorPath)
if err != nil {
// Fallback to tokenizer directory
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err = tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
}
m.Tokenizer = tok
fmt.Println("✓")
// Load processor (image preprocessing config)
fmt.Print(" Loading processor... ")
m.Processor = &Processor{}
if err := m.Processor.Load(processorPath); err != nil {
return fmt.Errorf("processor: %w", err)
}
fmt.Println("✓")
// Load vision-language text encoder (Qwen2.5-VL from qwen_image package)
m.TextEncoder = &qwen_image.Qwen25VL{}
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer (reuse qwen_image)
m.Transformer = &qwen_image.Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE (encoder + decoder)
m.VAE = &VAE{}
if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE: %w", err)
}
mlx.Eval(mlx.Collect(m.VAE)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Edit edits an image based on a text prompt.
// inputImagePath: path to input image
// prompt: text description of desired edit
func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// EditFromConfig edits images using the unified config struct.
// Accepts one or more input images.
func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
if len(inputImagePaths) == 0 {
return nil, fmt.Errorf("no input images provided")
}
start := time.Now()
result, err := m.edit(inputImagePaths, cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n",
len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n",
len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// EditImage implements model.ImageEditModel interface.
func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Edit(inputImagePath, prompt, width, height, steps, seed)
}
// EditMultiImage edits using multiple source images.
// This matches diffusers' QwenImageEditPlusPipeline behavior.
func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
return m.EditFromConfig(inputImagePaths, cfg)
}
// edit is the internal editing pipeline that handles one or more images.
func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
// Load and preprocess all input images
fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths))
condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths)
if err != nil {
return nil, fmt.Errorf("preprocess images: %w", err)
}
for _, img := range condImages {
mlx.Keep(img)
}
for _, img := range vaeImages {
mlx.Keep(img)
}
mlx.Eval(append(condImages, vaeImages...)...)
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
vaeScaleFactor := int32(8)
// Output dimensions - if not specified, use first input image dimensions
if cfg.Width <= 0 {
cfg.Width = inputDims[0].VaeW
}
if cfg.Height <= 0 {
cfg.Height = inputDims[0].VaeH
}
// Output (noise) latent dimensions
outLatentH := cfg.Height / vaeScaleFactor
outLatentW := cfg.Width / vaeScaleFactor
outPH := outLatentH / tcfg.PatchSize
outPW := outLatentW / tcfg.PatchSize
noiseSeqLen := outPH * outPW
imgSeqLen := noiseSeqLen
// Encode prompt with all images for conditioning
posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages)
if err != nil {
return nil, fmt.Errorf("encoding prompt: %w", err)
}
mlx.Keep(posEmb)
mlx.Eval(posEmb)
var negEmb *mlx.Array
if useCFG {
negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages)
if err != nil {
return nil, fmt.Errorf("encoding negative prompt: %w", err)
}
mlx.Keep(negEmb)
mlx.Eval(negEmb)
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
}
// Encode all input images to latents and concatenate
fmt.Println("Encoding images to latents...")
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
for i, vaeImage := range vaeImages {
imageLatents := m.VAE.Encode(vaeImage)
imageLatents = m.VAE.Normalize(imageLatents)
imageLatents2D := mlx.Squeeze(imageLatents, 2)
packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize)
mlx.Keep(packed)
mlx.Eval(packed)
allImageLatentsPacked[i] = packed
}
imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1)
mlx.Keep(imageLatentsPacked)
mlx.Eval(imageLatentsPacked)
// Scheduler
scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, noiseSeqLen)
// Init noise latents in packed format
packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize
packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed)
latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize)
mlx.Eval(latents)
// RoPE cache
ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Denoising loop
fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
mlx.Eval(timestep)
latents2D := mlx.Squeeze(latents, 2)
patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize)
latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1)
var output *mlx.Array
if useCFG {
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
} else {
output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]})
}
noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
mlx.Eval(latents)
oldLatents.Free()
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds())
}
// Free denoising temporaries
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
imageLatentsPacked.Free()
// Decode latents
decoded := m.decodeAndPostprocess(latents)
latents.Free()
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling.
// This prevents CFG from inflating magnitude too much.
func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array {
// Upcast to float32 for precision
posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32)
negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32)
// CFG: pred = neg + scale * (pos - neg)
diff := mlx.Sub(posF32, negF32)
scaledDiff := mlx.MulScalar(diff, scale)
combPred := mlx.Add(negF32, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional norm
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
mlx.Eval(output)
return mlx.ToBFloat16(output)
}
// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1].
func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array {
latents = m.VAE.Denormalize(latents)
decoded := m.VAE.Decode(latents)
// Post-process: squeeze temporal dim and rescale to [0, 1]
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
mlx.Eval(decoded)
return decoded
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is an alias for backward compatibility.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}
// PrepareRoPEMultiImage computes RoPE with interpolation for image editing.
// Handles single or multiple input images with different resolutions.
//
// Parameters:
// - outPH, outPW: output patch dimensions (noise latent resolution)
// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...]
// - txtLen: text sequence length
// - axesDims: RoPE axis dimensions [16, 56, 56]
//
// Returns RoPE cache where:
// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions
// - First outPH*outPW positions are for noise latents (standard RoPE at output res)
// - Following positions are for each input image (interpolated from output res)
func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache {
theta := float64(10000)
maxIdx := int32(4096)
// Compute base frequencies for each axis dimension
freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta)
freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta)
freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta)
// Build frequency lookup tables
posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false)
posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false)
negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image
negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true)
negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true)
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
// Helper to compute RoPE for a single position at output resolution with scale_rope
computePosFreqs := func(framePos, y, x int32) []float32 {
row := make([]float32, headDim)
idx := 0
// Frame position
for i := 0; i < len(freqsT)*2; i++ {
row[idx+i] = posFreqsT[framePos][i]
}
idx += len(freqsT) * 2
// Height with scale_rope centering (using OUTPUT dimensions)
outHHalf := outPH / 2
hNegCount := outPH - outHHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = posFreqsH[posIdx][i]
}
}
idx += len(freqsH) * 2
// Width with scale_rope centering (using OUTPUT dimensions)
outWHalf := outPW / 2
wNegCount := outPW - outWHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = posFreqsW[posIdx][i]
}
}
return row
}
// Helper to compute RoPE for frame -1 (used for last condition image)
// This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:]
computeNegFrameFreqs := func(y, x int32) []float32 {
row := make([]float32, headDim)
idx := 0
// Frame -1: use last row of negative frame frequencies
negFrameIdx := maxIdx - 1
for i := 0; i < len(freqsT)*2; i++ {
row[idx+i] = negFreqsT[negFrameIdx][i]
}
idx += len(freqsT) * 2
// Height with scale_rope centering (using OUTPUT dimensions)
outHHalf := outPH / 2
hNegCount := outPH - outHHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = posFreqsH[posIdx][i]
}
}
idx += len(freqsH) * 2
// Width with scale_rope centering (using OUTPUT dimensions)
outWHalf := outPW / 2
wNegCount := outPW - outWHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = posFreqsW[posIdx][i]
}
}
return row
}
// Total image sequence length: noise + all input images
noiseSeqLen := outPH * outPW
totalImgLen := noiseSeqLen
for _, dims := range inputDims {
totalImgLen += dims.PatchH * dims.PatchW
}
imgFreqsData := make([]float32, totalImgLen*headDim)
idx := int32(0)
// Segment 0: Noise latents - standard RoPE at output resolution (frame 0)
for y := int32(0); y < outPH; y++ {
for x := int32(0); x < outPW; x++ {
row := computePosFreqs(0, y, x)
copy(imgFreqsData[idx:], row)
idx += headDim
}
}
// Segments 1..N: Edit image latents - INTERPOLATED RoPE
// For single image: use frame 1 (matches original PrepareRoPEInterpolated)
// For multiple images: Python uses frame -1 for the LAST condition image
// (_compute_condition_freqs), positive indices for others.
numImages := len(inputDims)
lastImgIdx := numImages - 1
for imgIdx, dims := range inputDims {
inPH := dims.PatchH
inPW := dims.PatchW
// Determine frame index for this image
// Single image case: use frame 1 (like original PrepareRoPEInterpolated)
// Multi-image case: last image uses frame -1, others use frame 1, 2, etc.
useNegFrame := numImages > 1 && imgIdx == lastImgIdx
// Map each input position to an output position using linear interpolation
for y := int32(0); y < inPH; y++ {
for x := int32(0); x < inPW; x++ {
// Interpolate: map input (y, x) to output grid position
// This is the key fix from DiffSynth's forward_sampling
var yOut, xOut int32
if inPH == 1 {
yOut = 0
} else {
// Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1)
yOut = y * (outPH - 1) / (inPH - 1)
}
if inPW == 1 {
xOut = 0
} else {
xOut = x * (outPW - 1) / (inPW - 1)
}
var row []float32
if useNegFrame {
// Last image in multi-image uses frame -1
row = computeNegFrameFreqs(yOut, xOut)
} else {
// Single image uses frame 1, multi-image uses frame 1, 2, etc.
frameIdx := int32(imgIdx + 1)
row = computePosFreqs(frameIdx, yOut, xOut)
}
copy(imgFreqsData[idx:], row)
idx += headDim
}
}
}
imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim})
imgFreqs = mlx.ToBFloat16(imgFreqs)
// Text frequencies - start after max video index
maxVidIdx := max(outPH/2, outPW/2)
txtFreqsData := make([]float32, txtLen*headDim)
idx = 0
for t := int32(0); t < txtLen; t++ {
pos := maxVidIdx + t
for i := 0; i < len(freqsT)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
}
idx += int32(len(freqsT) * 2)
for i := 0; i < len(freqsH)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
}
idx += int32(len(freqsH) * 2)
for i := 0; i < len(freqsW)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
}
idx += int32(len(freqsW) * 2)
}
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
txtFreqs = mlx.ToBFloat16(txtFreqs)
return &qwen_image.RoPECache{
ImgFreqs: imgFreqs,
TxtFreqs: txtFreqs,
}
}

View File

@@ -1,225 +0,0 @@
package qwen_image_edit
import (
"math"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
)
// TestComputeAxisFreqs verifies frequency computation matches Python reference
func TestComputeAxisFreqs(t *testing.T) {
theta := float64(10000)
// Expected values from Python:
// freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim))
expectedFreqsT := []float64{
1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684,
0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017,
}
expectedFreqsH_first4 := []float64{
1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494,
}
expectedFreqsH_last4 := []float64{
0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437,
}
// Test temporal frequencies (dim=16)
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
if len(freqsT) != 8 {
t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT))
}
for i, expected := range expectedFreqsT {
if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 {
t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff)
}
}
// Test height/width frequencies (dim=56)
freqsH := qwen_image.ComputeAxisFreqs(56, theta)
if len(freqsH) != 28 {
t.Fatalf("expected 28 height frequencies, got %d", len(freqsH))
}
for i, expected := range expectedFreqsH_first4 {
if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 {
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff)
}
}
for i, expected := range expectedFreqsH_last4 {
idx := 24 + i // last 4 of 28
if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 {
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff)
}
}
}
// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions
func TestMakeFreqTable(t *testing.T) {
theta := float64(10000)
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
maxIdx := int32(4096)
// Test positive table
posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
// Position 0 should give cos=1, sin=0 for all frequencies
for i := 0; i < len(freqsT)*2; i += 2 {
if posTable[0][i] != 1.0 {
t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i])
}
if posTable[0][i+1] != 0.0 {
t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1])
}
}
// Position 1, first frequency (1.0): angle = 1*1 = 1
// cos(1) = 0.5403, sin(1) = 0.8415
if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 {
t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0])
}
if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 {
t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1])
}
// Test negative table
negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true)
// negTable[4095] corresponds to position -1
// cos(-1) = cos(1), sin(-1) = -sin(1)
if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 {
t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0])
}
if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 {
t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1])
}
// negTable[4094] corresponds to position -2
// cos(-2) = cos(2), sin(-2) = -sin(2)
cos2 := math.Cos(2.0)
sin2 := math.Sin(2.0)
if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 {
t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0])
}
if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 {
t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1])
}
}
// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case
func TestPrepareRoPE_QwenImage(t *testing.T) {
if !mlx.GPUIsAvailable() {
t.Skip("GPU not available")
}
mlx.SetDefaultDeviceCPU()
// 4x4 patch grid, single image
imgH, imgW := int32(4), int32(4)
txtLen := int32(5)
axesDims := []int32{16, 56, 56}
cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims)
mlx.Eval(cache.ImgFreqs, cache.TxtFreqs)
// Check shapes
imgShape := cache.ImgFreqs.Shape()
if imgShape[0] != 16 { // 4*4 patches
t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0])
}
// For single image (frame=0), all temporal values should be cos=1, sin=0
imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32)
mlx.Eval(imgFreqsCPU)
imgData := imgFreqsCPU.Data()
// Check first 16 values of patch 0 (temporal cos/sin pairs)
for i := 0; i < 16; i += 2 {
cosVal := imgData[i]
sinVal := imgData[i+1]
if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 {
t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal)
}
if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 {
t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal)
}
}
cache.ImgFreqs.Free()
cache.TxtFreqs.Free()
}
// TestScaleRopePositions verifies the centered position calculation for scale_rope=True
func TestScaleRopePositions(t *testing.T) {
// For a 4x4 grid with scale_rope=True:
// hHalf = 2, wHalf = 2
// hNegCount = 4 - 2 = 2 (positions 0,1 are negative)
// wNegCount = 4 - 2 = 2 (positions 0,1 are negative)
//
// Height positions:
// y=0: -(4-2) + 0 = -2
// y=1: -(4-2) + 1 = -1
// y=2: 2 - 2 = 0
// y=3: 3 - 2 = 1
//
// Same for width
pH, pW := int32(4), int32(4)
hHalf := pH / 2
wHalf := pW / 2
hNegCount := pH - hHalf
wNegCount := pW - wHalf
expectedH := []int32{-2, -1, 0, 1}
expectedW := []int32{-2, -1, 0, 1}
for y := int32(0); y < pH; y++ {
var hPos int32
if y < hNegCount {
hPos = -(pH - hHalf) + y
} else {
hPos = y - hNegCount
}
if hPos != expectedH[y] {
t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos)
}
}
for x := int32(0); x < pW; x++ {
var wPos int32
if x < wNegCount {
wPos = -(pW - wHalf) + x
} else {
wPos = x - wNegCount
}
if wPos != expectedW[x] {
t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos)
}
}
}
// TestRoPEHeadDimensions verifies the head dimension breakdown
func TestRoPEHeadDimensions(t *testing.T) {
// axes_dims_rope = [16, 56, 56]
// Each dimension uses half the values for frequencies
// So we get: 8 + 28 + 28 = 64 frequency values
// Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position
axesDims := []int32{16, 56, 56}
expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2)
expectedHeadDim := expectedFreqs * 2
if expectedFreqs != 64 {
t.Errorf("expected 64 frequency values, got %d", expectedFreqs)
}
if expectedHeadDim != 128 {
t.Errorf("expected head_dim=128, got %d", expectedHeadDim)
}
// This should match the transformer's attention head dimension
// hidden_size = 3072, num_heads = 24
// head_dim = 3072 / 24 = 128
}

View File

@@ -1,640 +0,0 @@
package qwen_image_edit
import (
"fmt"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds Qwen-Image VAE configuration
type VAEConfig struct {
ZDim int32 `json:"z_dim"` // 16
BaseDim int32 `json:"base_dim"` // 96
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
NumResBlocks int32 `json:"num_res_blocks"` // 2
LatentsMean []float32 `json:"latents_mean"` // 16 values
LatentsStd []float32 `json:"latents_std"` // 16 values
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
}
// defaultVAEConfig returns config for Qwen-Image VAE
func defaultVAEConfig() *VAEConfig {
return &VAEConfig{
ZDim: 16,
BaseDim: 96,
DimMult: []int32{1, 2, 4, 4},
NumResBlocks: 2,
LatentsMean: []float32{
-0.7571, -0.7089, -0.9113, 0.1075,
-0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632,
-0.1922, -0.9497, 0.2503, -0.2921,
},
LatentsStd: []float32{
2.8184, 1.4541, 2.3275, 2.6558,
1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579,
1.6382, 1.1253, 2.8251, 1.916,
},
TemperalDownsample: []bool{false, true, true},
}
}
// VAE is the full VAE with encoder and decoder
type VAE struct {
Config *VAEConfig
Encoder *VAEEncoder
Decoder *VAEDecoder
}
// Load loads the VAE from a directory
func (m *VAE) Load(path string) error {
fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")
cfg := defaultVAEConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Load weights as f32 for quality (matches Python default behavior)
// VAE decoder precision is critical for final image quality
fmt.Print(" Loading weights as f32... ")
if err := weights.Load(mlx.DtypeFloat32); err != nil {
return fmt.Errorf("failed to load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
// Load encoder
fmt.Print(" Loading encoder... ")
m.Encoder = &VAEEncoder{}
if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
return fmt.Errorf("encoder: %w", err)
}
fmt.Println("✓")
// Load decoder
fmt.Print(" Loading decoder... ")
m.Decoder = &VAEDecoder{}
if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
return fmt.Errorf("decoder: %w", err)
}
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// Encode encodes an image to latents
// x: [B, C, T, H, W] image tensor in [-1, 1] range
// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
return m.Encoder.Encode(x)
}
// Decode decodes latents to image
// z: [B, C, T, H, W] latents (denormalized)
// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
return m.Decoder.Decode(z)
}
// Normalize applies latent normalization
// Input z should be f32 (from VAE encoder), output is f32 for transformer
func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
// Mean/std are f32, will match z dtype through broadcasting
return mlx.Div(mlx.Sub(z, mean), std)
}
// Denormalize reverses latent normalization
// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
// Convert latents to f32 for VAE decoder quality
z = mlx.AsType(z, mlx.DtypeFloat32)
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
return mlx.Add(mlx.Mul(z, std), mean)
}
// VAEEncoder is the encoder part of the VAE
// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
// - Blocks 0,1: ResBlocks (base_dim)
// - Block 2: Downsample
// - Blocks 3,4: ResBlocks (base_dim*2)
// - Block 5: Downsample + temporal
// - Blocks 6,7: ResBlocks (base_dim*4)
// - Block 8: Downsample + temporal
// - Blocks 9,10: ResBlocks (base_dim*4)
type VAEEncoder struct {
Config *VAEConfig
ConvIn *CausalConv3d
Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers
MidBlock *MidBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
QuantConv *CausalConv3d
}
// EncoderBlock is either a ResBlock or a Downsample
type EncoderBlock interface {
Forward(x *mlx.Array) *mlx.Array
IsDownsample() bool
}
// EncoderResBlock wraps ResBlock
type EncoderResBlock struct {
*ResBlock
}
func (b *EncoderResBlock) IsDownsample() bool { return false }
// EncoderDownsample is a downsample layer
type EncoderDownsample struct {
Resample *CausalConv3d
TimeConv *CausalConv3d // Optional temporal downsample
}
func (d *EncoderDownsample) IsDownsample() bool { return true }
func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
// Spatial downsample with stride 2
// WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
x = d.forwardSpatialDownsample(x)
// NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
// with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
// The Python forward checks: if feat_cache is not None ... then use time_conv
// Since we don't support streaming, we skip time_conv entirely.
return x
}
// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
xShape := x.Shape()
B := xShape[0]
T := xShape[1]
H := xShape[2]
W := xShape[3]
C := xShape[4]
wShape := d.Resample.Weight.Shape()
outC := wShape[0]
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
// ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W
// Apply 2D conv with stride 2
weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
x = conv2DStrided(x, weight, 2)
if d.Resample.Bias != nil {
bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
// Output dims after stride 2: (H+1)/2, (W+1)/2
outH := (H + 1) / 2
outW := (W + 1) / 2
// Reshape back to [B, T, H', W', C]
x = mlx.Reshape(x, B, T, outH, outW, outC)
mlx.Eval(x)
return x
}
// loadFromWeights loads the encoder from pre-loaded weights
func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
e.Config = cfg
// Conv in
convIn, err := newCausalConv3d(weights, "encoder.conv_in")
if err != nil {
return err
}
e.ConvIn = convIn
// Encoder uses flat block structure:
// dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
// Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
// That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
e.Blocks = make([]EncoderBlock, 0, 11)
// Track dimensions
dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
blockIdx := 0
for stage := 0; stage < len(cfg.DimMult); stage++ {
inDim := cfg.BaseDim
if stage > 0 {
inDim = dims[stage-1]
}
outDim := dims[stage]
// ResBlocks for this stage (num_res_blocks per stage)
for r := int32(0); r < cfg.NumResBlocks; r++ {
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
currentInDim := inDim
if r > 0 {
currentInDim = outDim
}
block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
if err != nil {
return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
}
e.Blocks = append(e.Blocks, block)
blockIdx++
}
// Downsample after each stage except the last
if stage < len(cfg.DimMult)-1 {
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
if err != nil {
return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
}
e.Blocks = append(e.Blocks, down)
blockIdx++
}
}
// Mid block
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
if err != nil {
return err
}
e.MidBlock = midBlock
// Norm out
normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
if err != nil {
return err
}
e.NormOut = normOut
// Conv out
convOut, err := newCausalConv3d(weights, "encoder.conv_out")
if err != nil {
return err
}
e.ConvOut = convOut
// Quant conv
quantConv, err := newCausalConv3d(weights, "quant_conv")
if err != nil {
return err
}
e.QuantConv = quantConv
return nil
}
// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
block, err := newResBlock(weights, prefix, inDim, outDim)
if err != nil {
return nil, err
}
return &EncoderResBlock{block}, nil
}
// newEncoderDownsample creates a downsample layer for the encoder
func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
resample, err := newCausalConv3d(weights, prefix+".resample.1")
if err != nil {
return nil, err
}
var timeConv *CausalConv3d
if temporal {
timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
}
return &EncoderDownsample{
Resample: resample,
TimeConv: timeConv,
}, nil
}
// Encode encodes an image to latents
// x: [B, C, T, H, W] image tensor (channels-first)
// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
mlx.Eval(x)
// Conv in
x = e.ConvIn.Forward(x)
// Encoder blocks (mix of ResBlocks and Downsamplers)
for _, block := range e.Blocks {
prev := x
x = block.Forward(x)
prev.Free()
}
// Mid block
x = e.MidBlock.Forward(x)
// Norm + silu
{
prev := x
x = e.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// Conv out
{
prev := x
x = e.ConvOut.Forward(x)
prev.Free()
}
// Quant conv
{
prev := x
x = e.QuantConv.Forward(x)
prev.Free()
}
// Get mode from distribution (first half of channels = mean)
// Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
shape := x.Shape()
latentC := shape[4] / 2
x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
// Convert back to channels-first [N, C, T, H, W]
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
mlx.Eval(x)
return x
}
// VAEDecoder is the decoder part of the VAE
type VAEDecoder struct {
Config *VAEConfig
PostQuantConv *CausalConv3d
ConvIn *CausalConv3d
MidBlock *MidBlock
UpBlocks []*UpBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
}
// loadFromWeights loads the decoder from pre-loaded weights
func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
d.Config = cfg
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
if err != nil {
return err
}
d.PostQuantConv = postQuantConv
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
if err != nil {
return err
}
d.ConvIn = convIn
// Mid block
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
if err != nil {
return err
}
d.MidBlock = midBlock
// Up blocks (reversed dim_mult)
numUpBlocks := len(cfg.DimMult)
d.UpBlocks = make([]*UpBlock, numUpBlocks)
dimsMult := make([]int32, numUpBlocks+1)
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
for i := 0; i < numUpBlocks; i++ {
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
}
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
for i := range cfg.TemperalDownsample {
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
}
for i := 0; i < numUpBlocks; i++ {
inDim := cfg.BaseDim * dimsMult[i]
outDim := cfg.BaseDim * dimsMult[i+1]
if i > 0 {
inDim = inDim / 2
}
upsampleMode := ""
if i < numUpBlocks-1 {
if temporalUpsample[i] {
upsampleMode = "upsample3d"
} else {
upsampleMode = "upsample2d"
}
}
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
if err != nil {
return err
}
d.UpBlocks[i] = upBlock
}
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
if err != nil {
return err
}
d.NormOut = normOut
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
if err != nil {
return err
}
d.ConvOut = convOut
return nil
}
// Decode converts latents to image
// z: [B, C, T, H, W] denormalized latents
func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
var x *mlx.Array
// Convert from channels-first to channels-last
{
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
mlx.Eval(z)
}
// PostQuantConv
x = d.PostQuantConv.Forward(z)
z.Free()
// ConvIn
{
prev := x
x = d.ConvIn.Forward(x)
prev.Free()
}
// Mid block
x = d.MidBlock.Forward(x)
// Up blocks
for _, upBlock := range d.UpBlocks {
x = upBlock.Forward(x)
}
// NormOut + silu
{
prev := x
x = d.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// ConvOut
{
prev := x
x = d.ConvOut.Forward(x)
prev.Free()
}
// Post-processing: clamp and convert back to channels-first
{
prev := x
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
prev.Free()
mlx.Eval(x)
}
return x
}
// DownBlock handles downsampling in encoder
type DownBlock struct {
ResBlocks []*ResBlock
Downsampler *Downsample
}
// newDownBlock creates a down block
func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var downsampler *Downsample
if downsampleMode != "" {
downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
}
return &DownBlock{
ResBlocks: resBlocks,
Downsampler: downsampler,
}, nil
}
// Forward applies down block
func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
for _, block := range d.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
if d.Downsampler != nil {
prev := x
x = d.Downsampler.Forward(x)
prev.Free()
}
return x
}
// Downsample handles spatial downsampling
type Downsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newDownsample creates a downsampler
func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Downsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies downsampling to channels-last input [B, T, H, W, C]
func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := d.Conv.Shape()[0]
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding
// For 3x3 stride 2: pad 1 on all sides
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
// Conv with stride 2 using manual strided patching
weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
x = conv2DStrided(x, weight, 2)
if d.Bias != nil {
bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H/2, W/2, outC)
mlx.Eval(x)
return x
}

View File

@@ -1,146 +0,0 @@
package zimage
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// FlowMatchSchedulerConfig holds scheduler configuration
type FlowMatchSchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
Shift float32 `json:"shift"` // 3.0
UseDynamicShifting bool `json:"use_dynamic_shifting"` // false
}
// DefaultFlowMatchSchedulerConfig returns default config
func DefaultFlowMatchSchedulerConfig() *FlowMatchSchedulerConfig {
return &FlowMatchSchedulerConfig{
NumTrainTimesteps: 1000,
Shift: 3.0,
UseDynamicShifting: true, // Z-Image-Turbo uses dynamic shifting
}
}
// FlowMatchEulerScheduler implements the Flow Match Euler discrete scheduler
// This is used in Z-Image-Turbo for fast sampling
type FlowMatchEulerScheduler struct {
Config *FlowMatchSchedulerConfig
Timesteps []float32 // Discretized timesteps
Sigmas []float32 // Noise levels at each timestep
NumSteps int // Number of inference steps
}
// NewFlowMatchEulerScheduler creates a new scheduler
func NewFlowMatchEulerScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchEulerScheduler {
return &FlowMatchEulerScheduler{
Config: cfg,
}
}
// SetTimesteps sets up the scheduler for the given number of inference steps
func (s *FlowMatchEulerScheduler) SetTimesteps(numSteps int) {
s.SetTimestepsWithMu(numSteps, 0)
}
// SetTimestepsWithMu sets up the scheduler with dynamic mu shift
func (s *FlowMatchEulerScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
s.NumSteps = numSteps
// Create evenly spaced timesteps from 1.0 to 0.0 (flow matching goes t=1 to t=0)
// Match Python: np.linspace(1.0, 0.0, num_inference_steps + 1)
s.Timesteps = make([]float32, numSteps+1)
s.Sigmas = make([]float32, numSteps+1)
for i := 0; i <= numSteps; i++ {
t := 1.0 - float32(i)/float32(numSteps)
// Apply time shift if using dynamic shifting
if s.Config.UseDynamicShifting && mu != 0 {
t = s.timeShift(mu, t)
}
s.Timesteps[i] = t
s.Sigmas[i] = t
}
}
// timeShift applies the dynamic time shift (match Python)
func (s *FlowMatchEulerScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
// exp(mu) / (exp(mu) + (1/t - 1))
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
// modelOutput: predicted velocity/noise from the model
// timestepIdx: current timestep index
// sample: current noisy sample
// Returns: denoised sample for next step
func (s *FlowMatchEulerScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
// Get current and next sigma
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
// where v_t is the velocity predicted by the model
dt := sigmaNext - sigma // This is negative (going from noise to clean)
// x_next = x + dt * velocity
scaledOutput := mlx.MulScalar(modelOutput, dt)
return mlx.Add(sample, scaledOutput)
}
// ScaleSample scales the sample for model input (identity for flow matching)
func (s *FlowMatchEulerScheduler) ScaleSample(sample *mlx.Array, timestepIdx int) *mlx.Array {
// Flow matching doesn't need scaling
return sample
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchEulerScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// GetTimesteps returns all timesteps (implements Scheduler interface)
func (s *FlowMatchEulerScheduler) GetTimesteps() []float32 {
return s.Timesteps
}
// AddNoise adds noise to clean samples for a given timestep
// Used for img2img or inpainting
func (s *FlowMatchEulerScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
// In flow matching: x_t = (1-t) * x_0 + t * noise
t := s.Timesteps[timestepIdx]
oneMinusT := 1.0 - t
scaledClean := mlx.MulScalar(cleanSample, oneMinusT)
scaledNoise := mlx.MulScalar(noise, t)
return mlx.Add(scaledClean, scaledNoise)
}
// InitNoise creates initial noise for sampling
func (s *FlowMatchEulerScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return RandomNormal(shape, seed)
}
// RandomNormal creates a random normal tensor using MLX
func RandomNormal(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormal(shape, uint64(seed))
}
// GetLatentShape returns the latent shape for a given image size
func GetLatentShape(batchSize, height, width, latentChannels int32, patchSize int32) []int32 {
// Latent is 8x smaller than image (VAE downscale)
latentH := height / 8
latentW := width / 8
return []int32{batchSize, latentChannels, latentH, latentW}
}

View File

@@ -1,294 +0,0 @@
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Qwen3Config holds Qwen3 text encoder configuration
type Qwen3Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
}
// loadQwen3Config loads text encoder config from a JSON file
func loadQwen3Config(path string) (*Qwen3Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config: %w", err)
}
var cfg Qwen3Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
return &cfg, nil
}
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
QProj *nn.Linear `weight:"q_proj"`
KProj *nn.Linear `weight:"k_proj"`
VProj *nn.Linear `weight:"v_proj"`
OProj *nn.Linear `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
RopeTheta float32
}
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
freqsArr := make([]float32, half)
logTheta := float32(math.Log(float64(theta)))
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
posArr := make([]float32, seqLen)
for i := int32(0); i < seqLen; i++ {
posArr[i] = float32(i)
}
pos := mlx.NewArray(posArr, []int32{seqLen})
posExpanded := mlx.Reshape(pos, seqLen, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
args := mlx.Mul(posExpanded, freqsExpanded)
cosVals := mlx.Cos(args)
sinVals := mlx.Sin(args)
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}
// Forward computes attention with causal masking
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
q := attn.QProj.Forward(x)
k := attn.KProj.Forward(x)
v := attn.VProj.Forward(x)
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
q = attn.QNorm.Forward(q, 1e-6)
k = attn.KNorm.Forward(k, 1e-6)
q = applyRoPEQwen3(q, L, attn.RopeTheta)
k = applyRoPEQwen3(k, L, attn.RopeTheta)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
k = repeatKV(k, repeats)
v = repeatKV(v, repeats)
}
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
out = attn.OProj.Forward(out)
return out
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
x = mlx.ExpandDims(x, 2)
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
GateProj *nn.Linear `weight:"gate_proj"`
UpProj *nn.Linear `weight:"up_proj"`
DownProj *nn.Linear `weight:"down_proj"`
}
// Forward applies the MLP
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
gate := m.GateProj.Forward(x)
gate = mlx.SiLU(gate)
up := m.UpProj.Forward(x)
h := mlx.Mul(gate, up)
return m.DownProj.Forward(h)
}
// Qwen3Block represents a single Qwen3 transformer block
type Qwen3Block struct {
Attention *Qwen3Attention `weight:"self_attn"`
MLP *Qwen3MLP `weight:"mlp"`
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the Qwen3 block
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
h := qb.InputLayerNorm.Forward(x, eps)
attnOut := qb.Attention.Forward(h)
x = mlx.Add(x, attnOut)
h = qb.PostAttnLayerNorm.Forward(x, eps)
mlpOut := qb.MLP.Forward(h)
x = mlx.Add(x, mlpOut)
return x
}
// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
type Qwen3TextEncoder struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Qwen3Block `weight:"model.layers"`
FinalNorm *nn.RMSNorm `weight:"model.norm"`
*Qwen3Config
}
// Load loads the Qwen3 text encoder from a directory
func (m *Qwen3TextEncoder) Load(path string) error {
fmt.Println("Loading Qwen3 text encoder...")
// Load config
cfg, err := loadQwen3Config(filepath.Join(path, "config.json"))
if err != nil {
return fmt.Errorf("config: %w", err)
}
m.Qwen3Config = cfg
// Pre-allocate layers slice
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
// Load weights
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
fmt.Print(" Loading weights via struct tags... ")
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
fmt.Println("✓")
// Initialize computed fields
m.FinalNorm.Eps = cfg.RMSNormEps
for _, block := range m.Layers {
// Attention
block.Attention.NHeads = cfg.NumAttentionHeads
block.Attention.NKVHeads = cfg.NumKeyValueHeads
block.Attention.HeadDim = cfg.HeadDim
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.Attention.RopeTheta = cfg.RopeTheta
block.Attention.QNorm.Eps = cfg.RMSNormEps
block.Attention.KNorm.Eps = cfg.RMSNormEps
// Block norms
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
}
weights.ReleaseAll()
return nil
}
// Forward encodes text tokens
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
for _, layer := range te.Layers {
h = layer.Forward(h, eps)
}
// Apply final RMS norm
h = te.FinalNorm.Forward(h, eps)
return h
}
// ApplyChatTemplate wraps prompt in Qwen3 chat format
func ApplyChatTemplate(prompt string) string {
return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
}
// EncodePrompt encodes a text prompt using the tokenizer and encoder
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
formattedPrompt := ApplyChatTemplate(prompt)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
maskData := make([]float32, maxLen)
for i := 0; i < len(tokens); i++ {
maskData[i] = 1.0
}
// Get PAD token (different from EOS for Qwen3)
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
paddedTokens := make([]int32, maxLen)
copy(paddedTokens, tokens)
for i := len(tokens); i < maxLen; i++ {
paddedTokens[i] = padToken
}
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
embeddings := te.Forward(tokensArr)
return embeddings, maskArr
}

View File

@@ -1,690 +0,0 @@
// Package zimage implements the Z-Image diffusion transformer model.
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Z-Image transformer configuration
type TransformerConfig struct {
Dim int32 `json:"dim"`
NHeads int32 `json:"n_heads"`
NKVHeads int32 `json:"n_kv_heads"`
NLayers int32 `json:"n_layers"`
NRefinerLayers int32 `json:"n_refiner_layers"`
InChannels int32 `json:"in_channels"`
PatchSize int32 `json:"-"` // Computed from AllPatchSize
CapFeatDim int32 `json:"cap_feat_dim"`
NormEps float32 `json:"norm_eps"`
RopeTheta float32 `json:"rope_theta"`
TScale float32 `json:"t_scale"`
QKNorm bool `json:"qk_norm"`
AxesDims []int32 `json:"axes_dims"`
AxesLens []int32 `json:"axes_lens"`
AllPatchSize []int32 `json:"all_patch_size"` // JSON array, PatchSize = first element
}
// TimestepEmbedder creates sinusoidal timestep embeddings
// Output dimension is 256 (fixed), used for AdaLN modulation
type TimestepEmbedder struct {
Linear1 *nn.Linear `weight:"mlp.0"`
Linear2 *nn.Linear `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
}
// Forward computes timestep embeddings -> [B, 256]
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
// t: [B] timesteps
// Create sinusoidal embedding
half := te.FreqEmbedSize / 2
// freqs = exp(-log(10000) * arange(half) / half)
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
// t[:, None] * freqs[None, :] -> [B, half]
tExpanded := mlx.ExpandDims(t, 1) // [B, 1]
args := mlx.Mul(tExpanded, freqsArr)
// embedding = [cos(args), sin(args)] -> [B, 256]
cosArgs := mlx.Cos(args)
sinArgs := mlx.Sin(args)
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1)
// MLP: linear1 -> silu -> linear2
h := te.Linear1.Forward(embedding)
h = mlx.SiLU(h)
h = te.Linear2.Forward(h)
return h
}
// XEmbedder embeds image patches to model dimension
type XEmbedder struct {
Linear *nn.Linear `weight:"2-1"`
}
// Forward embeds patchified image latents
func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// x: [B, L, in_channels * 4] -> [B, L, dim]
return xe.Linear.Forward(x)
}
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear *nn.Linear `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
// RMSNorm on last axis (uses 1e-6)
h := ce.Norm.Forward(capFeats, 1e-6)
// Linear projection
return ce.Linear.Forward(h)
}
// FeedForward implements SwiGLU FFN
type FeedForward struct {
W1 *nn.Linear `weight:"w1"` // gate projection
W2 *nn.Linear `weight:"w2"` // down projection
W3 *nn.Linear `weight:"w3"` // up projection
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
// Reshape for matmul
x = mlx.Reshape(x, B*L, D)
gate := ff.W1.Forward(x)
gate = mlx.SiLU(gate)
up := ff.W3.Forward(x)
h := mlx.Mul(gate, up)
out := ff.W2.Forward(h)
return mlx.Reshape(out, B, L, ff.OutDim)
}
// Attention implements multi-head attention with QK norm
type Attention struct {
ToQ *nn.Linear `weight:"to_q"`
ToK *nn.Linear `weight:"to_k"`
ToV *nn.Linear `weight:"to_v"`
ToOut *nn.Linear `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Computed fields
NHeads int32
HeadDim int32
Dim int32
Scale float32
}
// Forward computes attention
func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
// Project Q, K, V
xFlat := mlx.Reshape(x, B*L, D)
q := attn.ToQ.Forward(xFlat)
k := attn.ToK.Forward(xFlat)
v := attn.ToV.Forward(xFlat)
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NHeads, attn.HeadDim)
// QK norm
q = mlx.RMSNorm(q, attn.NormQ, 1e-5)
k = mlx.RMSNorm(k, attn.NormK, 1e-5)
// Apply RoPE if provided
if cos != nil && sin != nil {
q = applyRoPE3D(q, cos, sin)
k = applyRoPE3D(k, cos, sin)
}
// Transpose to [B, nheads, L, head_dim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// SDPA
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false)
// Transpose back and reshape
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B*L, attn.Dim)
out = attn.ToOut.Forward(out)
return mlx.Reshape(out, B, L, attn.Dim)
}
// applyRoPE3D applies 3-axis rotary position embeddings
// x: [B, L, nheads, head_dim]
// cos, sin: [B, L, 1, head_dim/2]
func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
half := headDim / 2
// Create even/odd index arrays
evenIdx := make([]int32, half)
oddIdx := make([]int32, half)
for i := int32(0); i < half; i++ {
evenIdx[i] = i * 2
oddIdx[i] = i*2 + 1
}
evenIndices := mlx.NewArrayInt32(evenIdx, []int32{half})
oddIndices := mlx.NewArrayInt32(oddIdx, []int32{half})
// Extract x1 (even indices) and x2 (odd indices) along last axis
x1 := mlx.Take(x, evenIndices, 3) // [B, L, nheads, half]
x2 := mlx.Take(x, oddIndices, 3) // [B, L, nheads, half]
// Apply rotation: [x1*cos - x2*sin, x1*sin + x2*cos]
r1 := mlx.Sub(mlx.Mul(x1, cos), mlx.Mul(x2, sin))
r2 := mlx.Add(mlx.Mul(x1, sin), mlx.Mul(x2, cos))
// Stack and reshape to interleave: [r1_0, r2_0, r1_1, r2_1, ...]
r1 = mlx.ExpandDims(r1, 4) // [B, L, nheads, half, 1]
r2 = mlx.ExpandDims(r2, 4) // [B, L, nheads, half, 1]
stacked := mlx.Concatenate([]*mlx.Array{r1, r2}, 4) // [B, L, nheads, half, 2]
return mlx.Reshape(stacked, B, L, nheads, headDim)
}
// TransformerBlock is a single transformer block with optional AdaLN modulation
type TransformerBlock struct {
Attention *Attention `weight:"attention"`
FeedForward *FeedForward `weight:"feed_forward"`
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
}
// Forward applies the transformer block
func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *mlx.Array, eps float32) *mlx.Array {
if tb.AdaLN != nil && adaln != nil {
// Compute modulation: [B, 256] -> [B, 4*dim]
chunks := tb.AdaLN.Forward(adaln)
// Split into 4 parts: scale_msa, gate_msa, scale_mlp, gate_mlp
chunkShape := chunks.Shape()
chunkDim := chunkShape[1] / 4
scaleMSA := mlx.Slice(chunks, []int32{0, 0}, []int32{chunkShape[0], chunkDim})
gateMSA := mlx.Slice(chunks, []int32{0, chunkDim}, []int32{chunkShape[0], chunkDim * 2})
scaleMLP := mlx.Slice(chunks, []int32{0, chunkDim * 2}, []int32{chunkShape[0], chunkDim * 3})
gateMLP := mlx.Slice(chunks, []int32{0, chunkDim * 3}, []int32{chunkShape[0], chunkDim * 4})
// Expand for broadcasting: [B, 1, dim]
scaleMSA = mlx.ExpandDims(scaleMSA, 1)
gateMSA = mlx.ExpandDims(gateMSA, 1)
scaleMLP = mlx.ExpandDims(scaleMLP, 1)
gateMLP = mlx.ExpandDims(gateMLP, 1)
// Attention with modulation
normX := tb.AttentionNorm1.Forward(x, eps)
normX = mlx.Mul(normX, mlx.AddScalar(scaleMSA, 1.0))
attnOut := tb.Attention.Forward(normX, cos, sin)
attnOut = tb.AttentionNorm2.Forward(attnOut, eps)
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMSA), attnOut))
// FFN with modulation
normFFN := tb.FFNNorm1.Forward(x, eps)
normFFN = mlx.Mul(normFFN, mlx.AddScalar(scaleMLP, 1.0))
ffnOut := tb.FeedForward.Forward(normFFN)
ffnOut = tb.FFNNorm2.Forward(ffnOut, eps)
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMLP), ffnOut))
} else {
// No modulation (context refiner)
attnOut := tb.Attention.Forward(tb.AttentionNorm1.Forward(x, eps), cos, sin)
x = mlx.Add(x, tb.AttentionNorm2.Forward(attnOut, eps))
ffnOut := tb.FeedForward.Forward(tb.FFNNorm1.Forward(x, eps))
x = mlx.Add(x, tb.FFNNorm2.Forward(ffnOut, eps))
}
return x
}
// FinalLayer outputs the denoised patches
type FinalLayer struct {
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
}
// Forward computes final output
func (fl *FinalLayer) Forward(x *mlx.Array, c *mlx.Array) *mlx.Array {
// c: [B, 256] -> scale: [B, dim]
scale := mlx.SiLU(c)
scale = fl.AdaLN.Forward(scale)
scale = mlx.ExpandDims(scale, 1) // [B, 1, dim]
// LayerNorm (affine=False) then scale
x = layerNormNoAffine(x, 1e-6)
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
// Output projection
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
x = mlx.Reshape(x, B*L, D)
x = fl.Output.Forward(x)
return mlx.Reshape(x, B, L, fl.OutDim)
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}
// Transformer is the full Z-Image DiT model
type Transformer struct {
TEmbed *TimestepEmbedder `weight:"t_embedder"`
XEmbed *XEmbedder `weight:"all_x_embedder"`
CapEmbed *CapEmbedder `weight:"cap_embedder"`
NoiseRefiners []*TransformerBlock `weight:"noise_refiner"`
ContextRefiners []*TransformerBlock `weight:"context_refiner"`
Layers []*TransformerBlock `weight:"layers"`
FinalLayer *FinalLayer `weight:"all_final_layer.2-1"`
XPadToken *mlx.Array `weight:"x_pad_token"`
CapPadToken *mlx.Array `weight:"cap_pad_token"`
*TransformerConfig
}
// Load loads the Z-Image transformer from a directory
func (m *Transformer) Load(path string) error {
fmt.Println("Loading Z-Image transformer...")
// Load config
cfg, err := loadTransformerConfig(filepath.Join(path, "config.json"))
if err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = cfg
// Pre-allocate slices for loader
m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
// Load weights
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading weights via struct tags... ")
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
fmt.Println("✓")
// Initialize computed fields
m.TEmbed.FreqEmbedSize = 256
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
m.CapEmbed.Norm.Eps = 1e-6
for _, block := range m.NoiseRefiners {
initTransformerBlock(block, cfg)
}
for _, block := range m.ContextRefiners {
initTransformerBlock(block, cfg)
}
for _, block := range m.Layers {
initTransformerBlock(block, cfg)
}
weights.ReleaseAll()
return nil
}
// loadTransformerConfig loads transformer config from a JSON file
func loadTransformerConfig(path string) (*TransformerConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config: %w", err)
}
var cfg TransformerConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Extract PatchSize from array
if len(cfg.AllPatchSize) > 0 {
cfg.PatchSize = cfg.AllPatchSize[0]
}
return &cfg, nil
}
// initTransformerBlock sets computed fields on a transformer block
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
block.Dim = cfg.Dim
block.HasModulation = block.AdaLN != nil
// Init attention computed fields
attn := block.Attention
attn.NHeads = cfg.NHeads
attn.HeadDim = cfg.Dim / cfg.NHeads
attn.Dim = cfg.Dim
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
// Init feedforward OutDim
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
// Set eps on all RMSNorm layers
block.AttentionNorm1.Eps = cfg.NormEps
block.AttentionNorm2.Eps = cfg.NormEps
block.FFNNorm1.Eps = cfg.NormEps
block.FFNNorm2.Eps = cfg.NormEps
}
// RoPECache holds precomputed RoPE values
type RoPECache struct {
ImgCos *mlx.Array
ImgSin *mlx.Array
CapCos *mlx.Array
CapSin *mlx.Array
UnifiedCos *mlx.Array
UnifiedSin *mlx.Array
ImgLen int32
CapLen int32
}
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
// hTok and wTok are the number of tokens in each dimension (latentH/patchSize, latentW/patchSize).
func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
imgLen := hTok * wTok
// Image positions: grid over (1, H, W) starting at (capLen+1, 0, 0)
imgPos := createCoordinateGrid(1, hTok, wTok, capLen+1, 0, 0)
imgPos = mlx.ToBFloat16(imgPos)
// Caption positions: grid over (capLen, 1, 1) starting at (1, 0, 0)
capPos := createCoordinateGrid(capLen, 1, 1, 1, 0, 0)
capPos = mlx.ToBFloat16(capPos)
// Compute RoPE from UNIFIED positions
unifiedPos := mlx.Concatenate([]*mlx.Array{imgPos, capPos}, 1)
unifiedCos, unifiedSin := prepareRoPE3D(unifiedPos, m.TransformerConfig.AxesDims)
// Slice RoPE for image and caption parts
imgCos := mlx.Slice(unifiedCos, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
imgSin := mlx.Slice(unifiedSin, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
capCos := mlx.Slice(unifiedCos, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
capSin := mlx.Slice(unifiedSin, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
return &RoPECache{
ImgCos: imgCos,
ImgSin: imgSin,
CapCos: capCos,
CapSin: capSin,
UnifiedCos: unifiedCos,
UnifiedSin: unifiedSin,
ImgLen: imgLen,
CapLen: capLen,
}
}
// Forward runs the Z-Image transformer with precomputed RoPE
func (m *Transformer) Forward(x *mlx.Array, t *mlx.Array, capFeats *mlx.Array, rope *RoPECache) *mlx.Array {
imgLen := rope.ImgLen
// Timestep embedding -> [B, 256]
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
// Embed image patches -> [B, L_img, dim]
x = m.XEmbed.Forward(x)
// Embed caption features -> [B, L_cap, dim]
capEmb := m.CapEmbed.Forward(capFeats)
eps := m.NormEps
// Noise refiner: refine image patches with modulation
for _, refiner := range m.NoiseRefiners {
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
}
// Context refiner: refine caption (no modulation)
for _, refiner := range m.ContextRefiners {
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
}
// Concatenate image and caption for joint attention
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
// Main transformer layers use full unified RoPE
for _, layer := range m.Layers {
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
}
// Extract image tokens only
unifiedShape := unified.Shape()
B := unifiedShape[0]
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
// Final layer
return m.FinalLayer.Forward(imgOut, temb)
}
// ForwardWithCache runs the transformer with layer caching for faster inference.
// On refresh steps (step % cacheInterval == 0), all layers are computed and cached.
// On other steps, shallow layers (0 to cacheLayers-1) reuse cached outputs.
func (m *Transformer) ForwardWithCache(
x *mlx.Array,
t *mlx.Array,
capFeats *mlx.Array,
rope *RoPECache,
stepCache *cache.StepCache,
step int,
cacheInterval int,
) *mlx.Array {
imgLen := rope.ImgLen
cacheLayers := stepCache.NumLayers()
eps := m.NormEps
// Timestep embedding -> [B, 256]
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
// Embed image patches -> [B, L_img, dim]
x = m.XEmbed.Forward(x)
// Context refiners: compute once on step 0, reuse forever
// (caption embedding doesn't depend on timestep or latents)
var capEmb *mlx.Array
if stepCache.GetConstant() != nil {
capEmb = stepCache.GetConstant()
} else {
capEmb = m.CapEmbed.Forward(capFeats)
for _, refiner := range m.ContextRefiners {
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
}
stepCache.SetConstant(capEmb)
}
// Noise refiners: always compute (depend on x which changes each step)
for _, refiner := range m.NoiseRefiners {
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
}
// Concatenate image and caption for joint attention
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
// Determine if this is a cache refresh step
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
// Main transformer layers with caching
for i, layer := range m.Layers {
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
// Use cached output for shallow layers
unified = stepCache.Get(i)
} else {
// Compute layer
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
// Cache shallow layer outputs on refresh steps
if i < cacheLayers && refreshCache {
stepCache.Set(i, unified)
}
}
}
// Extract image tokens only
unifiedShape := unified.Shape()
B := unifiedShape[0]
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
// Final layer
return m.FinalLayer.Forward(imgOut, temb)
}
// createCoordinateGrid creates 3D position grid [1, d0*d1*d2, 3]
func createCoordinateGrid(d0, d1, d2, s0, s1, s2 int32) *mlx.Array {
// Create meshgrid and stack
total := d0 * d1 * d2
coords := make([]float32, total*3)
idx := 0
for i := int32(0); i < d0; i++ {
for j := int32(0); j < d1; j++ {
for k := int32(0); k < d2; k++ {
coords[idx*3+0] = float32(s0 + i)
coords[idx*3+1] = float32(s1 + j)
coords[idx*3+2] = float32(s2 + k)
idx++
}
}
}
return mlx.NewArray(coords, []int32{1, total, 3})
}
// prepareRoPE3D computes cos/sin for 3-axis RoPE
// positions: [B, L, 3] with (h, w, t) coordinates
// axesDims: [32, 48, 48] - dimensions for each axis
// Returns: cos, sin each [B, L, 1, head_dim/2]
func prepareRoPE3D(positions *mlx.Array, axesDims []int32) (*mlx.Array, *mlx.Array) {
// Compute frequencies for each axis
// dims = [32, 48, 48], so halves = [16, 24, 24]
ropeTheta := float32(256.0)
freqs := make([]*mlx.Array, 3)
for axis := 0; axis < 3; axis++ {
half := axesDims[axis] / 2
f := make([]float32, half)
for i := int32(0); i < half; i++ {
f[i] = float32(math.Exp(-math.Log(float64(ropeTheta)) * float64(i) / float64(half)))
}
freqs[axis] = mlx.NewArray(f, []int32{1, 1, 1, half})
}
// Extract position coordinates
shape := positions.Shape()
B := shape[0]
L := shape[1]
// positions[:, :, 0] -> h positions
posH := mlx.Slice(positions, []int32{0, 0, 0}, []int32{B, L, 1})
posW := mlx.Slice(positions, []int32{0, 0, 1}, []int32{B, L, 2})
posT := mlx.Slice(positions, []int32{0, 0, 2}, []int32{B, L, 3})
// Compute args: pos * freqs for each axis
posH = mlx.ExpandDims(posH, 3) // [B, L, 1, 1]
posW = mlx.ExpandDims(posW, 3)
posT = mlx.ExpandDims(posT, 3)
argsH := mlx.Mul(posH, freqs[0]) // [B, L, 1, 16]
argsW := mlx.Mul(posW, freqs[1]) // [B, L, 1, 24]
argsT := mlx.Mul(posT, freqs[2]) // [B, L, 1, 24]
// Concatenate: [B, L, 1, 16+24+24=64]
args := mlx.Concatenate([]*mlx.Array{argsH, argsW, argsT}, 3)
// Compute cos and sin
return mlx.Cos(args), mlx.Sin(args)
}
// PatchifyLatents converts latents [B, C, H, W] to patches [B, L, C*patch^2]
// Matches Python: x.reshape(C, 1, 1, H_tok, 2, W_tok, 2).transpose(1,2,3,5,4,6,0).reshape(1,-1,C*4)
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize // H_tok
pW := W / patchSize // W_tok
// Match Python exactly: reshape treating B=1 as part of contiguous data
// [1, C, H, W] -> [C, 1, 1, pH, 2, pW, 2]
x := mlx.Reshape(latents, C, 1, 1, pH, patchSize, pW, patchSize)
// Python: transpose(1, 2, 3, 5, 4, 6, 0)
// [C, 1, 1, pH, 2, pW, 2] -> [1, 1, pH, pW, 2, 2, C]
x = mlx.Transpose(x, 1, 2, 3, 5, 4, 6, 0)
// [1, 1, pH, pW, 2, 2, C] -> [1, pH*pW, C*4]
return mlx.Reshape(x, 1, pH*pW, C*patchSize*patchSize)
}
// UnpatchifyLatents converts patches [B, L, C*patch^2] back to [B, C, H, W]
// Matches Python: out.reshape(1,1,H_tok,W_tok,2,2,C).transpose(6,0,1,2,4,3,5).reshape(1,C,H,W)
func UnpatchifyLatents(patches *mlx.Array, patchSize, H, W, C int32) *mlx.Array {
pH := H / patchSize
pW := W / patchSize
// [1, L, C*4] -> [1, 1, pH, pW, 2, 2, C]
x := mlx.Reshape(patches, 1, 1, pH, pW, patchSize, patchSize, C)
// Python: transpose(6, 0, 1, 2, 4, 3, 5)
// [1, 1, pH, pW, 2, 2, C] -> [C, 1, 1, pH, 2, pW, 2]
x = mlx.Transpose(x, 6, 0, 1, 2, 4, 3, 5)
// [C, 1, 1, pH, 2, pW, 2] -> [1, C, H, W]
return mlx.Reshape(x, 1, C, H, W)
}

View File

@@ -1,650 +0,0 @@
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
InChannels int32 `json:"in_channels"`
OutChannels int32 `json:"out_channels"`
LatentChannels int32 `json:"latent_channels"`
BlockOutChannels []int32 `json:"block_out_channels"`
LayersPerBlock int32 `json:"layers_per_block"`
NormNumGroups int32 `json:"norm_num_groups"`
ScalingFactor float32 `json:"scaling_factor"`
ShiftFactor float32 `json:"shift_factor"`
}
// loadVAEConfig loads VAE config from a JSON file
func loadVAEConfig(path string) (*VAEConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config: %w", err)
}
var cfg VAEConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
return &cfg, nil
}
// GroupNormLayer implements group normalization
type GroupNormLayer struct {
Weight *mlx.Array
Bias *mlx.Array
NumGroups int32
Eps float32
}
// NewGroupNorm creates a group norm layer
func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
return &GroupNormLayer{
Weight: weight,
Bias: bias,
NumGroups: numGroups,
Eps: 1e-5,
}
}
// Forward applies group normalization
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
// x: [B, C, H, W]
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// Reshape to [B, groups, C/groups, H, W]
groupSize := C / gn.NumGroups
x = mlx.Reshape(x, B, gn.NumGroups, groupSize, H, W)
// Compute mean and variance per group
mean := mlx.Mean(x, 2, true)
mean = mlx.Mean(mean, 3, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), 2, true)
variance = mlx.Mean(variance, 3, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back to [B, C, H, W]
xNorm = mlx.Reshape(xNorm, B, C, H, W)
// Scale and shift (weight and bias are [C])
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, C, 1, 1)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, C, 1, 1)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Conv2D represents a 2D convolution layer
// MLX uses NHWC format, but we store weights in OHWI format for MLX conv
type Conv2D struct {
Weight *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
Bias *mlx.Array // [out_channels]
Stride int32
Padding int32
}
// NewConv2D creates a Conv2D layer
// weight comes in as [out_channels, in_channels, kH, kW] (OIHW from PyTorch)
// we transpose to [out_channels, kH, kW, in_channels] (OHWI for MLX)
func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
// Transpose weight from OIHW to OHWI
// [O, I, H, W] -> [O, H, W, I]
weightOHWI := mlx.Transpose(weight, 0, 2, 3, 1)
return &Conv2D{
Weight: weightOHWI,
Bias: bias,
Stride: stride,
Padding: padding,
}
}
// Forward applies convolution
// Input x is in NCHW format, we convert to NHWC for MLX, then back to NCHW
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
// x: [N, C, H, W] -> [N, H, W, C]
xNHWC := mlx.Transpose(x, 0, 2, 3, 1)
// Conv in NHWC format
outNHWC := mlx.Conv2d(xNHWC, conv.Weight, conv.Stride, conv.Padding)
// Convert back to NCHW: [N, H, W, C] -> [N, C, H, W]
out := mlx.Transpose(outNHWC, 0, 3, 1, 2)
if conv.Bias != nil {
bias := mlx.Reshape(conv.Bias, 1, conv.Bias.Dim(0), 1, 1)
out = mlx.Add(out, bias)
}
return out
}
// ResnetBlock2D implements a ResNet block for VAE
type ResnetBlock2D struct {
Norm1 *GroupNormLayer
Conv1 *Conv2D
Norm2 *GroupNormLayer
Conv2 *Conv2D
ConvShortcut *Conv2D // nil if in_channels == out_channels
}
// NewResnetBlock2D creates a ResNet block
func NewResnetBlock2D(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*ResnetBlock2D, error) {
norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
if err != nil {
return nil, err
}
norm1Bias, err := weights.GetTensor(prefix + ".norm1.bias")
if err != nil {
return nil, err
}
conv1Weight, err := weights.GetTensor(prefix + ".conv1.weight")
if err != nil {
return nil, err
}
conv1Bias, err := weights.GetTensor(prefix + ".conv1.bias")
if err != nil {
return nil, err
}
norm2Weight, err := weights.GetTensor(prefix + ".norm2.weight")
if err != nil {
return nil, err
}
norm2Bias, err := weights.GetTensor(prefix + ".norm2.bias")
if err != nil {
return nil, err
}
conv2Weight, err := weights.GetTensor(prefix + ".conv2.weight")
if err != nil {
return nil, err
}
conv2Bias, err := weights.GetTensor(prefix + ".conv2.bias")
if err != nil {
return nil, err
}
block := &ResnetBlock2D{
Norm1: NewGroupNorm(norm1Weight, norm1Bias, numGroups),
Conv1: NewConv2D(conv1Weight, conv1Bias, 1, 1),
Norm2: NewGroupNorm(norm2Weight, norm2Bias, numGroups),
Conv2: NewConv2D(conv2Weight, conv2Bias, 1, 1),
}
if weights.HasTensor(prefix + ".conv_shortcut.weight") {
shortcutWeight, err := weights.GetTensor(prefix + ".conv_shortcut.weight")
if err != nil {
return nil, err
}
shortcutBias, err := weights.GetTensor(prefix + ".conv_shortcut.bias")
if err != nil {
return nil, err
}
block.ConvShortcut = NewConv2D(shortcutWeight, shortcutBias, 1, 0)
}
return block, nil
}
// Forward applies the ResNet block with staged evaluation
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
var h *mlx.Array
// Stage 1: norm1
{
h = rb.Norm1.Forward(x)
mlx.Eval(h)
}
// Stage 2: silu + conv1
{
prev := h
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
prev.Free()
mlx.Eval(h)
}
// Stage 3: norm2
{
prev := h
h = rb.Norm2.Forward(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: silu + conv2
{
prev := h
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
prev.Free()
mlx.Eval(h)
}
// Residual connection
{
prev := h
if rb.ConvShortcut != nil {
shortcut := rb.ConvShortcut.Forward(x)
h = mlx.Add(h, shortcut)
} else {
h = mlx.Add(h, x)
}
prev.Free()
mlx.Eval(h)
}
return h
}
// VAEAttentionBlock implements self-attention for VAE
type VAEAttentionBlock struct {
GroupNorm *GroupNormLayer
ToQWeight *mlx.Array
ToQBias *mlx.Array
ToKWeight *mlx.Array
ToKBias *mlx.Array
ToVWeight *mlx.Array
ToVBias *mlx.Array
ToOutWeight *mlx.Array
ToOutBias *mlx.Array
NumHeads int32
}
// NewVAEAttentionBlock creates an attention block
func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
if err != nil {
return nil, err
}
normBias, err := weights.GetTensor(prefix + ".group_norm.bias")
if err != nil {
return nil, err
}
toQWeight, err := weights.GetTensor(prefix + ".to_q.weight")
if err != nil {
return nil, err
}
toQBias, err := weights.GetTensor(prefix + ".to_q.bias")
if err != nil {
return nil, err
}
toKWeight, err := weights.GetTensor(prefix + ".to_k.weight")
if err != nil {
return nil, err
}
toKBias, err := weights.GetTensor(prefix + ".to_k.bias")
if err != nil {
return nil, err
}
toVWeight, err := weights.GetTensor(prefix + ".to_v.weight")
if err != nil {
return nil, err
}
toVBias, err := weights.GetTensor(prefix + ".to_v.bias")
if err != nil {
return nil, err
}
toOutWeight, err := weights.GetTensor(prefix + ".to_out.0.weight")
if err != nil {
return nil, err
}
toOutBias, err := weights.GetTensor(prefix + ".to_out.0.bias")
if err != nil {
return nil, err
}
return &VAEAttentionBlock{
GroupNorm: NewGroupNorm(normWeight, normBias, numGroups),
ToQWeight: mlx.Transpose(toQWeight, 1, 0),
ToQBias: toQBias,
ToKWeight: mlx.Transpose(toKWeight, 1, 0),
ToKBias: toKBias,
ToVWeight: mlx.Transpose(toVWeight, 1, 0),
ToVBias: toVBias,
ToOutWeight: mlx.Transpose(toOutWeight, 1, 0),
ToOutBias: toOutBias,
NumHeads: 1,
}, nil
}
// Forward applies attention with staged evaluation
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
var h *mlx.Array
// Stage 1: GroupNorm + reshape
{
h = ab.GroupNorm.Forward(x)
h = mlx.Transpose(h, 0, 2, 3, 1)
h = mlx.Reshape(h, B, H*W, C)
mlx.Eval(h)
}
var out *mlx.Array
// Stage 2: Q, K, V projections + attention
{
q := mlx.Linear(h, ab.ToQWeight)
q = mlx.Add(q, ab.ToQBias)
k := mlx.Linear(h, ab.ToKWeight)
k = mlx.Add(k, ab.ToKBias)
v := mlx.Linear(h, ab.ToVWeight)
v = mlx.Add(v, ab.ToVBias)
h.Free()
q = mlx.ExpandDims(q, 1)
k = mlx.ExpandDims(k, 1)
v = mlx.ExpandDims(v, 1)
scale := float32(1.0 / math.Sqrt(float64(C)))
out = mlx.ScaledDotProductAttention(q, k, v, scale, false)
out = mlx.Squeeze(out, 1)
mlx.Eval(out)
}
// Stage 3: Output projection + reshape + residual
{
prev := out
out = mlx.Linear(out, ab.ToOutWeight)
out = mlx.Add(out, ab.ToOutBias)
out = mlx.Reshape(out, B, H, W, C)
out = mlx.Transpose(out, 0, 3, 1, 2)
out = mlx.Add(out, residual)
prev.Free()
mlx.Eval(out)
}
return out
}
// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
ResnetBlocks []*ResnetBlock2D
Upsample *Conv2D
}
// NewUpDecoderBlock2D creates an up decoder block
func NewUpDecoderBlock2D(weights *safetensors.ModelWeights, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
resnet, err := NewResnetBlock2D(weights, resPrefix, numGroups)
if err != nil {
return nil, err
}
resnets[i] = resnet
}
var upsample *Conv2D
if hasUpsample {
upWeight, err := weights.GetTensor(prefix + ".upsamplers.0.conv.weight")
if err != nil {
return nil, err
}
upBias, err := weights.GetTensor(prefix + ".upsamplers.0.conv.bias")
if err != nil {
return nil, err
}
upsample = NewConv2D(upWeight, upBias, 1, 1)
}
return &UpDecoderBlock2D{
ResnetBlocks: resnets,
Upsample: upsample,
}, nil
}
// Forward applies the up decoder block with staged evaluation to reduce peak memory
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range ub.ResnetBlocks {
prev := x
x = resnet.Forward(x) // ResNet handles its own pools
prev.Free()
}
if ub.Upsample != nil {
// Stage 1: Upsample2x (nearest neighbor)
{
prev := x
x = Upsample2x(x)
prev.Free()
mlx.Eval(x)
}
// Stage 2: Upsample conv
{
prev := x
x = ub.Upsample.Forward(x)
prev.Free()
mlx.Eval(x)
}
}
return x
}
// VAEMidBlock is the middle block with attention
type VAEMidBlock struct {
Resnet1 *ResnetBlock2D
Attention *VAEAttentionBlock
Resnet2 *ResnetBlock2D
}
// NewVAEMidBlock creates the mid block
func NewVAEMidBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEMidBlock, error) {
resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
if err != nil {
return nil, err
}
attention, err := NewVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
if err != nil {
return nil, err
}
resnet2, err := NewResnetBlock2D(weights, prefix+".resnets.1", numGroups)
if err != nil {
return nil, err
}
return &VAEMidBlock{
Resnet1: resnet1,
Attention: attention,
Resnet2: resnet2,
}, nil
}
// Forward applies the mid block with staged evaluation
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
prev := x
x = mb.Resnet1.Forward(x) // ResNet handles its own pools
prev.Free()
// Attention handles its own pools
prev = x
x = mb.Attention.Forward(x)
prev.Free()
prev = x
x = mb.Resnet2.Forward(x) // ResNet handles its own pools
prev.Free()
return x
}
// VAEDecoder is the full VAE decoder
type VAEDecoder struct {
Config *VAEConfig
ConvIn *Conv2D
MidBlock *VAEMidBlock
UpBlocks []*UpDecoderBlock2D
ConvNormOut *GroupNormLayer
ConvOut *Conv2D
}
// Load loads the VAE decoder from a directory
func (m *VAEDecoder) Load(path string) error {
fmt.Println("Loading VAE decoder...")
// Load config
cfg, err := loadVAEConfig(filepath.Join(path, "config.json"))
if err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = cfg
// Load weights
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Load conv_in
fmt.Print(" Loading conv_in... ")
convInWeight, err := weights.GetTensor("decoder.conv_in.weight")
if err != nil {
return err
}
convInBias, err := weights.GetTensor("decoder.conv_in.bias")
if err != nil {
return err
}
m.ConvIn = NewConv2D(convInWeight, convInBias, 1, 1)
fmt.Println("✓")
// Load mid block
fmt.Print(" Loading mid block... ")
m.MidBlock, err = NewVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
if err != nil {
return err
}
fmt.Println("✓")
// Load up blocks
fmt.Print(" Loading up blocks... ")
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*UpDecoderBlock2D, numBlocks)
for i := 0; i < numBlocks; i++ {
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
hasUpsample := i < numBlocks-1
m.UpBlocks[i], err = NewUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
if err != nil {
return err
}
}
fmt.Printf("✓ [%d blocks]\n", numBlocks)
// Load conv_norm_out
fmt.Print(" Loading conv_norm_out... ")
normWeight, err := weights.GetTensor("decoder.conv_norm_out.weight")
if err != nil {
return err
}
normBias, err := weights.GetTensor("decoder.conv_norm_out.bias")
if err != nil {
return err
}
m.ConvNormOut = NewGroupNorm(normWeight, normBias, cfg.NormNumGroups)
fmt.Println("✓")
// Load conv_out
fmt.Print(" Loading conv_out... ")
convOutWeight, err := weights.GetTensor("decoder.conv_out.weight")
if err != nil {
return err
}
convOutBias, err := weights.GetTensor("decoder.conv_out.bias")
if err != nil {
return err
}
m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// Decode decodes latents to images.
// Uses staged pools to free intermediate arrays and reduce peak memory.
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
var h *mlx.Array
{
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
h = vae.ConvIn.Forward(z)
mlx.Eval(h)
}
h = vae.MidBlock.Forward(h)
for _, upBlock := range vae.UpBlocks {
h = upBlock.Forward(h)
}
{
prev := h
h = vae.ConvNormOut.Forward(h)
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.AddScalar(mlx.MulScalar(h, 0.5), 0.5)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
prev.Free()
mlx.Eval(h)
}
return h
}
// Upsample2x performs 2x nearest neighbor upsampling using broadcast.
// x: [B, C, H, W] -> [B, C, H*2, W*2]
func Upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// [B, C, H, W] -> [B, C, H, 1, W, 1]
x = mlx.Reshape(x, B, C, H, 1, W, 1)
// Broadcast to [B, C, H, 2, W, 2]
x = mlx.BroadcastTo(x, []int32{B, C, H, 2, W, 2})
// Reshape to [B, C, H*2, W*2]
x = mlx.Reshape(x, B, C, H*2, W*2)
return x
}

View File

@@ -1,361 +0,0 @@
// Package zimage implements the Z-Image diffusion transformer model.
package zimage
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)
// Layer caching options (speedup via shallow layer reuse)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 15)
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Z-Image diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen3TextEncoder
Transformer *Transformer
VAEDecoder *VAEDecoder
}
// Load loads the Z-Image model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Z-Image model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokenizerPath := filepath.Join(modelPath, "tokenizer", "tokenizer.json")
tok, err := tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder
m.TextEncoder = &Qwen3TextEncoder{}
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 9 // Turbo default
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.LayerCache {
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 15 // Half of 30 layers
}
}
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.TransformerConfig
latentH := cfg.Height / 8
latentW := cfg.Width / 8
hTok := latentH / tcfg.PatchSize
wTok := latentW / tcfg.PatchSize
// Text encoding with padding to multiple of 32
var posEmb, negEmb *mlx.Array
{
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512)
if useCFG {
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512)
}
// Pad both to same length (multiple of 32)
maxLen := posEmb.Shape()[1]
if useCFG && negEmb.Shape()[1] > maxLen {
maxLen = negEmb.Shape()[1]
}
if pad := (32 - (maxLen % 32)) % 32; pad > 0 {
maxLen += pad
}
posEmb = padToLength(posEmb, maxLen)
if useCFG {
negEmb = padToLength(negEmb, maxLen)
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
} else {
mlx.Keep(posEmb)
mlx.Eval(posEmb)
}
}
// Scheduler
scheduler := NewFlowMatchEulerScheduler(DefaultFlowMatchSchedulerConfig())
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(hTok*wTok))
// Init latents [B, C, H, W]
var latents *mlx.Array
{
latents = scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
}
// RoPE cache
var ropeCache *RoPECache
{
ropeCache = m.Transformer.PrepareRoPECache(hTok, wTok, posEmb.Shape()[1])
mlx.Keep(ropeCache.ImgCos, ropeCache.ImgSin, ropeCache.CapCos, ropeCache.CapSin,
ropeCache.UnifiedCos, ropeCache.UnifiedSin)
mlx.Eval(ropeCache.UnifiedCos)
}
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
// GPU capture on step 2 if requested
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStartCapture(cfg.CapturePath)
}
tCurr := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
patches := PatchifyLatents(latents, tcfg.PatchSize)
var output *mlx.Array
if stepCache != nil {
// Use layer caching for faster inference
if useCFG {
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
// Note: CFG with layer cache shares the cache between pos/neg
// This is approximate but fast - neg prompt uses same cached shallow layers
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
} else {
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
}
} else {
// Standard forward without caching
if useCFG {
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
} else {
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
}
}
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep latents and any cached arrays
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStopCapture()
}
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n",
i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem)
}
// Free denoising temporaries before VAE decode
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgCos.Free()
ropeCache.ImgSin.Free()
ropeCache.CapCos.Free()
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode
decoded := m.VAEDecoder.Decode(latents)
latents.Free()
return decoded, nil
}
// padToLength pads a sequence tensor to the target length by repeating the last token.
func padToLength(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
lastToken := mlx.Slice(x, []int32{0, currentLen - 1, 0}, []int32{shape[0], currentLen, shape[2]})
padding := mlx.Tile(lastToken, []int32{1, padLen, 1})
return mlx.Concatenate([]*mlx.Array{x, padding}, 1)
}
// CalculateShift computes the mu shift value for dynamic scheduling
func CalculateShift(imgSeqLen int32) float32 {
baseSeqLen := float32(256)
maxSeqLen := float32(4096)
baseShift := float32(0.5)
maxShift := float32(1.15)
m := (maxShift - baseShift) / (maxSeqLen - baseSeqLen)
b := baseShift - m*baseSeqLen
return float32(imgSeqLen)*m + b
}

View File

@@ -1,201 +0,0 @@
// Package nn provides neural network layer types.
package nn
import "github.com/ollama/ollama/x/imagegen/mlx"
// Layer is the interface for neural network layers with a Forward method.
type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
// Linear applies an affine transformation: y = x @ W.T + b
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
type Linear struct {
Weight *mlx.Array `weight:"weight"` // [out_features, in_features]
Bias *mlx.Array `weight:"bias,optional"` // [out_features] or nil
}
// NewLinear creates a linear layer.
// Weight should be [out_features, in_features].
func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
return &Linear{Weight: weight, Bias: bias}
}
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
// Quantizes the weight immediately and evaluates to break lazy dependencies.
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
// Eval immediately so bf16 weight can be freed
mlx.Eval(qw, scales, qbiases)
return &QuantizedLinear{
Weight: qw,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
// Forward applies the linear transformation: x @ W.T + bias
func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
w := mlx.Transpose(l.Weight, 1, 0)
if l.Bias != nil {
return mlx.AddMM(l.Bias, x, w, 1.0, 1.0)
}
return mlx.Linear(x, w)
}
// ToQuantized converts this Linear to a QuantizedLinear.
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
return &QuantizedLinear{
Weight: qw,
Scales: scales,
QBiases: qbiases,
Bias: l.Bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
// QuantizedLinear applies an affine transformation using quantized weights.
// Equivalent to mlx.nn.QuantizedLinear.
type QuantizedLinear struct {
Weight *mlx.Array // Quantized weight data
Scales *mlx.Array // Scale factors for dequantization
QBiases *mlx.Array // Quantization biases (NOT layer bias)
Bias *mlx.Array // Layer bias [output_dims] or nil
GroupSize int
Bits int
Mode string
}
// Forward applies the quantized linear transformation.
func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
if ql.Bias != nil {
out = mlx.Add(out, ql.Bias)
}
return out
}
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array `weight:"weight"`
Eps float32 // optional: used if Forward called with eps=0
}
// NewRMSNorm creates an RMSNorm layer (for models not using weight loader).
func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
return &RMSNorm{Weight: weight, Eps: eps}
}
// Forward applies RMS normalization. If eps=0, uses stored Eps.
func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
if eps == 0 {
eps = rn.Eps
}
return mlx.RMSNorm(x, rn.Weight, eps)
}
// Embedding represents an embedding layer.
type Embedding struct {
Weight *mlx.Array `weight:"weight"`
}
// NewEmbedding creates an embedding layer.
func NewEmbedding(weight *mlx.Array) *Embedding {
return &Embedding{Weight: weight}
}
// Forward looks up embeddings by indices.
func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
return mlx.Take(e.Weight, indices, 0)
}
// RepeatKV repeats K/V tensors for grouped query attention
// x: [B, num_kv_heads, S, head_dim] -> [B, num_heads, S, head_dim]
func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array {
if repeatFactor == 1 {
return x
}
shape := x.Shape()
// [B, num_kv_heads, S, head_dim] -> [B, num_kv_heads, 1, S, head_dim]
x = mlx.ExpandDims(x, 2)
// Repeat along the new axis
reps := []int32{1, 1, repeatFactor, 1, 1}
x = mlx.Tile(x, reps)
// Reshape: [B, num_kv_heads, repeat, S, head_dim] -> [B, num_kv_heads * repeat, S, head_dim]
return mlx.Reshape(x, shape[0], shape[1]*repeatFactor, shape[2], shape[3])
}
// ApplyCausalMask applies causal (lower triangular) mask to attention scores
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
// scores: [B, num_heads, S, S]
shape := scores.Shape()
seqLen := shape[2]
// Create causal mask: 1 for positions to keep, 0 for positions to mask
mask := mlx.Tri(seqLen, seqLen, 0)
// Where mask is 0, set score to -inf
negInf := mlx.NewScalarArray(float32(-1e9))
// Broadcast mask to match scores shape
mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, S, S]
// Use where: if mask > 0, keep scores, else -inf
return mlx.Where(mask, scores, negInf)
}
// ApplyCausalMaskWithOffset applies causal mask for cached attention
// scores: [B, num_heads, queryLen, keyLen] where keyLen = cacheLen + queryLen
// offset: the starting position of the new queries (i.e., cache length)
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
if offset == 0 {
return ApplyCausalMask(scores)
}
shape := scores.Shape()
queryLen := shape[2]
keyLen := shape[3]
// For cached attention, new queries can attend to all cached keys plus
// new keys up to and including their position.
mask := mlx.Tri(queryLen, keyLen, int(offset))
negInf := mlx.NewScalarArray(float32(-1e9))
mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, queryLen, keyLen]
return mlx.Where(mask, scores, negInf)
}
// LayerNorm represents a standard layer normalization layer (with bias).
type LayerNorm struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
Eps float32
}
// Forward applies layer normalization: (x - mean) / sqrt(var + eps) * weight + bias
func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
eps := ln.Eps
if eps == 0 {
eps = 1e-5
}
// Compute mean and variance along last dimension
mean := mlx.Mean(x, -1, true)
centered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Mul(centered, centered), -1, true)
normalized := mlx.Mul(centered, mlx.RSqrt(mlx.AddScalar(variance, eps)))
// Scale and shift
out := mlx.Mul(normalized, ln.Weight)
if ln.Bias != nil {
out = mlx.Add(out, ln.Bias)
}
return out
}

View File

@@ -1,354 +0,0 @@
package nn
import (
"math"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly.
func TestLinearNoBias(t *testing.T) {
// Weight: [out=2, in=3] -> transposed at forward time
weight := mlx.NewArrayFloat32([]float32{
1, 2, 3, // row 0
4, 5, 6, // row 1
}, []int32{2, 3})
mlx.Eval(weight)
linear := NewLinear(weight, nil)
// Input: [1, 3]
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
mlx.Eval(x)
out := linear.Forward(x)
mlx.Eval(out)
// Expected: [1,1,1] @ [[1,4],[2,5],[3,6]] = [6, 15]
data := out.Data()
if len(data) != 2 || data[0] != 6 || data[1] != 15 {
t.Errorf("expected [6, 15], got %v", data)
}
}
// TestLinearWithBias verifies Linear with bias computes x @ w.T + b correctly.
func TestLinearWithBias(t *testing.T) {
weight := mlx.NewArrayFloat32([]float32{
1, 2, 3,
4, 5, 6,
}, []int32{2, 3})
bias := mlx.NewArrayFloat32([]float32{10, 20}, []int32{2})
mlx.Eval(weight, bias)
linear := NewLinear(weight, bias)
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
mlx.Eval(x)
out := linear.Forward(x)
mlx.Eval(out)
// Expected: [6, 15] + [10, 20] = [16, 35]
data := out.Data()
if len(data) != 2 || data[0] != 16 || data[1] != 35 {
t.Errorf("expected [16, 35], got %v", data)
}
}
// TestLinearBatched verifies Linear works with batched input.
func TestLinearBatched(t *testing.T) {
weight := mlx.NewArrayFloat32([]float32{
1, 0,
0, 1,
}, []int32{2, 2}) // Identity
mlx.Eval(weight)
linear := NewLinear(weight, nil)
// Batch of 3 inputs
x := mlx.NewArrayFloat32([]float32{
1, 2,
3, 4,
5, 6,
}, []int32{3, 2})
mlx.Eval(x)
out := linear.Forward(x)
mlx.Eval(out)
// Identity should return same values
data := out.Data()
expected := []float32{1, 2, 3, 4, 5, 6}
for i, v := range expected {
if data[i] != v {
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
}
}
}
// TestRMSNorm verifies RMSNorm computation.
func TestRMSNorm(t *testing.T) {
weight := mlx.NewArrayFloat32([]float32{1, 1, 1, 1}, []int32{4})
mlx.Eval(weight)
norm := NewRMSNorm(weight, 1e-5)
// Input with known RMS
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
mlx.Eval(x)
out := norm.Forward(x, 0) // eps=0 uses stored Eps
mlx.Eval(out)
// RMS of [2,2,2,2] = 2, so normalized = [1,1,1,1]
data := out.Data()
for i, v := range data {
if math.Abs(float64(v-1.0)) > 1e-4 {
t.Errorf("at %d: expected ~1.0, got %f", i, v)
}
}
}
// TestRMSNormWithScale verifies RMSNorm applies weight scaling.
func TestRMSNormWithScale(t *testing.T) {
weight := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{4})
mlx.Eval(weight)
norm := NewRMSNorm(weight, 1e-5)
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
mlx.Eval(x)
out := norm.Forward(x, 0) // eps=0 uses stored Eps
mlx.Eval(out)
// Normalized [1,1,1,1] * weight [2,2,2,2] = [2,2,2,2]
data := out.Data()
for i, v := range data {
if math.Abs(float64(v-2.0)) > 1e-4 {
t.Errorf("at %d: expected ~2.0, got %f", i, v)
}
}
}
// TestEmbedding verifies embedding lookup.
func TestEmbedding(t *testing.T) {
// Embedding table: 4 tokens, dim 3
weight := mlx.NewArrayFloat32([]float32{
0, 0, 0, // token 0
1, 1, 1, // token 1
2, 2, 2, // token 2
3, 3, 3, // token 3
}, []int32{4, 3})
mlx.Eval(weight)
emb := NewEmbedding(weight)
// Look up tokens [1, 3, 0]
indices := mlx.NewArrayInt32([]int32{1, 3, 0}, []int32{3})
mlx.Eval(indices)
out := emb.Forward(indices)
mlx.Eval(out)
data := out.Data()
expected := []float32{1, 1, 1, 3, 3, 3, 0, 0, 0}
for i, v := range expected {
if data[i] != v {
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
}
}
}
// TestRepeatKV verifies K/V repetition for GQA.
func TestRepeatKV(t *testing.T) {
// [B=1, num_kv_heads=2, S=2, head_dim=2]
x := mlx.NewArrayFloat32([]float32{
// head 0
1, 2, // pos 0
3, 4, // pos 1
// head 1
5, 6, // pos 0
7, 8, // pos 1
}, []int32{1, 2, 2, 2})
mlx.Eval(x)
// Repeat factor 2: 2 kv heads -> 4 heads
out := RepeatKV(x, 2)
mlx.Eval(out)
shape := out.Shape()
if shape[0] != 1 || shape[1] != 4 || shape[2] != 2 || shape[3] != 2 {
t.Errorf("expected shape [1,4,2,2], got %v", shape)
}
data := out.Data()
// After repeat: head0, head0, head1, head1
expected := []float32{
1, 2, 3, 4, // head 0 (original)
1, 2, 3, 4, // head 0 (repeat)
5, 6, 7, 8, // head 1 (original)
5, 6, 7, 8, // head 1 (repeat)
}
for i, v := range expected {
if data[i] != v {
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
}
}
}
// TestRepeatKVNoOp verifies RepeatKV with factor 1 returns input unchanged.
func TestRepeatKVNoOp(t *testing.T) {
x := mlx.NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{1, 1, 2, 2})
mlx.Eval(x)
out := RepeatKV(x, 1)
// Should return same pointer
if out != x {
t.Error("RepeatKV with factor 1 should return input unchanged")
}
}
// TestApplyCausalMask verifies causal masking.
func TestApplyCausalMask(t *testing.T) {
// [B=1, heads=1, S=3, S=3] - all ones
scores := mlx.Ones(1, 1, 3, 3)
mlx.Eval(scores)
out := ApplyCausalMask(scores)
mlx.Eval(out)
data := out.Data()
// Lower triangular should be 1, upper should be -1e9
// Row 0: [1, -inf, -inf]
// Row 1: [1, 1, -inf]
// Row 2: [1, 1, 1]
if data[0] != 1 || data[1] >= 0 || data[2] >= 0 {
t.Errorf("row 0 wrong: %v", data[0:3])
}
if data[3] != 1 || data[4] != 1 || data[5] >= 0 {
t.Errorf("row 1 wrong: %v", data[3:6])
}
if data[6] != 1 || data[7] != 1 || data[8] != 1 {
t.Errorf("row 2 wrong: %v", data[6:9])
}
}
// TestApplyCausalMaskWithOffset verifies causal masking with cache offset.
func TestApplyCausalMaskWithOffset(t *testing.T) {
// Simulating: cache has 2 tokens, adding 1 new query
// scores: [B=1, heads=1, queryLen=1, keyLen=3]
scores := mlx.Ones(1, 1, 1, 3)
mlx.Eval(scores)
out := ApplyCausalMaskWithOffset(scores, 2)
mlx.Eval(out)
data := out.Data()
// With offset=2, query at position 2 can attend to all 3 positions
if data[0] != 1 || data[1] != 1 || data[2] != 1 {
t.Errorf("expected [1, 1, 1], got %v", data)
}
}
// TestApplyCausalMaskWithOffsetZero verifies offset=0 falls back to regular causal.
func TestApplyCausalMaskWithOffsetZero(t *testing.T) {
scores := mlx.Ones(1, 1, 2, 2)
mlx.Eval(scores)
out := ApplyCausalMaskWithOffset(scores, 0)
mlx.Eval(out)
data := out.Data()
// Standard causal: [1, -inf], [1, 1]
if data[0] != 1 || data[1] >= 0 {
t.Errorf("row 0 wrong: %v", data[0:2])
}
if data[2] != 1 || data[3] != 1 {
t.Errorf("row 1 wrong: %v", data[2:4])
}
}
// BenchmarkLinearSmall benchmarks small Linear forward pass.
func BenchmarkLinearSmall(b *testing.B) {
weight := mlx.RandomNormal([]int32{256, 256}, 42)
mlx.Eval(weight)
linear := NewLinear(weight, nil)
x := mlx.RandomNormal([]int32{1, 256}, 43)
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
out := linear.Forward(x)
mlx.Eval(out)
}
}
// BenchmarkLinearLarge benchmarks larger Linear forward pass.
func BenchmarkLinearLarge(b *testing.B) {
weight := mlx.RandomNormal([]int32{4096, 4096}, 42)
mlx.Eval(weight)
linear := NewLinear(weight, nil)
x := mlx.RandomNormal([]int32{1, 4096}, 43)
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
out := linear.Forward(x)
mlx.Eval(out)
}
}
// BenchmarkRMSNorm benchmarks RMSNorm forward pass.
func BenchmarkRMSNorm(b *testing.B) {
weight := mlx.Ones(4096)
mlx.Eval(weight)
norm := NewRMSNorm(weight, 1e-5)
x := mlx.RandomNormal([]int32{1, 4096}, 42)
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
out := norm.Forward(x, 0)
mlx.Eval(out)
}
}
// BenchmarkEmbedding benchmarks embedding lookup.
func BenchmarkEmbedding(b *testing.B) {
// Typical vocab size
weight := mlx.RandomNormal([]int32{32000, 4096}, 42)
mlx.Eval(weight)
emb := NewEmbedding(weight)
// Single token lookup
indices := mlx.NewArrayInt32([]int32{1000}, []int32{1})
mlx.Eval(indices)
b.ResetTimer()
for i := 0; i < b.N; i++ {
out := emb.Forward(indices)
mlx.Eval(out)
}
}
// BenchmarkRepeatKV benchmarks K/V repetition.
func BenchmarkRepeatKV(b *testing.B) {
// Typical GQA setup: 8 kv heads -> 32 heads
x := mlx.RandomNormal([]int32{1, 8, 512, 128}, 42)
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
out := RepeatKV(x, 4)
mlx.Eval(out)
}
}

View File

@@ -1,168 +0,0 @@
package safetensors
import (
"fmt"
"reflect"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// LoadModule loads weights into a struct using reflection and struct tags.
//
// Struct tags use the format: `weight:"path[,optional]"`
// - path: the weight name suffix (appended to prefix)
// - optional: if present, missing weights don't cause errors
// - "-": skip this field entirely
// - no tag on struct pointer: recurse with current prefix
// - no tag on *mlx.Array: skip (computed fields don't need loading)
//
// For slices of struct pointers, the loader iterates with .0, .1, .2... suffixes.
// The slice must be pre-allocated to the correct length.
//
// Example:
//
// type Attention struct {
// QProj *nn.Linear `weight:"self_attn.q_proj"`
// KProj *nn.Linear `weight:"self_attn.k_proj"`
// Cache *mlx.Array // no tag = skipped (computed field)
// }
//
// err := LoadModule(&attn, weights, "model.layers.0")
func LoadModule(dst any, weights *ModelWeights, prefix string) error {
v := reflect.ValueOf(dst)
if v.Kind() != reflect.Ptr || v.IsNil() {
return fmt.Errorf("LoadModule: dst must be a non-nil pointer")
}
v = v.Elem()
if v.Kind() != reflect.Struct {
return fmt.Errorf("LoadModule: dst must be a pointer to struct, got %v", v.Kind())
}
var errs []string
loadStruct(v, weights, prefix, &errs, false)
if len(errs) > 0 {
return fmt.Errorf("LoadModule: missing weights:\n %s", strings.Join(errs, "\n "))
}
return nil
}
// loadStruct recursively loads weights into a struct value.
func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string, parentOptional bool) {
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldVal := v.Field(i)
// Skip unexported fields
if !fieldVal.CanSet() {
continue
}
// Parse tag
tag, hasTag := field.Tag.Lookup("weight")
if tag == "-" {
continue
}
// Parse tag options
optional := parentOptional
weightPath := tag
if idx := strings.Index(tag, ","); idx != -1 {
weightPath = tag[:idx]
if strings.Contains(tag[idx+1:], "optional") {
optional = true
}
}
// Build full path
fullPath := joinPath(prefix, weightPath)
// For struct pointers without a tag, recurse with current prefix
if !hasTag && fieldVal.Kind() == reflect.Ptr {
elemType := fieldVal.Type().Elem()
if elemType.Kind() == reflect.Struct && elemType != reflect.TypeOf(mlx.Array{}) {
if fieldVal.IsNil() {
fieldVal.Set(reflect.New(elemType))
}
loadStruct(fieldVal.Elem(), weights, prefix, errs, optional)
continue
}
}
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
elemType := fieldVal.Type().Elem()
// *mlx.Array - load directly (but skip if no tag - computed fields)
if fieldVal.Type() == reflect.TypeOf((*mlx.Array)(nil)) {
if !hasTag {
continue // no tag on *mlx.Array = computed field, skip
}
arr, err := weights.GetTensor(fullPath)
if err != nil {
if !optional {
*errs = append(*errs, fullPath)
}
continue
}
fieldVal.Set(reflect.ValueOf(arr))
continue
}
// Pointer to struct - allocate and recurse
if elemType.Kind() == reflect.Struct {
if optional && !hasWeightsWithPrefix(weights, fullPath) {
continue
}
if fieldVal.IsNil() {
fieldVal.Set(reflect.New(elemType))
}
loadStruct(fieldVal.Elem(), weights, fullPath, errs, optional)
}
case reflect.Slice:
elemType := fieldVal.Type().Elem()
if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct {
loadSlice(fieldVal, weights, fullPath, errs)
}
}
}
}
// hasWeightsWithPrefix checks if any weights exist with the given prefix.
func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
for _, name := range weights.ListTensors() {
if strings.HasPrefix(name, prefix+".") || name == prefix {
return true
}
}
return false
}
// loadSlice loads weights into each element of a slice of struct pointers.
func loadSlice(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string) {
elemStructType := v.Type().Elem().Elem()
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
if elem.IsNil() {
elem.Set(reflect.New(elemStructType))
}
loadStruct(elem.Elem(), weights, fmt.Sprintf("%s.%d", prefix, i), errs, false)
}
}
// joinPath joins path segments with dots, handling empty segments.
func joinPath(prefix, suffix string) string {
if prefix == "" {
return suffix
}
if suffix == "" {
return prefix
}
return prefix + "." + suffix
}

View File

@@ -1,278 +0,0 @@
package safetensors
import (
"encoding/binary"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SafetensorHeader represents the JSON header of a safetensors file
type SafetensorHeader map[string]TensorInfo
// TensorInfo contains metadata about a tensor
type TensorInfo struct {
Dtype string `json:"dtype"`
Shape []int32 `json:"shape"`
DataOffsets [2]int `json:"data_offsets"`
}
// parseSafetensorHeader reads only the JSON header from a safetensors file.
func parseSafetensorHeader(path string) (SafetensorHeader, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open file: %w", err)
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return nil, fmt.Errorf("failed to read header size: %w", err)
}
headerBytes := make([]byte, headerSize)
if _, err := f.Read(headerBytes); err != nil {
return nil, fmt.Errorf("failed to read header: %w", err)
}
var header SafetensorHeader
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse header: %w", err)
}
delete(header, "__metadata__")
return header, nil
}
// dtypeFromString converts safetensors dtype string to mlx.Dtype
func dtypeFromString(s string) mlx.Dtype {
switch strings.ToUpper(s) {
case "F32", "FLOAT32":
return mlx.DtypeFloat32
case "F16", "FLOAT16":
return mlx.DtypeFloat16
case "BF16", "BFLOAT16":
return mlx.DtypeBFloat16
case "I32", "INT32":
return mlx.DtypeInt32
case "I64", "INT64":
return mlx.DtypeInt64
case "U8", "UINT8":
return mlx.DtypeUint8
default:
return mlx.DtypeFloat32
}
}
// ModelWeights manages weights from multiple safetensor files.
type ModelWeights struct {
dir string // Model directory
tensorFiles map[string]string // tensor name -> file path
tensorInfo map[string]TensorInfo // tensor name -> metadata
nativeCache map[string]*mlx.SafetensorsFile // file path -> loaded native handle
cache map[string]*mlx.Array // tensor name -> array (after Load)
}
// LoadModelWeights scans safetensor files and builds a tensor index.
// This only reads JSON headers, not tensor data.
func LoadModelWeights(dir string) (*ModelWeights, error) {
mw := &ModelWeights{
dir: dir,
tensorFiles: make(map[string]string),
tensorInfo: make(map[string]TensorInfo),
nativeCache: make(map[string]*mlx.SafetensorsFile),
}
entries, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("failed to read directory: %w", err)
}
for _, entry := range entries {
if strings.HasSuffix(entry.Name(), ".safetensors") {
path := filepath.Join(dir, entry.Name())
header, err := parseSafetensorHeader(path)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", entry.Name(), err)
}
for name, info := range header {
mw.tensorFiles[name] = path
mw.tensorInfo[name] = info
}
}
}
if len(mw.tensorFiles) == 0 {
return nil, fmt.Errorf("no safetensor files found in %s", dir)
}
return mw, nil
}
// Load loads all tensors into cache with the specified dtype.
// If dtype is 0, tensors are loaded in their original dtype.
// Automatically uses streaming (memory-efficient) when dtype conversion is needed,
// or native loading when tensors are already in the target dtype.
func (mw *ModelWeights) Load(dtype mlx.Dtype) error {
if dtype == 0 {
return mw.loadNative()
}
// Check if any tensor needs conversion
needsConversion := false
for name := range mw.tensorFiles {
info := mw.tensorInfo[name]
if dtypeFromString(info.Dtype) != dtype {
needsConversion = true
break
}
}
if needsConversion {
return mw.loadStreaming(dtype)
}
return mw.loadNative()
}
// loadNative loads all tensors using the native memory-mapped loader.
func (mw *ModelWeights) loadNative() error {
mw.cache = make(map[string]*mlx.Array)
fileToTensors := make(map[string][]string)
for name, path := range mw.tensorFiles {
fileToTensors[path] = append(fileToTensors[path], name)
}
for path, names := range fileToTensors {
native, err := mlx.LoadSafetensorsNative(path)
if err != nil {
return fmt.Errorf("failed to load %s: %w", path, err)
}
for _, name := range names {
arr := native.Get(name)
if arr == nil {
native.Free()
return fmt.Errorf("tensor %q not found in %s", name, path)
}
mw.cache[name] = arr
}
mw.nativeCache[path] = native
}
return nil
}
// loadStreaming loads tensors with dtype conversion.
// Uses the same pattern as Python: replace each entry in the map after conversion,
// so the original tensor loses its reference and can be freed.
func (mw *ModelWeights) loadStreaming(dtype mlx.Dtype) error {
mw.cache = make(map[string]*mlx.Array)
fileToTensors := make(map[string][]string)
for name, path := range mw.tensorFiles {
fileToTensors[path] = append(fileToTensors[path], name)
}
for path, names := range fileToTensors {
native, err := mlx.LoadSafetensorsNative(path)
if err != nil {
return fmt.Errorf("failed to load %s: %w", path, err)
}
for _, name := range names {
src := native.Get(name)
if src == nil {
native.Free()
return fmt.Errorf("tensor %q not found in %s", name, path)
}
dst := mlx.AsType(src, dtype)
mlx.Eval(dst)
native.Set(name, dst)
mw.cache[name] = dst
}
native.Free()
}
return nil
}
// Get returns a tensor from cache. Call Load() first.
func (mw *ModelWeights) Get(name string) (*mlx.Array, error) {
if mw.cache == nil {
return nil, fmt.Errorf("cache not initialized: call Load() first")
}
arr, ok := mw.cache[name]
if !ok {
return nil, fmt.Errorf("tensor %q not found in cache", name)
}
return arr, nil
}
// GetTensor loads a tensor using the native loader without caching.
// For bulk loading, use Load() + Get() instead.
func (mw *ModelWeights) GetTensor(name string) (*mlx.Array, error) {
if mw.cache != nil {
if arr, ok := mw.cache[name]; ok {
return arr, nil
}
}
path, ok := mw.tensorFiles[name]
if !ok {
return nil, fmt.Errorf("tensor %q not found", name)
}
native, ok := mw.nativeCache[path]
if !ok {
var err error
native, err = mlx.LoadSafetensorsNative(path)
if err != nil {
return nil, fmt.Errorf("failed to load %s: %w", path, err)
}
mw.nativeCache[path] = native
}
return native.Get(name), nil
}
// GetTensorInfo returns metadata about a tensor without loading it.
func (mw *ModelWeights) GetTensorInfo(name string) (TensorInfo, bool) {
info, ok := mw.tensorInfo[name]
return info, ok
}
// ListTensors returns all tensor names.
func (mw *ModelWeights) ListTensors() []string {
names := make([]string, 0, len(mw.tensorFiles))
for name := range mw.tensorFiles {
names = append(names, name)
}
sort.Strings(names)
return names
}
// HasTensor checks if a tensor exists.
func (mw *ModelWeights) HasTensor(name string) bool {
_, ok := mw.tensorFiles[name]
return ok
}
// ReleaseAll releases all cached native file handles.
func (mw *ModelWeights) ReleaseAll() {
for path, native := range mw.nativeCache {
native.Free()
delete(mw.nativeCache, path)
}
}

View File

@@ -1,165 +0,0 @@
package safetensors
import (
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
func TestLoadModelWeights(t *testing.T) {
// Skip if no model available
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
// Check we found tensors
tensors := mw.ListTensors()
if len(tensors) == 0 {
t.Fatal("no tensors found")
}
t.Logf("found %d tensors", len(tensors))
// Check HasTensor
if !mw.HasTensor(tensors[0]) {
t.Errorf("HasTensor(%q) = false", tensors[0])
}
if mw.HasTensor("nonexistent.weight") {
t.Error("HasTensor returned true for nonexistent tensor")
}
}
func TestGetTensor(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
tensors := mw.ListTensors()
if len(tensors) == 0 {
t.Skip("no tensors")
}
// Load first tensor
arr, err := mw.GetTensor(tensors[0])
if err != nil {
t.Fatalf("GetTensor(%q): %v", tensors[0], err)
}
// Verify it has a shape
shape := arr.Shape()
if len(shape) == 0 {
t.Error("tensor has no shape")
}
t.Logf("%s: shape=%v dtype=%v", tensors[0], shape, arr.Dtype())
}
func TestLoadWithDtype(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
// Load all tensors as bfloat16
if err := mw.Load(mlx.DtypeBFloat16); err != nil {
t.Fatalf("Load: %v", err)
}
// Get a tensor from cache
tensors := mw.ListTensors()
arr, err := mw.Get(tensors[0])
if err != nil {
t.Fatalf("Get: %v", err)
}
// Verify dtype (unless it was already bf16)
t.Logf("%s: dtype=%v", tensors[0], arr.Dtype())
}
func TestLookupTensor(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
// HasTensor returns false for nonexistent
if mw.HasTensor("nonexistent") {
t.Error("HasTensor should return false for nonexistent")
}
// HasTensor returns true for existing tensor
tensors := mw.ListTensors()
if !mw.HasTensor(tensors[0]) {
t.Error("HasTensor should return true for existing tensor")
}
}
func TestParseSafetensorHeader(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
// Find a safetensors file
entries, err := os.ReadDir(modelDir)
if err != nil {
t.Fatal(err)
}
var stFile string
for _, e := range entries {
if filepath.Ext(e.Name()) == ".safetensors" {
stFile = filepath.Join(modelDir, e.Name())
break
}
}
if stFile == "" {
t.Skip("no safetensors file found")
}
header, err := parseSafetensorHeader(stFile)
if err != nil {
t.Fatalf("parseSafetensorHeader: %v", err)
}
if len(header) == 0 {
t.Error("header is empty")
}
// Check a tensor has valid info
for name, info := range header {
if info.Dtype == "" {
t.Errorf("%s: empty dtype", name)
}
if len(info.Shape) == 0 {
t.Errorf("%s: empty shape", name)
}
break // just check one
}
}

View File

@@ -1,166 +0,0 @@
# Tokenizer
Fast, correct tokenizer for LLM inference supporting BPE, SentencePiece, and WordPiece algorithms.
## Features
- **BPE (Byte Pair Encoding)** - GPT-2/Llama style with byte-level encoding
- **SentencePiece** - Gemma style with `▁` space handling
- **WordPiece** - BERT style with `##` continuation tokens
- **Parallel encoding** - Automatic parallelization for inputs >4KB
- **HuggingFace compatible** - Loads `tokenizer.json` directly
## Usage
```go
import "github.com/ollama/ollama/x/imagegen/tokenizer"
// Load from HuggingFace model directory
tok, err := tokenizer.Load("./weights/Llama-3.2-1B")
if err != nil {
log.Fatal(err)
}
// Encode text to token IDs
ids := tok.Encode("Hello, world!", false) // false = don't add BOS
// Decode back to text
text := tok.Decode(ids)
// Check special tokens
if tok.IsEOS(ids[len(ids)-1]) {
// End of sequence
}
```
## Performance
Benchmarks on Apple M3 Max:
| Input Size | Encode | Decode | Tokens |
|------------|--------|--------|--------|
| 1 KB | 14.5 MB/s | 267 MB/s | 231 |
| 10 KB | 10.9 MB/s | 321 MB/s | 2,301 |
| 100 KB | 8.9 MB/s | 311 MB/s | 23,001 |
| 1 MB | 9.6 MB/s | 321 MB/s | 230,001 |
Comparison with other implementations (10 MB input):
| Implementation | Encode Speed | Notes |
|----------------|--------------|-------|
| Engine (this) | ~10 MB/s | stdlib RE2, parallel >4KB |
| tiktoken (Rust) | ~17 MB/s | Highly optimized regex |
| llama.cpp (C++) | ~2 MB/s | Single-threaded only |
| Ollama (Go) | ~2-3 MB/s | regexp2 backtracking |
## Correctness
The tokenizer matches HuggingFace transformers exactly. Verified with:
- 82 rigorous test cases for Gemma (SentencePiece)
- 458 fuzz test cases covering Unicode edge cases
- Full 0x00-0xFF byte roundtrip for BPE
Run tests:
```bash
go test ./tokenizer/... -v
```
## Architecture
```
Load(path)
├─ tokenizer.json → loadFromTokenizerJSON()
│ ├─ Parse vocab, merges, added_tokens
│ ├─ Detect type: BPE / SentencePiece / WordPiece
│ ├─ Compile pretokenizer regex (BPE only)
│ └─ Load special tokens from config files
└─ vocab.json + merges.txt → loadVocabMerges()
Encode(text)
├─ Split by special tokens
├─ Apply pretokenizer regex (BPE) or space→▁ (SentencePiece)
├─ For each chunk:
│ ├─ Fast path: single token lookup
│ └─ Slow path: BPE merge algorithm
└─ Parallel for inputs >4KB
Decode(ids)
├─ Look up each token
└─ Apply inverse transform:
├─ BPE: byte-level decode (0x0100 → 0x00, etc.)
├─ SentencePiece: ▁→space, <0xNN>→byte
└─ WordPiece: strip ## prefix
```
## Key Implementation Details
### BPE Byte-Level Encoding
GPT-2 style encoding maps bytes to Unicode codepoints to handle arbitrary binary data:
```go
// Precomputed table: byte → rune
var byteToRune [256]rune // 0x00→0x0100, 0x20→0x0120, etc.
```
### Pretokenizer Regex
HuggingFace patterns use PCRE features not supported by Go's RE2. We rewrite:
```go
// PCRE (HuggingFace)
`\s+(?!\S)|\s+`
// RE2 (Go) - with post-processing for whitespace boundaries
`\s+`
```
### Special Token Handling
Special tokens are matched greedily (longest first) before pretokenization:
```go
// Sorted by length, checked with HasPrefix
for _, tok := range sortedSpecialTokens {
if strings.HasPrefix(remaining, tok) {
// Found special token
}
}
```
## Performance Opportunities
Potential optimizations not yet implemented:
| Optimization | Expected Gain | Complexity |
|--------------|---------------|------------|
| Aho-Corasick for special tokens | 2-3x for many special tokens | Medium |
| Custom regex engine (like tiktoken) | 1.5-2x | High |
| SIMD byte scanning | 1.3-1.5x for pretokenizer | Medium |
| Assembly BPE merge loop | 1.2-1.5x | High |
| Memoization for repeated substrings | Variable | Low |
Current bottleneck is the pretokenizer regex (~60% of encode time). tiktoken achieves ~17 MB/s with a hand-tuned Rust regex engine.
## Not Yet Implemented
| Feature | Used By | Notes |
|---------|---------|-------|
| Unigram tokenizer | T5, ALBERT, mBART | Different algorithm (not BPE) |
| Unicode normalizers | Some multilingual models | NFD, NFKC, lowercase, etc. |
| Custom pretokenizers | Model-specific | Beyond standard patterns |
Most HuggingFace models use BPE or SentencePiece, which are fully supported. WordPiece (BERT-style) is also supported with standard `[UNK]` fallback for out-of-vocabulary characters.
## Files
| File | Description |
|------|-------------|
| `tokenizer.go` | Main implementation (~1000 lines) |
| `tokenizer_test.go` | Tests and benchmarks |
| `testdata/` | Mini tokenizer for unit tests |

Some files were not shown because too many files have changed in this diff Show More