mirror of
https://github.com/ollama/ollama.git
synced 2026-01-08 15:39:54 -05:00
Compare commits
13 Commits
mlx-engine
...
implement-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
14499406d2 | ||
|
|
2b5093e2e7 | ||
|
|
f4537dd113 | ||
|
|
34257b6e37 | ||
|
|
01d12cd98f | ||
|
|
9a16ad3857 | ||
|
|
b14ba5285f | ||
|
|
4ec2873ed1 | ||
|
|
e9f4ef84fb | ||
|
|
3531fb5970 | ||
|
|
f5a85e8ac6 | ||
|
|
214563ab17 | ||
|
|
2b90199b91 |
@@ -12,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
set(GGML_BUILD ON)
|
||||
set(GGML_SHARED ON)
|
||||
@@ -147,48 +147,14 @@ if(CMAKE_HIP_COMPILER)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT APPLE)
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
||||
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
install(TARGETS mlx mlxc
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
|
||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB CUDART_LIBS
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
||||
if(CUDART_LIBS)
|
||||
install(FILES ${CUDART_LIBS}
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -41,7 +41,7 @@
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
|
||||
"CMAKE_CUDA_FLAGS": "-t 4",
|
||||
"CMAKE_CUDA_FLAGS": "-t 2",
|
||||
"OLLAMA_RUNNER_DIR": "cuda_v13"
|
||||
}
|
||||
},
|
||||
@@ -83,28 +83,6 @@
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "vulkan"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"inherits": [ "Default" ],
|
||||
"cacheVariables": {
|
||||
"MLX_ENGINE": "ON",
|
||||
"OLLAMA_RUNNER_DIR": "mlx"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"inherits": [ "MLX", "CUDA 12" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"inherits": [ "MLX", "CUDA 13" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||
}
|
||||
}
|
||||
],
|
||||
"buildPresets": [
|
||||
@@ -162,21 +140,6 @@
|
||||
"name": "Vulkan",
|
||||
"targets": [ "ggml-vulkan" ],
|
||||
"configurePreset": "Vulkan"
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 12"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 13"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
37
Dockerfile
37
Dockerfile
@@ -131,40 +131,7 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'Vulkan' \
|
||||
&& cmake --build --parallel --preset 'Vulkan' \
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
|
||||
FROM base AS mlx
|
||||
ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
|
||||
&& dnf install -y openblas-devel lapack-devel \
|
||||
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
|
||||
&& dnf install -y libnccl libnccl-devel
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||
ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY go.mod go.sum .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
||||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
# TODO wire up the actual MLX engine here instead of building the main binary...
|
||||
RUN mkdir -p dist/bin
|
||||
RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx-engine .
|
||||
RUN go build -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
|
||||
|
||||
FROM base AS build
|
||||
@@ -186,8 +153,6 @@ FROM --platform=linux/amd64 scratch AS amd64
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
778
anthropic/anthropic.go
Normal file
778
anthropic/anthropic.go
Normal file
@@ -0,0 +1,778 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Error types matching Anthropic API
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Type string `json:"type"` // always "error"
|
||||
Error Error `json:"error"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
}
|
||||
|
||||
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
|
||||
func NewError(code int, message string) ErrorResponse {
|
||||
var etype string
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
etype = "invalid_request_error"
|
||||
case http.StatusUnauthorized:
|
||||
etype = "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
etype = "permission_error"
|
||||
case http.StatusNotFound:
|
||||
etype = "not_found_error"
|
||||
case http.StatusTooManyRequests:
|
||||
etype = "rate_limit_error"
|
||||
case http.StatusServiceUnavailable, 529:
|
||||
etype = "overloaded_error"
|
||||
default:
|
||||
etype = "api_error"
|
||||
}
|
||||
|
||||
return ErrorResponse{
|
||||
Type: "error",
|
||||
Error: Error{Type: etype, Message: message},
|
||||
RequestID: generateID("req"),
|
||||
}
|
||||
}
|
||||
|
||||
// Request types
|
||||
|
||||
// MessagesRequest represents an Anthropic Messages API request
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"` // string or []ContentBlock
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Metadata *Metadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// MessageParam represents a message in the request
|
||||
type MessageParam struct {
|
||||
Role string `json:"role"` // "user" or "assistant"
|
||||
Content any `json:"content"` // string or []ContentBlock
|
||||
}
|
||||
|
||||
// ContentBlock represents a content block in a message.
|
||||
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
|
||||
// only when set, which is required for SDK streaming accumulation.
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
|
||||
|
||||
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Text *string `json:"text,omitempty"`
|
||||
|
||||
// For image blocks
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
|
||||
// For tool_use blocks
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
|
||||
// For tool_result blocks
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content any `json:"content,omitempty"` // string or []ContentBlock
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
|
||||
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Thinking *string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ImageSource represents the source of an image
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64" or "url"
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// Tool represents a tool definition
|
||||
type Tool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
||||
}
|
||||
|
||||
// ToolChoice controls how the model uses tools
|
||||
type ToolChoice struct {
|
||||
Type string `json:"type"` // "auto", "any", "tool", "none"
|
||||
Name string `json:"name,omitempty"`
|
||||
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// ThinkingConfig controls extended thinking
|
||||
type ThinkingConfig struct {
|
||||
Type string `json:"type"` // "enabled" or "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// Metadata for the request
|
||||
type Metadata struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// Response types
|
||||
|
||||
// MessagesResponse represents an Anthropic Messages API response
|
||||
type MessagesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// Usage contains token usage information
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// Streaming event types
|
||||
|
||||
// MessageStartEvent is sent at the start of streaming
|
||||
type MessageStartEvent struct {
|
||||
Type string `json:"type"` // "message_start"
|
||||
Message MessagesResponse `json:"message"`
|
||||
}
|
||||
|
||||
// ContentBlockStartEvent signals the start of a content block
|
||||
type ContentBlockStartEvent struct {
|
||||
Type string `json:"type"` // "content_block_start"
|
||||
Index int `json:"index"`
|
||||
ContentBlock ContentBlock `json:"content_block"`
|
||||
}
|
||||
|
||||
// ContentBlockDeltaEvent contains incremental content updates
|
||||
type ContentBlockDeltaEvent struct {
|
||||
Type string `json:"type"` // "content_block_delta"
|
||||
Index int `json:"index"`
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
|
||||
// Delta represents an incremental update
|
||||
type Delta struct {
|
||||
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ContentBlockStopEvent signals the end of a content block
|
||||
type ContentBlockStopEvent struct {
|
||||
Type string `json:"type"` // "content_block_stop"
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
// MessageDeltaEvent contains updates to the message
|
||||
type MessageDeltaEvent struct {
|
||||
Type string `json:"type"` // "message_delta"
|
||||
Delta MessageDelta `json:"delta"`
|
||||
Usage DeltaUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// MessageDelta contains stop information
|
||||
type MessageDelta struct {
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
}
|
||||
|
||||
// DeltaUsage contains cumulative token usage
|
||||
type DeltaUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// MessageStopEvent signals the end of the message
|
||||
type MessageStopEvent struct {
|
||||
Type string `json:"type"` // "message_stop"
|
||||
}
|
||||
|
||||
// PingEvent is a keepalive event
|
||||
type PingEvent struct {
|
||||
Type string `json:"type"` // "ping"
|
||||
}
|
||||
|
||||
// StreamErrorEvent is an error during streaming
|
||||
type StreamErrorEvent struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
||||
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
var messages []api.Message
|
||||
|
||||
if r.System != nil {
|
||||
switch sys := r.System.(type) {
|
||||
case string:
|
||||
if sys != "" {
|
||||
messages = append(messages, api.Message{Role: "system", Content: sys})
|
||||
}
|
||||
case []any:
|
||||
// System can be an array of content blocks
|
||||
var content strings.Builder
|
||||
for _, block := range sys {
|
||||
if blockMap, ok := block.(map[string]any); ok {
|
||||
if blockMap["type"] == "text" {
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
content.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if content.Len() > 0 {
|
||||
messages = append(messages, api.Message{Role: "system", Content: content.String()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range r.Messages {
|
||||
converted, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, converted...)
|
||||
}
|
||||
|
||||
options := make(map[string]any)
|
||||
|
||||
options["num_predict"] = r.MaxTokens
|
||||
|
||||
if r.Temperature != nil {
|
||||
options["temperature"] = *r.Temperature
|
||||
}
|
||||
|
||||
if r.TopP != nil {
|
||||
options["top_p"] = *r.TopP
|
||||
}
|
||||
|
||||
if r.TopK != nil {
|
||||
options["top_k"] = *r.TopK
|
||||
}
|
||||
|
||||
if len(r.StopSequences) > 0 {
|
||||
options["stop"] = r.StopSequences
|
||||
}
|
||||
|
||||
var tools api.Tools
|
||||
for _, t := range r.Tools {
|
||||
tool, err := convertTool(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
if r.Thinking != nil && r.Thinking.Type == "enabled" {
|
||||
think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
|
||||
stream := r.Stream
|
||||
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Options: options,
|
||||
Stream: &stream,
|
||||
Tools: tools,
|
||||
Think: think,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||
func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
role := strings.ToLower(msg.Role)
|
||||
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
messages = append(messages, api.Message{Role: role, Content: content})
|
||||
|
||||
case []any:
|
||||
var textContent strings.Builder
|
||||
var images []api.ImageData
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid content block format")
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent.WriteString(text)
|
||||
}
|
||||
|
||||
case "image":
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
|
||||
case "tool_use":
|
||||
id, ok := blockMap["id"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
name, ok := blockMap["name"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "tool_result":
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
|
||||
switch c := blockMap["content"].(type) {
|
||||
case string:
|
||||
resultContent = c
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
if t, ok := blockMap["thinking"].(string); ok {
|
||||
thinking = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||
m := api.Message{
|
||||
Role: role,
|
||||
Content: textContent.String(),
|
||||
Images: images,
|
||||
ToolCalls: toolCalls,
|
||||
Thinking: thinking,
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool
|
||||
func convertTool(t Tool) (api.Tool, error) {
|
||||
var params api.ToolFunctionParameters
|
||||
if len(t.InputSchema) > 0 {
|
||||
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
||||
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
||||
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
|
||||
var content []ContentBlock
|
||||
|
||||
if r.Message.Thinking != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(r.Message.Thinking),
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Content != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(r.Message.Content),
|
||||
})
|
||||
}
|
||||
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
|
||||
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
|
||||
|
||||
return MessagesResponse{
|
||||
ID: id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: r.Model,
|
||||
Content: content,
|
||||
StopReason: stopReason,
|
||||
Usage: Usage{
|
||||
InputTokens: r.Metrics.PromptEvalCount,
|
||||
OutputTokens: r.Metrics.EvalCount,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
|
||||
func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
if hasToolCalls {
|
||||
return "tool_use"
|
||||
}
|
||||
|
||||
switch reason {
|
||||
case "stop":
|
||||
return "end_turn"
|
||||
case "length":
|
||||
return "max_tokens"
|
||||
default:
|
||||
if reason != "" {
|
||||
return "stop_sequence"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// StreamEvent represents a streaming event to be sent to the client
|
||||
type StreamEvent struct {
|
||||
Event string
|
||||
Data any
|
||||
}
|
||||
|
||||
// Process converts an Ollama ChatResponse to Anthropic streaming events
|
||||
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
var events []StreamEvent
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
Data: MessageStartEvent{
|
||||
Type: "message_start",
|
||||
Message: MessagesResponse{
|
||||
ID: c.ID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: c.Model,
|
||||
Content: []ContentBlock{},
|
||||
Usage: Usage{
|
||||
InputTokens: c.inputTokens,
|
||||
OutputTokens: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Thinking != "" && !c.thinkingDone {
|
||||
if !c.thinkingStarted {
|
||||
c.thinkingStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(""),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "thinking_delta",
|
||||
Thinking: r.Message.Thinking,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Content != "" {
|
||||
if c.thinkingStarted && !c.thinkingDone {
|
||||
c.thinkingDone = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if !c.textStarted {
|
||||
c.textStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(""),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "text_delta",
|
||||
Text: r.Message.Content,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
if c.toolCallsSent[tc.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
c.textStarted = false
|
||||
}
|
||||
|
||||
argsJSON, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: map[string]any{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: string(argsJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
|
||||
c.toolCallsSent[tc.ID] = true
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
} else if c.thinkingStarted && !c.thinkingDone {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.outputTokens = r.Metrics.EvalCount
|
||||
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_delta",
|
||||
Data: MessageDeltaEvent{
|
||||
Type: "message_delta",
|
||||
Delta: MessageDelta{
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: DeltaUsage{
|
||||
OutputTokens: c.outputTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_stop",
|
||||
Data: MessageStopEvent{
|
||||
Type: "message_stop",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// generateID generates a unique ID with the given prefix using crypto/rand
|
||||
func generateID(prefix string) string {
|
||||
b := make([]byte, 12)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to time-based ID if crypto/rand fails
|
||||
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
|
||||
}
|
||||
return fmt.Sprintf("%s_%x", prefix, b)
|
||||
}
|
||||
|
||||
// GenerateMessageID generates a unique message ID
|
||||
func GenerateMessageID() string {
|
||||
return generateID("msg")
|
||||
}
|
||||
|
||||
// ptr returns a pointer to the given string value
|
||||
func ptr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// mapToArgs converts a map to ToolCallFunctionArguments
|
||||
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
953
anthropic/anthropic_test.go
Normal file
953
anthropic/anthropic_test.go
Normal file
@@ -0,0 +1,953 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_Basic(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
|
||||
t.Errorf("unexpected message: %+v", result.Messages[0])
|
||||
}
|
||||
|
||||
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
|
||||
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
|
||||
t.Errorf("unexpected system message: %+v", result.Messages[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: []any{
|
||||
map[string]any{"type": "text", "text": "You are helpful."},
|
||||
map[string]any{"type": "text", "text": " Be concise."},
|
||||
},
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "You are helpful. Be concise." {
|
||||
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithOptions(t *testing.T) {
|
||||
temp := 0.7
|
||||
topP := 0.9
|
||||
topK := 40
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 2048,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
StopSequences: []string{"\n", "END"},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Options["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
|
||||
}
|
||||
if result.Options["top_p"] != 0.9 {
|
||||
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
|
||||
}
|
||||
if result.Options["top_k"] != 40 {
|
||||
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
|
||||
}
|
||||
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
|
||||
t.Errorf("stop sequences mismatch: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithImage(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "What's in this image?"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "What's in this image?" {
|
||||
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
|
||||
}
|
||||
|
||||
if len(result.Messages[0].Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
|
||||
}
|
||||
|
||||
if string(result.Messages[0].Images[0]) != string(imgData) {
|
||||
t.Error("image data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": map[string]any{"location": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if len(result.Messages[1].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
|
||||
}
|
||||
|
||||
tc := result.Messages[1].ToolCalls[0]
|
||||
if tc.ID != "call_123" {
|
||||
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
|
||||
}
|
||||
if tc.Function.Name != "get_weather" {
|
||||
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_123",
|
||||
"content": "The weather in Paris is sunny, 22°C",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
msg := result.Messages[0]
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if msg.ToolCallID != "call_123" {
|
||||
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
|
||||
}
|
||||
if msg.Content != "The weather in Paris is sunny, 22°C" {
|
||||
t.Errorf("unexpected content: %q", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
|
||||
}
|
||||
|
||||
tool := result.Tools[0]
|
||||
if tool.Type != "function" {
|
||||
t.Errorf("expected type 'function', got %q", tool.Type)
|
||||
}
|
||||
if tool.Function.Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
|
||||
}
|
||||
if tool.Function.Description != "Get current weather" {
|
||||
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Think == nil {
|
||||
t.Fatal("expected Think to be set")
|
||||
}
|
||||
if v, ok := result.Think.Value.(bool); !ok || !v {
|
||||
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
|
||||
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
|
||||
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this...",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
assistantMsg := result.Messages[1]
|
||||
if assistantMsg.Thinking != "Let me think about this..." {
|
||||
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"name": "get_weather",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use id")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'id' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use name")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'name' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "bad_tool",
|
||||
InputSchema: json.RawMessage(`{invalid json`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid tool schema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_Basic(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello there!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 5,
|
||||
},
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if result.ID != "msg_123" {
|
||||
t.Errorf("expected ID 'msg_123', got %q", result.ID)
|
||||
}
|
||||
if result.Type != "message" {
|
||||
t.Errorf("expected type 'message', got %q", result.Type)
|
||||
}
|
||||
if result.Role != "assistant" {
|
||||
t.Errorf("expected role 'assistant', got %q", result.Role)
|
||||
}
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
|
||||
t.Errorf("unexpected content: %+v", result.Content[0])
|
||||
}
|
||||
if result.StopReason != "end_turn" {
|
||||
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
||||
}
|
||||
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
|
||||
t.Errorf("unexpected usage: %+v", result.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "tool_use" {
|
||||
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].ID != "call_123" {
|
||||
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
|
||||
}
|
||||
if result.Content[0].Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
|
||||
}
|
||||
if result.StopReason != "tool_use" {
|
||||
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithThinking(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "The answer is 42.",
|
||||
Thinking: "Let me think about this...",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 2 {
|
||||
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "thinking" {
|
||||
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
|
||||
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
|
||||
}
|
||||
if result.Content[1].Type != "text" {
|
||||
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapStopReason(t *testing.T) {
|
||||
tests := []struct {
|
||||
reason string
|
||||
hasToolCalls bool
|
||||
want string
|
||||
}{
|
||||
{"stop", false, "end_turn"},
|
||||
{"length", false, "max_tokens"},
|
||||
{"stop", true, "tool_use"},
|
||||
{"other", false, "stop_sequence"},
|
||||
{"", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := mapStopReason(tt.reason, tt.hasToolCalls)
|
||||
if got != tt.want {
|
||||
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewError(t *testing.T) {
|
||||
tests := []struct {
|
||||
code int
|
||||
want string
|
||||
}{
|
||||
{400, "invalid_request_error"},
|
||||
{401, "authentication_error"},
|
||||
{403, "permission_error"},
|
||||
{404, "not_found_error"},
|
||||
{429, "rate_limit_error"},
|
||||
{500, "api_error"},
|
||||
{503, "overloaded_error"},
|
||||
{529, "overloaded_error"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := NewError(tt.code, "test message")
|
||||
if result.Type != "error" {
|
||||
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
|
||||
}
|
||||
if result.Error.Type != tt.want {
|
||||
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
|
||||
}
|
||||
if result.Error.Message != "test message" {
|
||||
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
|
||||
}
|
||||
if result.RequestID == "" {
|
||||
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMessageID(t *testing.T) {
|
||||
id1 := GenerateMessageID()
|
||||
id2 := GenerateMessageID()
|
||||
|
||||
if id1 == "" {
|
||||
t.Error("GenerateMessageID returned empty string")
|
||||
}
|
||||
if id1 == id2 {
|
||||
t.Error("GenerateMessageID returned duplicate IDs")
|
||||
}
|
||||
if len(id1) < 10 {
|
||||
t.Errorf("GenerateMessageID returned short ID: %q", id1)
|
||||
}
|
||||
if id1[:4] != "msg_" {
|
||||
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello",
|
||||
},
|
||||
Metrics: api.Metrics{PromptEvalCount: 10},
|
||||
}
|
||||
|
||||
events1 := conv.Process(resp1)
|
||||
if len(events1) < 3 {
|
||||
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
|
||||
}
|
||||
|
||||
// Should have message_start, content_block_start, content_block_delta
|
||||
if events1[0].Event != "message_start" {
|
||||
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
|
||||
}
|
||||
if events1[1].Event != "content_block_start" {
|
||||
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
|
||||
}
|
||||
if events1[2].Event != "content_block_delta" {
|
||||
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
|
||||
}
|
||||
|
||||
// Final chunk
|
||||
resp2 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: " world!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{EvalCount: 5},
|
||||
}
|
||||
|
||||
events2 := conv.Process(resp2)
|
||||
|
||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||
hasStop := false
|
||||
for _, e := range events2 {
|
||||
if e.Event == "message_stop" {
|
||||
hasStop = true
|
||||
}
|
||||
}
|
||||
if !hasStop {
|
||||
t.Error("expected message_stop event in final chunk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
hasToolStart := false
|
||||
hasToolDelta := false
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
hasToolStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if e.Event == "content_block_delta" {
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
|
||||
if delta.Delta.Type == "input_json_delta" {
|
||||
hasToolDelta = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasToolStart {
|
||||
t.Error("expected tool_use content_block_start event")
|
||||
}
|
||||
if !hasToolDelta {
|
||||
t.Error("expected input_json_delta event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
badArgs.Set("channel", unmarshalable)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_bad",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bad_function",
|
||||
Arguments: badArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
// Should not panic and should skip the unmarshalable tool call
|
||||
events := conv.Process(resp)
|
||||
|
||||
// Verify no tool_use block was started (since marshal failed before block start)
|
||||
hasToolStart := false
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
hasToolStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasToolStart {
|
||||
t.Error("expected no tool_use block when arguments cannot be marshaled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
badArgs.Set("channel", unmarshalable)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_good",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "good_function",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_bad",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bad_function",
|
||||
Arguments: badArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
// Count tool_use blocks - should only have 1 (the valid one)
|
||||
toolStartCount := 0
|
||||
toolDeltaCount := 0
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
toolStartCount++
|
||||
if start.ContentBlock.Name != "good_function" {
|
||||
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if e.Event == "content_block_delta" {
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
|
||||
if delta.Delta.Type == "input_json_delta" {
|
||||
toolDeltaCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if toolStartCount != 1 {
|
||||
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
|
||||
}
|
||||
if toolDeltaCount != 1 {
|
||||
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
|
||||
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
|
||||
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
|
||||
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
|
||||
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
block ContentBlock
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "text block includes empty text field",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
{
|
||||
name: "thinking block includes empty thinking field",
|
||||
block: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "thinking"},
|
||||
},
|
||||
{
|
||||
name: "text block with content",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr("hello"),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.block)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range tt.wantKeys {
|
||||
if _, ok := result[key]; !ok {
|
||||
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Content: "hello"},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundTextStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "text" {
|
||||
foundTextStart = true
|
||||
// Marshal and verify the text field is present
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["text"]; !ok {
|
||||
t.Error("content_block_start for text should include 'text' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundTextStart {
|
||||
t.Error("expected text content_block_start event")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundThinkingStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "thinking" {
|
||||
foundThinkingStart = true
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["thinking"]; !ok {
|
||||
t.Error("content_block_start for thinking should include 'thinking' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundThinkingStart {
|
||||
t.Error("expected thinking content_block_start event")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -520,7 +520,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
// Check for experimental flag
|
||||
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||
yoloMode, _ := cmd.Flags().GetBool("yolo")
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
@@ -548,9 +547,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Use experimental agent loop with tools
|
||||
// Use experimental agent loop with
|
||||
if isExperimental {
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
@@ -1765,7 +1764,6 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
|
||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||
runCmd.Flags().BoolP("yolo", "y", false, "Skip all tool approval prompts (use with caution)")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
|
||||
@@ -6,14 +6,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -21,28 +18,9 @@ type ModelParameters struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
|
||||
// TODO is this needed?
|
||||
ModelType string `json:"model_type"`
|
||||
|
||||
TextModel struct {
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
ModelType string `json:"model_type"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
} `json:"text_config"`
|
||||
|
||||
// TODO vision config
|
||||
/*
|
||||
"vision_config": {
|
||||
"hidden_size": 1152,
|
||||
"image_size": 896,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 14,
|
||||
"vision_use_head": false
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
type AdapterParameters struct {
|
||||
@@ -55,91 +33,8 @@ type AdapterParameters struct {
|
||||
} `json:"lora_parameters"`
|
||||
}
|
||||
|
||||
type KV map[string]any
|
||||
|
||||
func (kv KV) Architecture() string {
|
||||
return kv.String("general.architecture", "unknown")
|
||||
}
|
||||
|
||||
type valueTypes interface {
|
||||
uint8 | int8 | uint16 | int16 |
|
||||
uint32 | int32 | uint64 | int64 |
|
||||
string | float32 | float64 | bool
|
||||
}
|
||||
|
||||
type arrayValueTypes interface {
|
||||
[]uint8 | []int8 | []uint16 | []int16 |
|
||||
[]uint32 | []int32 | []uint64 | []int64 |
|
||||
[]string | []float32 | []float64 | []bool
|
||||
}
|
||||
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
if val, ok := kv[key].(T); ok {
|
||||
return val, true
|
||||
}
|
||||
return defaultValue[0], false
|
||||
}
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
}
|
||||
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, false)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
|
||||
return val
|
||||
}
|
||||
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
|
||||
return val
|
||||
}
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (ModelParameters) KV(t *Tokenizer) KV {
|
||||
kv := KV{
|
||||
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||
kv := ggml.KV{
|
||||
"general.file_type": uint32(1),
|
||||
"general.quantization_version": uint32(2),
|
||||
"tokenizer.ggml.pre": t.Pre,
|
||||
@@ -168,7 +63,7 @@ func (ModelParameters) KV(t *Tokenizer) KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p AdapterParameters) KV() KV {
|
||||
func (p AdapterParameters) KV() ggml.KV {
|
||||
var alpha float32
|
||||
if p.LoraParameters.Alpha == 0 {
|
||||
alpha = float32(p.Alpha)
|
||||
@@ -176,7 +71,7 @@ func (p AdapterParameters) KV() KV {
|
||||
alpha = p.LoraParameters.Alpha
|
||||
}
|
||||
|
||||
kv := KV{
|
||||
kv := ggml.KV{
|
||||
"adapter.lora.alpha": alpha,
|
||||
"adapter.type": "lora",
|
||||
"general.file_type": uint32(1),
|
||||
@@ -193,14 +88,9 @@ func (ModelParameters) specialTokenTypes() []string {
|
||||
}
|
||||
}
|
||||
|
||||
type ModelKV interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) KV
|
||||
}
|
||||
|
||||
type ModelConverter interface {
|
||||
ModelKV
|
||||
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -217,7 +107,7 @@ type moreParser interface {
|
||||
|
||||
type AdapterConverter interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(ofs.Config) KV
|
||||
KV(ggml.KV) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -225,7 +115,7 @@ type AdapterConverter interface {
|
||||
Replacements() []string
|
||||
}
|
||||
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -236,8 +126,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
arch := baseKV.Architecture()
|
||||
if arch == "" {
|
||||
arch, ok := baseKV["general.architecture"]
|
||||
if !ok {
|
||||
return errors.New("architecture not set for the base model")
|
||||
}
|
||||
|
||||
@@ -263,19 +153,23 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
bts, err := fs.ReadFile(fsys, "config.json")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
var p ModelParameters
|
||||
if err := json.Unmarshal(bts, &p); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if len(p.Architectures) < 1 {
|
||||
return nil, nil, errors.New("unknown architecture")
|
||||
return errors.New("unknown architecture")
|
||||
}
|
||||
|
||||
var conv ModelConverter
|
||||
@@ -323,22 +217,22 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
case "DeepseekV3ForCausalLM":
|
||||
conv = &deepseek2Model{}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, conv); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if t, ok := conv.(moreParser); ok {
|
||||
if err := t.parseMore(fsys); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
|
||||
@@ -360,19 +254,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
default:
|
||||
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
||||
}
|
||||
return conv, t, nil
|
||||
}
|
||||
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
kv, t, err := LoadModelMetadata(fsys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conv := kv.(ModelConverter)
|
||||
|
||||
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
|
||||
if err != nil {
|
||||
@@ -382,7 +263,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
return writeFile(f, conv.KV(t), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
|
||||
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
|
||||
for i := range ts {
|
||||
ts[i].Shape = slices.Clone(ts[i].Shape)
|
||||
slices.Reverse(ts[i].Shape)
|
||||
|
||||
@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *bertModel) KV(t *Tokenizer) KV {
|
||||
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "bert"
|
||||
kv["bert.attention.causal"] = false
|
||||
|
||||
@@ -24,7 +24,7 @@ type commandrModel struct {
|
||||
|
||||
var _ ModelConverter = (*commandrModel)(nil)
|
||||
|
||||
func (p *commandrModel) KV(t *Tokenizer) KV {
|
||||
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "command-r"
|
||||
kv["general.name"] = "command-r"
|
||||
|
||||
@@ -47,7 +47,7 @@ type deepseek2Model struct {
|
||||
Architecture string
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) KV {
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseek2"
|
||||
kv["general.type"] = "model"
|
||||
|
||||
@@ -41,7 +41,7 @@ type deepseekocr struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *deepseekocr) KV(t *Tokenizer) KV {
|
||||
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseekocr"
|
||||
kv["block_count"] = m.LanguageConfig.HiddenLayers
|
||||
|
||||
@@ -23,7 +23,7 @@ type gemmaModel struct {
|
||||
|
||||
var _ ModelConverter = (*gemmaModel)(nil)
|
||||
|
||||
func (p *gemmaModel) KV(t *Tokenizer) KV {
|
||||
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma"
|
||||
kv["gemma.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package convert
|
||||
|
||||
import "github.com/ollama/ollama/fs/ggml"
|
||||
|
||||
type gemma2Model struct {
|
||||
gemmaModel
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
@@ -7,7 +9,7 @@ type gemma2Model struct {
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
}
|
||||
|
||||
func (p *gemma2Model) KV(t *Tokenizer) KV {
|
||||
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma2"
|
||||
kv["gemma2.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -16,7 +15,7 @@ type gemma2Adapter struct {
|
||||
|
||||
var _ AdapterConverter = (*gemma2Adapter)(nil)
|
||||
|
||||
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
|
||||
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "gemma2"
|
||||
return kv
|
||||
|
||||
@@ -3,6 +3,8 @@ package convert
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type gemma3Model struct {
|
||||
@@ -53,7 +55,7 @@ const (
|
||||
gemma27BLayerCount = 62
|
||||
)
|
||||
|
||||
func (p *gemma3Model) KV(t *Tokenizer) KV {
|
||||
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3"
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ type gemma3nModel struct {
|
||||
VisionModel struct{} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) KV {
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3n"
|
||||
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
|
||||
|
||||
@@ -37,7 +37,7 @@ type gptossModel struct {
|
||||
|
||||
var _ ModelConverter = (*gptossModel)(nil)
|
||||
|
||||
func (m *gptossModel) KV(t *Tokenizer) KV {
|
||||
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gptoss"
|
||||
kv["general.file_type"] = uint32(4)
|
||||
|
||||
@@ -48,7 +48,7 @@ type llamaModel struct {
|
||||
|
||||
var _ ModelConverter = (*llamaModel)(nil)
|
||||
|
||||
func (p *llamaModel) KV(t *Tokenizer) KV {
|
||||
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -35,7 +35,7 @@ type llama4Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (p *llama4Model) KV(t *Tokenizer) KV {
|
||||
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama4"
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -19,13 +18,13 @@ type llamaAdapter struct {
|
||||
|
||||
var _ AdapterConverter = (*llamaAdapter)(nil)
|
||||
|
||||
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
|
||||
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
|
||||
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
|
||||
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
|
||||
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
|
||||
|
||||
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
|
||||
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ type mistral3Model struct {
|
||||
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
}
|
||||
|
||||
func (p *mistral3Model) KV(t *Tokenizer) KV {
|
||||
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
|
||||
|
||||
@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
|
||||
} `json:"rope_parameters"`
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -12,7 +12,7 @@ type mixtralModel struct {
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
}
|
||||
|
||||
func (p *mixtralModel) KV(t *Tokenizer) KV {
|
||||
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.llamaModel.KV(t)
|
||||
|
||||
if p.NumLocalExperts > 0 {
|
||||
|
||||
@@ -34,7 +34,7 @@ type mllamaModel struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *mllamaModel) KV(t *Tokenizer) KV {
|
||||
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mllama"
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) KV {
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
|
||||
// Determine architecture based on MoE parameters (following qwen3 pattern)
|
||||
|
||||
@@ -34,7 +34,7 @@ type olmoModel struct {
|
||||
|
||||
var _ ModelConverter = (*olmoModel)(nil)
|
||||
|
||||
func (p *olmoModel) KV(t *Tokenizer) KV {
|
||||
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "olmo3"
|
||||
kv["olmo3.block_count"] = p.NumHiddenLayers
|
||||
|
||||
@@ -37,7 +37,7 @@ type phi3Model struct {
|
||||
|
||||
var _ ModelConverter = (*phi3Model)(nil)
|
||||
|
||||
func (p *phi3Model) KV(t *Tokenizer) KV {
|
||||
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "phi3"
|
||||
kv["phi3.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -22,7 +22,7 @@ type qwen2Model struct {
|
||||
|
||||
var _ ModelConverter = (*qwen2Model)(nil)
|
||||
|
||||
func (q *qwen2Model) KV(t *Tokenizer) KV {
|
||||
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen2"
|
||||
kv["qwen2.block_count"] = q.HiddenLayers
|
||||
|
||||
@@ -29,7 +29,7 @@ type qwen25VLModel struct {
|
||||
|
||||
var _ ModelConverter = (*qwen25VLModel)(nil)
|
||||
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen25vl"
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ type qwen3Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (q *qwen3Model) KV(t *Tokenizer) KV {
|
||||
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
|
||||
arch := "qwen3"
|
||||
if q.NumExperts > 0 {
|
||||
arch += "moe"
|
||||
|
||||
@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
|
||||
return json.Unmarshal(bts, &m.VisionModel)
|
||||
}
|
||||
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.qwen3Model.KV(t)
|
||||
|
||||
arch := "qwen3vl"
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
fsc "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -29,7 +28,7 @@ type tensorData struct {
|
||||
Shape []int `json:"shape"`
|
||||
}
|
||||
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), "f16")
|
||||
@@ -60,10 +59,9 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors)
|
||||
return r, m.KV(), m.Tensors()
|
||||
}
|
||||
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
|
||||
actual := make(map[string]string)
|
||||
for k := range kv.Keys() {
|
||||
v := kv.Value(k)
|
||||
for k, v := range kv {
|
||||
if s, ok := v.(json.Marshaler); !ok {
|
||||
actual[k] = fmt.Sprintf("%v", v)
|
||||
} else {
|
||||
@@ -279,7 +277,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
|
||||
func TestConvertAdapter(t *testing.T) {
|
||||
type AdapterCase struct {
|
||||
Name string
|
||||
BaseKV KV
|
||||
BaseKV map[string]any
|
||||
Expected map[string]string
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
* [API Reference](https://docs.ollama.com/api)
|
||||
* [Modelfile Reference](https://docs.ollama.com/modelfile)
|
||||
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
|
||||
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
|
||||
|
||||
### Resources
|
||||
|
||||
|
||||
406
docs/api/anthropic-compatibility.mdx
Normal file
406
docs/api/anthropic-compatibility.mdx
Normal file
@@ -0,0 +1,406 @@
|
||||
---
|
||||
title: Anthropic compatibility
|
||||
---
|
||||
|
||||
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
|
||||
|
||||
## Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
|
||||
|
||||
Pull a model before use:
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment variables
|
||||
|
||||
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
|
||||
### Simple `/v1/messages` example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python basic.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama', # required but ignored
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
messages=[
|
||||
{'role': 'user', 'content': 'Hello, how are you?'}
|
||||
]
|
||||
)
|
||||
print(message.content[0].text)
|
||||
```
|
||||
|
||||
```javascript basic.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama", // required but ignored
|
||||
});
|
||||
|
||||
const message = await anthropic.messages.create({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Hello, how are you?" }],
|
||||
});
|
||||
|
||||
console.log(message.content[0].text);
|
||||
```
|
||||
|
||||
```shell basic.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-api-key: ollama" \
|
||||
-H "anthropic-version: 2023-06-01" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Streaming example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python streaming.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
with client.messages.stream(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
print(text, end='', flush=True)
|
||||
```
|
||||
|
||||
```javascript streaming.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama",
|
||||
});
|
||||
|
||||
const stream = await anthropic.messages.stream({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Count from 1 to 10" }],
|
||||
});
|
||||
|
||||
for await (const event of stream) {
|
||||
if (
|
||||
event.type === "content_block_delta" &&
|
||||
event.delta.type === "text_delta"
|
||||
) {
|
||||
process.stdout.write(event.delta.text);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```shell streaming.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Tool calling example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python tools.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
tools=[
|
||||
{
|
||||
'name': 'get_weather',
|
||||
'description': 'Get the current weather in a location',
|
||||
'input_schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'location': {
|
||||
'type': 'string',
|
||||
'description': 'The city and state, e.g. San Francisco, CA'
|
||||
}
|
||||
},
|
||||
'required': ['location']
|
||||
}
|
||||
}
|
||||
],
|
||||
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
|
||||
)
|
||||
|
||||
for block in message.content:
|
||||
if block.type == 'tool_use':
|
||||
print(f'Tool: {block.name}')
|
||||
print(f'Input: {block.input}')
|
||||
```
|
||||
|
||||
```javascript tools.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama",
|
||||
});
|
||||
|
||||
const message = await anthropic.messages.create({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
tools: [
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the current weather in a location",
|
||||
input_schema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: {
|
||||
type: "string",
|
||||
description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
},
|
||||
],
|
||||
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
|
||||
});
|
||||
|
||||
for (const block of message.content) {
|
||||
if (block.type === "tool_use") {
|
||||
console.log("Tool:", block.name);
|
||||
console.log("Input:", block.input);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```shell tools.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a location",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Using with Claude Code
|
||||
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
```shell
|
||||
# Local models
|
||||
claude --model qwen3-coder
|
||||
claude --model gpt-oss:20b
|
||||
|
||||
# Cloud models
|
||||
claude --model glm-4.7:cloud
|
||||
claude --model minimax-m2.1:cloud
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
### `/v1/messages`
|
||||
|
||||
#### Supported features
|
||||
|
||||
- [x] Messages
|
||||
- [x] Streaming
|
||||
- [x] System prompts
|
||||
- [x] Multi-turn conversations
|
||||
- [x] Vision (images)
|
||||
- [x] Tools (function calling)
|
||||
- [x] Tool results
|
||||
- [x] Thinking/extended thinking
|
||||
|
||||
#### Supported request fields
|
||||
|
||||
- [x] `model`
|
||||
- [x] `max_tokens`
|
||||
- [x] `messages`
|
||||
- [x] Text `content`
|
||||
- [x] Image `content` (base64)
|
||||
- [x] Array of content blocks
|
||||
- [x] `tool_use` blocks
|
||||
- [x] `tool_result` blocks
|
||||
- [x] `thinking` blocks
|
||||
- [x] `system` (string or array)
|
||||
- [x] `stream`
|
||||
- [x] `temperature`
|
||||
- [x] `top_p`
|
||||
- [x] `top_k`
|
||||
- [x] `stop_sequences`
|
||||
- [x] `tools`
|
||||
- [x] `thinking`
|
||||
- [ ] `tool_choice`
|
||||
- [ ] `metadata`
|
||||
|
||||
#### Supported response fields
|
||||
|
||||
- [x] `id`
|
||||
- [x] `type`
|
||||
- [x] `role`
|
||||
- [x] `model`
|
||||
- [x] `content` (text, tool_use, thinking blocks)
|
||||
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
|
||||
- [x] `usage` (input_tokens, output_tokens)
|
||||
|
||||
#### Streaming events
|
||||
|
||||
- [x] `message_start`
|
||||
- [x] `content_block_start`
|
||||
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
|
||||
- [x] `content_block_stop`
|
||||
- [x] `message_delta`
|
||||
- [x] `message_stop`
|
||||
- [x] `ping`
|
||||
- [x] `error`
|
||||
|
||||
## Models
|
||||
|
||||
Ollama supports both local and cloud models.
|
||||
|
||||
### Local models
|
||||
|
||||
Pull a local model before use:
|
||||
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
```
|
||||
|
||||
Recommended local models:
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
|
||||
### Cloud models
|
||||
|
||||
Cloud models are available immediately without pulling:
|
||||
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
|
||||
### Default model names
|
||||
|
||||
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
|
||||
|
||||
```shell
|
||||
ollama cp qwen3-coder claude-3-5-sonnet
|
||||
```
|
||||
|
||||
Afterwards, this new model name can be specified in the `model` field:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "claude-3-5-sonnet",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## Differences from the Anthropic API
|
||||
|
||||
### Behavior differences
|
||||
|
||||
- API key is accepted but not validated
|
||||
- `anthropic-version` header is accepted but not used
|
||||
- Token counts are approximations based on the underlying model's tokenizer
|
||||
|
||||
### Not supported
|
||||
|
||||
The following Anthropic API features are not currently supported:
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| `/v1/messages/count_tokens` | Token counting endpoint |
|
||||
| `tool_choice` | Forcing specific tool use or disabling tools |
|
||||
| `metadata` | Request metadata (user_id) |
|
||||
| Prompt caching | `cache_control` blocks for caching prefixes |
|
||||
| Batches API | `/v1/messages/batches` for async batch processing |
|
||||
| Citations | `citations` content blocks |
|
||||
| PDF support | `document` content blocks with PDF files |
|
||||
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
|
||||
|
||||
### Partial support
|
||||
|
||||
| Feature | Status |
|
||||
|---------|--------|
|
||||
| Image content | Base64 images supported; URL images not supported |
|
||||
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |
|
||||
@@ -32,7 +32,9 @@
|
||||
"codeblocks": "system"
|
||||
},
|
||||
"contextual": {
|
||||
"options": ["copy"]
|
||||
"options": [
|
||||
"copy"
|
||||
]
|
||||
},
|
||||
"navbar": {
|
||||
"links": [
|
||||
@@ -52,7 +54,9 @@
|
||||
"display": "simple"
|
||||
},
|
||||
"examples": {
|
||||
"languages": ["curl"]
|
||||
"languages": [
|
||||
"curl"
|
||||
]
|
||||
}
|
||||
},
|
||||
"redirects": [
|
||||
@@ -97,6 +101,7 @@
|
||||
{
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/codex",
|
||||
@@ -139,7 +144,8 @@
|
||||
"/api/streaming",
|
||||
"/api/usage",
|
||||
"/api/errors",
|
||||
"/api/openai-compatibility"
|
||||
"/api/openai-compatibility",
|
||||
"/api/anthropic-compatibility"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
69
docs/integrations/claude-code.mdx
Normal file
69
docs/integrations/claude-code.mdx
Normal file
@@ -0,0 +1,69 @@
|
||||
---
|
||||
title: Claude Code
|
||||
---
|
||||
|
||||
## Install
|
||||
|
||||
Install [Claude Code](https://code.claude.com/docs/en/overview):
|
||||
|
||||
<CodeGroup>
|
||||
|
||||
```shell macOS / Linux
|
||||
curl -fsSL https://claude.ai/install.sh | bash
|
||||
```
|
||||
|
||||
```powershell Windows
|
||||
irm https://claude.ai/install.ps1 | iex
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
2. Run Claude Code with an Ollama model:
|
||||
|
||||
```shell
|
||||
claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
|
||||
2. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=https://ollama.com
|
||||
export ANTHROPIC_API_KEY=<your-api-key>
|
||||
```
|
||||
|
||||
3. Run Claude Code with a cloud model:
|
||||
|
||||
```shell
|
||||
claude --model glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Recommended Models
|
||||
|
||||
### Cloud models
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
- `qwen3-coder:480b` - Large coding model
|
||||
|
||||
### Local models
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
@@ -1,7 +1,5 @@
|
||||
package fs
|
||||
|
||||
import "iter"
|
||||
|
||||
type Config interface {
|
||||
Architecture() string
|
||||
String(string, ...string) string
|
||||
@@ -13,8 +11,4 @@ type Config interface {
|
||||
Ints(string, ...[]int32) []int32
|
||||
Floats(string, ...[]float32) []float32
|
||||
Bools(string, ...[]bool) []bool
|
||||
|
||||
Len() int
|
||||
Keys() iter.Seq[string]
|
||||
Value(key string) any
|
||||
}
|
||||
|
||||
@@ -6,9 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -241,18 +239,6 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
|
||||
return binary.Write(w, binary.LittleEndian, s)
|
||||
}
|
||||
|
||||
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
|
||||
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
arch := kv.String("general.architecture")
|
||||
if arch == "" {
|
||||
return fmt.Errorf("architecture not set")
|
||||
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range slices.Sorted(kv.Keys()) {
|
||||
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
|
||||
for _, key := range slices.Sorted(maps.Keys(kv)) {
|
||||
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
149
middleware/anthropic.go
Normal file
149
middleware/anthropic.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/anthropic"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
|
||||
type AnthropicWriter struct {
|
||||
BaseWriter
|
||||
stream bool
|
||||
id string
|
||||
model string
|
||||
converter *anthropic.StreamConverter
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
|
||||
var errData struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &errData); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if w.stream {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
events := w.converter.Process(chatResponse)
|
||||
for _, event := range events {
|
||||
if err := w.writeEvent(event.Event, event.Data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := anthropic.ToMessagesResponse(w.id, chatResponse)
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
|
||||
func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req anthropic.MessagesRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.MaxTokens <= 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
|
||||
return
|
||||
}
|
||||
|
||||
chatReq, err := anthropic.FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
messageID := anthropic.GenerateMessageID()
|
||||
|
||||
w := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model),
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
584
middleware/anthropic_test.go
Normal file
584
middleware/anthropic_test.go
Normal file
@@ -0,0 +1,584 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
|
||||
"github.com/ollama/ollama/anthropic"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
_ = json.Unmarshal(bodyBytes, capturedRequest)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
|
||||
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.ChatRequest
|
||||
err anthropic.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.ChatRequest
|
||||
stream := true
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "basic message",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with system prompt",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"system": "You are helpful.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with options",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"top_k": 40,
|
||||
"stop_sequences": ["\n", "END"],
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"num_predict": 2048,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"top_k": 40,
|
||||
"stop": []string{"\n", "END"},
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "streaming",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &stream,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"}
|
||||
],
|
||||
"tools": [{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
Tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: testProps(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with tool result",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
|
||||
]},
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
|
||||
]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with thinking enabled",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"thinking": {"type": "enabled", "budget_tokens": 1000},
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
Think: &api.ThinkValue{Value: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing model error",
|
||||
body: `{
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "model is required",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing max_tokens error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "max_tokens is required and must be positive",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing messages error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "messages is required",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_use missing id error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "tool_use", "name": "test"}
|
||||
]}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "tool_use block missing required 'id' field",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/v1/messages", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if tc.err.Type != "" {
|
||||
// Expect error
|
||||
if resp.Code == http.StatusOK {
|
||||
t.Fatalf("expected error response, got 200 OK")
|
||||
}
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error: %v", err)
|
||||
}
|
||||
if errResp.Type != tc.err.Type {
|
||||
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != tc.err.Error.Type {
|
||||
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
|
||||
}
|
||||
if errResp.Error.Message != tc.err.Error.Message {
|
||||
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if capturedRequest == nil {
|
||||
t.Fatal("request was not captured")
|
||||
}
|
||||
|
||||
// Compare relevant fields
|
||||
if capturedRequest.Model != tc.req.Model {
|
||||
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
|
||||
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
|
||||
t.Errorf("messages mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if tc.req.Stream != nil && capturedRequest.Stream != nil {
|
||||
if *tc.req.Stream != *capturedRequest.Stream {
|
||||
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.req.Think != nil {
|
||||
if capturedRequest.Think == nil {
|
||||
t.Error("expected Think to be set")
|
||||
} else if capturedRequest.Think.Value != tc.req.Think.Value {
|
||||
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("streaming sets correct headers", func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Check headers were set
|
||||
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
|
||||
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
|
||||
}
|
||||
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
|
||||
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", resp.Code)
|
||||
}
|
||||
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
if errResp.Type != "error" {
|
||||
t.Errorf("expected type 'error', got %q", errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != "invalid_request_error" {
|
||||
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicWriter_NonStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Simulate Ollama response
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello there!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 5,
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, _ = c.Writer.Write(data)
|
||||
})
|
||||
|
||||
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", resp.Code)
|
||||
}
|
||||
|
||||
var result anthropic.MessagesResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if result.Type != "message" {
|
||||
t.Errorf("expected type 'message', got %q", result.Type)
|
||||
}
|
||||
if result.Role != "assistant" {
|
||||
t.Errorf("expected role 'assistant', got %q", result.Role)
|
||||
}
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
|
||||
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
|
||||
}
|
||||
if result.StopReason != "end_turn" {
|
||||
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
||||
}
|
||||
if result.Usage.InputTokens != 10 {
|
||||
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
|
||||
}
|
||||
if result.Usage.OutputTokens != 5 {
|
||||
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
|
||||
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
|
||||
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
errorPayload any
|
||||
wantErrorType string
|
||||
wantMessage string
|
||||
}{
|
||||
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
|
||||
{
|
||||
name: "404 with gin.H error (model not found)",
|
||||
statusCode: http.StatusNotFound,
|
||||
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
|
||||
wantErrorType: "not_found_error",
|
||||
wantMessage: "model 'nonexistent' not found",
|
||||
},
|
||||
{
|
||||
name: "400 with gin.H error (bad request)",
|
||||
statusCode: http.StatusBadRequest,
|
||||
errorPayload: gin.H{"error": "model is required"},
|
||||
wantErrorType: "invalid_request_error",
|
||||
wantMessage: "model is required",
|
||||
},
|
||||
{
|
||||
name: "500 with gin.H error (internal error)",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
errorPayload: gin.H{"error": "something went wrong"},
|
||||
wantErrorType: "api_error",
|
||||
wantMessage: "something went wrong",
|
||||
},
|
||||
{
|
||||
name: "404 with api.StatusError",
|
||||
statusCode: http.StatusNotFound,
|
||||
errorPayload: api.StatusError{
|
||||
StatusCode: http.StatusNotFound,
|
||||
ErrorMessage: "model not found via StatusError",
|
||||
},
|
||||
wantErrorType: "not_found_error",
|
||||
wantMessage: "model not found via StatusError",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Simulate what routes.go does - set status and write error JSON
|
||||
data, _ := json.Marshal(tt.errorPayload)
|
||||
c.Writer.WriteHeader(tt.statusCode)
|
||||
_, _ = c.Writer.Write(data)
|
||||
})
|
||||
|
||||
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != tt.statusCode {
|
||||
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
|
||||
}
|
||||
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
|
||||
}
|
||||
|
||||
if errResp.Type != "error" {
|
||||
t.Errorf("expected type 'error', got %q", errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != tt.wantErrorType {
|
||||
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
|
||||
}
|
||||
if errResp.Error.Message != tt.wantMessage {
|
||||
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -802,8 +801,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var base convert.KV
|
||||
base = map[string]any{"general.architecture": "test"}
|
||||
base := map[string]any{"general.architecture": "test"}
|
||||
maps.Copy(base, kv)
|
||||
|
||||
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
||||
|
||||
@@ -6,9 +6,6 @@ import (
|
||||
|
||||
var ErrInterrupt = errors.New("Interrupt")
|
||||
|
||||
// ErrExpandOutput is returned when user presses Ctrl+O to expand tool output
|
||||
var ErrExpandOutput = errors.New("ExpandOutput")
|
||||
|
||||
type InterruptError struct {
|
||||
Line []rune
|
||||
}
|
||||
|
||||
@@ -206,9 +206,6 @@ func (i *Instance) Readline() (string, error) {
|
||||
buf.DeleteBefore()
|
||||
case CharCtrlL:
|
||||
buf.ClearScreen()
|
||||
case CharCtrlO:
|
||||
// Ctrl+O - expand tool output
|
||||
return "", ErrExpandOutput
|
||||
case CharCtrlW:
|
||||
buf.DeleteWord()
|
||||
case CharCtrlZ:
|
||||
|
||||
@@ -18,7 +18,6 @@ const (
|
||||
CharCtrlL = 12
|
||||
CharEnter = 13
|
||||
CharNext = 14
|
||||
CharCtrlO = 15 // Ctrl+O - used for expanding tool output
|
||||
CharPrev = 16
|
||||
CharBckSearch = 18
|
||||
CharFwdSearch = 19
|
||||
|
||||
@@ -37,55 +37,6 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
|
||||
.
|
||||
fi
|
||||
|
||||
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
|
||||
deduplicate_cuda_libs() {
|
||||
local base_dir="$1"
|
||||
echo "Deduplicating CUDA libraries in ${base_dir}..."
|
||||
|
||||
# Find all mlx_cuda_* directories
|
||||
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
|
||||
[ -d "${mlx_dir}" ] || continue
|
||||
|
||||
# Extract CUDA version (e.g., v12, v13)
|
||||
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
|
||||
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
|
||||
|
||||
# Skip if corresponding cuda_* directory doesn't exist
|
||||
[ -d "${cuda_dir}" ] || continue
|
||||
|
||||
echo " Checking ${mlx_dir} against ${cuda_dir}..."
|
||||
|
||||
# Find all .so* files in mlx directory
|
||||
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
|
||||
filename=$(basename "${mlx_file}")
|
||||
cuda_file="${cuda_dir}/${filename}"
|
||||
|
||||
# Skip if file doesn't exist in cuda directory
|
||||
[ -f "${cuda_file}" ] || continue
|
||||
|
||||
# Compare checksums
|
||||
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
|
||||
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
|
||||
|
||||
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
|
||||
echo " Deduplicating ${filename}"
|
||||
# Calculate relative path from mlx_dir to cuda_dir
|
||||
rel_path="../cuda_${cuda_version}/${filename}"
|
||||
rm -f "${mlx_file}"
|
||||
ln -s "${rel_path}" "${mlx_file}"
|
||||
fi
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
# Run deduplication for each platform output directory
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist/linux_amd64"
|
||||
deduplicate_cuda_libs "./dist/linux_arm64"
|
||||
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist"
|
||||
fi
|
||||
|
||||
# buildx behavior changes for single vs. multiplatform
|
||||
echo "Compressing linux tar bundles..."
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
@@ -455,7 +454,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
|
||||
func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
|
||||
for _, l := range baseLayers {
|
||||
if l.GGML != nil {
|
||||
return l.KV(), nil
|
||||
|
||||
@@ -1544,6 +1544,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
gocmpopts "github.com/google/go-cmp/cmp/cmpopts"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -42,8 +41,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var base convert.KV
|
||||
base = map[string]any{"general.architecture": "test"}
|
||||
base := map[string]any{"general.architecture": "test"}
|
||||
maps.Copy(base, kv)
|
||||
|
||||
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
||||
|
||||
@@ -381,28 +381,6 @@ func (t templateTools) String() string {
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateArgs is a map type with JSON string output for templates.
|
||||
type templateArgs map[string]any
|
||||
|
||||
func (t templateArgs) String() string {
|
||||
if t == nil {
|
||||
return "{}"
|
||||
}
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateProperties is a map type with JSON string output for templates.
|
||||
type templateProperties map[string]api.ToolProperty
|
||||
|
||||
func (t templateProperties) String() string {
|
||||
if t == nil {
|
||||
return "{}"
|
||||
}
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateTool is a template-compatible representation of api.Tool
|
||||
// with Properties as a regular map for template ranging.
|
||||
type templateTool struct {
|
||||
@@ -418,11 +396,11 @@ type templateToolFunction struct {
|
||||
}
|
||||
|
||||
type templateToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties templateProperties `json:"properties"`
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
// templateToolCall is a template-compatible representation of api.ToolCall
|
||||
@@ -435,7 +413,7 @@ type templateToolCall struct {
|
||||
type templateToolCallFunction struct {
|
||||
Index int
|
||||
Name string
|
||||
Arguments templateArgs
|
||||
Arguments map[string]any
|
||||
}
|
||||
|
||||
// templateMessage is a template-compatible representation of api.Message
|
||||
@@ -468,7 +446,7 @@ func convertToolsForTemplate(tools api.Tools) templateTools {
|
||||
Defs: tool.Function.Parameters.Defs,
|
||||
Items: tool.Function.Parameters.Items,
|
||||
Required: tool.Function.Parameters.Required,
|
||||
Properties: templateProperties(tool.Function.Parameters.Properties.ToMap()),
|
||||
Properties: tool.Function.Parameters.Properties.ToMap(),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -490,7 +468,7 @@ func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
|
||||
Function: templateToolCallFunction{
|
||||
Index: tc.Function.Index,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: templateArgs(tc.Function.Arguments.ToMap()),
|
||||
Arguments: tc.Function.Arguments.ToMap(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -613,159 +613,3 @@ func TestCollate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateArgumentsJSON(t *testing.T) {
|
||||
// Test that {{ .Function.Arguments }} outputs valid JSON, not map[key:value]
|
||||
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{ .Function.Arguments }}{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("location", "Tokyo")
|
||||
args.Set("unit", "celsius")
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
// Should be valid JSON, not "map[location:Tokyo unit:celsius]"
|
||||
if strings.HasPrefix(got, "map[") {
|
||||
t.Errorf("Arguments output as Go map format: %s", got)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
|
||||
t.Errorf("Arguments not valid JSON: %s, error: %v", got, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplatePropertiesJSON(t *testing.T) {
|
||||
// Test that {{ .Function.Parameters.Properties }} outputs valid JSON
|
||||
// Note: template must reference .Messages to trigger the modern code path that converts Tools
|
||||
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{ .Function.Parameters.Properties }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{Role: "user", Content: "test"}},
|
||||
Tools: api.Tools{{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
// Should be valid JSON, not "map[location:{...}]"
|
||||
if strings.HasPrefix(got, "map[") {
|
||||
t.Errorf("Properties output as Go map format: %s", got)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
|
||||
t.Errorf("Properties not valid JSON: %s, error: %v", got, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateArgumentsRange(t *testing.T) {
|
||||
// Test that we can range over Arguments in templates
|
||||
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{- range $k, $v := .Function.Arguments }}{{ $k }}={{ $v }};{{- end }}{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("city", "Tokyo")
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
if got != "city=Tokyo;" {
|
||||
t.Errorf("Range over Arguments failed, got: %s, want: city=Tokyo;", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplatePropertiesRange(t *testing.T) {
|
||||
// Test that we can range over Properties in templates
|
||||
// Note: template must reference .Messages to trigger the modern code path that converts Tools
|
||||
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{- range $name, $prop := .Function.Parameters.Properties }}{{ $name }}:{{ $prop.Type }};{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{Role: "user", Content: "test"}},
|
||||
Tools: api.Tools{{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
if got != "location:string;" {
|
||||
t.Errorf("Range over Properties failed, got: %s, want: location:string;", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ package agent
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -180,7 +179,6 @@ func FormatDeniedResult(command string, pattern string) string {
|
||||
// extractBashPrefix extracts a prefix pattern from a bash command.
|
||||
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
|
||||
// For commands without path args, returns empty string.
|
||||
// Paths with ".." traversal that escape the base directory return empty string for security.
|
||||
func extractBashPrefix(command string) string {
|
||||
// Split command by pipes and get the first part
|
||||
parts := strings.Split(command, "|")
|
||||
@@ -206,8 +204,8 @@ func extractBashPrefix(command string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the first path-like argument (must contain / or \ or start with .)
|
||||
// First pass: look for clear paths (containing path separators or starting with .)
|
||||
// Find the first path-like argument (must contain / or start with .)
|
||||
// First pass: look for clear paths (containing / or starting with .)
|
||||
for _, arg := range fields[1:] {
|
||||
// Skip flags
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
@@ -217,49 +215,19 @@ func extractBashPrefix(command string) string {
|
||||
if isNumeric(arg) {
|
||||
continue
|
||||
}
|
||||
// Only process if it looks like a path (contains / or \ or starts with .)
|
||||
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
|
||||
// Only process if it looks like a path (contains / or starts with .)
|
||||
if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
|
||||
continue
|
||||
}
|
||||
// Normalize to forward slashes for consistent cross-platform matching
|
||||
arg = strings.ReplaceAll(arg, "\\", "/")
|
||||
|
||||
// Security: reject absolute paths
|
||||
if path.IsAbs(arg) {
|
||||
return "" // Absolute path - don't create prefix
|
||||
// If arg ends with /, it's a directory - use it directly
|
||||
if strings.HasSuffix(arg, "/") {
|
||||
return fmt.Sprintf("%s:%s", baseCmd, arg)
|
||||
}
|
||||
|
||||
// Normalize the path using stdlib path.Clean (resolves . and ..)
|
||||
cleaned := path.Clean(arg)
|
||||
|
||||
// Security: reject if cleaned path escapes to parent directory
|
||||
if strings.HasPrefix(cleaned, "..") {
|
||||
return "" // Path escapes - don't create prefix
|
||||
}
|
||||
|
||||
// Security: if original had "..", verify cleaned path didn't escape to sibling
|
||||
// e.g., "tools/a/b/../../../etc" -> "etc" (escaped tools/ to sibling)
|
||||
if strings.Contains(arg, "..") {
|
||||
origBase := strings.SplitN(arg, "/", 2)[0]
|
||||
cleanedBase := strings.SplitN(cleaned, "/", 2)[0]
|
||||
if origBase != cleanedBase {
|
||||
return "" // Path escaped to sibling directory
|
||||
}
|
||||
}
|
||||
|
||||
// Check if arg ends with / (explicit directory)
|
||||
isDir := strings.HasSuffix(arg, "/")
|
||||
|
||||
// Get the directory part
|
||||
var dir string
|
||||
if isDir {
|
||||
dir = cleaned
|
||||
} else {
|
||||
dir = path.Dir(cleaned)
|
||||
}
|
||||
|
||||
// Get the directory part of a file path
|
||||
dir := filepath.Dir(arg)
|
||||
if dir == "." {
|
||||
return fmt.Sprintf("%s:./", baseCmd)
|
||||
// Path is just a directory like "tools" or "src" (no trailing /)
|
||||
return fmt.Sprintf("%s:%s/", baseCmd, arg)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s/", baseCmd, dir)
|
||||
}
|
||||
@@ -364,8 +332,6 @@ func AllowlistKey(toolName string, args map[string]any) string {
|
||||
}
|
||||
|
||||
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
|
||||
// For bash commands, hierarchical path matching is used - if "cat:tools/" is allowed,
|
||||
// then "cat:tools/subdir/" is also allowed (subdirectories inherit parent permissions).
|
||||
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
@@ -376,20 +342,12 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// For bash commands, check prefix matches with hierarchical path support
|
||||
// For bash commands, check prefix matches
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
prefix := extractBashPrefix(cmd)
|
||||
if prefix != "" {
|
||||
// Check exact prefix match first
|
||||
if a.prefixes[prefix] {
|
||||
return true
|
||||
}
|
||||
// Check hierarchical match: if any stored prefix is a parent of current prefix
|
||||
// e.g., stored "cat:tools/" should match current "cat:tools/subdir/"
|
||||
if a.matchesHierarchicalPrefix(prefix) {
|
||||
return true
|
||||
}
|
||||
if prefix != "" && a.prefixes[prefix] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -402,40 +360,6 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchesHierarchicalPrefix checks if the given prefix matches any stored prefix hierarchically.
|
||||
// For example, if "cat:tools/" is stored, it will match "cat:tools/subdir/" or "cat:tools/a/b/c/".
|
||||
func (a *ApprovalManager) matchesHierarchicalPrefix(currentPrefix string) bool {
|
||||
// Split prefix into command and path parts (format: "cmd:path/")
|
||||
colonIdx := strings.Index(currentPrefix, ":")
|
||||
if colonIdx == -1 {
|
||||
return false
|
||||
}
|
||||
currentCmd := currentPrefix[:colonIdx]
|
||||
currentPath := currentPrefix[colonIdx+1:]
|
||||
|
||||
for storedPrefix := range a.prefixes {
|
||||
storedColonIdx := strings.Index(storedPrefix, ":")
|
||||
if storedColonIdx == -1 {
|
||||
continue
|
||||
}
|
||||
storedCmd := storedPrefix[:storedColonIdx]
|
||||
storedPath := storedPrefix[storedColonIdx+1:]
|
||||
|
||||
// Commands must match exactly
|
||||
if currentCmd != storedCmd {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if current path starts with stored path (hierarchical match)
|
||||
// e.g., "tools/subdir/" starts with "tools/"
|
||||
if strings.HasPrefix(currentPath, storedPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddToAllowlist adds a tool/command to the session allowlist.
|
||||
// For bash commands, it adds the prefix pattern instead of exact command.
|
||||
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
|
||||
@@ -519,12 +443,11 @@ func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
}
|
||||
}
|
||||
|
||||
// For web search, show query and internet notice
|
||||
// For web search, show query
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
|
||||
sb.WriteString("Uses internet via ollama.com")
|
||||
sb.WriteString(fmt.Sprintf("Query: %s", query))
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
@@ -1028,79 +951,3 @@ func FormatDenyResult(toolName string, reason string) string {
|
||||
}
|
||||
return fmt.Sprintf("User denied execution of %s.", toolName)
|
||||
}
|
||||
|
||||
// PromptYesNo displays a simple Yes/No prompt and returns the user's choice.
|
||||
// Returns true for Yes, false for No.
|
||||
func PromptYesNo(question string) (bool, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
selected := 0 // 0 = Yes, 1 = No
|
||||
options := []string{"Yes", "No"}
|
||||
|
||||
// Hide cursor
|
||||
fmt.Fprint(os.Stderr, "\033[?25l")
|
||||
defer fmt.Fprint(os.Stderr, "\033[?25h")
|
||||
|
||||
renderYesNo := func() {
|
||||
// Move to start of line and clear
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
fmt.Fprintf(os.Stderr, "\033[36m%s\033[0m ", question)
|
||||
for i, opt := range options {
|
||||
if i == selected {
|
||||
fmt.Fprintf(os.Stderr, "\033[1;32m[%s]\033[0m ", opt)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s \033[0m ", opt)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\033[90m(←/→ or y/n, Enter to confirm)\033[0m")
|
||||
}
|
||||
|
||||
renderYesNo()
|
||||
|
||||
buf := make([]byte, 3)
|
||||
for {
|
||||
n, err := os.Stdin.Read(buf)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if n == 1 {
|
||||
switch buf[0] {
|
||||
case 'y', 'Y':
|
||||
selected = 0
|
||||
renderYesNo()
|
||||
case 'n', 'N':
|
||||
selected = 1
|
||||
renderYesNo()
|
||||
case '\r', '\n': // Enter
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K") // Clear line
|
||||
return selected == 0, nil
|
||||
case 3: // Ctrl+C
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return false, nil
|
||||
case 27: // Escape - could be arrow key
|
||||
// Read more bytes for arrow keys
|
||||
continue
|
||||
}
|
||||
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
|
||||
// Arrow keys
|
||||
switch buf[2] {
|
||||
case 'D': // Left
|
||||
if selected > 0 {
|
||||
selected--
|
||||
}
|
||||
renderYesNo()
|
||||
case 'C': // Right
|
||||
if selected < len(options)-1 {
|
||||
selected++
|
||||
}
|
||||
renderYesNo()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,27 +151,6 @@ func TestExtractBashPrefix(t *testing.T) {
|
||||
command: "head -n 100",
|
||||
expected: "",
|
||||
},
|
||||
// Path traversal security tests
|
||||
{
|
||||
name: "path traversal - parent escape",
|
||||
command: "cat tools/../../etc/passwd",
|
||||
expected: "", // Should NOT create a prefix - path escapes
|
||||
},
|
||||
{
|
||||
name: "path traversal - deep escape",
|
||||
command: "cat tools/a/b/../../../etc/passwd",
|
||||
expected: "", // Normalizes to "../etc/passwd" - escapes
|
||||
},
|
||||
{
|
||||
name: "path traversal - absolute path",
|
||||
command: "cat /etc/passwd",
|
||||
expected: "", // Absolute paths should not create prefix
|
||||
},
|
||||
{
|
||||
name: "path with safe dotdot - normalized",
|
||||
command: "cat tools/subdir/../file.go",
|
||||
expected: "cat:tools/", // Normalizes to tools/file.go - safe, creates prefix
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -185,34 +164,6 @@ func TestExtractBashPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_PathTraversalBlocked(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Allow "cat tools/file.go" - creates prefix "cat:tools/"
|
||||
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||
|
||||
// Path traversal attack: should NOT be allowed
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../etc/passwd"}) {
|
||||
t.Error("SECURITY: path traversal attack should NOT be allowed")
|
||||
}
|
||||
|
||||
// Another traversal variant
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../../etc/shadow"}) {
|
||||
t.Error("SECURITY: deep path traversal should NOT be allowed")
|
||||
}
|
||||
|
||||
// Valid subdirectory access should still work
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
|
||||
t.Error("expected cat tools/subdir/file.go to be allowed")
|
||||
}
|
||||
|
||||
// Safe ".." that normalizes to within allowed directory should work
|
||||
// tools/subdir/../other.go normalizes to tools/other.go which is under tools/
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/../other.go"}) {
|
||||
t.Error("expected cat tools/subdir/../other.go to be allowed (normalizes to tools/other.go)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
@@ -235,119 +186,6 @@ func TestApprovalManager_PrefixAllowlist(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_HierarchicalPrefixAllowlist(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Allow "cat tools/file.go" - this creates prefix "cat:tools/"
|
||||
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||
|
||||
// Should allow subdirectories (hierarchical matching)
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
|
||||
t.Error("expected cat tools/subdir/file.go to be allowed via hierarchical prefix")
|
||||
}
|
||||
|
||||
// Should allow deeply nested subdirectories
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/a/b/c/deep.go"}) {
|
||||
t.Error("expected cat tools/a/b/c/deep.go to be allowed via hierarchical prefix")
|
||||
}
|
||||
|
||||
// Should still allow same directory
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/another.go"}) {
|
||||
t.Error("expected cat tools/another.go to be allowed")
|
||||
}
|
||||
|
||||
// Should NOT allow different base directory
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
|
||||
t.Error("expected cat src/main.go to NOT be allowed")
|
||||
}
|
||||
|
||||
// Should NOT allow different command even in subdirectory
|
||||
if am.IsAllowed("bash", map[string]any{"command": "ls tools/subdir/"}) {
|
||||
t.Error("expected ls tools/subdir/ to NOT be allowed (different command)")
|
||||
}
|
||||
|
||||
// Should NOT allow similar but different directory name
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat toolsbin/file.go"}) {
|
||||
t.Error("expected cat toolsbin/file.go to NOT be allowed (different directory)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_HierarchicalPrefixAllowlist_CrossPlatform(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Allow with forward slashes (Unix-style)
|
||||
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||
|
||||
// Should work with backslashes too (Windows-style) - normalized internally
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\subdir\\file.go"}) {
|
||||
t.Error("expected cat tools\\subdir\\file.go to be allowed via hierarchical prefix (Windows path)")
|
||||
}
|
||||
|
||||
// Mixed slashes should also work
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\a/b\\c/deep.go"}) {
|
||||
t.Error("expected mixed slash path to be allowed via hierarchical prefix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesHierarchicalPrefix(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Add prefix for "cat:tools/"
|
||||
am.prefixes["cat:tools/"] = true
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
prefix: "cat:tools/",
|
||||
expected: true, // exact match also passes HasPrefix - caller handles exact match first
|
||||
},
|
||||
{
|
||||
name: "subdirectory",
|
||||
prefix: "cat:tools/subdir/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "deeply nested",
|
||||
prefix: "cat:tools/a/b/c/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different base directory",
|
||||
prefix: "cat:src/",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different command same path",
|
||||
prefix: "ls:tools/",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "similar directory name",
|
||||
prefix: "cat:toolsbin/",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid prefix format",
|
||||
prefix: "cattools",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := am.matchesHierarchicalPrefix(tt.prefix)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchesHierarchicalPrefix(%q) = %v, expected %v",
|
||||
tt.prefix, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatApprovalResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
263
x/cmd/run.go
263
x/cmd/run.go
@@ -6,12 +6,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
@@ -24,101 +22,6 @@ import (
|
||||
"github.com/ollama/ollama/x/tools"
|
||||
)
|
||||
|
||||
// Tool output capping constants
|
||||
const (
|
||||
// localModelTokenLimit is the token limit for local models (smaller context).
|
||||
localModelTokenLimit = 4000
|
||||
|
||||
// defaultTokenLimit is the token limit for cloud/remote models.
|
||||
defaultTokenLimit = 10000
|
||||
|
||||
// charsPerToken is a rough estimate of characters per token.
|
||||
// TODO: Estimate tokens more accurately using tokenizer if available
|
||||
charsPerToken = 4
|
||||
)
|
||||
|
||||
// isLocalModel checks if the model is running locally (not a cloud model).
|
||||
// TODO: Improve local/cloud model identification - could check model metadata
|
||||
func isLocalModel(modelName string) bool {
|
||||
return !strings.HasSuffix(modelName, "-cloud")
|
||||
}
|
||||
|
||||
// isLocalServer checks if connecting to a local Ollama server.
|
||||
// TODO: Could also check other indicators of local vs cloud server
|
||||
func isLocalServer() bool {
|
||||
host := os.Getenv("OLLAMA_HOST")
|
||||
if host == "" {
|
||||
return true // Default is localhost:11434
|
||||
}
|
||||
|
||||
// Parse the URL to check host
|
||||
parsed, err := url.Parse(host)
|
||||
if err != nil {
|
||||
return true // If can't parse, assume local
|
||||
}
|
||||
|
||||
hostname := parsed.Hostname()
|
||||
return hostname == "localhost" || hostname == "127.0.0.1" || strings.Contains(parsed.Host, ":11434")
|
||||
}
|
||||
|
||||
// truncateToolOutput truncates tool output to prevent context overflow.
|
||||
// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote.
|
||||
func truncateToolOutput(output, modelName string) string {
|
||||
var tokenLimit int
|
||||
if isLocalModel(modelName) && isLocalServer() {
|
||||
tokenLimit = localModelTokenLimit
|
||||
} else {
|
||||
tokenLimit = defaultTokenLimit
|
||||
}
|
||||
|
||||
maxChars := tokenLimit * charsPerToken
|
||||
if len(output) > maxChars {
|
||||
return output[:maxChars] + "\n... (output truncated)"
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
// waitForOllamaSignin shows the signin URL and polls until authentication completes.
|
||||
func waitForOllamaSignin(ctx context.Context) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get signin URL from initial Whoami call
|
||||
_, err = client.Whoami(ctx)
|
||||
if err != nil {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
||||
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
|
||||
fmt.Fprintf(os.Stderr, " \033[36m%s\033[0m\n\n", aErr.SigninURL)
|
||||
fmt.Fprintf(os.Stderr, " \033[90mWaiting for sign in to complete...\033[0m")
|
||||
|
||||
// Poll until auth succeeds
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
user, whoamiErr := client.Whoami(ctx)
|
||||
if whoamiErr == nil && user != nil && user.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K \033[32mSigned in as %s\033[0m\n", user.Name)
|
||||
return nil
|
||||
}
|
||||
// Still waiting, show dot
|
||||
fmt.Fprintf(os.Stderr, ".")
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunOptions contains options for running an interactive agent session.
|
||||
type RunOptions struct {
|
||||
Model string
|
||||
@@ -134,16 +37,6 @@ type RunOptions struct {
|
||||
// Agent fields (managed externally for session persistence)
|
||||
Tools *tools.Registry
|
||||
Approval *agent.ApprovalManager
|
||||
|
||||
// YoloMode skips all tool approval prompts
|
||||
YoloMode bool
|
||||
|
||||
// LastToolOutput stores the full output of the last tool execution
|
||||
// for Ctrl+O expansion. Updated by Chat(), read by caller.
|
||||
LastToolOutput *string
|
||||
|
||||
// LastToolOutputTruncated stores the truncated version shown inline
|
||||
LastToolOutputTruncated *string
|
||||
}
|
||||
|
||||
// Chat runs an agent chat loop with tool support.
|
||||
@@ -184,7 +77,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
var pendingToolCalls []api.ToolCall
|
||||
var consecutiveErrors int // Track consecutive 500 errors for retry limit
|
||||
|
||||
role := "assistant"
|
||||
messages := opts.Messages
|
||||
@@ -267,58 +159,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check for 401 Unauthorized - prompt user to sign in
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) {
|
||||
p.StopAndClear()
|
||||
fmt.Fprintf(os.Stderr, "\033[33mAuthentication required to use this cloud model.\033[0m\n")
|
||||
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
|
||||
if promptErr == nil && result {
|
||||
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
|
||||
// Retry the chat request
|
||||
fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n")
|
||||
continue // Retry the loop
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
|
||||
}
|
||||
|
||||
// Check for 500 errors (often tool parsing failures) - inform the model
|
||||
var statusErr api.StatusError
|
||||
if errors.As(err, &statusErr) && statusErr.StatusCode >= 500 {
|
||||
consecutiveErrors++
|
||||
p.StopAndClear()
|
||||
|
||||
if consecutiveErrors >= 3 {
|
||||
fmt.Fprintf(os.Stderr, "\033[31m✗ Too many consecutive errors, giving up\033[0m\n")
|
||||
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[33m⚠ Server error (attempt %d/3): %s\033[0m\n", consecutiveErrors, statusErr.ErrorMessage)
|
||||
|
||||
// Include both the model's response and the error so it can learn
|
||||
assistantContent := fullResponse.String()
|
||||
if assistantContent == "" {
|
||||
assistantContent = "(empty response)"
|
||||
}
|
||||
errorMsg := fmt.Sprintf("Your previous response caused an error: %s\n\nYour response was:\n%s\n\nPlease try again with a valid response.", statusErr.ErrorMessage, assistantContent)
|
||||
messages = append(messages,
|
||||
api.Message{Role: "user", Content: errorMsg},
|
||||
)
|
||||
|
||||
// Reset state and retry
|
||||
fullResponse.Reset()
|
||||
thinkingContent.Reset()
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = false
|
||||
pendingToolCalls = nil
|
||||
state = &displayResponseState{}
|
||||
p = progress.NewProgress(os.Stderr)
|
||||
spinner = progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
@@ -328,9 +168,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset consecutive error counter on success
|
||||
consecutiveErrors = 0
|
||||
|
||||
// If no tool calls, we're done
|
||||
if len(pendingToolCalls) == 0 || toolRegistry == nil {
|
||||
break
|
||||
@@ -379,12 +216,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
}
|
||||
|
||||
// Check approval (uses prefix matching for bash commands)
|
||||
// In yolo mode, skip all approval prompts
|
||||
if opts.YoloMode {
|
||||
if !skipApproval {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
}
|
||||
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
|
||||
if !skipApproval && !approval.IsAllowed(toolName, args) {
|
||||
result, err := approval.RequestApproval(toolName, args)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
|
||||
@@ -418,23 +250,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
// Execute the tool
|
||||
toolResult, err := toolRegistry.Execute(call)
|
||||
if err != nil {
|
||||
// Check if web search needs authentication
|
||||
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
|
||||
// Prompt user to sign in
|
||||
fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n")
|
||||
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
|
||||
if promptErr == nil && result {
|
||||
// Get signin URL and wait for auth completion
|
||||
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
|
||||
// Retry the web search
|
||||
fmt.Fprintf(os.Stderr, "\033[90m Retrying web search...\033[0m\n")
|
||||
toolResult, err = toolRegistry.Execute(call)
|
||||
if err == nil {
|
||||
goto toolSuccess
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
@@ -443,34 +258,20 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
})
|
||||
continue
|
||||
}
|
||||
toolSuccess:
|
||||
|
||||
// Display tool output (truncated for display)
|
||||
truncatedOutput := ""
|
||||
if toolResult != "" {
|
||||
output := toolResult
|
||||
if len(output) > 300 {
|
||||
output = output[:300] + "... (truncated, press Ctrl+O to expand)"
|
||||
output = output[:300] + "... (truncated)"
|
||||
}
|
||||
truncatedOutput = output
|
||||
// Show result in grey, indented
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
|
||||
}
|
||||
|
||||
// Store full and truncated output for Ctrl+O toggle
|
||||
if opts.LastToolOutput != nil {
|
||||
*opts.LastToolOutput = toolResult
|
||||
}
|
||||
if opts.LastToolOutputTruncated != nil {
|
||||
*opts.LastToolOutputTruncated = truncatedOutput
|
||||
}
|
||||
|
||||
// Truncate output to prevent context overflow
|
||||
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: toolResultForLLM,
|
||||
Content: toolResult,
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
}
|
||||
@@ -648,8 +449,7 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
|
||||
|
||||
// GenerateInteractive runs an interactive agent session.
|
||||
// This is called from cmd.go when --experimental flag is set.
|
||||
// If yoloMode is true, all tool approvals are skipped.
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
@@ -674,11 +474,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var toolRegistry *tools.Registry
|
||||
if supportsTools {
|
||||
toolRegistry = tools.DefaultRegistry()
|
||||
if toolRegistry.Count() > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
}
|
||||
if yoloMode {
|
||||
fmt.Fprintf(os.Stderr, "\033[33m⚠ YOLO mode: All tool approvals will be skipped\033[0m\n")
|
||||
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
|
||||
// Check for OLLAMA_API_KEY for web search
|
||||
if os.Getenv("OLLAMA_API_KEY") == "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
|
||||
@@ -690,11 +490,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var messages []api.Message
|
||||
var sb strings.Builder
|
||||
|
||||
// Track last tool output for Ctrl+O toggle
|
||||
var lastToolOutput string
|
||||
var lastToolOutputTruncated string
|
||||
var toolOutputExpanded bool
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
@@ -707,20 +502,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
}
|
||||
sb.Reset()
|
||||
continue
|
||||
case errors.Is(err, readline.ErrExpandOutput):
|
||||
// Ctrl+O pressed - toggle between expanded and collapsed tool output
|
||||
if lastToolOutput == "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[90mNo tool output to expand\033[0m\n")
|
||||
} else if toolOutputExpanded {
|
||||
// Currently expanded, show truncated
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutputTruncated, "\n", "\n "))
|
||||
toolOutputExpanded = false
|
||||
} else {
|
||||
// Currently collapsed, show full
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutput, "\n", "\n "))
|
||||
toolOutputExpanded = true
|
||||
}
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
@@ -743,9 +524,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Keyboard Shortcuts:")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
|
||||
@@ -759,21 +537,16 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
opts := RunOptions{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
KeepAlive: keepAlive,
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
YoloMode: yoloMode,
|
||||
LastToolOutput: &lastToolOutput,
|
||||
LastToolOutputTruncated: &lastToolOutputTruncated,
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
KeepAlive: keepAlive,
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
}
|
||||
// Reset expanded state for new tool execution
|
||||
toolOutputExpanded = false
|
||||
|
||||
assistant, err := Chat(cmd.Context(), opts)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsLocalModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "local model without suffix",
|
||||
modelName: "llama3.2",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "local model with version",
|
||||
modelName: "qwen2.5:7b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "cloud model",
|
||||
modelName: "gpt-4-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with version",
|
||||
modelName: "claude-3-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty model name",
|
||||
modelName: "",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isLocalModel(tt.modelName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isLocalModel(%q) = %v, expected %v", tt.modelName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalServer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty host (default)",
|
||||
host: "",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
host: "http://localhost:11434",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1",
|
||||
host: "http://127.0.0.1:11434",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "custom port on localhost",
|
||||
host: "http://localhost:8080",
|
||||
expected: true, // localhost is always considered local
|
||||
},
|
||||
{
|
||||
name: "remote host",
|
||||
host: "http://ollama.example.com:11434",
|
||||
expected: true, // has :11434
|
||||
},
|
||||
{
|
||||
name: "remote host different port",
|
||||
host: "http://ollama.example.com:8080",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", tt.host)
|
||||
result := isLocalServer()
|
||||
if result != tt.expected {
|
||||
t.Errorf("isLocalServer() with OLLAMA_HOST=%q = %v, expected %v", tt.host, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateToolOutput(t *testing.T) {
|
||||
// Create outputs of different sizes
|
||||
localLimitOutput := make([]byte, 20000) // > 4k tokens (16k chars)
|
||||
defaultLimitOutput := make([]byte, 50000) // > 10k tokens (40k chars)
|
||||
for i := range localLimitOutput {
|
||||
localLimitOutput[i] = 'a'
|
||||
}
|
||||
for i := range defaultLimitOutput {
|
||||
defaultLimitOutput[i] = 'b'
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
modelName string
|
||||
host string
|
||||
shouldTrim bool
|
||||
expectedLimit int
|
||||
}{
|
||||
{
|
||||
name: "short output local model",
|
||||
output: "hello world",
|
||||
modelName: "llama3.2",
|
||||
host: "",
|
||||
shouldTrim: false,
|
||||
expectedLimit: localModelTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "long output local model - trimmed at 4k",
|
||||
output: string(localLimitOutput),
|
||||
modelName: "llama3.2",
|
||||
host: "",
|
||||
shouldTrim: true,
|
||||
expectedLimit: localModelTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "long output cloud model - uses 10k limit",
|
||||
output: string(localLimitOutput), // 20k chars, under 10k token limit
|
||||
modelName: "gpt-4-cloud",
|
||||
host: "",
|
||||
shouldTrim: false,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "very long output cloud model - trimmed at 10k",
|
||||
output: string(defaultLimitOutput),
|
||||
modelName: "gpt-4-cloud",
|
||||
host: "",
|
||||
shouldTrim: true,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "long output remote server - uses 10k limit",
|
||||
output: string(localLimitOutput),
|
||||
modelName: "llama3.2",
|
||||
host: "http://remote.example.com:8080",
|
||||
shouldTrim: false,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", tt.host)
|
||||
result := truncateToolOutput(tt.output, tt.modelName)
|
||||
|
||||
if tt.shouldTrim {
|
||||
maxLen := tt.expectedLimit * charsPerToken
|
||||
if len(result) > maxLen+50 { // +50 for the truncation message
|
||||
t.Errorf("expected output to be truncated to ~%d chars, got %d", maxLen, len(result))
|
||||
}
|
||||
if result == tt.output {
|
||||
t.Error("expected output to be truncated but it wasn't")
|
||||
}
|
||||
} else {
|
||||
if result != tt.output {
|
||||
t.Error("expected output to not be truncated")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
38
x/imagegen/.gitignore
vendored
38
x/imagegen/.gitignore
vendored
@@ -1,38 +0,0 @@
|
||||
# Build directories
|
||||
build/
|
||||
dist/
|
||||
|
||||
# CMake
|
||||
CMakeCache.txt
|
||||
CMakeFiles/
|
||||
cmake_install.cmake
|
||||
Makefile
|
||||
*.cmake
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
*.dSYM/
|
||||
|
||||
# Go
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Python
|
||||
*.npy
|
||||
|
||||
/engine
|
||||
weights
|
||||
outputs
|
||||
|
||||
prompt.txt
|
||||
negative.txt
|
||||
@@ -1,151 +0,0 @@
|
||||
# MLX engine
|
||||
|
||||
This is a small inference engine written in Go using [MLX](https://github.com/ml-explore/mlx), Apple's array framework for machine learning
|
||||
|
||||
## Goals
|
||||
|
||||
1. Implement multimodal runners: in a dedicated runner but eventually to be integrated into Ollama's primary runner.
|
||||
2. Optimizing for image generation memory usage and output speed
|
||||
3. (secondary): implement fast text model inference for gpt-oss, Llama.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
**macOS:**
|
||||
|
||||
- macOS 14.0+ (Sonoma or later)
|
||||
- Apple Silicon (M1/M2/M3)
|
||||
- Xcode Command Line Tools
|
||||
|
||||
**Linux (building from source):**
|
||||
|
||||
- NVIDIA GPU (compute capability 7.0+)
|
||||
- CUDA 12.0+ toolkit
|
||||
- cuDNN
|
||||
|
||||
**Linux (prebuilt binaries):**
|
||||
|
||||
- NVIDIA GPU (compute capability 7.0+)
|
||||
- NVIDIA driver 525+ (CUDA runtime libs are bundled)
|
||||
|
||||
**Both:**
|
||||
|
||||
- CMake 3.25+
|
||||
- Go 1.21+
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Build MLX
|
||||
|
||||
```bash
|
||||
cmake -B build
|
||||
cmake --build build --parallel
|
||||
cmake --install build
|
||||
```
|
||||
|
||||
This fetches MLX and mlx-c, builds them, and installs to `dist/`:
|
||||
|
||||
- `dist/lib/libmlxc.so` (or `.dylib`) - MLX C bindings
|
||||
- `dist/lib/libmlx.a` - MLX static library
|
||||
- `dist/include/` - Headers (mlx-c, CCCL for CUDA JIT)
|
||||
|
||||
To update MLX version, change `MLX_GIT_TAG` in `CMakeLists.txt` and rebuild.
|
||||
|
||||
### 2. Download a Model
|
||||
|
||||
Download Llama 3.1 8B (or any compatible model) in safetensors format:
|
||||
|
||||
```bash
|
||||
mkdir -p ./weights
|
||||
|
||||
# Example using huggingface-cli
|
||||
hf download meta-llama/Llama-3.1-8B --local-dir ./weights/Llama-3.1-8B
|
||||
hf download openai/gpt-oss-20b --local-dir ./weights/gpt-oss-20b
|
||||
```
|
||||
|
||||
### 3. Run Inference
|
||||
|
||||
```bash
|
||||
# Build
|
||||
go build ./cmd/engine
|
||||
|
||||
# Text generation
|
||||
./engine -model ./weights/Llama-3.1-8B -prompt "Hello, world!" -max-tokens 250
|
||||
|
||||
# Qwen-Image 2512 (text-to-image)
|
||||
./engine -qwen-image -model ./weights/Qwen-Image-2512 -prompt "A mountain landscape at sunset" \
|
||||
-width 1024 -height 1024 -steps 20 -seed 42 -output landscape.png
|
||||
|
||||
# Qwen-Image Edit (experimental) - 8 steps for speed, but model recommends 50
|
||||
./engine -qwen-image-edit -model ./weights/Qwen-Image-Edit-2511 \
|
||||
-input-image input.png -prompt "Make it winter" -negative-prompt " " -cfg-scale 4.0 \
|
||||
-steps 8 -seed 42 -output edited.png
|
||||
```
|
||||
|
||||
## Adding a Model
|
||||
|
||||
Use Claude Code with this repo. See `models/CLAUDE.md` for the full guide covering:
|
||||
|
||||
- Porting Python models to Go (forward pass, weight loading)
|
||||
- Component testing with Python reference data
|
||||
- Performance optimization
|
||||
|
||||
Reference implementations: `llama` (LLM), `qwen_image` (image generation), `qwen_image_edit` (image editing)
|
||||
|
||||
## Memory Management
|
||||
|
||||
MLX Python/C++ uses scope-based memory management - arrays are freed when they go out of scope. Go's garbage collector is non-deterministic, so we can't rely on finalizers to free GPU memory promptly.
|
||||
|
||||
Instead, arrays are automatically tracked and freed on `Eval()`:
|
||||
|
||||
```go
|
||||
// All arrays are automatically tracked when created
|
||||
x := mlx.Add(a, b)
|
||||
y := mlx.Matmul(x, w)
|
||||
|
||||
// Eval frees non-kept arrays, evaluates outputs (auto-kept)
|
||||
mlx.Eval(y)
|
||||
|
||||
// After copying to CPU, free the array
|
||||
data := y.Data()
|
||||
y.Free()
|
||||
```
|
||||
|
||||
Key points:
|
||||
|
||||
- All created arrays are automatically tracked
|
||||
- `mlx.Eval(outputs...)` frees non-kept arrays, evaluates outputs (outputs auto-kept)
|
||||
- `mlx.Keep(arrays...)` marks arrays to survive multiple Eval cycles (for weights, caches)
|
||||
- Call `.Free()` when done with an array
|
||||
|
||||
## Testing
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests (tests skip if dependencies missing)
|
||||
go test ./...
|
||||
|
||||
# Run specific model tests
|
||||
go test ./models/qwen_image/...
|
||||
```
|
||||
|
||||
### Model Weights
|
||||
|
||||
Tests require model weights in `./weights/<model-name>/`:
|
||||
|
||||
```
|
||||
weights/
|
||||
├── Qwen-Image-2512/ # Qwen image generation
|
||||
│ ├── text_encoder/
|
||||
│ ├── transformer/
|
||||
│ ├── vae/
|
||||
│ └── tokenizer/
|
||||
├── Llama-3.1-8B/ # LLM
|
||||
└── ...
|
||||
```
|
||||
|
||||
Download models using `huggingface-cli`:
|
||||
|
||||
```bash
|
||||
hf download ./weights/Qwen-Image-2512 --local-dir ./weights/Qwen-Image-2512
|
||||
```
|
||||
154
x/imagegen/cache/cache.go
vendored
154
x/imagegen/cache/cache.go
vendored
@@ -1,154 +0,0 @@
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
type Cache interface {
|
||||
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
|
||||
Offset() int
|
||||
Len() int
|
||||
State() []*mlx.Array
|
||||
}
|
||||
|
||||
type KVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
step int
|
||||
}
|
||||
|
||||
func NewKVCache() *KVCache {
|
||||
return &KVCache{step: 256}
|
||||
}
|
||||
|
||||
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
prev := c.offset
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
// Grow buffer if needed
|
||||
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
|
||||
nSteps := (c.step + seqLen - 1) / c.step
|
||||
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
|
||||
|
||||
if c.keys != nil {
|
||||
if prev%c.step != 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
|
||||
}
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += seqLen
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
}
|
||||
|
||||
func (c *KVCache) State() []*mlx.Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
func (c *KVCache) Len() int { return c.offset }
|
||||
|
||||
// RotatingKVCache implements sliding window attention with bounded memory
|
||||
type RotatingKVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
maxSize int
|
||||
step int
|
||||
idx int
|
||||
}
|
||||
|
||||
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
|
||||
return &RotatingKVCache{maxSize: maxSize, step: 256}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
if seqLen > 1 {
|
||||
return c.updateConcat(k, v, seqLen)
|
||||
}
|
||||
return c.updateInPlace(k, v)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
// Grow buffer if not yet at max
|
||||
if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
|
||||
var cap int
|
||||
if c.keys != nil {
|
||||
cap = int(c.keys.Shape()[2])
|
||||
}
|
||||
newSize := min(c.step, c.maxSize-cap)
|
||||
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
|
||||
if c.keys != nil {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
// Rotate when hitting max
|
||||
if c.idx >= c.maxSize {
|
||||
c.idx = 0
|
||||
}
|
||||
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
|
||||
|
||||
c.offset++
|
||||
c.idx++
|
||||
|
||||
validLen := int32(min(c.offset, c.maxSize))
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = k, v
|
||||
} else {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
|
||||
}
|
||||
c.offset += seqLen
|
||||
|
||||
// Trim to max_size to maintain sliding window
|
||||
cap := int(c.keys.Shape()[2])
|
||||
if trim := cap - c.maxSize; trim > 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
|
||||
}
|
||||
|
||||
c.idx = int(c.keys.Shape()[2])
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Offset() int { return c.offset }
|
||||
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
|
||||
162
x/imagegen/cache/step.go
vendored
162
x/imagegen/cache/step.go
vendored
@@ -1,162 +0,0 @@
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
// StepCache caches layer outputs across diffusion denoising steps.
|
||||
// Based on DeepCache (CVPR 2024) and Learning-to-Cache (NeurIPS 2024):
|
||||
// shallow layers change little between consecutive steps, so we can
|
||||
// cache their outputs and skip recomputation on non-refresh steps.
|
||||
//
|
||||
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
|
||||
// - Single-stream: use Get/Set for the single output per layer
|
||||
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
|
||||
//
|
||||
// Usage (single-stream):
|
||||
//
|
||||
// cache := NewStepCache(15) // cache first 15 layers
|
||||
// for step := 0; step < numSteps; step++ {
|
||||
// refresh := cache.ShouldRefresh(step, 3) // refresh every 3 steps
|
||||
// for i, layer := range layers {
|
||||
// if i < 15 && !refresh && cache.Get(i) != nil {
|
||||
// output = cache.Get(i) // reuse cached
|
||||
// } else {
|
||||
// output = layer.Forward(input)
|
||||
// if i < 15 && refresh {
|
||||
// cache.Set(i, output)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// cache.Free() // cleanup when done
|
||||
//
|
||||
// Usage (dual-stream):
|
||||
//
|
||||
// cache := NewStepCache(15)
|
||||
// for step := 0; step < numSteps; step++ {
|
||||
// refresh := cache.ShouldRefresh(step, 3)
|
||||
// for i, layer := range layers {
|
||||
// if i < 15 && !refresh && cache.Get(i) != nil {
|
||||
// imgH, txtH = cache.Get(i), cache.Get2(i)
|
||||
// } else {
|
||||
// imgH, txtH = layer.Forward(imgH, txtH, ...)
|
||||
// if i < 15 && refresh {
|
||||
// cache.Set(i, imgH)
|
||||
// cache.Set2(i, txtH)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
type StepCache struct {
|
||||
layers []*mlx.Array // cached layer outputs (stream 1)
|
||||
layers2 []*mlx.Array // cached layer outputs (stream 2, for dual-stream models)
|
||||
constant *mlx.Array // optional constant (e.g., text embeddings)
|
||||
}
|
||||
|
||||
// NewStepCache creates a cache for the given number of layers.
|
||||
func NewStepCache(numLayers int) *StepCache {
|
||||
return &StepCache{
|
||||
layers: make([]*mlx.Array, numLayers),
|
||||
layers2: make([]*mlx.Array, numLayers),
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldRefresh returns true if the cache should be refreshed at this step.
|
||||
// Refresh happens on step 0, interval, 2*interval, etc.
|
||||
func (c *StepCache) ShouldRefresh(step, interval int) bool {
|
||||
return step%interval == 0
|
||||
}
|
||||
|
||||
// Get returns the cached output for a layer, or nil if not cached.
|
||||
func (c *StepCache) Get(layer int) *mlx.Array {
|
||||
if layer < len(c.layers) {
|
||||
return c.layers[layer]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set stores a layer output (stream 1), freeing any previous value.
|
||||
func (c *StepCache) Set(layer int, arr *mlx.Array) {
|
||||
if layer < len(c.layers) {
|
||||
if c.layers[layer] != nil {
|
||||
c.layers[layer].Free()
|
||||
}
|
||||
c.layers[layer] = arr
|
||||
}
|
||||
}
|
||||
|
||||
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
func (c *StepCache) Get2(layer int) *mlx.Array {
|
||||
if layer < len(c.layers2) {
|
||||
return c.layers2[layer]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set2 stores a layer output (stream 2), freeing any previous value.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
|
||||
if layer < len(c.layers2) {
|
||||
if c.layers2[layer] != nil {
|
||||
c.layers2[layer].Free()
|
||||
}
|
||||
c.layers2[layer] = arr
|
||||
}
|
||||
}
|
||||
|
||||
// GetConstant returns the cached constant value.
|
||||
func (c *StepCache) GetConstant() *mlx.Array {
|
||||
return c.constant
|
||||
}
|
||||
|
||||
// SetConstant stores a constant value, freeing any previous value.
|
||||
func (c *StepCache) SetConstant(arr *mlx.Array) {
|
||||
if c.constant != nil {
|
||||
c.constant.Free()
|
||||
}
|
||||
c.constant = arr
|
||||
}
|
||||
|
||||
// Arrays returns all non-nil cached arrays (for pool.Keep).
|
||||
func (c *StepCache) Arrays() []*mlx.Array {
|
||||
var result []*mlx.Array
|
||||
if c.constant != nil {
|
||||
result = append(result, c.constant)
|
||||
}
|
||||
for _, arr := range c.layers {
|
||||
if arr != nil {
|
||||
result = append(result, arr)
|
||||
}
|
||||
}
|
||||
for _, arr := range c.layers2 {
|
||||
if arr != nil {
|
||||
result = append(result, arr)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Free releases all cached arrays. Call when generation completes.
|
||||
func (c *StepCache) Free() {
|
||||
if c.constant != nil {
|
||||
c.constant.Free()
|
||||
c.constant = nil
|
||||
}
|
||||
for i, arr := range c.layers {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
c.layers[i] = nil
|
||||
}
|
||||
}
|
||||
for i, arr := range c.layers2 {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
c.layers2[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NumLayers returns the number of layers this cache can store.
|
||||
func (c *StepCache) NumLayers() int {
|
||||
return len(c.layers)
|
||||
}
|
||||
@@ -1,357 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Dedicated stream for generation (like mlx-lm's generation_stream)
|
||||
var generationStream *mlx.Stream
|
||||
|
||||
// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
|
||||
// This handles cases where tokenizers output partial multi-byte sequences.
|
||||
type utf8Streamer struct {
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
// Write adds decoded text to the buffer and returns complete UTF-8 characters.
|
||||
func (s *utf8Streamer) Write(text string) string {
|
||||
s.buffer = append(s.buffer, text...)
|
||||
|
||||
// Find the last position that ends with a complete UTF-8 character
|
||||
validLen := 0
|
||||
for i := 0; i < len(s.buffer); {
|
||||
r, size := utf8.DecodeRune(s.buffer[i:])
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
// Invalid or incomplete UTF-8 sequence at this position
|
||||
// Check if it could be a valid start of a multi-byte sequence
|
||||
if len(s.buffer)-i < 4 {
|
||||
// Might be incomplete, keep it in buffer
|
||||
break
|
||||
}
|
||||
// Definitely invalid, skip this byte
|
||||
i++
|
||||
validLen = i
|
||||
} else {
|
||||
i += size
|
||||
validLen = i
|
||||
}
|
||||
}
|
||||
|
||||
if validLen == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := string(s.buffer[:validLen])
|
||||
s.buffer = s.buffer[validLen:]
|
||||
return result
|
||||
}
|
||||
|
||||
// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
|
||||
func (s *utf8Streamer) Flush() string {
|
||||
if len(s.buffer) == 0 {
|
||||
return ""
|
||||
}
|
||||
result := string(s.buffer)
|
||||
s.buffer = nil
|
||||
return result
|
||||
}
|
||||
|
||||
func init() {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
|
||||
// withStream runs fn with the generation stream as default
|
||||
func withStream(fn func()) {
|
||||
orig := mlx.GetDefaultStream()
|
||||
mlx.SetDefaultStream(generationStream)
|
||||
fn()
|
||||
mlx.SetDefaultStream(orig)
|
||||
}
|
||||
|
||||
type Model interface {
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
VocabSize() int32
|
||||
NewCache(maxSeqLen int32) []cache.Cache
|
||||
Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
}
|
||||
|
||||
// ChatModel is an optional interface for models that support chat formatting
|
||||
type ChatModel interface {
|
||||
FormatPrompt(prompt string) string
|
||||
}
|
||||
|
||||
// MultimodalModel is for models that support image input
|
||||
type MultimodalModel interface {
|
||||
Model
|
||||
FormatPromptWithImage(prompt string) string
|
||||
ExpandImageTokens(tokens []int32) []int32
|
||||
ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
ImageSize() int32 // Returns expected image size for preprocessing
|
||||
}
|
||||
|
||||
// ImageLoader loads and preprocesses an image for multimodal models
|
||||
// Returns nil if path is empty
|
||||
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)
|
||||
|
||||
type input struct {
|
||||
Prompt string
|
||||
Image *mlx.Array // Optional preprocessed image for multimodal models
|
||||
MaxTokens int
|
||||
Temperature float32
|
||||
TopP float32
|
||||
TopK int
|
||||
WiredLimitGB int // Metal wired memory limit in GB (default 32)
|
||||
}
|
||||
|
||||
type output struct {
|
||||
Text string
|
||||
Done bool
|
||||
PrefillTokSec float64
|
||||
GenTokSec float64
|
||||
}
|
||||
|
||||
// Decoder wraps model + cache for autoregressive generation.
|
||||
type Decoder struct {
|
||||
model Model
|
||||
caches []cache.Cache
|
||||
vocabSize int32
|
||||
temp float32
|
||||
topK int
|
||||
topP float32
|
||||
token *mlx.Array // Current token (kept across pools)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
image *mlx.Array // Optional image for multimodal prefill
|
||||
}
|
||||
|
||||
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
|
||||
caches := m.NewCache(0)
|
||||
return &Decoder{
|
||||
model: m,
|
||||
caches: caches,
|
||||
vocabSize: m.VocabSize(),
|
||||
temp: temp,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
|
||||
}
|
||||
}
|
||||
|
||||
// SetImage sets the image for multimodal prefill (call before prefill)
|
||||
func (d *Decoder) SetImage(img *mlx.Array) {
|
||||
d.image = img
|
||||
}
|
||||
|
||||
func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
processed := 0
|
||||
|
||||
// Track old cache state to free after each chunk
|
||||
var oldCacheState []*mlx.Array
|
||||
|
||||
// For multimodal models with an image, we need to process all tokens together
|
||||
// in the first forward pass so the image embeddings can be inserted properly.
|
||||
// Skip chunking for multimodal prefill.
|
||||
isMultimodal := d.image != nil
|
||||
|
||||
// Process all-but-1 tokens in chunks, eval cache state for memory management
|
||||
// Skip chunking for multimodal - process everything in the final step
|
||||
if !isMultimodal {
|
||||
for len(inputIDs) > 1 {
|
||||
chunkSize := min(2048, len(inputIDs)-1)
|
||||
if chunkSize <= 0 {
|
||||
break
|
||||
}
|
||||
chunk := inputIDs[:chunkSize]
|
||||
|
||||
// Save old cache state before forward
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
var cacheState []*mlx.Array
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
|
||||
d.model.Forward(x, d.caches)
|
||||
for _, c := range d.caches {
|
||||
cacheState = append(cacheState, c.State()...)
|
||||
}
|
||||
})
|
||||
mlx.Eval(cacheState...)
|
||||
|
||||
// Free old cache state
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
inputIDs = inputIDs[chunkSize:]
|
||||
processed += chunkSize
|
||||
}
|
||||
}
|
||||
|
||||
// Save old cache state before final step
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
// Final token + sampling (or all tokens for multimodal)
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
|
||||
mlx.Eval(x) // Materialize before any other evals
|
||||
|
||||
var logits *mlx.Array
|
||||
// Use ForwardWithImage if we have an image and model supports it
|
||||
if d.image != nil {
|
||||
if mm, ok := d.model.(MultimodalModel); ok {
|
||||
logits = mm.ForwardWithImage(x, d.image, d.caches)
|
||||
d.image = nil // Only use image for first forward
|
||||
} else {
|
||||
logits = d.model.Forward(x, d.caches)
|
||||
}
|
||||
} else {
|
||||
logits = d.model.Forward(x, d.caches)
|
||||
}
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep cache state (token auto-kept by AsyncEval)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Free old cache state from before final step
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
mlx.ClearCache()
|
||||
|
||||
return processed + len(inputIDs)
|
||||
}
|
||||
|
||||
func (d *Decoder) step() int32 {
|
||||
prevToken := d.token
|
||||
|
||||
// Save old cache state (reuse preallocated slice)
|
||||
d.oldCacheState = d.oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
d.oldCacheState = append(d.oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
withStream(func() {
|
||||
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep token and new cache state so they survive cleanup
|
||||
mlx.Keep(d.token)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Sync on previous token (GPU already working on next step)
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Free old token and old cache state
|
||||
prevToken.Free()
|
||||
for _, arr := range d.oldCacheState {
|
||||
arr.Free()
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
mlx.EnableCompile()
|
||||
wiredLimit := in.WiredLimitGB
|
||||
if wiredLimit <= 0 {
|
||||
wiredLimit = 32 // default 32GB
|
||||
}
|
||||
mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)
|
||||
|
||||
temp := in.Temperature
|
||||
if temp < 0 {
|
||||
temp = 0.7
|
||||
}
|
||||
|
||||
tok := m.Tokenizer()
|
||||
dec := NewDecoder(m, temp, in.TopK, in.TopP)
|
||||
|
||||
// Apply chat template - use image template if we have an image
|
||||
prompt := in.Prompt
|
||||
var tokens []int32
|
||||
if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
|
||||
prompt = mm.FormatPromptWithImage(prompt)
|
||||
tokens = tok.Encode(prompt, true)
|
||||
tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
|
||||
dec.SetImage(in.Image)
|
||||
} else if cm, ok := m.(ChatModel); ok {
|
||||
prompt = cm.FormatPrompt(prompt)
|
||||
tokens = tok.Encode(prompt, true)
|
||||
} else {
|
||||
tokens = tok.Encode(prompt, true)
|
||||
}
|
||||
|
||||
prefillStart := time.Now()
|
||||
prefillTokens := dec.prefill(tokens)
|
||||
// Prefill measurement should include time to first token (like mlx-lm)
|
||||
// Step() waits for prefill to complete and returns first token
|
||||
firstToken := dec.step()
|
||||
prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()
|
||||
|
||||
genStart := time.Now()
|
||||
maxTokens := max(in.MaxTokens, 100)
|
||||
var genTokens int
|
||||
|
||||
// UTF-8 streamer to handle partial multi-byte characters
|
||||
streamer := &utf8Streamer{}
|
||||
|
||||
// Handle first token
|
||||
genTokens++
|
||||
if tok.IsEOS(firstToken) {
|
||||
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
|
||||
return nil
|
||||
}
|
||||
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
for n := 1; n < maxTokens; n++ {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
token := dec.step()
|
||||
genTokens++
|
||||
|
||||
if tok.IsEOS(token) {
|
||||
break
|
||||
}
|
||||
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
if n%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Flush any remaining buffered bytes
|
||||
if text := streamer.Flush(); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
|
||||
cb(output{Done: true, PrefillTokSec: prefillTokSec,
|
||||
GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
|
||||
return nil
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"image/png"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// saveImageArray saves an MLX array as a PNG image.
|
||||
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
||||
func saveImageArray(arr *mlx.Array, path string) error {
|
||||
img, err := arrayToImage(arr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return savePNG(img, path)
|
||||
}
|
||||
|
||||
func savePNG(img *image.RGBA, path string) error {
|
||||
if filepath.Ext(path) != ".png" {
|
||||
path = path + ".png"
|
||||
}
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
return png.Encode(f, img)
|
||||
}
|
||||
|
||||
func arrayToImage(arr *mlx.Array) (*image.RGBA, error) {
|
||||
shape := arr.Shape()
|
||||
if len(shape) != 4 {
|
||||
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
|
||||
}
|
||||
|
||||
// Transform to [H, W, C] for image conversion
|
||||
img := mlx.Squeeze(arr, 0)
|
||||
arr.Free()
|
||||
img = mlx.Transpose(img, 1, 2, 0)
|
||||
img = mlx.Contiguous(img)
|
||||
mlx.Eval(img)
|
||||
|
||||
imgShape := img.Shape()
|
||||
H := int(imgShape[0])
|
||||
W := int(imgShape[1])
|
||||
C := int(imgShape[2])
|
||||
|
||||
if C != 3 {
|
||||
img.Free()
|
||||
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
|
||||
}
|
||||
|
||||
// Copy to CPU and free GPU memory
|
||||
data := img.Data()
|
||||
img.Free()
|
||||
|
||||
// Write directly to Pix slice (faster than SetRGBA)
|
||||
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
|
||||
pix := goImg.Pix
|
||||
for y := 0; y < H; y++ {
|
||||
for x := 0; x < W; x++ {
|
||||
srcIdx := (y*W + x) * C
|
||||
dstIdx := (y*W + x) * 4
|
||||
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
|
||||
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
|
||||
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
|
||||
pix[dstIdx+3] = 255
|
||||
}
|
||||
}
|
||||
|
||||
return goImg, nil
|
||||
}
|
||||
|
||||
func clampF(v, min, max float32) float32 {
|
||||
if v < min {
|
||||
return min
|
||||
}
|
||||
if v > max {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
}
|
||||
@@ -1,284 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// stringSlice is a flag type that accumulates multiple values
|
||||
type stringSlice []string
|
||||
|
||||
func (s *stringSlice) String() string {
|
||||
return fmt.Sprintf("%v", *s)
|
||||
}
|
||||
|
||||
func (s *stringSlice) Set(value string) error {
|
||||
*s = append(*s, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
modelPath := flag.String("model", "", "Model directory")
|
||||
prompt := flag.String("prompt", "Hello", "Prompt")
|
||||
|
||||
// Text generation params
|
||||
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
|
||||
temperature := flag.Float64("temperature", 0.7, "Temperature")
|
||||
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
|
||||
topK := flag.Int("top-k", 40, "Top-k sampling")
|
||||
imagePath := flag.String("image", "", "Image path for multimodal models")
|
||||
|
||||
// Image generation params
|
||||
width := flag.Int("width", 1024, "Image width")
|
||||
height := flag.Int("height", 1024, "Image height")
|
||||
steps := flag.Int("steps", 9, "Denoising steps")
|
||||
seed := flag.Int64("seed", 42, "Random seed")
|
||||
out := flag.String("output", "output.png", "Output path")
|
||||
|
||||
// Utility flags
|
||||
listTensors := flag.Bool("list", false, "List tensors only")
|
||||
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
|
||||
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
|
||||
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
|
||||
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
|
||||
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
|
||||
var inputImages stringSlice
|
||||
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
||||
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
||||
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *modelPath == "" {
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
// CPU profiling
|
||||
if *cpuProfile != "" {
|
||||
f, err := os.Create(*cpuProfile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
if err := pprof.StartCPUProfile(f); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer pprof.StopCPUProfile()
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Handle legacy mode flags that aren't unified yet
|
||||
switch {
|
||||
case *zimageFlag:
|
||||
m := &zimage.Model{}
|
||||
if loadErr := m.Load(*modelPath); loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(&zimage.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
CapturePath: *gpuCapture,
|
||||
LayerCache: *layerCache,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImage:
|
||||
m, loadErr := qwen_image.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
LayerCache: *layerCache,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImageEdit:
|
||||
if len(inputImages) == 0 {
|
||||
log.Fatal("qwen-image-edit requires at least one -input-image")
|
||||
}
|
||||
|
||||
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
// For image editing, use 0 for dimensions to auto-detect from input image
|
||||
// unless explicitly overridden from defaults
|
||||
editWidth := int32(0)
|
||||
editHeight := int32(0)
|
||||
if *width != 1024 {
|
||||
editWidth = int32(*width)
|
||||
}
|
||||
if *height != 1024 {
|
||||
editHeight = int32(*height)
|
||||
}
|
||||
|
||||
cfg := &qwen_image_edit.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: editWidth,
|
||||
Height: editHeight,
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
}
|
||||
|
||||
var img *mlx.Array
|
||||
img, err = m.EditFromConfig(inputImages, cfg)
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *listTensors:
|
||||
err = listModelTensors(*modelPath)
|
||||
default:
|
||||
// llm path
|
||||
m, err := load(*modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Load image if provided and model supports it
|
||||
var image *mlx.Array
|
||||
if *imagePath != "" {
|
||||
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
||||
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
|
||||
if err != nil {
|
||||
log.Fatal("load image:", err)
|
||||
}
|
||||
} else {
|
||||
log.Fatal("model does not support image input")
|
||||
}
|
||||
}
|
||||
|
||||
err = generate(context.Background(), m, input{
|
||||
Prompt: *prompt,
|
||||
Image: image,
|
||||
MaxTokens: *maxTokens,
|
||||
Temperature: float32(*temperature),
|
||||
TopP: float32(*topP),
|
||||
TopK: *topK,
|
||||
WiredLimitGB: *wiredLimitGB,
|
||||
}, func(out output) {
|
||||
if out.Text != "" {
|
||||
fmt.Print(out.Text)
|
||||
}
|
||||
if out.Done {
|
||||
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func listModelTensors(modelPath string) error {
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range weights.ListTensors() {
|
||||
info, _ := weights.GetTensorInfo(name)
|
||||
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadModel builds and evaluates a model using the common load pattern.
|
||||
// Release safetensors BEFORE eval - lazy arrays have captured their data,
|
||||
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
|
||||
func loadModel[T Model](build func() T, cleanup func()) T {
|
||||
m := build()
|
||||
weights := mlx.Collect(m)
|
||||
cleanup()
|
||||
mlx.Eval(weights...)
|
||||
return m
|
||||
}
|
||||
|
||||
func load(modelPath string) (Model, error) {
|
||||
kind, err := detectModelKind(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("detect model kind: %w", err)
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case "gpt_oss":
|
||||
return gpt_oss.Load(modelPath)
|
||||
case "gemma3":
|
||||
return gemma3.Load(modelPath)
|
||||
case "gemma3_text":
|
||||
return gemma3.LoadText(modelPath)
|
||||
default:
|
||||
return llama.Load(modelPath)
|
||||
}
|
||||
}
|
||||
|
||||
func detectModelKind(modelPath string) (string, error) {
|
||||
indexPath := filepath.Join(modelPath, "model_index.json")
|
||||
if _, err := os.Stat(indexPath); err == nil {
|
||||
data, err := os.ReadFile(indexPath)
|
||||
if err != nil {
|
||||
return "zimage", nil
|
||||
}
|
||||
var index struct {
|
||||
ClassName string `json:"_class_name"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &index); err == nil {
|
||||
switch index.ClassName {
|
||||
case "FluxPipeline", "ZImagePipeline":
|
||||
return "zimage", nil
|
||||
}
|
||||
}
|
||||
return "zimage", nil
|
||||
}
|
||||
|
||||
configPath := filepath.Join(modelPath, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return "", fmt.Errorf("parse config.json: %w", err)
|
||||
}
|
||||
|
||||
return cfg.ModelType, nil
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package main
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
// sampleTopK samples from top-k logits using global random state
|
||||
func sampleTopK(scaledLogits *mlx.Array, k int) *mlx.Array {
|
||||
neg := mlx.Neg(scaledLogits)
|
||||
indices := mlx.Argpartition(neg, k-1, -1)
|
||||
topKIdx := mlx.Slice(indices, []int32{0}, []int32{int32(k)})
|
||||
values := mlx.TakeAlongAxis(scaledLogits, topKIdx, -1)
|
||||
sampled := mlx.RandomCategorical(values, -1, 1)
|
||||
return mlx.Take(topKIdx, sampled, -1)
|
||||
}
|
||||
|
||||
// sampleTopP samples using nucleus sampling with global random state
|
||||
func sampleTopP(scaledLogits *mlx.Array, p float32, vocabSize int32) *mlx.Array {
|
||||
sorted := mlx.Argsort(mlx.Neg(scaledLogits), -1)
|
||||
sortedLogits := mlx.TakeAlongAxis(scaledLogits, sorted, -1)
|
||||
probs := mlx.Softmax(sortedLogits, -1)
|
||||
cumProbs := mlx.Cumsum(probs, -1)
|
||||
mask := mlx.LessScalar(cumProbs, p)
|
||||
negInf := mlx.FullDtype(float32(-1e9), scaledLogits.Dtype(), vocabSize)
|
||||
masked := mlx.Where(mask, sortedLogits, negInf)
|
||||
sampled := mlx.RandomCategorical(masked, -1, 1)
|
||||
return mlx.Take(sorted, sampled, -1)
|
||||
}
|
||||
|
||||
// sample samples from logits at the last position
|
||||
func sample(logits *mlx.Array, temp float32, topK int, topP float32, vocab int32) *mlx.Array {
|
||||
// Get last position logits: [1, L, vocab] -> [vocab]
|
||||
shape := logits.Shape()
|
||||
seqLen := shape[1]
|
||||
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocab})
|
||||
lastLogits = mlx.Reshape(lastLogits, vocab)
|
||||
|
||||
if temp == 0 {
|
||||
return mlx.Argmax(lastLogits, -1, false)
|
||||
}
|
||||
scaled := mlx.DivScalar(lastLogits, temp)
|
||||
if topK > 0 && topK < int(vocab) {
|
||||
return sampleTopK(scaled, topK)
|
||||
}
|
||||
if topP > 0 && topP < 1.0 {
|
||||
return sampleTopP(scaled, topP, vocab)
|
||||
}
|
||||
return mlx.RandomCategorical(scaled, -1, 1)
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
# MLX Memory Management
|
||||
|
||||
| This package will get consolidated with `x/ml/backend/mlx` in the future.
|
||||
|
||||
## Automatic Tracking
|
||||
|
||||
All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed.
|
||||
|
||||
### API
|
||||
|
||||
```go
|
||||
result := mlx.Matmul(x, w) // arrays automatically tracked
|
||||
mlx.Eval(result) // free non-kept, eval result (auto-kept)
|
||||
```
|
||||
|
||||
### Key Functions
|
||||
|
||||
- `mlx.Eval(outputs...)` - free non-kept arrays, then evaluate (outputs auto-kept)
|
||||
- `mlx.AsyncEval(outputs...)` - async version of Eval (outputs auto-kept)
|
||||
- `mlx.Keep(arrays...)` - mark arrays to survive cleanup (for weights, caches)
|
||||
- `array.Free()` - mark array for cleanup on next Eval
|
||||
|
||||
### Loop Pattern
|
||||
|
||||
```go
|
||||
for step := 0; step < maxTokens; step++ {
|
||||
logits := model.Forward(token, caches)
|
||||
oldToken := token
|
||||
token = sample(logits)
|
||||
|
||||
// Keep cache state across iterations
|
||||
for _, c := range caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
|
||||
oldToken.Free() // mark for cleanup
|
||||
mlx.AsyncEval(token) // frees old, evals new
|
||||
}
|
||||
```
|
||||
|
||||
### Notes
|
||||
|
||||
- `Eval()` and `AsyncEval()` auto-keep their outputs
|
||||
- `Free()` marks for cleanup - actual free happens during next Eval
|
||||
- Use `Keep()` for weights and cache state that must survive multiple Eval cycles
|
||||
- Arrays created inside compiled closures are managed by MLX, not tracked
|
||||
@@ -1,171 +0,0 @@
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
#include <stdlib.h>
|
||||
|
||||
// Forward declaration for Go callback
|
||||
extern int goClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
||||
|
||||
// Destructor for payload (Go handle)
|
||||
extern void goClosureDestructor(void* payload);
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"runtime/cgo"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// inClosureCallback is set to true during closure callback execution.
|
||||
var inClosureCallback bool
|
||||
var closureCallbackMu sync.Mutex
|
||||
|
||||
// InClosureCallback returns true if we're currently executing inside a closure callback.
|
||||
func InClosureCallback() bool {
|
||||
closureCallbackMu.Lock()
|
||||
defer closureCallbackMu.Unlock()
|
||||
return inClosureCallback
|
||||
}
|
||||
|
||||
// CompiledFunc is a compiled MLX function that can be called efficiently.
|
||||
// All intermediate arrays during execution stay inside MLX - only inputs
|
||||
// and outputs cross the Go boundary.
|
||||
type CompiledFunc struct {
|
||||
closure C.mlx_closure
|
||||
compiled C.mlx_closure
|
||||
}
|
||||
|
||||
// ClosureFunc is the signature for functions that can be compiled.
|
||||
// It takes a slice of input arrays and returns a slice of output arrays.
|
||||
type ClosureFunc func(inputs []*Array) []*Array
|
||||
|
||||
// Compile compiles a Go function into an optimized MLX closure.
|
||||
// The function is traced once during compilation, then subsequent calls
|
||||
// run the optimized graph without creating Go intermediate arrays.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// compiled := mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
// a, b := inputs[0], inputs[1]
|
||||
// c := mlx.Add(a, b)
|
||||
// d := mlx.Mul(c, c)
|
||||
// return []*mlx.Array{d}
|
||||
// })
|
||||
// defer compiled.Free()
|
||||
//
|
||||
// result := compiled.Call(x, y)[0]
|
||||
func Compile(fn ClosureFunc) *CompiledFunc {
|
||||
return CompileShapeless(fn, false)
|
||||
}
|
||||
|
||||
// CompileShapeless compiles with optional shapeless mode.
|
||||
// If shapeless=true, the function works for any input shape after tracing.
|
||||
func CompileShapeless(fn ClosureFunc, shapeless bool) *CompiledFunc {
|
||||
// Create a cgo.Handle to prevent the Go function from being GC'd
|
||||
handle := cgo.NewHandle(fn)
|
||||
|
||||
// Create the closure from the Go callback
|
||||
closure := C.mlx_closure_new_func_payload(
|
||||
(*[0]byte)(C.goClosureCallback),
|
||||
unsafe.Pointer(handle),
|
||||
(*[0]byte)(C.goClosureDestructor),
|
||||
)
|
||||
|
||||
// Compile the closure
|
||||
compiled := C.mlx_closure_new()
|
||||
C.mlx_compile(&compiled, closure, C.bool(shapeless))
|
||||
|
||||
return &CompiledFunc{
|
||||
closure: closure,
|
||||
compiled: compiled,
|
||||
}
|
||||
}
|
||||
|
||||
// Call invokes the compiled function with the given inputs.
|
||||
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
|
||||
// Pack inputs into vector
|
||||
inputVec := C.mlx_vector_array_new()
|
||||
for _, arr := range inputs {
|
||||
C.mlx_vector_array_append_value(inputVec, arr.c)
|
||||
}
|
||||
|
||||
// Apply compiled closure
|
||||
outputVec := C.mlx_vector_array_new()
|
||||
C.mlx_closure_apply(&outputVec, cf.compiled, inputVec)
|
||||
C.mlx_vector_array_free(inputVec)
|
||||
|
||||
// Unpack outputs
|
||||
numOutputs := int(C.mlx_vector_array_size(outputVec))
|
||||
outputs := make([]*Array, numOutputs)
|
||||
for i := 0; i < numOutputs; i++ {
|
||||
var arr C.mlx_array
|
||||
C.mlx_vector_array_get(&arr, outputVec, C.size_t(i))
|
||||
outputs[i] = newArray(arr)
|
||||
}
|
||||
C.mlx_vector_array_free(outputVec)
|
||||
|
||||
return outputs
|
||||
}
|
||||
|
||||
// CallEval invokes the compiled function and evaluates the results.
|
||||
func (cf *CompiledFunc) CallEval(inputs ...*Array) []*Array {
|
||||
outputs := cf.Call(inputs...)
|
||||
Eval(outputs...)
|
||||
return outputs
|
||||
}
|
||||
|
||||
// Free releases the compiled function resources.
|
||||
func (cf *CompiledFunc) Free() {
|
||||
C.mlx_closure_free(cf.compiled)
|
||||
C.mlx_closure_free(cf.closure)
|
||||
}
|
||||
|
||||
// borrowArray wraps a C array WITHOUT setting up GC cleanup.
|
||||
// Use this for arrays we don't own (e.g., borrowed references in callbacks).
|
||||
func borrowArray(array C.mlx_array) *Array {
|
||||
return &Array{c: array}
|
||||
}
|
||||
|
||||
//export goClosureCallback
|
||||
func goClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) C.int {
|
||||
// Set flag to disable AddCleanup during callback
|
||||
closureCallbackMu.Lock()
|
||||
inClosureCallback = true
|
||||
closureCallbackMu.Unlock()
|
||||
defer func() {
|
||||
closureCallbackMu.Lock()
|
||||
inClosureCallback = false
|
||||
closureCallbackMu.Unlock()
|
||||
}()
|
||||
|
||||
// Recover the Go function from the handle
|
||||
handle := cgo.Handle(payload)
|
||||
fn := handle.Value().(ClosureFunc)
|
||||
|
||||
// Convert input vector to Go slice - use borrowArray since MLX owns these
|
||||
numInputs := int(C.mlx_vector_array_size(input))
|
||||
inputs := make([]*Array, numInputs)
|
||||
for i := 0; i < numInputs; i++ {
|
||||
var arr C.mlx_array
|
||||
C.mlx_vector_array_get(&arr, input, C.size_t(i))
|
||||
inputs[i] = borrowArray(arr) // Don't set up cleanup - MLX owns these
|
||||
}
|
||||
|
||||
// Call the Go function
|
||||
outputs := fn(inputs)
|
||||
|
||||
// Build output vector
|
||||
*res = C.mlx_vector_array_new()
|
||||
for _, arr := range outputs {
|
||||
C.mlx_vector_array_append_value(*res, arr.c)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
//export goClosureDestructor
|
||||
func goClosureDestructor(payload unsafe.Pointer) {
|
||||
handle := cgo.Handle(payload)
|
||||
handle.Delete()
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,277 +0,0 @@
|
||||
# Model Implementation Guide
|
||||
|
||||
See `README.md` for memory management (critical for Go + MLX).
|
||||
|
||||
## Phase 1: Import & Forward Pass
|
||||
|
||||
- Read Python reference implementation (PyTorch/Transformers)
|
||||
- Create Go struct mirroring layer hierarchy
|
||||
- Implement weight loading from safetensors (see `safetensors.go`)
|
||||
- Port forward pass layer-by-layer, bottom-up
|
||||
- For tokenizers: check if BPE (`bpe.go`) or custom needed
|
||||
|
||||
**Key files to reference:** `llama` (dense LLM), `gpt_oss` (MoE LLM), `zimage` (image generation), `qwen_image_edit` (image editing)
|
||||
|
||||
### Vision Models: Image Preprocessing
|
||||
|
||||
When implementing vision models (image-to-text, image editing, etc.), image preprocessing must match Python exactly. Common pitfalls:
|
||||
|
||||
1. **Resolution constraints**: Many vision models use `min_pixels` and `max_pixels` to constrain image size, not a fixed target area. Check the Python processor's `smart_resize` logic.
|
||||
|
||||
2. **Patch alignment**: Images must be resized to multiples of `factor = patch_size * spatial_merge_size` (e.g., 14 \* 2 = 28 for Qwen2.5-VL).
|
||||
|
||||
3. **Normalization**: Vision encoders use ImageNet stats (mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]), not simple [-1, 1] scaling.
|
||||
|
||||
4. **Temporal dimension**: Video/image models may expect a temporal dimension (e.g., `[B, T, C, H, W]`). For single images, duplicate frames if `temporal_patch_size > 1`.
|
||||
|
||||
**Verification**: Always compare Go preprocessed image shape and statistics against Python to catch sizing mismatches early.
|
||||
|
||||
### Tokenizer & Chat Templates
|
||||
|
||||
Most instruction-tuned models require:
|
||||
|
||||
1. **BOS token**: Added at start of input (token ID 2 for most models)
|
||||
2. **Chat template**: Wraps user prompt in model-specific format
|
||||
|
||||
**Common chat templates:**
|
||||
|
||||
| Model | Format |
|
||||
| ------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| Llama 3 | `<\|begin_of_text\|><\|start_header_id\|>user<\|end_header_id\|>\n{prompt}<\|eot_id\|><\|start_header_id\|>assistant<\|end_header_id\|>\n` |
|
||||
| Gemma 3 | `<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n` |
|
||||
| Qwen | `<\|im_start\|>user\n{prompt}<\|im_end\|>\n<\|im_start\|>assistant\n` |
|
||||
|
||||
**Checking tokenization:**
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate && python3 -c "
|
||||
from transformers import AutoTokenizer
|
||||
tok = AutoTokenizer.from_pretrained('./weights/model-name')
|
||||
tokens = tok.encode('Hello', add_special_tokens=True)
|
||||
print('Tokens:', tokens)
|
||||
print('Decoded:', [tok.decode([t]) for t in tokens])
|
||||
"
|
||||
```
|
||||
|
||||
### Text Model Checklist
|
||||
|
||||
Before moving to vision components, ensure the text model is fully working:
|
||||
|
||||
1. **Sliding window cache**: Some models (Gemma 3, GPT-OSS) use sliding window attention on certain layers. Use `cache.NewRotatingKVCache(windowSize)` for those layers, not `cache.NewKVCache()`. Check config for `sliding_window` and `sliding_window_pattern`.
|
||||
|
||||
2. **Unicode/UTF-8 decoding**: If output shows garbled characters like `Â` before spaces, the tokenizer's byte-level encoding isn't being decoded properly. Check `Decode()` handles UTF-8 byte sequences correctly.
|
||||
|
||||
3. **EOS tokens from vocabulary**: Don't hardcode EOS token IDs. The tokenizer should extract them from `added_tokens` in `tokenizer.json`. Multiple EOS tokens are common (e.g., Gemma has both `<eos>` and `<end_of_turn>`).
|
||||
|
||||
4. **Chat template**: Instruction-tuned models need chat formatting. Test with and without to ensure the model responds coherently.
|
||||
|
||||
5. **Compare with reference**: Always test against `mlx_lm.generate` with same prompt and `--temp 0` to verify outputs match.
|
||||
|
||||
## Phase 2: Correctness Testing
|
||||
|
||||
Run the model and look at the output. Make sure it outputs something coherent.
|
||||
|
||||
To compare correctness, add hooks to the python model and compare the output with debug statements in Go.
|
||||
|
||||
## Phase 3: Memory Verification
|
||||
|
||||
After loading, verify peak memory is close to final model size:
|
||||
|
||||
```bash
|
||||
# Run and check peak vs active memory
|
||||
/tmp/engine -model ./weights/MyModel -steps 1 2>&1 | grep -E "(peak|GB)"
|
||||
```
|
||||
|
||||
**Expected:** Peak should be ~1.1x final size (small overhead is OK). If peak is 2-3x final size, you have a memory problem.
|
||||
|
||||
### Checking Weight Dtypes
|
||||
|
||||
```bash
|
||||
# Check dtype of weights in safetensors files
|
||||
python3 -c "
|
||||
from safetensors import safe_open
|
||||
f = safe_open('model.safetensors', 'pt')
|
||||
for k in list(f.keys())[:5]:
|
||||
print(k, f.get_tensor(k).dtype)
|
||||
"
|
||||
```
|
||||
|
||||
### f32 Weights Need Special Handling
|
||||
|
||||
If weights are f32 but model runs in bf16, use `GetTensorBF16()` instead of `GetTensor()`:
|
||||
|
||||
- `GetTensor()` uses MLX's native loader (loads all tensors from file at once)
|
||||
- `GetTensorBF16()` loads one tensor at a time, converts to bf16, frees f32 immediately
|
||||
|
||||
This prevents peak memory from being 2x model size during loading.
|
||||
|
||||
## Phase 4: Performance
|
||||
|
||||
### Evaluation Strategy
|
||||
|
||||
- Call `mlx.Eval()` once per token/step, not inside loops
|
||||
- Use `mlx.AsyncEval()` to pipeline: build next step's graph while current executes
|
||||
- Never call `mlx.Eval()` inside attention or MLP - batch it at the end
|
||||
|
||||
### Fast Operations (Already Built-in)
|
||||
|
||||
These Go functions use MLX's fast fused kernels internally:
|
||||
|
||||
- `mlx.RMSNorm(x, weight, eps)` → uses `mlx_fast_rms_norm`
|
||||
- `mlx.RoPE(x, dims, traditional, base, scale, offset)` → uses `mlx_fast_rope`
|
||||
- `mlx.ScaledDotProductAttention(q, k, v, scale, causalMask)` → uses `mlx_fast_scaled_dot_product_attention`
|
||||
|
||||
### Type Promotion Gotchas
|
||||
|
||||
- `mlx.Mul(bf16Array, mlx.Full(shape, 2.0, mlx.Float32))` → upcasts everything to f32
|
||||
- Use `mlx.MulScalar(bf16Array, 2.0)` to preserve dtype (if available), or ensure scalar arrays match input dtype
|
||||
|
||||
### Profiling
|
||||
|
||||
- Use `mactop` to check GPU utilization - should be ~100%
|
||||
- If low, bottleneck is likely Go code (tokenization, data prep), not MLX
|
||||
- Use `pprof` for CPU profiling to find Go-side overhead (CGO calls, tokenization, etc.)
|
||||
- Use Metal debugger for kernel-level profiling (see docs/performance.md)
|
||||
- Profile with `time.Since()` around major blocks
|
||||
- Compare tok/s against reference (llama.cpp, MLX-LM)
|
||||
|
||||
## Phase 5: Polish
|
||||
|
||||
- Remove debug prints
|
||||
- Add proper error handling
|
||||
- Document config.json fields used
|
||||
|
||||
## Tips
|
||||
|
||||
- MLX is lazy; call `Eval()` only when you need values
|
||||
- Check `model.safetensors.index.json` for weight→file mapping
|
||||
|
||||
## Common Gotchas
|
||||
|
||||
### MLX Transpose requires Contiguous
|
||||
|
||||
`mlx.Transpose()` returns a view with modified strides - calling `Data()` returns the original memory layout. Always follow with `mlx.Contiguous()` if you need correct data ordering:
|
||||
|
||||
```go
|
||||
// Wrong - Data() returns original layout
|
||||
x = mlx.Transpose(x, 0, 2, 3, 4, 1)
|
||||
data := x.Data() // Bug: data is in wrong order
|
||||
|
||||
// Correct
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
|
||||
data := x.Data() // Data is in transposed order
|
||||
```
|
||||
|
||||
### Missing Biases in Weight Loading
|
||||
|
||||
Python layers often have optional biases. Check the safetensors files for bias tensors:
|
||||
|
||||
```bash
|
||||
python3 -c "from safetensors import safe_open; f=safe_open('model.safetensors','pt'); print([k for k in f.keys() if 'bias' in k])"
|
||||
```
|
||||
|
||||
### Don't Spam ClearCache() or Eval()
|
||||
|
||||
- `mlx.ClearCache()` clears the GPU cache but doesn't free arrays - it has minimal effect on memory. Don't call it repeatedly.
|
||||
- `mlx.Eval()` forces synchronous evaluation and frees non-kept arrays. Call it once per step/token, not inside loops.
|
||||
|
||||
### Lazy Eval and Free() - The Critical Pattern
|
||||
|
||||
MLX arrays are lazy - operations build a graph, actual computation happens at `Eval()`. This has a critical implication for `Free()`:
|
||||
|
||||
```go
|
||||
// WRONG: Lazy array references freed input
|
||||
func BadForward(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Add(compute(x), x) // Returns lazy array referencing x
|
||||
}
|
||||
|
||||
func Caller() {
|
||||
result := BadForward(input)
|
||||
input.Free() // Frees input, but result still references it!
|
||||
mlx.Eval(result) // CRASH: "expected a non-empty mlx_array"
|
||||
}
|
||||
|
||||
// CORRECT: Eval before caller can free inputs
|
||||
func GoodForward(x *mlx.Array) *mlx.Array {
|
||||
out := mlx.Add(compute(x), x)
|
||||
mlx.Eval(out) // Materialize before returning
|
||||
return out
|
||||
}
|
||||
```
|
||||
|
||||
**Rule**: If your function returns an array that references its input (residual connections, skip connections), you MUST `Eval()` before returning - otherwise the caller may free the input while the result still needs it.
|
||||
|
||||
**Debugging**: Errors like "expected a non-empty mlx_array" at Eval time often mean a tensor was freed while still referenced by a lazy graph. Add logging BEFORE the Free() calls to find which one, not inside the lazy operations.
|
||||
|
||||
### Data() and DataInt32() Trigger Eval
|
||||
|
||||
Calling `.Data()` or `.DataInt32()` on an array does an implicit `Eval()`, which frees any un-eval'd arrays:
|
||||
|
||||
```go
|
||||
// WRONG: tokenArray gets freed when we eval image
|
||||
tokenArray := mlx.NewArrayInt32(tokens, shape)
|
||||
image := processImage(path) // This evals image internally
|
||||
mlx.Eval(image) // This frees tokenArray!
|
||||
|
||||
tokenData := tokenArray.DataInt32() // CRASH: tokenArray was freed
|
||||
|
||||
// CORRECT: Eval arrays you need to keep before other evals
|
||||
tokenArray := mlx.NewArrayInt32(tokens, shape)
|
||||
mlx.Eval(tokenArray) // Materialize it first
|
||||
image := processImage(path)
|
||||
|
||||
tokenData := tokenArray.DataInt32() // Works fine
|
||||
```
|
||||
|
||||
**Rule**: Before calling any function that might do an `Eval()` internally, make sure to `Eval()` any arrays you'll need later. When passing arrays to model forward functions, eval them first if they were just created.
|
||||
|
||||
### Diffusers Pipeline vs Scheduler Defaults
|
||||
|
||||
Diffusers pipelines often pass custom parameters that override scheduler defaults. When writing tests, match what the **pipeline** does, not the raw scheduler:
|
||||
|
||||
```python
|
||||
# Scheduler default (when no sigmas passed):
|
||||
# sigmas from 1.0 to 1/1000 = 0.001
|
||||
|
||||
# But pipeline passes custom sigmas:
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
scheduler.set_timesteps(sigmas=sigmas, ...) # 1.0 to 1/30 for 30 steps
|
||||
```
|
||||
|
||||
Always check the pipeline source to see what parameters it passes to components.
|
||||
|
||||
### Diffusion Models: Timestep Scaling
|
||||
|
||||
Diffusion transformers use sinusoidal timestep embeddings with internal scaling. **Critical**: Check what the pipeline actually passes to the transformer, not just what the scheduler stores.
|
||||
|
||||
**Common pattern in diffusers (tricky!):**
|
||||
|
||||
- `scheduler.sigmas` = values in [0, 1] range (e.g., 1.0, 0.608, 0.02)
|
||||
- `scheduler.timesteps` = sigmas × 1000 (e.g., 1000, 608, 20)
|
||||
- **BUT** the pipeline often divides by 1000 before passing to transformer: `timestep=t / 1000`
|
||||
- Transformer's `Timesteps` class has `scale=1000`, multiplying input by 1000
|
||||
- Net effect: transformer receives sigma (0.608), scales to 608
|
||||
|
||||
**Verification - check the actual pipeline source:**
|
||||
|
||||
```bash
|
||||
grep -A2 "timestep=" .venv/.../pipeline_*.py
|
||||
# Look for: timestep=timestep / 1000 ← pipeline normalizes!
|
||||
```
|
||||
|
||||
**Go approach (skip the multiply/divide dance):**
|
||||
|
||||
```go
|
||||
// Store sigmas directly as timesteps - equivalent to Python's
|
||||
// scheduler.timesteps / 1000 that the pipeline passes to transformer
|
||||
s.Timesteps[i] = sigmas[i] // 0.608
|
||||
// Transformer does: 0.608 * 1000 = 608 ✓
|
||||
```
|
||||
|
||||
**Symptoms of wrong timestep scaling:**
|
||||
|
||||
- Noise predictions have wrong magnitude (off by orders of magnitude)
|
||||
- Output images are completely noisy/corrupted or have extreme contrast
|
||||
- Latents diverge from Python after first denoising step
|
||||
|
||||
**Key lesson:** Don't assume scheduler.timesteps is what the transformer receives - always check the pipeline's forward pass for any normalization.
|
||||
@@ -1,612 +0,0 @@
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// TextConfig holds configuration for the text model
|
||||
type TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
SlidingWindow int32 `json:"sliding_window"`
|
||||
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
|
||||
|
||||
// Computed fields
|
||||
Scale float32 `json:"-"`
|
||||
}
|
||||
|
||||
// TextModel is the Gemma 3 text-only model
|
||||
type TextModel struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*DecoderLayer `weight:"model.layers"`
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
Output *nn.Linear `weight:"-"` // Tied to EmbedTokens, set manually
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm to avoid allocation per forward
|
||||
NormScaled *mlx.Array `weight:"-"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*TextConfig
|
||||
}
|
||||
|
||||
// DecoderLayer is a single transformer block
|
||||
type DecoderLayer struct {
|
||||
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
Attention *Attention
|
||||
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
PreFFNorm *nn.RMSNorm `weight:"pre_feedforward_layernorm"`
|
||||
MLP *MLP
|
||||
PostFFNorm *nn.RMSNorm `weight:"post_feedforward_layernorm"`
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
||||
InputNormScaled *mlx.Array `weight:"-"`
|
||||
PostAttnNormScaled *mlx.Array `weight:"-"`
|
||||
PreFFNormScaled *mlx.Array `weight:"-"`
|
||||
PostFFNormScaled *mlx.Array `weight:"-"`
|
||||
|
||||
// Whether this layer uses sliding window attention
|
||||
IsSliding bool
|
||||
LayerIdx int32
|
||||
}
|
||||
|
||||
// Attention implements Gemma 3 attention with Q/K normalization
|
||||
type Attention struct {
|
||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"self_attn.q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"self_attn.k_norm"`
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
||||
QNormScaled *mlx.Array `weight:"-"`
|
||||
KNormScaled *mlx.Array `weight:"-"`
|
||||
}
|
||||
|
||||
// MLP is the feed-forward network with GELU activation
|
||||
type MLP struct {
|
||||
GateProj *nn.Linear `weight:"mlp.gate_proj"`
|
||||
UpProj *nn.Linear `weight:"mlp.up_proj"`
|
||||
DownProj *nn.Linear `weight:"mlp.down_proj"`
|
||||
}
|
||||
|
||||
// LoadText loads the text-only Gemma 3 model
|
||||
func LoadText(modelPath string) (*TextModel, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
var cfg TextConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Compute scale
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
// Set defaults if not specified
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 1000000
|
||||
}
|
||||
if cfg.RopeLocalBaseFreq == 0 {
|
||||
cfg.RopeLocalBaseFreq = 10000
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &TextModel{
|
||||
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
|
||||
TextConfig: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Initialize layer metadata
|
||||
for i := range m.Layers {
|
||||
m.Layers[i] = &DecoderLayer{
|
||||
LayerIdx: int32(i),
|
||||
IsSliding: isLayerSliding(int32(i), cfg.SlidingWindowPattern),
|
||||
}
|
||||
}
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Tied embeddings for output
|
||||
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
// Precompute (1 + weight) for Gemma-style RMSNorm to avoid per-forward allocation
|
||||
precomputeGemmaScaledWeights(m)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// precomputeGemmaScaledWeights computes (1 + weight) for all RMSNorm layers
|
||||
// This avoids creating temporary arrays on every forward pass
|
||||
func precomputeGemmaScaledWeights(m *TextModel) {
|
||||
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
|
||||
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
||||
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
|
||||
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
|
||||
|
||||
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
||||
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
||||
}
|
||||
|
||||
// Eval all the precomputed weights
|
||||
var scaled []*mlx.Array
|
||||
scaled = append(scaled, m.NormScaled)
|
||||
for _, layer := range m.Layers {
|
||||
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
|
||||
layer.PreFFNormScaled, layer.PostFFNormScaled,
|
||||
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
|
||||
}
|
||||
mlx.Eval(scaled...)
|
||||
}
|
||||
|
||||
// isLayerSliding determines if a layer uses sliding window attention
|
||||
// Pattern N means: layers 0 to N-1 sliding, N full, N+1 to 2N-1 sliding, 2N full, etc.
|
||||
func isLayerSliding(layerIdx, pattern int32) bool {
|
||||
if pattern <= 0 {
|
||||
return false // No sliding window
|
||||
}
|
||||
// Layer is full attention if (layerIdx + 1) % pattern == 0
|
||||
return (layerIdx+1)%pattern != 0
|
||||
}
|
||||
|
||||
// Forward runs the text model forward pass
|
||||
func (m *TextModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
|
||||
// Get embeddings and scale by sqrt(hidden_size)
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.Forward(h, caches[i], B, L, m.TextConfig)
|
||||
}
|
||||
|
||||
// Final norm and output projection
|
||||
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.RMSNormEps))
|
||||
}
|
||||
|
||||
// Forward runs a decoder layer
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
// Pre-attention norm (use precomputed scaled weight)
|
||||
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// Attention
|
||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
|
||||
|
||||
// Post-attention norm and residual
|
||||
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
// Pre-FFN norm
|
||||
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// MLP
|
||||
mlpOut := l.MLP.Forward(normed)
|
||||
|
||||
// Post-FFN norm and residual
|
||||
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
return mlx.Add(h, mlpOut)
|
||||
}
|
||||
|
||||
// Forward runs attention with Q/K normalization
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
// Q/K normalization after reshaping (use precomputed scaled weight)
|
||||
q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
|
||||
k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// Apply RoPE with appropriate theta
|
||||
ropeTheta := cfg.RopeTheta
|
||||
if isSliding {
|
||||
ropeTheta = cfg.RopeLocalBaseFreq
|
||||
}
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
|
||||
// Update cache
|
||||
k, v = c.Update(k, v, int(L))
|
||||
|
||||
// Repeat K/V for GQA if needed
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = nn.RepeatKV(k, repeatFactor)
|
||||
v = nn.RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
// Attention
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// compiledGeluApprox is a singleton compiled GELU function shared across all layers
|
||||
var compiledGeluApprox *mlx.CompiledFunc
|
||||
|
||||
// getCompiledGeluApprox returns the compiled GELU function, creating it once if needed
|
||||
func getCompiledGeluApprox() *mlx.CompiledFunc {
|
||||
if compiledGeluApprox == nil {
|
||||
compiledGeluApprox = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
return []*mlx.Array{geluApproxImpl(inputs[0])}
|
||||
}, true)
|
||||
}
|
||||
return compiledGeluApprox
|
||||
}
|
||||
|
||||
// Forward runs the MLP with GELU approximation (tanh variant)
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := getCompiledGeluApprox().Call(m.GateProj.Forward(x))[0]
|
||||
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// geluApproxImpl computes GELU using the tanh approximation (gelu_pytorch_tanh):
|
||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
func geluApproxImpl(x *mlx.Array) *mlx.Array {
|
||||
// Constants
|
||||
const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi)
|
||||
const coeff = 0.044715
|
||||
|
||||
// x^3
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
// x + 0.044715 * x^3
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
|
||||
// sqrt(2/pi) * (x + 0.044715 * x^3)
|
||||
scaled := mlx.MulScalar(inner, sqrt2OverPi)
|
||||
// tanh(...)
|
||||
tanh := mlx.Tanh(scaled)
|
||||
// 1 + tanh(...)
|
||||
onePlusTanh := mlx.AddScalar(tanh, 1.0)
|
||||
// 0.5 * x * (1 + tanh(...))
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusTanh)
|
||||
}
|
||||
|
||||
// gemmaRMSNorm applies Gemma-style RMS normalization: x * rsqrt(mean(x^2) + eps) * (1 + weight)
|
||||
// Uses mlx.RMSNorm fast kernel with pre-computed (1 + weight)
|
||||
func gemmaRMSNorm(x, weight *mlx.Array, eps float32) *mlx.Array {
|
||||
// Gemma uses (1 + weight) instead of weight
|
||||
scaledWeight := mlx.AddScalar(weight, 1.0)
|
||||
return mlx.RMSNorm(x, scaledWeight, eps)
|
||||
}
|
||||
|
||||
// Interface methods
|
||||
func (m *TextModel) NumLayers() int { return len(m.Layers) }
|
||||
func (m *TextModel) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
func (m *TextModel) VocabSize() int32 { return m.TextConfig.VocabSize }
|
||||
|
||||
// Tokenizer returns the tokenizer wrapped to add BOS and apply chat template
|
||||
func (m *TextModel) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
// FormatPrompt applies the Gemma 3 chat template to a prompt
|
||||
func (m *TextModel) FormatPrompt(prompt string) string {
|
||||
// Gemma 3 chat format: <start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n
|
||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
func (m *TextModel) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
if m.Layers[i].IsSliding {
|
||||
// Use rotating cache for sliding window layers
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
||||
} else {
|
||||
// Use regular cache for global attention layers
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// Config holds config for the full multimodal model
|
||||
type Config struct {
|
||||
TextConfig TextConfig `json:"text_config"`
|
||||
VisionConfig VisionConfig `json:"vision_config"`
|
||||
|
||||
// Image token config (from config.json)
|
||||
BOITokenIndex int32 `json:"boi_token_index"` // <start_of_image> = 255999
|
||||
EOITokenIndex int32 `json:"eoi_token_index"` // <end_of_image> = 256000
|
||||
ImageTokenIndex int32 `json:"image_token_index"` // <image_soft_token> = 262144
|
||||
MMTokensPerImage int32 `json:"mm_tokens_per_image"` // 256
|
||||
}
|
||||
|
||||
// Model is the full Gemma 3 multimodal model
|
||||
type Model struct {
|
||||
VisionTower *VisionTower `weight:"vision_tower"`
|
||||
Projector *MultiModalProjector `weight:"multi_modal_projector"`
|
||||
TextModel *TextModel `weight:"language_model"`
|
||||
Config *Config
|
||||
tok *tokenizer.Tokenizer
|
||||
}
|
||||
|
||||
// Load loads the full multimodal Gemma 3 model
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Set defaults for text config (multimodal config often has incomplete text_config)
|
||||
// These defaults match transformers.Gemma3TextConfig defaults
|
||||
tc := &cfg.TextConfig
|
||||
if tc.HeadDim == 0 {
|
||||
tc.HeadDim = 256 // Gemma 3 uses head_dim=256
|
||||
}
|
||||
if tc.NumAttentionHeads == 0 {
|
||||
// Gemma 3 4B uses 8 attention heads (cannot infer from hidden_size/head_dim)
|
||||
tc.NumAttentionHeads = 8
|
||||
}
|
||||
if tc.NumKeyValueHeads == 0 {
|
||||
// Gemma 3 4B uses 4 KV heads (GQA with 2:1 ratio)
|
||||
tc.NumKeyValueHeads = 4
|
||||
}
|
||||
if tc.VocabSize == 0 {
|
||||
tc.VocabSize = 262208 // Gemma 3 vocab size (not 262144!)
|
||||
}
|
||||
if tc.RopeTheta == 0 {
|
||||
tc.RopeTheta = 1000000
|
||||
}
|
||||
if tc.RopeLocalBaseFreq == 0 {
|
||||
tc.RopeLocalBaseFreq = 10000
|
||||
}
|
||||
if tc.RMSNormEps == 0 {
|
||||
tc.RMSNormEps = 1e-6
|
||||
}
|
||||
if tc.SlidingWindowPattern == 0 {
|
||||
tc.SlidingWindowPattern = 6
|
||||
}
|
||||
if tc.MaxPositionEmbeddings == 0 {
|
||||
tc.MaxPositionEmbeddings = 131072 // Gemma 3 4B default
|
||||
}
|
||||
|
||||
// Compute text model scale
|
||||
tc.Scale = float32(1.0 / math.Sqrt(float64(tc.HeadDim)))
|
||||
|
||||
// Set defaults for image token config
|
||||
if cfg.BOITokenIndex == 0 {
|
||||
cfg.BOITokenIndex = 255999 // <start_of_image>
|
||||
}
|
||||
if cfg.EOITokenIndex == 0 {
|
||||
cfg.EOITokenIndex = 256000 // <end_of_image>
|
||||
}
|
||||
if cfg.ImageTokenIndex == 0 {
|
||||
cfg.ImageTokenIndex = 262144 // <image_soft_token>
|
||||
}
|
||||
if cfg.MMTokensPerImage == 0 {
|
||||
cfg.MMTokensPerImage = 256
|
||||
}
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
VisionTower: &VisionTower{
|
||||
Embeddings: &VisionEmbeddings{},
|
||||
Encoder: make([]*VisionEncoderLayer, cfg.VisionConfig.NumHiddenLayers),
|
||||
Config: &cfg.VisionConfig,
|
||||
},
|
||||
Projector: &MultiModalProjector{},
|
||||
TextModel: &TextModel{
|
||||
Layers: make([]*DecoderLayer, cfg.TextConfig.NumHiddenLayers),
|
||||
TextConfig: &cfg.TextConfig,
|
||||
},
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Initialize text layer metadata
|
||||
for i := range m.TextModel.Layers {
|
||||
m.TextModel.Layers[i] = &DecoderLayer{
|
||||
LayerIdx: int32(i),
|
||||
IsSliding: isLayerSliding(int32(i), cfg.TextConfig.SlidingWindowPattern),
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize vision encoder layers
|
||||
for i := range m.VisionTower.Encoder {
|
||||
m.VisionTower.Encoder[i] = &VisionEncoderLayer{}
|
||||
}
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Tied embeddings for text output
|
||||
m.TextModel.Output = nn.NewLinear(m.TextModel.EmbedTokens.Weight, nil)
|
||||
m.TextModel.tok = tok
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
// Precompute (1 + weight) for Gemma-style RMSNorm
|
||||
precomputeGemmaScaledWeights(m.TextModel)
|
||||
|
||||
// Precompute projector's scaled weight
|
||||
m.Projector.SoftEmbNormScaled = mlx.AddScalar(m.Projector.SoftEmbNorm.Weight, 1.0)
|
||||
mlx.Eval(m.Projector.SoftEmbNormScaled)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward runs the text-only forward pass
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
return m.TextModel.Forward(tokens, caches)
|
||||
}
|
||||
|
||||
// ForwardWithImage runs the multimodal forward pass
|
||||
// tokens: [B, L] input token IDs (with image placeholder tokens)
|
||||
// image: [B, H, W, C] preprocessed image tensor
|
||||
func (m *Model) ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
cfg := m.Config.TextConfig
|
||||
|
||||
// Find image token position FIRST before any eval that might free tokens
|
||||
imageStartPos := int32(-1)
|
||||
if image != nil && B == 1 {
|
||||
tokenData := tokens.DataInt32() // This evals tokens
|
||||
for i, t := range tokenData {
|
||||
if t == m.Config.ImageTokenIndex {
|
||||
imageStartPos = int32(i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get text embeddings and scale
|
||||
h := m.TextModel.EmbedTokens.Forward(tokens)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(cfg.HiddenSize))))
|
||||
|
||||
// Process image if provided
|
||||
if image != nil && imageStartPos >= 0 {
|
||||
// Vision tower: [B, H, W, C] -> [B, num_patches, vision_hidden]
|
||||
visionFeatures := m.VisionTower.Forward(image)
|
||||
|
||||
// Project to text space: [B, num_patches, vision_hidden] -> [B, 256, text_hidden]
|
||||
imageEmbeds := m.Projector.Forward(visionFeatures, cfg.RMSNormEps)
|
||||
|
||||
// Eval h and imageEmbeds together so neither gets freed
|
||||
mlx.Eval(h, imageEmbeds)
|
||||
|
||||
// Cast imageEmbeds to match text embeddings dtype (bf16)
|
||||
if imageEmbeds.Dtype() != h.Dtype() {
|
||||
imageEmbeds = mlx.AsType(imageEmbeds, h.Dtype())
|
||||
mlx.Eval(imageEmbeds)
|
||||
}
|
||||
|
||||
// Insert image embeddings at the known position
|
||||
h = m.insertImageEmbeddingsAt(h, imageEmbeds, imageStartPos)
|
||||
}
|
||||
|
||||
// Run through text model layers
|
||||
for i, layer := range m.TextModel.Layers {
|
||||
h = layer.Forward(h, caches[i], B, L, m.TextModel.TextConfig)
|
||||
}
|
||||
|
||||
// Final norm and output projection
|
||||
return m.TextModel.Output.Forward(mlx.RMSNorm(h, m.TextModel.NormScaled, cfg.RMSNormEps))
|
||||
}
|
||||
|
||||
// insertImageEmbeddingsAt replaces image placeholder tokens with actual image embeddings
|
||||
// at a known position (to avoid re-scanning tokens after eval)
|
||||
// textEmbeds: [B, L, hidden_size] text embeddings
|
||||
// imageEmbeds: [B, 256, hidden_size] image embeddings from projector
|
||||
// startPos: starting position of image tokens in the sequence
|
||||
func (m *Model) insertImageEmbeddingsAt(textEmbeds, imageEmbeds *mlx.Array, startPos int32) *mlx.Array {
|
||||
numImageTokens := imageEmbeds.Shape()[1]
|
||||
L := textEmbeds.Shape()[1]
|
||||
|
||||
// Split text embeddings: [0:startPos] + imageEmbeds + [startPos+256:L]
|
||||
afterStart := startPos + numImageTokens
|
||||
|
||||
// Slice before image tokens: textEmbeds[:, 0:startPos, :]
|
||||
before := mlx.SliceAxis(textEmbeds, 1, 0, startPos)
|
||||
|
||||
// Slice after image tokens: textEmbeds[:, startPos+256:L, :]
|
||||
after := mlx.SliceAxis(textEmbeds, 1, afterStart, L)
|
||||
|
||||
// Concatenate: before + imageEmbeds + after along axis 1
|
||||
return mlx.Concatenate([]*mlx.Array{before, imageEmbeds, after}, 1)
|
||||
}
|
||||
|
||||
// Interface methods for Model
|
||||
func (m *Model) NumLayers() int { return len(m.TextModel.Layers) }
|
||||
func (m *Model) MaxContextLength() int32 { return m.Config.TextConfig.MaxPositionEmbeddings }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.TextConfig.VocabSize }
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { return m.TextModel.NewCache(maxSeqLen) }
|
||||
func (m *Model) ImageSize() int32 { return m.Config.VisionConfig.ImageSize }
|
||||
|
||||
// FormatPrompt applies the Gemma 3 multimodal chat template
|
||||
func (m *Model) FormatPrompt(prompt string) string {
|
||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
// FormatPromptWithImage applies the Gemma 3 multimodal chat template with image
|
||||
func (m *Model) FormatPromptWithImage(prompt string) string {
|
||||
return fmt.Sprintf("<start_of_turn>user\n<start_of_image>%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
// ExpandImageTokens expands <start_of_image> into 256 image placeholder tokens
|
||||
// Input tokens containing boi_token (255999) are expanded to:
|
||||
// boi_token + 256 * image_token + eoi_token
|
||||
func (m *Model) ExpandImageTokens(tokens []int32) []int32 {
|
||||
result := make([]int32, 0, len(tokens)+int(m.Config.MMTokensPerImage)+1)
|
||||
|
||||
for _, t := range tokens {
|
||||
if t == m.Config.BOITokenIndex {
|
||||
// Expand: boi + 256 * image_token + eoi
|
||||
result = append(result, m.Config.BOITokenIndex)
|
||||
for i := int32(0); i < m.Config.MMTokensPerImage; i++ {
|
||||
result = append(result, m.Config.ImageTokenIndex)
|
||||
}
|
||||
result = append(result, m.Config.EOITokenIndex)
|
||||
} else {
|
||||
result = append(result, t)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"golang.org/x/image/draw"
|
||||
)
|
||||
|
||||
// ProcessImage loads and preprocesses an image for the vision tower
|
||||
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
|
||||
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open image: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
img, _, err := image.Decode(f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode image: %w", err)
|
||||
}
|
||||
|
||||
return ProcessImageData(img, imageSize)
|
||||
}
|
||||
|
||||
// ProcessImageData preprocesses an image.Image for the vision tower
|
||||
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
||||
// Resize to target size using bilinear interpolation
|
||||
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
|
||||
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
|
||||
|
||||
// Convert to float32 array [H, W, C] and normalize
|
||||
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
|
||||
data := make([]float32, imageSize*imageSize*3)
|
||||
idx := 0
|
||||
for y := int32(0); y < imageSize; y++ {
|
||||
for x := int32(0); x < imageSize; x++ {
|
||||
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
|
||||
// RGBA returns 16-bit values, convert to 8-bit
|
||||
data[idx] = float32(r>>8)/127.5 - 1.0
|
||||
data[idx+1] = float32(g>>8)/127.5 - 1.0
|
||||
data[idx+2] = float32(b>>8)/127.5 - 1.0
|
||||
idx += 3
|
||||
}
|
||||
}
|
||||
|
||||
// Create MLX array [1, H, W, C] for NHWC layout
|
||||
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
|
||||
mlx.Eval(arr) // Materialize to prevent use-after-free
|
||||
return arr, nil
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
)
|
||||
|
||||
// MultiModalProjector projects vision features to text embedding space
|
||||
type MultiModalProjector struct {
|
||||
// mm_input_projection_weight: [vision_hidden, text_hidden]
|
||||
InputProjection *mlx.Array `weight:"mm_input_projection_weight"`
|
||||
SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"`
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
||||
SoftEmbNormScaled *mlx.Array `weight:"-"`
|
||||
}
|
||||
|
||||
// Forward projects vision features to text space
|
||||
// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152])
|
||||
// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560])
|
||||
func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array {
|
||||
// Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152]
|
||||
// 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens
|
||||
B := visionFeatures.Shape()[0]
|
||||
visionHidden := visionFeatures.Shape()[2]
|
||||
|
||||
// Reshape to [B, 64, 64, hidden]
|
||||
gridSize := int32(64) // sqrt(4096)
|
||||
pooledSize := int32(16) // 64/4
|
||||
h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden)
|
||||
|
||||
// Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling
|
||||
h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden)
|
||||
|
||||
// Average over pooling dimensions (axes 2 and 4)
|
||||
h = mlx.Mean(h, 4, false)
|
||||
h = mlx.Mean(h, 2, false)
|
||||
|
||||
// h is now [B, 16, 16, hidden], reshape to [B, 256, hidden]
|
||||
numTokens := pooledSize * pooledSize
|
||||
h = mlx.Reshape(h, B, numTokens, visionHidden)
|
||||
|
||||
// Apply Gemma-style RMS norm (use precomputed 1 + weight)
|
||||
h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps)
|
||||
|
||||
// Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden]
|
||||
return mlx.Linear(h, p.InputProjection)
|
||||
}
|
||||
@@ -1,136 +0,0 @@
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
)
|
||||
|
||||
// VisionConfig holds configuration for the SigLIP vision tower
|
||||
type VisionConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
ImageSize int32 `json:"image_size"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
PatchSize int32 `json:"patch_size"`
|
||||
}
|
||||
|
||||
// VisionTower is the SigLIP vision encoder
|
||||
type VisionTower struct {
|
||||
Embeddings *VisionEmbeddings `weight:"vision_model.embeddings"`
|
||||
Encoder []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
|
||||
PostLayerNorm *nn.LayerNorm `weight:"vision_model.post_layernorm"`
|
||||
Config *VisionConfig
|
||||
}
|
||||
|
||||
// VisionEmbeddings handles patch and position embeddings
|
||||
type VisionEmbeddings struct {
|
||||
// PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX
|
||||
PatchWeight *mlx.Array `weight:"patch_embedding.weight"`
|
||||
PatchBias *mlx.Array `weight:"patch_embedding.bias"`
|
||||
PosEmbed *nn.Embedding `weight:"position_embedding"`
|
||||
}
|
||||
|
||||
// VisionEncoderLayer is a single transformer encoder layer
|
||||
type VisionEncoderLayer struct {
|
||||
LayerNorm1 *nn.LayerNorm `weight:"layer_norm1"`
|
||||
Attention *VisionAttention `weight:"self_attn"`
|
||||
LayerNorm2 *nn.LayerNorm `weight:"layer_norm2"`
|
||||
MLP *VisionMLP `weight:"mlp"`
|
||||
}
|
||||
|
||||
// VisionAttention implements multi-head self-attention
|
||||
type VisionAttention struct {
|
||||
QProj *nn.Linear `weight:"q_proj"`
|
||||
KProj *nn.Linear `weight:"k_proj"`
|
||||
VProj *nn.Linear `weight:"v_proj"`
|
||||
OutProj *nn.Linear `weight:"out_proj"`
|
||||
}
|
||||
|
||||
// VisionMLP is the feed-forward network
|
||||
type VisionMLP struct {
|
||||
FC1 *nn.Linear `weight:"fc1"`
|
||||
FC2 *nn.Linear `weight:"fc2"`
|
||||
}
|
||||
|
||||
// Forward runs the vision tower on preprocessed images
|
||||
// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX)
|
||||
// Output: [B, num_patches, hidden_size]
|
||||
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O]
|
||||
// Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C]
|
||||
weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1)
|
||||
h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding
|
||||
|
||||
// Add bias: [O] -> [1, 1, 1, O] for broadcasting
|
||||
bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0])
|
||||
h = mlx.Add(h, bias)
|
||||
|
||||
// h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden]
|
||||
B := h.Shape()[0]
|
||||
gridH, gridW := h.Shape()[1], h.Shape()[2]
|
||||
hidden := h.Shape()[3]
|
||||
numPatches := gridH * gridW
|
||||
h = mlx.Reshape(h, B, numPatches, hidden)
|
||||
|
||||
// Add position embeddings
|
||||
posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
|
||||
posEmbed := v.Embeddings.PosEmbed.Forward(posIds)
|
||||
h = mlx.Add(h, posEmbed)
|
||||
|
||||
// Encoder layers
|
||||
headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
|
||||
scale := float32(1.0 / math.Sqrt(float64(headDim)))
|
||||
for _, layer := range v.Encoder {
|
||||
h = layer.Forward(h, v.Config, scale)
|
||||
}
|
||||
|
||||
// Final layer norm
|
||||
h = v.PostLayerNorm.Forward(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Forward runs a vision encoder layer
|
||||
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
|
||||
// Pre-norm attention
|
||||
h := l.LayerNorm1.Forward(x)
|
||||
h = l.Attention.Forward(h, cfg, scale)
|
||||
x = mlx.Add(x, h)
|
||||
|
||||
// Pre-norm MLP
|
||||
h = l.LayerNorm2.Forward(x)
|
||||
h = l.MLP.Forward(h)
|
||||
return mlx.Add(x, h)
|
||||
}
|
||||
|
||||
// Forward runs multi-head self-attention
|
||||
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
|
||||
B, L := x.Shape()[0], x.Shape()[1]
|
||||
headDim := cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
|
||||
// Scaled dot-product attention (no causal mask for vision)
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
// Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden]
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize)
|
||||
|
||||
return a.OutProj.Forward(out)
|
||||
}
|
||||
|
||||
// Forward runs the MLP with GELU activation
|
||||
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
h := mlx.GELU(m.FC1.Forward(x))
|
||||
return m.FC2.Forward(h)
|
||||
}
|
||||
@@ -1,485 +0,0 @@
|
||||
package gpt_oss
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// RopeScaling holds YaRN or other RoPE scaling configuration
|
||||
type RopeScaling struct {
|
||||
RopeType string `json:"rope_type"`
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
SlidingWindow int32 `json:"sliding_window"`
|
||||
NumLocalExperts int32 `json:"num_local_experts"`
|
||||
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
SwiGLULimit float32 `json:"swiglu_limit"`
|
||||
RopeScaling *RopeScaling `json:"rope_scaling"`
|
||||
Scale float32 `json:"-"` // computed: 1/sqrt(HeadDim)
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
||||
Sinks *mlx.Array `weight:"self_attn.sinks,optional"`
|
||||
YarnFreqs *mlx.Array // computed
|
||||
YarnMscale float32
|
||||
}
|
||||
|
||||
// swiGLU applies the GPT-OSS custom SwiGLU activation.
|
||||
// Formula: (gate * sigmoid(alpha * gate)) * (up + 1)
|
||||
// with clipping: gate to [None, limit], up to [-limit, limit]
|
||||
func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array {
|
||||
// Clip gate to [None, limit]
|
||||
gateClipped := mlx.ClipScalar(gate, 0, limit, false, true)
|
||||
|
||||
// Clip up to [-limit, limit]
|
||||
upClipped := mlx.ClipScalar(up, -limit, limit, true, true)
|
||||
|
||||
// glu_scaled = alpha * gate_clipped
|
||||
gluScaled := mlx.MulScalar(gateClipped, alpha)
|
||||
|
||||
// sig = sigmoid(glu_scaled)
|
||||
sig := mlx.Sigmoid(gluScaled)
|
||||
|
||||
// out_glu = gate_clipped * sig
|
||||
outGlu := mlx.Mul(gateClipped, sig)
|
||||
|
||||
// result = out_glu * (up_clipped + 1)
|
||||
return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0))
|
||||
}
|
||||
|
||||
// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers
|
||||
var compiledSwiGLU *mlx.CompiledFunc
|
||||
|
||||
// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed
|
||||
func getCompiledSwiGLU() *mlx.CompiledFunc {
|
||||
if compiledSwiGLU == nil {
|
||||
const alpha float32 = 1.702
|
||||
const limit float32 = 7.0
|
||||
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
|
||||
}, true) // shapeless=true so it works for any input size
|
||||
}
|
||||
return compiledSwiGLU
|
||||
}
|
||||
|
||||
// ComputeYarnFreqs computes YaRN-modified RoPE frequencies
|
||||
// Based on mlx-lm's YarnRoPE implementation
|
||||
func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) {
|
||||
// yarn_find_correction_dim
|
||||
yarnFindCorrectionDim := func(numRotations float64) float64 {
|
||||
return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base)))
|
||||
}
|
||||
|
||||
// yarn_find_correction_range
|
||||
low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast))))
|
||||
high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow))))
|
||||
if low < 0 {
|
||||
low = 0
|
||||
}
|
||||
if high > int(dims)-1 {
|
||||
high = int(dims) - 1
|
||||
}
|
||||
|
||||
// yarn_get_mscale
|
||||
yarnGetMscale := func(scale, mscale float64) float64 {
|
||||
if scale <= 1 {
|
||||
return 1.0
|
||||
}
|
||||
return 0.1*mscale*math.Log(scale) + 1.0
|
||||
}
|
||||
mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0))
|
||||
|
||||
// Compute frequencies
|
||||
// freq_extra = base ** (arange(0, dims, 2) / dims)
|
||||
// freq_inter = scaling_factor * freq_extra
|
||||
halfDims := dims / 2
|
||||
freqData := make([]float32, halfDims)
|
||||
for i := int32(0); i < halfDims; i++ {
|
||||
exp := float64(2*i) / float64(dims)
|
||||
freqExtra := math.Pow(float64(base), exp)
|
||||
freqInter := float64(scalingFactor) * freqExtra
|
||||
|
||||
// linear ramp mask
|
||||
var freqMask float64
|
||||
if low == high {
|
||||
freqMask = 0.0
|
||||
} else {
|
||||
t := (float64(i) - float64(low)) / float64(high-low)
|
||||
if t < 0 {
|
||||
t = 0
|
||||
}
|
||||
if t > 1 {
|
||||
t = 1
|
||||
}
|
||||
freqMask = 1.0 - t
|
||||
}
|
||||
|
||||
// Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask))
|
||||
freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask)))
|
||||
}
|
||||
|
||||
return mlx.NewArray(freqData, []int32{halfDims}), mscale
|
||||
}
|
||||
|
||||
// initYarn initializes YaRN RoPE if configured
|
||||
func (a *Attention) initYarn(cfg *Config) {
|
||||
a.YarnMscale = 1.0
|
||||
if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" {
|
||||
a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs(
|
||||
cfg.HeadDim,
|
||||
cfg.RopeTheta,
|
||||
cfg.RopeScaling.Factor,
|
||||
cfg.RopeScaling.OriginalMaxPositionEmbeddings,
|
||||
cfg.RopeScaling.BetaFast,
|
||||
cfg.RopeScaling.BetaSlow,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim]
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
if a.YarnFreqs != nil {
|
||||
if a.YarnMscale != 1.0 {
|
||||
q = mlx.MulScalar(q, a.YarnMscale)
|
||||
}
|
||||
q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
|
||||
k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
|
||||
} else {
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v, int(L))
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// CreateSlidingWindowMask creates a causal mask with sliding window
|
||||
// Mirrors mlx-lm's create_causal_mask with window_size
|
||||
func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array {
|
||||
// Build mask aligned to actual cache length (may be rotated)
|
||||
// rinds covers existing keys: [keyStart, keyStart+keyLen)
|
||||
// linds covers new queries: [queryStart, queryStart+seqLen)
|
||||
rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1) // [keyLen]
|
||||
linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen]
|
||||
|
||||
linds = mlx.ExpandDims(linds, 1) // [seqLen, 1]
|
||||
rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen]
|
||||
|
||||
causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen]
|
||||
windowLimit := mlx.AddScalar(rinds, float32(windowSize))
|
||||
windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen]
|
||||
|
||||
return mlx.LogicalAnd(causalMask, windowMask)
|
||||
}
|
||||
|
||||
// MoE represents the Mixture of Experts SwiGLU layer with quantized experts.
|
||||
type MoE struct {
|
||||
Router *nn.Linear `weight:"mlp.router"`
|
||||
TopK int32
|
||||
HiddenSize int32
|
||||
GroupSize int
|
||||
Bits int
|
||||
// Expert weights (loaded manually via sanitizeExpertWeights)
|
||||
GateBlocks, GateScales, GateBias *mlx.Array
|
||||
UpBlocks, UpScales, UpBias *mlx.Array
|
||||
DownBlocks, DownScales, DownBias *mlx.Array
|
||||
}
|
||||
|
||||
func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array {
|
||||
logits := moe.Router.Forward(x)
|
||||
neg := mlx.Neg(logits)
|
||||
part := mlx.Argpartition(neg, int(moe.TopK)-1, -1)
|
||||
topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK})
|
||||
topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1)
|
||||
weights := mlx.Softmax(topKVal, -1)
|
||||
|
||||
xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize)
|
||||
idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK)
|
||||
|
||||
doSort := B*L >= 64
|
||||
var invOrder *mlx.Array
|
||||
sorted := false
|
||||
n := B * L * moe.TopK
|
||||
|
||||
if doSort {
|
||||
idxAll := mlx.Flatten(idxFlat)
|
||||
order := mlx.Argsort(idxAll, 0)
|
||||
invOrder = mlx.Argsort(order, 0)
|
||||
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1)
|
||||
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
|
||||
sorted = true
|
||||
}
|
||||
|
||||
gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
||||
up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
||||
|
||||
if moe.GateBias != nil {
|
||||
gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2))
|
||||
}
|
||||
if moe.UpBias != nil {
|
||||
up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2))
|
||||
}
|
||||
|
||||
hidden := getCompiledSwiGLU().Call(gate, up)[0]
|
||||
|
||||
down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
||||
if moe.DownBias != nil {
|
||||
down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2))
|
||||
}
|
||||
|
||||
if doSort {
|
||||
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize)
|
||||
} else {
|
||||
down = mlx.Squeeze(down, 2)
|
||||
}
|
||||
|
||||
ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1)
|
||||
return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize)
|
||||
}
|
||||
|
||||
type Block struct {
|
||||
Attention *Attention
|
||||
MLP *MoE
|
||||
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
LayerType string // "sliding_attention" or "full_attention"
|
||||
}
|
||||
|
||||
func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg))
|
||||
return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L))
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Block `weight:"-"` // loaded manually due to MoE sanitization
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
LMHead *nn.Linear `weight:"lm_head"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
func (m *Model) NewCache(int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i, layer := range m.Layers {
|
||||
if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 {
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
||||
} else {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
x := m.EmbedTokens.Forward(tokens)
|
||||
|
||||
// Find representative cache indices for sliding window attention
|
||||
var swaIdx int = -1
|
||||
for i, layer := range m.Layers {
|
||||
if layer.LayerType == "sliding_attention" {
|
||||
swaIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Create masks once at model level
|
||||
var fullMask, swaMask *mlx.Array
|
||||
var fullMaskMode, swaMaskMode string
|
||||
|
||||
if L > 1 {
|
||||
fullMaskMode = "causal"
|
||||
if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil {
|
||||
c := caches[swaIdx]
|
||||
offset := c.Offset()
|
||||
windowSize := int(m.SlidingWindow)
|
||||
cacheLen := min(int(L), windowSize)
|
||||
if offset > 0 {
|
||||
cacheLen = min(c.Len()+int(L), windowSize)
|
||||
}
|
||||
if int(L) > windowSize {
|
||||
swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize)
|
||||
} else {
|
||||
swaMaskMode = "causal"
|
||||
}
|
||||
} else {
|
||||
swaMaskMode = "causal"
|
||||
}
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil {
|
||||
c = caches[i]
|
||||
}
|
||||
mask, maskMode := fullMask, fullMaskMode
|
||||
if layer.LayerType == "sliding_attention" {
|
||||
mask, maskMode = swaMask, swaMaskMode
|
||||
}
|
||||
x = layer.Forward(x, c, B, L, mask, maskMode, m.Config)
|
||||
}
|
||||
|
||||
return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps))
|
||||
}
|
||||
|
||||
// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays.
|
||||
// MXFP4 quantized weights require contiguous memory - strided views give wrong results.
|
||||
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) {
|
||||
gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks")
|
||||
gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales")
|
||||
gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias")
|
||||
downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks")
|
||||
downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales")
|
||||
downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias")
|
||||
|
||||
moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias}
|
||||
|
||||
if gateUpBlocks != nil {
|
||||
gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1)
|
||||
s := gub.Shape()
|
||||
moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
}
|
||||
if gateUpScales != nil {
|
||||
s := gateUpScales.Shape()
|
||||
moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
}
|
||||
if gateUpBias != nil {
|
||||
s := gateUpBias.Shape()
|
||||
moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2}))
|
||||
moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2}))
|
||||
}
|
||||
if downBlocks != nil {
|
||||
moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1)
|
||||
}
|
||||
return moe
|
||||
}
|
||||
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*Block, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Load simple weights via struct tags
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load layers with custom MoE handling
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
||||
layer := &Block{}
|
||||
if err := safetensors.LoadModule(layer, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d: %w", i, err)
|
||||
}
|
||||
|
||||
// Initialize attention YaRN
|
||||
layer.Attention.initYarn(&cfg)
|
||||
|
||||
// Load MoE with weight sanitization
|
||||
moe := sanitizeExpertWeights(weights, prefix)
|
||||
moe.Router = layer.MLP.Router // Router was loaded by LoadModule
|
||||
moe.TopK = cfg.NumExpertsPerTok
|
||||
moe.HiddenSize = cfg.HiddenSize
|
||||
layer.MLP = moe
|
||||
|
||||
// Set layer type
|
||||
layer.LayerType = "full_attention"
|
||||
if int(i) < len(cfg.LayerTypes) {
|
||||
layer.LayerType = cfg.LayerTypes[i]
|
||||
}
|
||||
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
// Release safetensors BEFORE eval - lazy arrays have captured data,
|
||||
// this reduces peak memory by freeing mmap during materialization
|
||||
weights.ReleaseAll()
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int32 {
|
||||
if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 {
|
||||
return m.RopeScaling.OriginalMaxPositionEmbeddings
|
||||
}
|
||||
return 131072
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
package llama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
HeadDim int32 `json:"-"`
|
||||
Scale float32 `json:"-"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Layer `weight:"model.layers"`
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
Output *nn.Linear `weight:"lm_head,optional"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
Attention *Attention
|
||||
MLP *MLP
|
||||
AttentionNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
MLPNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
GateProj *nn.Linear `weight:"mlp.gate_proj"`
|
||||
UpProj *nn.Linear `weight:"mlp.up_proj"`
|
||||
DownProj *nn.Linear `weight:"mlp.down_proj"`
|
||||
}
|
||||
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*Layer, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.Forward(h, caches[i], B, L, m.Config)
|
||||
}
|
||||
return m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps))
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
|
||||
k, v = c.Update(k, v, int(L))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// Interface methods
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestPipelineOutput runs the full pipeline (integration test).
|
||||
// Skips if model weights not found. Requires ~50GB VRAM.
|
||||
func TestPipelineOutput(t *testing.T) {
|
||||
modelPath := "../../../weights/Qwen-Image-2512"
|
||||
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + modelPath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
pm, err := LoadPersistent(modelPath)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: failed to load model: %v", err)
|
||||
}
|
||||
|
||||
// Run 2-step pipeline (minimum for stable scheduler)
|
||||
cfg := &GenerateConfig{
|
||||
Prompt: "a cat",
|
||||
Width: 256,
|
||||
Height: 256,
|
||||
Steps: 2,
|
||||
Seed: 42,
|
||||
}
|
||||
|
||||
output, err := pm.GenerateFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Pipeline failed: %v", err)
|
||||
}
|
||||
mlx.Eval(output)
|
||||
|
||||
// Verify output shape [1, C, H, W]
|
||||
shape := output.Shape()
|
||||
if len(shape) != 4 {
|
||||
t.Errorf("Expected 4D output, got %v", shape)
|
||||
}
|
||||
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
|
||||
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
|
||||
}
|
||||
|
||||
// Verify values in expected range [0, 1]
|
||||
data := output.Data()
|
||||
minVal, maxVal := float32(1.0), float32(0.0)
|
||||
for _, v := range data {
|
||||
if v < minVal {
|
||||
minVal = v
|
||||
}
|
||||
if v > maxVal {
|
||||
maxVal = v
|
||||
}
|
||||
}
|
||||
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
|
||||
|
||||
if minVal < -0.1 || maxVal > 1.1 {
|
||||
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,348 +0,0 @@
|
||||
// Package qwen_image implements the Qwen-Image diffusion transformer model.
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 30)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
|
||||
// Layer caching (DeepCache/Learning-to-Cache speedup)
|
||||
LayerCache bool // Enable layer caching (default: false)
|
||||
CacheInterval int // Refresh cache every N steps (default: 3)
|
||||
CacheLayers int // Number of shallow layers to cache (default: 25)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Qwen-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *Qwen25VL
|
||||
Transformer *Transformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the Qwen-Image model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Qwen-Image model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
||||
tok, err := tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
|
||||
m.TextEncoder = &Qwen25VL{}
|
||||
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
peak := mlx.MetalGetPeakMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
||||
time.Since(start).Seconds(),
|
||||
float64(mem)/(1024*1024*1024),
|
||||
float64(peak)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
CFGScale: cfgScale,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 30
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
if cfg.CacheInterval <= 0 {
|
||||
cfg.CacheInterval = 3
|
||||
}
|
||||
if cfg.CacheLayers <= 0 {
|
||||
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
|
||||
}
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.Config
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
pH := latentH / tcfg.PatchSize
|
||||
pW := latentW / tcfg.PatchSize
|
||||
imgSeqLen := pH * pW
|
||||
|
||||
// Text encoding
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
|
||||
if useCFG {
|
||||
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
} else {
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
}
|
||||
}
|
||||
|
||||
// Pad sequences to same length for CFG
|
||||
txtLen := posEmb.Shape()[1]
|
||||
if useCFG {
|
||||
negLen := negEmb.Shape()[1]
|
||||
if negLen > txtLen {
|
||||
txtLen = negLen
|
||||
}
|
||||
if posEmb.Shape()[1] < txtLen {
|
||||
posEmb = padSequence(posEmb, txtLen)
|
||||
}
|
||||
if negEmb.Shape()[1] < txtLen {
|
||||
negEmb = padSequence(negEmb, txtLen)
|
||||
}
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
|
||||
|
||||
// Init latents [B, C, T, H, W]
|
||||
var latents *mlx.Array
|
||||
{
|
||||
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
}
|
||||
|
||||
// RoPE cache
|
||||
var ropeCache *RoPECache
|
||||
{
|
||||
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
|
||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(ropeCache.ImgFreqs)
|
||||
}
|
||||
|
||||
// Layer cache for DeepCache/Learning-to-Cache speedup
|
||||
var stepCache *cache.StepCache
|
||||
if cfg.LayerCache {
|
||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
||||
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
|
||||
}
|
||||
|
||||
// Denoising loop
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
t := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
||||
|
||||
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
|
||||
latents2D := mlx.Squeeze(latents, 2)
|
||||
patches := PackLatents(latents2D, tcfg.PatchSize)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
// True CFG: run twice and combine with norm rescaling
|
||||
// Note: layer caching with CFG is not supported yet (would need 2 caches)
|
||||
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
combPred := mlx.Add(negOutput, scaledDiff)
|
||||
|
||||
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
|
||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
|
||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
||||
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
||||
} else if stepCache != nil {
|
||||
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
|
||||
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
|
||||
} else {
|
||||
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
}
|
||||
|
||||
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
|
||||
// Keep cached arrays alive across cleanup
|
||||
if stepCache != nil {
|
||||
mlx.Keep(stepCache.Arrays()...)
|
||||
}
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
|
||||
}
|
||||
|
||||
// Free denoising temporaries before VAE decode
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
if stepCache != nil {
|
||||
stepCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode (Decode manages its own pools for staged memory)
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
||||
{
|
||||
decoded = mlx.Squeeze(decoded, 2)
|
||||
decoded = mlx.AddScalar(decoded, 1.0)
|
||||
decoded = mlx.DivScalar(decoded, 2.0)
|
||||
mlx.Eval(decoded)
|
||||
}
|
||||
|
||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// padSequence pads a sequence tensor to the target length with zeros
|
||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
// Pad on sequence dimension (axis 1)
|
||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
||||
}
|
||||
|
||||
// LoadPersistent is an alias for backward compatibility.
|
||||
// Use m := &Model{}; m.Load(path) instead.
|
||||
func LoadPersistent(modelPath string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
@@ -1,216 +0,0 @@
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
|
||||
type SchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
BaseShift float32 `json:"base_shift"` // 0.5
|
||||
MaxShift float32 `json:"max_shift"` // 0.9
|
||||
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
|
||||
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
|
||||
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
|
||||
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
|
||||
}
|
||||
|
||||
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
|
||||
func DefaultSchedulerConfig() *SchedulerConfig {
|
||||
return &SchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
BaseShift: 0.5,
|
||||
MaxShift: 0.9, // Matches scheduler_config.json
|
||||
BaseImageSeqLen: 256,
|
||||
MaxImageSeqLen: 8192,
|
||||
ShiftTerminal: 0.02,
|
||||
UseDynamicShift: true,
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
|
||||
type FlowMatchScheduler struct {
|
||||
Config *SchedulerConfig
|
||||
Timesteps []float32
|
||||
Sigmas []float32
|
||||
NumSteps int
|
||||
}
|
||||
|
||||
// NewFlowMatchScheduler creates a new scheduler
|
||||
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
|
||||
return &FlowMatchScheduler{
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateShift computes the dynamic shift based on image sequence length
|
||||
// This matches Python's calculate_shift function
|
||||
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
|
||||
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
|
||||
b := baseShift - m*float32(baseSeqLen)
|
||||
mu := float32(imageSeqLen)*m + b
|
||||
return mu
|
||||
}
|
||||
|
||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
||||
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
|
||||
// 1. Create sigmas from sigma_max to sigma_min (linspace)
|
||||
// 2. Apply time_shift with mu (if dynamic shifting)
|
||||
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
|
||||
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Calculate mu for dynamic shifting
|
||||
var mu float32
|
||||
if s.Config.UseDynamicShift {
|
||||
mu = CalculateShift(
|
||||
imageSeqLen,
|
||||
s.Config.BaseImageSeqLen,
|
||||
s.Config.MaxImageSeqLen,
|
||||
s.Config.BaseShift,
|
||||
s.Config.MaxShift,
|
||||
)
|
||||
}
|
||||
|
||||
// Step 1: Create sigmas from 1.0 to 1/num_steps
|
||||
// Python (pipeline_qwenimage.py:639):
|
||||
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
|
||||
sigmas := make([]float32, numSteps)
|
||||
sigmaMax := float32(1.0)
|
||||
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
|
||||
if numSteps == 1 {
|
||||
sigmas[0] = sigmaMax
|
||||
} else {
|
||||
for i := 0; i < numSteps; i++ {
|
||||
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Apply time shift if using dynamic shifting
|
||||
if s.Config.UseDynamicShift && mu != 0 {
|
||||
for i := range sigmas {
|
||||
sigmas[i] = s.timeShift(mu, sigmas[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Apply stretch_shift_to_terminal
|
||||
if s.Config.ShiftTerminal > 0 {
|
||||
sigmas = s.stretchShiftToTerminal(sigmas)
|
||||
}
|
||||
|
||||
// Step 4: Append terminal sigma (0) and store
|
||||
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
|
||||
// before passing to transformer. We skip both steps and just use sigmas directly.
|
||||
s.Sigmas = make([]float32, numSteps+1)
|
||||
s.Timesteps = make([]float32, numSteps+1)
|
||||
for i := 0; i < numSteps; i++ {
|
||||
s.Sigmas[i] = sigmas[i]
|
||||
s.Timesteps[i] = sigmas[i]
|
||||
}
|
||||
s.Sigmas[numSteps] = 0.0
|
||||
s.Timesteps[numSteps] = 0.0
|
||||
}
|
||||
|
||||
// stretchShiftToTerminal stretches and shifts the timestep schedule
|
||||
// so the final value equals shift_terminal (matches Python behavior)
|
||||
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
|
||||
if len(sigmas) == 0 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
// one_minus_z = 1 - t
|
||||
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
|
||||
// stretched_t = 1 - (one_minus_z / scale_factor)
|
||||
lastSigma := sigmas[len(sigmas)-1]
|
||||
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
|
||||
|
||||
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
|
||||
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
|
||||
if scaleFactor < 1e-6 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
result := make([]float32, len(sigmas))
|
||||
for i, t := range sigmas {
|
||||
oneMinusZ := 1.0 - t
|
||||
result[i] = 1.0 - (oneMinusZ / scaleFactor)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// timeShift applies the dynamic time shift (exponential)
|
||||
// exp(mu) / (exp(mu) + (1/t - 1))
|
||||
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + (1.0/t - 1.0))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
// modelOutput: predicted velocity from the transformer
|
||||
// sample: current noisy sample
|
||||
// timestepIdx: current timestep index
|
||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Get current and next sigma
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
||||
dt := sigmaNext - sigma
|
||||
|
||||
// Upcast to float32 to avoid precision issues (matches Python diffusers)
|
||||
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
|
||||
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
|
||||
|
||||
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
|
||||
result := mlx.Add(sampleF32, scaledOutput)
|
||||
|
||||
// Cast back to original dtype
|
||||
return mlx.ToBFloat16(result)
|
||||
}
|
||||
|
||||
// GetTimestep returns the timestep value at the given index
|
||||
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
|
||||
if idx < len(s.Timesteps) {
|
||||
return s.Timesteps[idx]
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
|
||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
|
||||
// This matches how Python diffusers generates noise - directly in packed space.
|
||||
// Generating in unpacked format and then packing produces different spatial
|
||||
// correlation structure, which affects model output quality.
|
||||
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
|
||||
shape := []int32{batchSize, seqLen, channels}
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// GetLatentShape returns the latent shape for a given image size
|
||||
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
|
||||
func GetLatentShape(batchSize, height, width int32) []int32 {
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
|
||||
}
|
||||
|
||||
// GetPatchedLatentShape returns the patchified latent shape
|
||||
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
|
||||
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
pH := latentH / patchSize
|
||||
pW := latentW / patchSize
|
||||
inChannels := int32(64) // 16 * patch_size^2
|
||||
return []int32{batchSize, pH * pW, inChannels}
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
|
||||
// Golden values generated via:
|
||||
//
|
||||
// python3 -c "
|
||||
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
// import numpy as np
|
||||
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
|
||||
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
|
||||
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
|
||||
// sigmas = np.linspace(1.0, 1.0/30, 30)
|
||||
// s.set_timesteps(sigmas=sigmas, mu=mu)
|
||||
// print(s.sigmas.numpy())"
|
||||
func TestSchedulerSetTimesteps(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Golden values from Python diffusers (first 3, last 3 before terminal)
|
||||
wantFirst := []float32{1.000000, 0.982251, 0.963889}
|
||||
wantLast := []float32{0.142924, 0.083384, 0.020000}
|
||||
|
||||
// Check first 3
|
||||
for i, want := range wantFirst {
|
||||
got := scheduler.Sigmas[i]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check last 3 (indices 27, 28, 29)
|
||||
for i, want := range wantLast {
|
||||
idx := 27 + i
|
||||
got := scheduler.Sigmas[idx]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check terminal is 0
|
||||
if scheduler.Sigmas[30] != 0.0 {
|
||||
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(scheduler.Sigmas) != 31 {
|
||||
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerProperties tests mathematical invariants of the scheduler.
|
||||
func TestSchedulerProperties(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Property: sigmas monotonically decreasing
|
||||
for i := 1; i < len(scheduler.Sigmas); i++ {
|
||||
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
|
||||
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
|
||||
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
|
||||
}
|
||||
}
|
||||
|
||||
// Property: first sigma should be ~1.0 (with time shift)
|
||||
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
|
||||
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
|
||||
}
|
||||
|
||||
// Property: terminal sigma should be exactly 0
|
||||
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
|
||||
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
|
||||
}
|
||||
|
||||
// Property: last non-terminal sigma should be shift_terminal (0.02)
|
||||
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
|
||||
if abs32(lastNonTerminal-0.02) > 1e-5 {
|
||||
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
|
||||
}
|
||||
|
||||
// Property: length = steps + 1
|
||||
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
|
||||
t.Errorf("sigmas length should be steps+1: got %d, want %d",
|
||||
len(scheduler.Sigmas), scheduler.NumSteps+1)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateShift verifies the mu calculation against Python reference.
|
||||
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
func TestCalculateShift(t *testing.T) {
|
||||
cases := []struct {
|
||||
imgSeqLen int32
|
||||
want float32
|
||||
}{
|
||||
{256, 0.5}, // base case
|
||||
{8192, 0.9}, // max case
|
||||
{4096, 0.6935}, // middle case (rounded)
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
|
||||
if abs32(got-c.want) > 0.001 {
|
||||
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerStep verifies the Euler step formula.
|
||||
func TestSchedulerStep(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Verify dt calculation for first step
|
||||
sigma0 := scheduler.Sigmas[0]
|
||||
sigma1 := scheduler.Sigmas[1]
|
||||
expectedDt := sigma1 - sigma0
|
||||
|
||||
// dt should be negative (sigmas decrease)
|
||||
if expectedDt >= 0 {
|
||||
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
|
||||
}
|
||||
}
|
||||
|
||||
func abs32(x float32) float32 {
|
||||
return float32(math.Abs(float64(x)))
|
||||
}
|
||||
@@ -1,172 +0,0 @@
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TinyTextEncoderConfig holds config for the tiny test text encoder
|
||||
type TinyTextEncoderConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
}
|
||||
|
||||
// loadTinyTextEncoder loads the tiny text encoder from testdata
|
||||
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
|
||||
t.Helper()
|
||||
|
||||
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
|
||||
|
||||
// Load config
|
||||
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
|
||||
}
|
||||
|
||||
var tinyCfg TinyTextEncoderConfig
|
||||
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
|
||||
t.Fatalf("Failed to parse config: %v", err)
|
||||
}
|
||||
|
||||
// Create encoder config (using Qwen25VLConfig)
|
||||
cfg := &Qwen25VLConfig{
|
||||
HiddenSize: tinyCfg.HiddenSize,
|
||||
NumHiddenLayers: tinyCfg.NumHiddenLayers,
|
||||
IntermediateSize: tinyCfg.IntermediateSize,
|
||||
NumAttentionHeads: tinyCfg.NumAttentionHeads,
|
||||
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
|
||||
VocabSize: tinyCfg.VocabSize,
|
||||
RMSNormEps: tinyCfg.RMSNormEps,
|
||||
RopeTheta: tinyCfg.RopeTheta,
|
||||
HeadDim: tinyCfg.HeadDim,
|
||||
MRoPESection: tinyCfg.MRoPESection,
|
||||
}
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(testdataDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load weights: %v", err)
|
||||
}
|
||||
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
t.Fatalf("Failed to bulk load weights: %v", err)
|
||||
}
|
||||
|
||||
// Build encoder
|
||||
embedding, err := weights.Get("model.embed_tokens.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get embedding: %v", err)
|
||||
}
|
||||
|
||||
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
block, err := newVLTextBlock(weights, int(i), cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load block %d: %v", i, err)
|
||||
}
|
||||
blocks[i] = block
|
||||
}
|
||||
|
||||
finalNorm, err := weights.Get("model.norm.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get final norm: %v", err)
|
||||
}
|
||||
|
||||
encoder := &Qwen25VL{
|
||||
Config: cfg,
|
||||
Embedding: embedding,
|
||||
Blocks: blocks,
|
||||
FinalNorm: finalNorm,
|
||||
HasVision: false, // Text-only mode
|
||||
}
|
||||
|
||||
return encoder, &tinyCfg
|
||||
}
|
||||
|
||||
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
|
||||
func TestTextEncoderForward(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// Create test tokens (within vocab range)
|
||||
tokens := []int32{1, 2, 3, 4, 5}
|
||||
|
||||
// Forward pass using EncodeTextOnly
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [batch, seq_len, hidden_size]
|
||||
wantShape := []int32{1, 5, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify output is finite (not NaN or Inf)
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
|
||||
t.Errorf("output[%d] is not finite: %v", i, v)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextEncoderBatch tests batch processing.
|
||||
func TestTextEncoderBatch(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// For batch test, we'll use EncodeTextOnly with a single sequence
|
||||
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
|
||||
tokens := []int32{1, 2, 3}
|
||||
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
wantShape := []int32{1, 3, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
|
||||
func TestMRoPEComputation(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
cossin := encoder.computeTextRoPE(10, 1)
|
||||
mlx.Eval(cossin[0], cossin[1])
|
||||
|
||||
// Verify shapes: [3, B, L, head_dim]
|
||||
wantShape := []int32{3, 1, 10, cfg.HeadDim}
|
||||
if !slices.Equal(cossin[0].Shape(), wantShape) {
|
||||
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
|
||||
}
|
||||
if !slices.Equal(cossin[1].Shape(), wantShape) {
|
||||
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify cos/sin values are in valid range [-1, 1]
|
||||
cosData := cossin[0].Data()
|
||||
sinData := cossin[1].Data()
|
||||
for i := 0; i < min(100, len(cosData)); i++ {
|
||||
if cosData[i] < -1.01 || cosData[i] > 1.01 {
|
||||
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
|
||||
}
|
||||
if sinData[i] < -1.01 || sinData[i] > 1.01 {
|
||||
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,866 +0,0 @@
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TransformerConfig holds Qwen-Image transformer configuration
|
||||
type TransformerConfig struct {
|
||||
HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
|
||||
NHeads int32 `json:"num_attention_heads"` // 24
|
||||
HeadDim int32 `json:"attention_head_dim"` // 128
|
||||
NLayers int32 `json:"num_layers"` // 60
|
||||
InChannels int32 `json:"in_channels"` // 64
|
||||
OutChannels int32 `json:"out_channels"` // 16
|
||||
PatchSize int32 `json:"patch_size"` // 2
|
||||
JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
|
||||
NormEps float32 `json:"norm_eps"` // 1e-6
|
||||
AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
|
||||
GuidanceEmbeds bool `json:"guidance_embeds"` // false
|
||||
}
|
||||
|
||||
// defaultTransformerConfig returns config for Qwen-Image transformer
|
||||
func defaultTransformerConfig() *TransformerConfig {
|
||||
return &TransformerConfig{
|
||||
HiddenDim: 3072, // 24 * 128
|
||||
NHeads: 24,
|
||||
HeadDim: 128,
|
||||
NLayers: 60,
|
||||
InChannels: 64,
|
||||
OutChannels: 16,
|
||||
PatchSize: 2,
|
||||
JointAttentionDim: 3584,
|
||||
NormEps: 1e-6,
|
||||
AxesDimsRope: []int32{16, 56, 56},
|
||||
GuidanceEmbeds: false,
|
||||
}
|
||||
}
|
||||
|
||||
// TimestepEmbedder creates timestep embeddings
|
||||
type TimestepEmbedder struct {
|
||||
Linear1Weight *mlx.Array // [256, hidden_dim]
|
||||
Linear1Bias *mlx.Array
|
||||
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
|
||||
Linear2Bias *mlx.Array
|
||||
}
|
||||
|
||||
// newTimestepEmbedder creates a timestep embedder from weights
|
||||
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
|
||||
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TimestepEmbedder{
|
||||
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
|
||||
Linear1Bias: linear1Bias,
|
||||
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
|
||||
Linear2Bias: linear2Bias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings
|
||||
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
|
||||
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
||||
half := int32(128) // embedding_dim / 2
|
||||
|
||||
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
|
||||
freqs := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
||||
}
|
||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
tExpanded := mlx.ExpandDims(t, 1)
|
||||
args := mlx.Mul(tExpanded, freqsArr)
|
||||
args = mlx.MulScalar(args, 1000.0) // scale
|
||||
|
||||
// [cos, sin] (flip_sin_to_cos=True)
|
||||
sinArgs := mlx.Sin(args)
|
||||
cosArgs := mlx.Cos(args)
|
||||
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
|
||||
|
||||
// MLP: linear1 -> silu -> linear2
|
||||
h := mlx.Linear(embedding, te.Linear1Weight)
|
||||
h = mlx.Add(h, te.Linear1Bias)
|
||||
h = mlx.SiLU(h)
|
||||
h = mlx.Linear(h, te.Linear2Weight)
|
||||
h = mlx.Add(h, te.Linear2Bias)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// JointAttention implements dual-stream joint attention
|
||||
type JointAttention struct {
|
||||
// Image projections
|
||||
ToQ *mlx.Array
|
||||
ToQB *mlx.Array
|
||||
ToK *mlx.Array
|
||||
ToKB *mlx.Array
|
||||
ToV *mlx.Array
|
||||
ToVB *mlx.Array
|
||||
ToOut *mlx.Array
|
||||
ToOutB *mlx.Array
|
||||
NormQ *mlx.Array
|
||||
NormK *mlx.Array
|
||||
|
||||
// Text (added) projections
|
||||
AddQProj *mlx.Array
|
||||
AddQProjB *mlx.Array
|
||||
AddKProj *mlx.Array
|
||||
AddKProjB *mlx.Array
|
||||
AddVProj *mlx.Array
|
||||
AddVProjB *mlx.Array
|
||||
ToAddOut *mlx.Array
|
||||
ToAddOutB *mlx.Array
|
||||
NormAddQ *mlx.Array
|
||||
NormAddK *mlx.Array
|
||||
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// newJointAttention creates a joint attention layer
|
||||
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
|
||||
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
|
||||
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
|
||||
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
|
||||
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
|
||||
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
|
||||
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
|
||||
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
|
||||
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
|
||||
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
|
||||
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
|
||||
|
||||
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
|
||||
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
|
||||
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
|
||||
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
|
||||
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
|
||||
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
|
||||
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
|
||||
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
|
||||
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
|
||||
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
|
||||
|
||||
return &JointAttention{
|
||||
ToQ: mlx.Transpose(toQ, 1, 0),
|
||||
ToQB: toQB,
|
||||
ToK: mlx.Transpose(toK, 1, 0),
|
||||
ToKB: toKB,
|
||||
ToV: mlx.Transpose(toV, 1, 0),
|
||||
ToVB: toVB,
|
||||
ToOut: mlx.Transpose(toOut, 1, 0),
|
||||
ToOutB: toOutB,
|
||||
NormQ: normQ,
|
||||
NormK: normK,
|
||||
AddQProj: mlx.Transpose(addQProj, 1, 0),
|
||||
AddQProjB: addQProjB,
|
||||
AddKProj: mlx.Transpose(addKProj, 1, 0),
|
||||
AddKProjB: addKProjB,
|
||||
AddVProj: mlx.Transpose(addVProj, 1, 0),
|
||||
AddVProjB: addVProjB,
|
||||
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
|
||||
ToAddOutB: toAddOutB,
|
||||
NormAddQ: normAddQ,
|
||||
NormAddK: normAddK,
|
||||
NHeads: cfg.NHeads,
|
||||
HeadDim: cfg.HeadDim,
|
||||
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes joint attention
|
||||
// img: [B, L_img, D], txt: [B, L_txt, D]
|
||||
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
|
||||
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
D := imgShape[2]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// === Image Q/K/V ===
|
||||
imgFlat := mlx.Reshape(img, B*Limg, D)
|
||||
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
|
||||
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
|
||||
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
|
||||
|
||||
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
|
||||
// QK norm (RMSNorm per head)
|
||||
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
|
||||
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
|
||||
|
||||
// Apply RoPE
|
||||
if imgFreqs != nil {
|
||||
qImg = applyRoPE(qImg, imgFreqs)
|
||||
kImg = applyRoPE(kImg, imgFreqs)
|
||||
}
|
||||
|
||||
// === Text Q/K/V ===
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
|
||||
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
|
||||
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
|
||||
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
|
||||
|
||||
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
|
||||
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
|
||||
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
|
||||
|
||||
if txtFreqs != nil {
|
||||
qTxt = applyRoPE(qTxt, txtFreqs)
|
||||
kTxt = applyRoPE(kTxt, txtFreqs)
|
||||
}
|
||||
|
||||
// Concatenate for joint attention: [txt, img] order
|
||||
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
|
||||
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
|
||||
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
|
||||
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
|
||||
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
|
||||
|
||||
// SDPA
|
||||
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
|
||||
|
||||
// Transpose back and split
|
||||
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
|
||||
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
|
||||
|
||||
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
|
||||
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
|
||||
|
||||
// Output projections
|
||||
outImg = mlx.Reshape(outImg, B*Limg, D)
|
||||
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
|
||||
outImg = mlx.Reshape(outImg, B, Limg, D)
|
||||
|
||||
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
|
||||
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
|
||||
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
|
||||
|
||||
return outImg, outTxt
|
||||
}
|
||||
|
||||
// applyRoPE applies rotary embeddings using complex multiplication
|
||||
// x: [B, L, nheads, head_dim]
|
||||
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
|
||||
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
nheads := shape[2]
|
||||
headDim := shape[3]
|
||||
halfDim := headDim / 2
|
||||
|
||||
// Reshape x to pairs: [B, L, nheads, half, 2]
|
||||
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
|
||||
|
||||
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
|
||||
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
|
||||
|
||||
// Extract real/imag parts
|
||||
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
xReal = mlx.Squeeze(xReal, 4)
|
||||
xImag = mlx.Squeeze(xImag, 4)
|
||||
|
||||
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
freqReal = mlx.Squeeze(freqReal, 4)
|
||||
freqImag = mlx.Squeeze(freqImag, 4)
|
||||
|
||||
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
|
||||
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
|
||||
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
|
||||
|
||||
// Interleave back
|
||||
outReal = mlx.ExpandDims(outReal, 4)
|
||||
outImag = mlx.ExpandDims(outImag, 4)
|
||||
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
|
||||
|
||||
return mlx.Reshape(out, B, L, nheads, headDim)
|
||||
}
|
||||
|
||||
// MLP implements GELU MLP (not GEGLU)
|
||||
type MLP struct {
|
||||
ProjWeight *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
OutWeight *mlx.Array
|
||||
OutBias *mlx.Array
|
||||
}
|
||||
|
||||
// newMLP creates a GELU MLP
|
||||
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
|
||||
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
|
||||
outWeight, _ := weights.Get(prefix + ".net.2.weight")
|
||||
outBias, _ := weights.Get(prefix + ".net.2.bias")
|
||||
|
||||
return &MLP{
|
||||
ProjWeight: mlx.Transpose(projWeight, 1, 0),
|
||||
ProjBias: projBias,
|
||||
OutWeight: mlx.Transpose(outWeight, 1, 0),
|
||||
OutBias: outBias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies GELU MLP
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
xFlat := mlx.Reshape(x, B*L, D)
|
||||
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
|
||||
h = geluApprox(h)
|
||||
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
|
||||
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
|
||||
}
|
||||
|
||||
// geluApprox implements approximate GELU
|
||||
func geluApprox(x *mlx.Array) *mlx.Array {
|
||||
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
|
||||
inner = mlx.MulScalar(inner, sqrt2OverPi)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
|
||||
}
|
||||
|
||||
// TransformerBlock is a single dual-stream transformer block
|
||||
type TransformerBlock struct {
|
||||
Attention *JointAttention
|
||||
ImgMLP *MLP
|
||||
TxtMLP *MLP
|
||||
|
||||
ImgModWeight *mlx.Array
|
||||
ImgModBias *mlx.Array
|
||||
TxtModWeight *mlx.Array
|
||||
TxtModBias *mlx.Array
|
||||
|
||||
HiddenDim int32
|
||||
NormEps float32
|
||||
}
|
||||
|
||||
// newTransformerBlock creates a transformer block
|
||||
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
|
||||
attn, err := newJointAttention(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
|
||||
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
|
||||
|
||||
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
|
||||
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
|
||||
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
|
||||
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
|
||||
|
||||
return &TransformerBlock{
|
||||
Attention: attn,
|
||||
ImgMLP: imgMLP,
|
||||
TxtMLP: txtMLP,
|
||||
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
|
||||
ImgModBias: imgModBias,
|
||||
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
|
||||
TxtModBias: txtModBias,
|
||||
HiddenDim: cfg.HiddenDim,
|
||||
NormEps: cfg.NormEps,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the transformer block
|
||||
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
|
||||
siluT := mlx.SiLU(temb)
|
||||
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
|
||||
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
|
||||
|
||||
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
|
||||
imgModParts := splitMod6(imgMod, tb.HiddenDim)
|
||||
txtModParts := splitMod6(txtMod, tb.HiddenDim)
|
||||
|
||||
// Pre-attention: norm + modulate
|
||||
imgNorm := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
|
||||
|
||||
txtNorm := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
|
||||
|
||||
// Joint attention
|
||||
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
|
||||
|
||||
// Pre-MLP: norm + modulate
|
||||
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
|
||||
|
||||
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
|
||||
|
||||
// MLP
|
||||
mlpImg := tb.ImgMLP.Forward(imgNorm2)
|
||||
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
|
||||
|
||||
return img, txt
|
||||
}
|
||||
|
||||
// splitMod6 splits modulation into 6 parts each [B, 1, D]
|
||||
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
|
||||
shape := mod.Shape()
|
||||
B := shape[0]
|
||||
parts := make([]*mlx.Array, 6)
|
||||
for i := int32(0); i < 6; i++ {
|
||||
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
|
||||
parts[i] = mlx.ExpandDims(part, 1)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
|
||||
// Transformer is the full Qwen-Image transformer model
|
||||
type Transformer struct {
|
||||
Config *TransformerConfig
|
||||
|
||||
ImgIn *mlx.Array
|
||||
ImgInBias *mlx.Array
|
||||
TxtIn *mlx.Array
|
||||
TxtInBias *mlx.Array
|
||||
TxtNorm *mlx.Array
|
||||
|
||||
TEmbed *TimestepEmbedder
|
||||
Layers []*TransformerBlock
|
||||
|
||||
NormOutWeight *mlx.Array
|
||||
NormOutBias *mlx.Array
|
||||
ProjOut *mlx.Array
|
||||
ProjOutBias *mlx.Array
|
||||
}
|
||||
|
||||
// Load loads the transformer from a directory
|
||||
func (m *Transformer) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image transformer...")
|
||||
|
||||
cfg := defaultTransformerConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Bulk load all weights as bf16
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading input projections... ")
|
||||
imgIn, _ := weights.Get("img_in.weight")
|
||||
imgInBias, _ := weights.Get("img_in.bias")
|
||||
txtIn, _ := weights.Get("txt_in.weight")
|
||||
txtInBias, _ := weights.Get("txt_in.bias")
|
||||
txtNorm, _ := weights.Get("txt_norm.weight")
|
||||
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
|
||||
m.ImgInBias = imgInBias
|
||||
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
|
||||
m.TxtInBias = txtInBias
|
||||
m.TxtNorm = txtNorm
|
||||
fmt.Println("✓")
|
||||
|
||||
fmt.Print(" Loading timestep embedder... ")
|
||||
m.TEmbed, err = newTimestepEmbedder(weights)
|
||||
if err != nil {
|
||||
return fmt.Errorf("timestep embedder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
for i := int32(0); i < cfg.NLayers; i++ {
|
||||
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
|
||||
prefix := fmt.Sprintf("transformer_blocks.%d", i)
|
||||
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("layer %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
|
||||
|
||||
fmt.Print(" Loading output layers... ")
|
||||
normOutWeight, _ := weights.Get("norm_out.linear.weight")
|
||||
normOutBias, _ := weights.Get("norm_out.linear.bias")
|
||||
projOut, _ := weights.Get("proj_out.weight")
|
||||
projOutBias, _ := weights.Get("proj_out.bias")
|
||||
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
|
||||
m.NormOutBias = normOutBias
|
||||
m.ProjOut = mlx.Transpose(projOut, 1, 0)
|
||||
m.ProjOutBias = projOutBias
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath is a convenience function to load transformer from path
|
||||
func LoadTransformerFromPath(path string) (*Transformer, error) {
|
||||
m := &Transformer{}
|
||||
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward runs the transformer
|
||||
// img: [B, L_img, in_channels] patchified latents
|
||||
// txt: [B, L_txt, joint_attention_dim] text embeddings
|
||||
// t: [B] timesteps (0-1)
|
||||
// imgFreqs, txtFreqs: RoPE frequencies
|
||||
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
for _, layer := range tr.Layers {
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// ForwardWithCache runs the transformer with layer caching for speedup.
|
||||
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
|
||||
// shallow layers change little between denoising steps, so we cache their
|
||||
// outputs and reuse them on non-refresh steps.
|
||||
//
|
||||
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
|
||||
// step: current denoising step (0-indexed)
|
||||
// cacheInterval: refresh cache every N steps (e.g., 3)
|
||||
// cacheLayers: number of shallow layers to cache (e.g., 15)
|
||||
func (tr *Transformer) ForwardWithCache(
|
||||
img, txt, t *mlx.Array,
|
||||
imgFreqs, txtFreqs *mlx.Array,
|
||||
stepCache *cache.StepCache,
|
||||
step, cacheInterval, cacheLayers int,
|
||||
) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
// Check if we should refresh the cache
|
||||
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
|
||||
|
||||
for i, layer := range tr.Layers {
|
||||
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
|
||||
// Use cached outputs for shallow layers
|
||||
imgH = stepCache.Get(i)
|
||||
txtH = stepCache.Get2(i)
|
||||
} else {
|
||||
// Compute layer
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
// Cache shallow layers on refresh steps
|
||||
if i < cacheLayers && refreshCache {
|
||||
stepCache.Set(i, imgH)
|
||||
stepCache.Set2(i, txtH)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// RoPECache holds precomputed RoPE frequencies
|
||||
type RoPECache struct {
|
||||
ImgFreqs *mlx.Array // [L_img, head_dim]
|
||||
TxtFreqs *mlx.Array // [L_txt, head_dim]
|
||||
}
|
||||
|
||||
// PrepareRoPE computes RoPE for image and text sequences
|
||||
// This matches Python's QwenEmbedRope with scale_rope=True
|
||||
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
|
||||
theta := float64(10000)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Compute base frequencies for each axis dimension
|
||||
freqsT := ComputeAxisFreqs(axesDims[0], theta)
|
||||
freqsH := ComputeAxisFreqs(axesDims[1], theta)
|
||||
freqsW := ComputeAxisFreqs(axesDims[2], theta)
|
||||
|
||||
// Build frequency lookup tables
|
||||
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
|
||||
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
|
||||
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
|
||||
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
|
||||
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
|
||||
|
||||
// Image frequencies with scale_rope=True
|
||||
imgLen := imgH * imgW
|
||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
||||
imgFreqsData := make([]float32, imgLen*headDim)
|
||||
|
||||
hHalf := imgH / 2
|
||||
wHalf := imgW / 2
|
||||
|
||||
idx := int32(0)
|
||||
for y := int32(0); y < imgH; y++ {
|
||||
for x := int32(0); x < imgW; x++ {
|
||||
// Frame = 0
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
|
||||
// Height: scale_rope pattern
|
||||
hNegCount := imgH - hHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
|
||||
// Width: scale_rope pattern
|
||||
wNegCount := imgW - wHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
}
|
||||
|
||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
|
||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
||||
|
||||
// Text frequencies
|
||||
maxVidIdx := max(hHalf, wHalf)
|
||||
txtFreqsData := make([]float32, txtLen*headDim)
|
||||
|
||||
idx = 0
|
||||
for t := int32(0); t < txtLen; t++ {
|
||||
pos := maxVidIdx + t
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
|
||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
||||
|
||||
return &RoPECache{
|
||||
ImgFreqs: imgFreqs,
|
||||
TxtFreqs: txtFreqs,
|
||||
}
|
||||
}
|
||||
|
||||
// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
|
||||
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
|
||||
halfDim := dim / 2
|
||||
freqs := make([]float64, halfDim)
|
||||
for i := int32(0); i < halfDim; i++ {
|
||||
freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
|
||||
}
|
||||
return freqs
|
||||
}
|
||||
|
||||
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
|
||||
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
|
||||
table := make([][]float32, maxIdx)
|
||||
for idx := int32(0); idx < maxIdx; idx++ {
|
||||
var pos float64
|
||||
if negative {
|
||||
pos = float64(-maxIdx + int32(idx))
|
||||
} else {
|
||||
pos = float64(idx)
|
||||
}
|
||||
|
||||
row := make([]float32, len(baseFreqs)*2)
|
||||
for i, f := range baseFreqs {
|
||||
angle := pos * f
|
||||
row[i*2] = float32(math.Cos(angle))
|
||||
row[i*2+1] = float32(math.Sin(angle))
|
||||
}
|
||||
table[idx] = row
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func max(a, b int32) int32 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
|
||||
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
|
||||
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
|
||||
// -> [B, pH, pW, C, 2, 2]
|
||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
||||
// -> [B, pH*pW, C*4]
|
||||
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
|
||||
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
|
||||
shape := patches.Shape()
|
||||
B := shape[0]
|
||||
channels := shape[2] / (patchSize * patchSize)
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
|
||||
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
|
||||
// -> [B, C, pH, 2, pW, 2]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
||||
// -> [B, C, H, W]
|
||||
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
|
||||
// Add temporal dimension for VAE: [B, C, 1, H, W]
|
||||
return mlx.ExpandDims(x, 2)
|
||||
}
|
||||
@@ -1,117 +0,0 @@
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestTransformerConfig tests configuration invariants.
|
||||
func TestTransformerConfig(t *testing.T) {
|
||||
cfg := defaultTransformerConfig()
|
||||
|
||||
// Property: hidden_dim = n_heads * head_dim
|
||||
if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
|
||||
t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
|
||||
cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
|
||||
}
|
||||
|
||||
// Property: axes_dims_rope sums to head_dim
|
||||
var ropeSum int32
|
||||
for _, d := range cfg.AxesDimsRope {
|
||||
ropeSum += d
|
||||
}
|
||||
if ropeSum != cfg.HeadDim {
|
||||
t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
|
||||
}
|
||||
|
||||
// Property: in_channels = out_channels * patch_size^2
|
||||
expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
|
||||
if cfg.InChannels != expectedIn {
|
||||
t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransformerRoPE tests RoPE frequency computation produces valid values.
|
||||
func TestTransformerRoPE(t *testing.T) {
|
||||
cfg := defaultTransformerConfig()
|
||||
|
||||
// Test with small image dimensions
|
||||
imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
|
||||
txtLen := int32(5)
|
||||
|
||||
ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
|
||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Verify shapes: [seq_len, head_dim]
|
||||
imgSeqLen := imgH * imgW
|
||||
if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
|
||||
t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
|
||||
}
|
||||
if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
|
||||
t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
|
||||
}
|
||||
|
||||
if ropeCache.TxtFreqs.Shape()[0] != txtLen {
|
||||
t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
|
||||
}
|
||||
|
||||
// Verify values are finite
|
||||
imgData := ropeCache.ImgFreqs.Data()
|
||||
for i := 0; i < min(100, len(imgData)); i++ {
|
||||
if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
|
||||
t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransformerForward tests full forward pass (integration test).
|
||||
// Skips if model weights are not available.
|
||||
func TestTransformerForward(t *testing.T) {
|
||||
weightsPath := "../../../weights/Qwen-Image-2512/transformer"
|
||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
||||
}
|
||||
|
||||
transformer := &Transformer{}
|
||||
if err := transformer.Load(weightsPath); err != nil {
|
||||
t.Fatalf("Failed to load transformer: %v", err)
|
||||
}
|
||||
mlx.Keep(mlx.Collect(transformer)...)
|
||||
cfg := transformer.Config
|
||||
|
||||
// Small test inputs
|
||||
batchSize := int32(1)
|
||||
imgH, imgW := int32(4), int32(4)
|
||||
imgSeqLen := imgH * imgW
|
||||
txtSeqLen := int32(5)
|
||||
|
||||
hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
|
||||
encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
|
||||
timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})
|
||||
|
||||
ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)
|
||||
|
||||
// Forward pass
|
||||
out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [batch, img_seq_len, in_channels]
|
||||
wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
|
||||
gotShape := out.Shape()
|
||||
if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
|
||||
t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
|
||||
}
|
||||
|
||||
// Verify output is finite
|
||||
outData := out.Data()
|
||||
for i := 0; i < min(100, len(outData)); i++ {
|
||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,852 +0,0 @@
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds Qwen-Image VAE configuration
|
||||
type VAEConfig struct {
|
||||
ZDim int32 `json:"z_dim"` // 16
|
||||
BaseDim int32 `json:"base_dim"` // 96
|
||||
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
|
||||
NumResBlocks int32 `json:"num_res_blocks"` // 2
|
||||
LatentsMean []float32 `json:"latents_mean"` // 16 values
|
||||
LatentsStd []float32 `json:"latents_std"` // 16 values
|
||||
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
|
||||
}
|
||||
|
||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
||||
func defaultVAEConfig() *VAEConfig {
|
||||
return &VAEConfig{
|
||||
ZDim: 16,
|
||||
BaseDim: 96,
|
||||
DimMult: []int32{1, 2, 4, 4},
|
||||
NumResBlocks: 2,
|
||||
LatentsMean: []float32{
|
||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632,
|
||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
||||
},
|
||||
LatentsStd: []float32{
|
||||
2.8184, 1.4541, 2.3275, 2.6558,
|
||||
1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579,
|
||||
1.6382, 1.1253, 2.8251, 1.916,
|
||||
},
|
||||
TemperalDownsample: []bool{false, true, true},
|
||||
}
|
||||
}
|
||||
|
||||
// CausalConv3d is a causal 3D convolution (for temporal causality)
|
||||
type CausalConv3d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
|
||||
KernelT int32
|
||||
}
|
||||
|
||||
// newCausalConv3d creates a 3D causal conv
|
||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
||||
weight, err := weights.Get(prefix + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
||||
}
|
||||
bias, _ := weights.Get(prefix + ".bias")
|
||||
|
||||
kernelT := weight.Shape()[2]
|
||||
outC := weight.Shape()[0]
|
||||
|
||||
var biasReshaped *mlx.Array
|
||||
if bias != nil {
|
||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
||||
}
|
||||
|
||||
return &CausalConv3d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
BiasReshaped: biasReshaped,
|
||||
KernelT: kernelT,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies causal 3D convolution
|
||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
|
||||
kernelT := shape[2]
|
||||
kernelH := shape[3]
|
||||
kernelW := shape[4]
|
||||
|
||||
// Causal temporal padding, same spatial padding
|
||||
// Input is channels-last: [B, T, H, W, C]
|
||||
padT := kernelT - 1
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Stage 1: Pad
|
||||
{
|
||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
var out *mlx.Array
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// RMSNorm3D applies RMS normalization over channels
|
||||
// Works with channels-last [B, T, H, W, C] format
|
||||
type RMSNorm3D struct {
|
||||
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
|
||||
}
|
||||
|
||||
// newRMSNorm3D creates an RMS norm
|
||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
||||
gamma, err := weights.Get(prefix + ".gamma")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
|
||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
||||
return &RMSNorm3D{Gamma: gamma}, nil
|
||||
}
|
||||
|
||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
||||
// RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
|
||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
||||
return mlx.Mul(normalized, n.Gamma)
|
||||
}
|
||||
|
||||
// ResBlock is a residual block with RMS norm and causal convs
|
||||
type ResBlock struct {
|
||||
Norm1 *RMSNorm3D
|
||||
Conv1 *CausalConv3d
|
||||
Norm2 *RMSNorm3D
|
||||
Conv2 *CausalConv3d
|
||||
Shortcut *CausalConv3d
|
||||
}
|
||||
|
||||
// newResBlock creates a residual block
|
||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shortcut *CausalConv3d
|
||||
if inDim != outDim {
|
||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ResBlock{
|
||||
Norm1: norm1,
|
||||
Conv1: conv1,
|
||||
Norm2: norm2,
|
||||
Conv2: conv2,
|
||||
Shortcut: shortcut,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the residual block
|
||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Use h as working variable, keep x intact for residual (caller will free x)
|
||||
// Conv handles its own pools, so we just need pools for non-conv operations
|
||||
var h *mlx.Array
|
||||
|
||||
// Keep x so it survives Eval() cleanup - needed for residual connection
|
||||
mlx.Keep(x)
|
||||
|
||||
// Stage 1: norm1 + silu
|
||||
{
|
||||
h = r.Norm1.Forward(x)
|
||||
h = silu3D(h)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: conv1 (handles its own pools)
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 3: norm2 + silu
|
||||
{
|
||||
prev := h
|
||||
h = r.Norm2.Forward(h)
|
||||
h = silu3D(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: conv2 (handles its own pools)
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Residual connection (shortcut handles its own pools if present)
|
||||
if r.Shortcut != nil {
|
||||
shortcut := r.Shortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
mlx.Eval(h)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// AttentionBlock is a 2D attention block
|
||||
type AttentionBlock struct {
|
||||
Norm *RMSNorm3D
|
||||
ToQKV *mlx.Array
|
||||
ToQKVBias *mlx.Array
|
||||
Proj *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// newAttentionBlock creates an attention block
|
||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
||||
|
||||
return &AttentionBlock{
|
||||
Norm: norm,
|
||||
ToQKV: toQKV,
|
||||
ToQKVBias: toQKVBias,
|
||||
Proj: proj,
|
||||
ProjBias: projBias,
|
||||
Dim: dim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies 2D attention
|
||||
// Input: [B, T, H, W, C] (channels-last)
|
||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
identity := x
|
||||
|
||||
// Flatten to [B*T, 1, H, W, C] for norm
|
||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
||||
x = a.Norm.Forward(x)
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Flatten spatial to [B*T, H*W, C]
|
||||
x = mlx.Reshape(x, B*T, H*W, C)
|
||||
|
||||
// Linear to get Q, K, V: [B*T, H*W, 3*C]
|
||||
// Weight is [outC, inC] or [outC, inC, 1, 1]
|
||||
wShape := a.ToQKV.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = a.ToQKV
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0) // [inC, outC]
|
||||
|
||||
qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
|
||||
if a.ToQKVBias != nil {
|
||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
||||
}
|
||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
||||
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
// out: [B*T, 1, H*W, C]
|
||||
out = mlx.Reshape(out, B*T, H*W, C)
|
||||
|
||||
// Project back
|
||||
pShape := a.Proj.Shape()
|
||||
var p *mlx.Array
|
||||
if len(pShape) == 4 {
|
||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
||||
} else {
|
||||
p = a.Proj
|
||||
}
|
||||
p = mlx.Transpose(p, 1, 0) // [inC, outC]
|
||||
out = mlx.Linear(out, p) // [B*T, H*W, C]
|
||||
if a.ProjBias != nil {
|
||||
out = mlx.Add(out, a.ProjBias)
|
||||
}
|
||||
|
||||
out = mlx.Reshape(out, B, T, H, W, C)
|
||||
return mlx.Add(out, identity)
|
||||
}
|
||||
|
||||
// UpBlock handles upsampling in decoder
|
||||
type UpBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Upsampler *Upsample
|
||||
}
|
||||
|
||||
// newUpBlock creates an up block
|
||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var upsampler *Upsample
|
||||
if upsampleMode != "" {
|
||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
||||
}
|
||||
|
||||
return &UpBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Upsampler: upsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies up block with staged memory management
|
||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// ResBlocks handle their own pools
|
||||
for _, block := range u.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Upsampler handles its own pools
|
||||
if u.Upsampler != nil {
|
||||
prev := x
|
||||
x = u.Upsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Upsample handles spatial upsampling
|
||||
type Upsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newUpsample creates an upsampler
|
||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Upsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
||||
// Uses staged pools to reduce peak memory during 2x upsampling
|
||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := u.Conv.Shape()[0]
|
||||
|
||||
// Stage 1: 2x nearest neighbor upsample
|
||||
{
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
x = upsample2xChannelsLast(x)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
||||
if u.Bias != nil {
|
||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// MidBlock is the middle block of decoder
|
||||
type MidBlock struct {
|
||||
ResBlock1 *ResBlock
|
||||
Attention *AttentionBlock
|
||||
ResBlock2 *ResBlock
|
||||
}
|
||||
|
||||
// newMidBlock creates a mid block
|
||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MidBlock{
|
||||
ResBlock1: res1,
|
||||
Attention: attn,
|
||||
ResBlock2: res2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies mid block
|
||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Each component handles its own pools; we just free inputs
|
||||
prev := x
|
||||
x = m.ResBlock1.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.ResBlock2.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the full VAE decoder
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
PostQuantConv *CausalConv3d
|
||||
ConvIn *CausalConv3d
|
||||
MidBlock *MidBlock
|
||||
UpBlocks []*UpBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from a directory
|
||||
func (m *VAEDecoder) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image VAE decoder...")
|
||||
|
||||
cfg := defaultVAEConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Bulk load all weights as bf16
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("failed to load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading post_quant_conv... ")
|
||||
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.PostQuantConv = postQuantConv
|
||||
fmt.Println("✓")
|
||||
|
||||
fmt.Print(" Loading conv_in... ")
|
||||
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvIn = convIn
|
||||
fmt.Println("✓")
|
||||
|
||||
// Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
|
||||
fmt.Print(" Loading mid_block... ")
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.MidBlock = midBlock
|
||||
fmt.Println("✓")
|
||||
|
||||
// Up blocks (reversed dim_mult)
|
||||
fmt.Print(" Loading up_blocks... ")
|
||||
numUpBlocks := len(cfg.DimMult)
|
||||
m.UpBlocks = make([]*UpBlock, numUpBlocks)
|
||||
|
||||
dimsMult := make([]int32, numUpBlocks+1)
|
||||
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
|
||||
}
|
||||
|
||||
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
|
||||
for i := range cfg.TemperalDownsample {
|
||||
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
|
||||
}
|
||||
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
inDim := cfg.BaseDim * dimsMult[i]
|
||||
outDim := cfg.BaseDim * dimsMult[i+1]
|
||||
|
||||
if i > 0 {
|
||||
inDim = inDim / 2
|
||||
}
|
||||
|
||||
upsampleMode := ""
|
||||
if i < numUpBlocks-1 {
|
||||
if temporalUpsample[i] {
|
||||
upsampleMode = "upsample3d"
|
||||
} else {
|
||||
upsampleMode = "upsample2d"
|
||||
}
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.UpBlocks[i] = upBlock
|
||||
}
|
||||
fmt.Printf("✓ [%d blocks]\n", numUpBlocks)
|
||||
|
||||
fmt.Print(" Loading output layers... ")
|
||||
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.NormOut = normOut
|
||||
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvOut = convOut
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadVAEDecoderFromPath is a convenience function to load VAE from path
|
||||
func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
|
||||
m := &VAEDecoder{}
|
||||
if err := m.Load(filepath.Join(path, "vae")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Decode converts latents to image
|
||||
// z: [B, C, T, H, W] normalized latents
|
||||
// Uses staged pools to free intermediate arrays and reduce peak memory.
|
||||
func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
|
||||
var x *mlx.Array
|
||||
|
||||
// Stage 1a: Denormalize and transpose
|
||||
{
|
||||
z = vae.Denormalize(z)
|
||||
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
||||
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(z)
|
||||
}
|
||||
|
||||
// Stage 1b: PostQuantConv (handles its own pools)
|
||||
x = vae.PostQuantConv.Forward(z)
|
||||
z.Free()
|
||||
|
||||
// Stage 1c: ConvIn (handles its own pools)
|
||||
{
|
||||
prev := x
|
||||
x = vae.ConvIn.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 2: Mid block (handles its own pools)
|
||||
x = vae.MidBlock.Forward(x)
|
||||
|
||||
// Stage 3: Up blocks (each handles its own pools)
|
||||
for _, upBlock := range vae.UpBlocks {
|
||||
x = upBlock.Forward(x)
|
||||
}
|
||||
|
||||
// Stage 4a: NormOut + silu
|
||||
{
|
||||
prev := x
|
||||
x = vae.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 4b: ConvOut (handles its own pools)
|
||||
{
|
||||
prev := x
|
||||
x = vae.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 4c: Post-processing
|
||||
{
|
||||
prev := x
|
||||
// Clamp to [-1, 1]
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
// Convert back from channels-last to channels-first
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Denormalize reverses the normalization applied during encoding
|
||||
func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
mean = mlx.ToBFloat16(mean)
|
||||
std = mlx.ToBFloat16(std)
|
||||
|
||||
return mlx.Add(mlx.Mul(z, std), mean)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func silu3D(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
||||
}
|
||||
|
||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
// Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
|
||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
||||
}
|
||||
|
||||
func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
|
||||
}
|
||||
|
||||
func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
||||
x = mlx.Reshape(x, B*H*W, shape[1])
|
||||
|
||||
wShape := weight.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(weight, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = weight
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0)
|
||||
|
||||
out := mlx.Linear(x, w)
|
||||
outC := w.Dim(1)
|
||||
out = mlx.Reshape(out, B, H, W, outC)
|
||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
||||
}
|
||||
|
||||
func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
|
||||
x = pad2D(x, 1, 1, 1, 1)
|
||||
return conv2D(x, weight, 1, 1)
|
||||
}
|
||||
|
||||
func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
|
||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
||||
w = mlx.Transpose(w, 0, 2, 3, 1)
|
||||
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
wShape := w.Shape()
|
||||
Cout := wShape[0]
|
||||
kH := wShape[1]
|
||||
kW := wShape[2]
|
||||
|
||||
outH := (H-kH)/strideH + 1
|
||||
outW := (W-kW)/strideW + 1
|
||||
|
||||
patches := extractPatches2D(x, kH, kW, strideH, strideW)
|
||||
wFlat := mlx.Reshape(w, Cout, -1)
|
||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
out = mlx.Reshape(out, B, outH, outW, Cout)
|
||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
||||
}
|
||||
|
||||
func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
outH := (H-kH)/strideH + 1
|
||||
outW := (W-kW)/strideW + 1
|
||||
|
||||
patches := make([]*mlx.Array, outH*outW)
|
||||
idx := 0
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startH := i * strideH
|
||||
startW := j * strideW
|
||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
||||
}
|
||||
|
||||
func upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
x = mlx.Take(x, rowIdx, 2)
|
||||
x = mlx.Take(x, colIdx, 3)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
// Create repeat indices for rows
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
// Create repeat indices for columns
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
// Take along H (axis 1) then W (axis 2)
|
||||
x = mlx.Take(x, rowIdx, 1)
|
||||
x = mlx.Take(x, colIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
||||
// weight: [outC, kH, kW, inC] (MLX channels-last format)
|
||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
||||
// Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
// Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
|
||||
// stride=1, padding=0 (we already padded manually)
|
||||
return mlx.Conv2d(x, weight, 1, 0)
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestVAEConfig tests configuration invariants.
|
||||
func TestVAEConfig(t *testing.T) {
|
||||
cfg := defaultVAEConfig()
|
||||
|
||||
// Property: latents_mean and latents_std have z_dim elements
|
||||
if int32(len(cfg.LatentsMean)) != cfg.ZDim {
|
||||
t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim)
|
||||
}
|
||||
if int32(len(cfg.LatentsStd)) != cfg.ZDim {
|
||||
t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim)
|
||||
}
|
||||
|
||||
// Property: dim_mult defines 4 stages
|
||||
if len(cfg.DimMult) != 4 {
|
||||
t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult))
|
||||
}
|
||||
|
||||
// Property: temperal_downsample has 3 elements (for 3 transitions)
|
||||
if len(cfg.TemperalDownsample) != 3 {
|
||||
t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample))
|
||||
}
|
||||
}
|
||||
|
||||
// TestVAELatentsNormalization tests the latent denormalization values.
|
||||
func TestVAELatentsNormalization(t *testing.T) {
|
||||
cfg := defaultVAEConfig()
|
||||
|
||||
// Verify latents_std values are all positive
|
||||
for i, std := range cfg.LatentsStd {
|
||||
if std <= 0 {
|
||||
t.Errorf("latents_std[%d] should be positive: %v", i, std)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify values are in reasonable range (from actual model)
|
||||
for i, mean := range cfg.LatentsMean {
|
||||
if math.Abs(float64(mean)) > 5 {
|
||||
t.Errorf("latents_mean[%d] seems too large: %v", i, mean)
|
||||
}
|
||||
}
|
||||
for i, std := range cfg.LatentsStd {
|
||||
if std > 10 {
|
||||
t.Errorf("latents_std[%d] seems too large: %v", i, std)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVAEDecoderForward tests full forward pass (integration test).
|
||||
// Skips if model weights are not available.
|
||||
func TestVAEDecoderForward(t *testing.T) {
|
||||
weightsPath := "../../../weights/Qwen-Image-2512/vae"
|
||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
||||
}
|
||||
|
||||
vae := &VAEDecoder{}
|
||||
if err := vae.Load(weightsPath); err != nil {
|
||||
t.Fatalf("Failed to load VAE decoder: %v", err)
|
||||
}
|
||||
mlx.Keep(mlx.Collect(vae)...)
|
||||
|
||||
// Small test input: [B, C, T, H, W]
|
||||
// After 4 upsampling stages (2x each), H/W multiply by 16
|
||||
batchSize := int32(1)
|
||||
channels := int32(16)
|
||||
frames := int32(1)
|
||||
latentH := int32(4)
|
||||
latentW := int32(4)
|
||||
|
||||
latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0)
|
||||
|
||||
// Decode
|
||||
out := vae.Decode(latents)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [B, 3, T, H*16, W*16]
|
||||
outShape := out.Shape()
|
||||
if outShape[0] != batchSize {
|
||||
t.Errorf("batch size: got %d, want %d", outShape[0], batchSize)
|
||||
}
|
||||
if outShape[1] != 3 {
|
||||
t.Errorf("channels: got %d, want 3", outShape[1])
|
||||
}
|
||||
if outShape[2] != frames {
|
||||
t.Errorf("frames: got %d, want %d", outShape[2], frames)
|
||||
}
|
||||
expectedH := latentH * 16 // 4 stages of 2x upsampling
|
||||
expectedW := latentW * 16
|
||||
if outShape[3] != expectedH || outShape[4] != expectedW {
|
||||
t.Errorf("spatial dims: got [%d, %d], want [%d, %d]",
|
||||
outShape[3], outShape[4], expectedH, expectedW)
|
||||
}
|
||||
|
||||
// Verify output is in valid range (should be clamped to [0, 1] by decode)
|
||||
outData := out.Data()
|
||||
for i := 0; i < min(100, len(outData)); i++ {
|
||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,680 +0,0 @@
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// CausalConv3d is a causal 3D convolution (for temporal causality)
|
||||
type CausalConv3d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
|
||||
KernelT int32
|
||||
}
|
||||
|
||||
// newCausalConv3d creates a 3D causal conv
|
||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
||||
weight, err := weights.Get(prefix + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
||||
}
|
||||
bias, _ := weights.Get(prefix + ".bias")
|
||||
|
||||
kernelT := weight.Shape()[2]
|
||||
outC := weight.Shape()[0]
|
||||
|
||||
var biasReshaped *mlx.Array
|
||||
if bias != nil {
|
||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
||||
}
|
||||
|
||||
return &CausalConv3d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
BiasReshaped: biasReshaped,
|
||||
KernelT: kernelT,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies causal 3D convolution (or 2D if weight is 4D)
|
||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := c.Weight.Shape()
|
||||
|
||||
// Handle both 5D (3D conv) and 4D (2D conv) weights
|
||||
if len(shape) == 4 {
|
||||
// 2D conv: [O, I, kH, kW] - need to apply per-frame
|
||||
return c.forward2D(x)
|
||||
}
|
||||
|
||||
// 3D conv: [O, I, kT, kH, kW]
|
||||
kernelT := shape[2]
|
||||
kernelH := shape[3]
|
||||
kernelW := shape[4]
|
||||
|
||||
// Causal temporal padding, same spatial padding
|
||||
padT := kernelT - 1
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Stage 1: Pad
|
||||
{
|
||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
var out *mlx.Array
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// forward2D applies 2D conv per-frame for [B, T, H, W, C] input
|
||||
func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
|
||||
xShape := x.Shape()
|
||||
B := xShape[0]
|
||||
T := xShape[1]
|
||||
H := xShape[2]
|
||||
W := xShape[3]
|
||||
C := xShape[4]
|
||||
|
||||
wShape := c.Weight.Shape() // [O, I, kH, kW]
|
||||
kernelH := wShape[2]
|
||||
kernelW := wShape[3]
|
||||
outC := wShape[0]
|
||||
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Pad spatially
|
||||
x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})
|
||||
|
||||
// Apply 2D conv
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
||||
x = mlx.Conv2d(x, weight, 1, 0)
|
||||
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
// Get output spatial dims
|
||||
outH := H
|
||||
outW := W
|
||||
|
||||
// Reshape back to [B, T, H, W, C]
|
||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// RMSNorm3D applies RMS normalization over channels
|
||||
type RMSNorm3D struct {
|
||||
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
|
||||
}
|
||||
|
||||
// newRMSNorm3D creates an RMS norm
|
||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
||||
gamma, err := weights.Get(prefix + ".gamma")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
||||
return &RMSNorm3D{Gamma: gamma}, nil
|
||||
}
|
||||
|
||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
||||
return mlx.Mul(normalized, n.Gamma)
|
||||
}
|
||||
|
||||
// ResBlock is a residual block with RMS norm and causal convs
|
||||
type ResBlock struct {
|
||||
Norm1 *RMSNorm3D
|
||||
Conv1 *CausalConv3d
|
||||
Norm2 *RMSNorm3D
|
||||
Conv2 *CausalConv3d
|
||||
Shortcut *CausalConv3d
|
||||
}
|
||||
|
||||
// newResBlock creates a residual block
|
||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shortcut *CausalConv3d
|
||||
if inDim != outDim {
|
||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ResBlock{
|
||||
Norm1: norm1,
|
||||
Conv1: conv1,
|
||||
Norm2: norm2,
|
||||
Conv2: conv2,
|
||||
Shortcut: shortcut,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the residual block
|
||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
|
||||
mlx.Keep(x)
|
||||
|
||||
// Stage 1: norm1 + silu
|
||||
{
|
||||
h = r.Norm1.Forward(x)
|
||||
h = silu3D(h)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: conv1
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 3: norm2 + silu
|
||||
{
|
||||
prev := h
|
||||
h = r.Norm2.Forward(h)
|
||||
h = silu3D(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: conv2
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
if r.Shortcut != nil {
|
||||
shortcut := r.Shortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
mlx.Eval(h)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// AttentionBlock is a 2D attention block
|
||||
type AttentionBlock struct {
|
||||
Norm *RMSNorm3D
|
||||
ToQKV *mlx.Array
|
||||
ToQKVBias *mlx.Array
|
||||
Proj *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// newAttentionBlock creates an attention block
|
||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
||||
|
||||
return &AttentionBlock{
|
||||
Norm: norm,
|
||||
ToQKV: toQKV,
|
||||
ToQKVBias: toQKVBias,
|
||||
Proj: proj,
|
||||
ProjBias: projBias,
|
||||
Dim: dim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies 2D attention
|
||||
// Input: [B, T, H, W, C] (channels-last)
|
||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
identity := x
|
||||
|
||||
// Flatten to [B*T, 1, H, W, C] for norm
|
||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
||||
x = a.Norm.Forward(x)
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Flatten spatial to [B*T, H*W, C]
|
||||
x = mlx.Reshape(x, B*T, H*W, C)
|
||||
|
||||
// Linear to get Q, K, V
|
||||
wShape := a.ToQKV.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = a.ToQKV
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0)
|
||||
|
||||
qkv := mlx.Linear(x, w)
|
||||
if a.ToQKVBias != nil {
|
||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
||||
}
|
||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
||||
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
out = mlx.Reshape(out, B*T, H*W, C)
|
||||
|
||||
// Project back
|
||||
pShape := a.Proj.Shape()
|
||||
var p *mlx.Array
|
||||
if len(pShape) == 4 {
|
||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
||||
} else {
|
||||
p = a.Proj
|
||||
}
|
||||
p = mlx.Transpose(p, 1, 0)
|
||||
out = mlx.Linear(out, p)
|
||||
if a.ProjBias != nil {
|
||||
out = mlx.Add(out, a.ProjBias)
|
||||
}
|
||||
|
||||
out = mlx.Reshape(out, B, T, H, W, C)
|
||||
return mlx.Add(out, identity)
|
||||
}
|
||||
|
||||
// UpBlock handles upsampling in decoder
|
||||
type UpBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Upsampler *Upsample
|
||||
}
|
||||
|
||||
// newUpBlock creates an up block
|
||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var upsampler *Upsample
|
||||
if upsampleMode != "" {
|
||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
||||
}
|
||||
|
||||
return &UpBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Upsampler: upsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies up block
|
||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, block := range u.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if u.Upsampler != nil {
|
||||
prev := x
|
||||
x = u.Upsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Upsample handles spatial upsampling
|
||||
type Upsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newUpsample creates an upsampler
|
||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Upsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := u.Conv.Shape()[0]
|
||||
|
||||
// Stage 1: 2x nearest neighbor upsample
|
||||
{
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
x = upsample2xChannelsLast(x)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
||||
if u.Bias != nil {
|
||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// MidBlock is the middle block
|
||||
type MidBlock struct {
|
||||
ResBlock1 *ResBlock
|
||||
Attention *AttentionBlock
|
||||
ResBlock2 *ResBlock
|
||||
}
|
||||
|
||||
// newMidBlock creates a mid block
|
||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MidBlock{
|
||||
ResBlock1: res1,
|
||||
Attention: attn,
|
||||
ResBlock2: res2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies mid block
|
||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
prev := x
|
||||
x = m.ResBlock1.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.ResBlock2.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func silu3D(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
||||
}
|
||||
|
||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
||||
}
|
||||
|
||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
x = mlx.Take(x, rowIdx, 1)
|
||||
x = mlx.Take(x, colIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
return mlx.Conv2d(x, weight, 1, 0)
|
||||
}
|
||||
|
||||
// conv2DStrided applies conv with stride > 1 using manual patch extraction
|
||||
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
|
||||
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
wShape := weight.Shape()
|
||||
Cout := wShape[0]
|
||||
kH := wShape[1]
|
||||
kW := wShape[2]
|
||||
|
||||
outH := (H - kH) / stride + 1
|
||||
outW := (W - kW) / stride + 1
|
||||
|
||||
patches := extractPatches2DStrided(x, kH, kW, stride)
|
||||
wFlat := mlx.Reshape(weight, Cout, -1)
|
||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
return mlx.Reshape(out, B, outH, outW, Cout)
|
||||
}
|
||||
|
||||
// conv3DStrided applies 3D conv with strides using manual patch extraction
|
||||
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
|
||||
// strideT, strideH, strideW are the strides for each dimension
|
||||
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
|
||||
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
wShape := weight.Shape()
|
||||
Cout := wShape[0]
|
||||
// I := wShape[1]
|
||||
kT := wShape[2]
|
||||
kH := wShape[3]
|
||||
kW := wShape[4]
|
||||
|
||||
// For temporal: if T < kT, we need to repeat frames temporally
|
||||
// For single image with T=1 and kT=2, we duplicate the frame to T=kT
|
||||
// Python Qwen2.5-VL duplicates the frame, not zero-pads
|
||||
if T < kT {
|
||||
// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
|
||||
x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
|
||||
T = kT
|
||||
}
|
||||
|
||||
outT := (T - kT) / strideT + 1
|
||||
outH := (H - kH) / strideH + 1
|
||||
outW := (W - kW) / strideW + 1
|
||||
|
||||
// Extract 3D patches in [C, T, H, W] order to match Python
|
||||
patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
|
||||
// patches shape: [B, outT, outH, outW, C*kT*kH*kW]
|
||||
|
||||
// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
|
||||
wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
|
||||
patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
return mlx.Reshape(out, B, outT, outH, outW, Cout)
|
||||
}
|
||||
|
||||
// extractPatches3DStrided extracts 3D patches with given strides
|
||||
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
|
||||
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
outT := (T - kT) / strideT + 1
|
||||
outH := (H - kH) / strideH + 1
|
||||
outW := (W - kW) / strideW + 1
|
||||
|
||||
numPatches := outT * outH * outW
|
||||
patches := make([]*mlx.Array, numPatches)
|
||||
idx := 0
|
||||
for t := int32(0); t < outT; t++ {
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startT := t * strideT
|
||||
startH := i * strideH
|
||||
startW := j * strideW
|
||||
// Extract patch: [B, kT, kH, kW, C]
|
||||
patch := mlx.Slice(x,
|
||||
[]int32{0, startT, startH, startW, 0},
|
||||
[]int32{B, startT + kT, startH + kH, startW + kW, C})
|
||||
// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
|
||||
patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
|
||||
// Flatten to [B, C*T*H*W]
|
||||
patch = mlx.Reshape(patch, B, C*kT*kH*kW)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
|
||||
}
|
||||
|
||||
// extractPatches2DStrided extracts patches with given stride
|
||||
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
outH := (H - kH) / stride + 1
|
||||
outW := (W - kW) / stride + 1
|
||||
|
||||
patches := make([]*mlx.Array, outH*outW)
|
||||
idx := 0
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startH := i * stride
|
||||
startW := j * stride
|
||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
@@ -1,473 +0,0 @@
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"math"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"golang.org/x/image/draw"
|
||||
_ "golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// loadImageFile loads an image from disk
|
||||
func loadImageFile(path string) (image.Image, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open image: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
img, _, err := image.Decode(f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode image: %w", err)
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range
|
||||
func imageToFloat32Pixels(img image.Image, width, height int) []float32 {
|
||||
pixels := make([]float32, width*height*3)
|
||||
idx := 0
|
||||
for y := 0; y < height; y++ {
|
||||
for x := 0; x < width; x++ {
|
||||
r, g, b, _ := img.At(x, y).RGBA()
|
||||
pixels[idx] = float32(r) / 65535.0
|
||||
pixels[idx+1] = float32(g) / 65535.0
|
||||
pixels[idx+2] = float32(b) / 65535.0
|
||||
idx += 3
|
||||
}
|
||||
}
|
||||
return pixels
|
||||
}
|
||||
|
||||
// normalizeImageNet applies ImageNet normalization to an image tensor
|
||||
func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array {
|
||||
mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3})
|
||||
std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3})
|
||||
return mlx.Div(mlx.Sub(arr, mean), std)
|
||||
}
|
||||
|
||||
// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16
|
||||
func prepareImageTensor(arr *mlx.Array) *mlx.Array {
|
||||
// Transpose to [C, H, W] and make contiguous
|
||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
||||
// Add batch dimension [1, C, H, W]
|
||||
arr = mlx.ExpandDims(arr, 0)
|
||||
// Convert to bf16
|
||||
arr = mlx.ToBFloat16(arr)
|
||||
mlx.Eval(arr)
|
||||
return arr
|
||||
}
|
||||
|
||||
// clampFloat clamps a value to [0, 255] and returns uint8
|
||||
func clampFloat(v, weightSum float64) uint8 {
|
||||
v /= weightSum
|
||||
if v < 0 {
|
||||
v = 0
|
||||
}
|
||||
if v > 255 {
|
||||
v = 255
|
||||
}
|
||||
return uint8(math.Round(v))
|
||||
}
|
||||
|
||||
// ImageDims holds dimensions for a preprocessed image
|
||||
type ImageDims struct {
|
||||
// Original image dimensions
|
||||
OrigW, OrigH int32
|
||||
// Condition image dimensions (for vision encoder)
|
||||
CondW, CondH int32
|
||||
// VAE image dimensions
|
||||
VaeW, VaeH int32
|
||||
// Latent dimensions (VAE dims / vae_scale_factor)
|
||||
LatentW, LatentH int32
|
||||
// Patch dimensions (latent dims / patch_size)
|
||||
PatchW, PatchH int32
|
||||
}
|
||||
|
||||
// ProcessorConfig holds image processor configuration
|
||||
type ProcessorConfig struct {
|
||||
// Condition image size (target pixel area for vision encoder input)
|
||||
// Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456
|
||||
// Pipeline resizes image to this area before passing to encode_prompt
|
||||
ConditionImageSize int32
|
||||
|
||||
// VAE image size (target pixel area)
|
||||
// Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576
|
||||
VAEImageSize int32
|
||||
|
||||
// Image normalization (ImageNet stats for vision encoder)
|
||||
ImageMean []float32
|
||||
ImageStd []float32
|
||||
}
|
||||
|
||||
// defaultProcessorConfig returns default processor config
|
||||
func defaultProcessorConfig() *ProcessorConfig {
|
||||
return &ProcessorConfig{
|
||||
ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE
|
||||
VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE
|
||||
ImageMean: []float32{0.48145466, 0.4578275, 0.40821073},
|
||||
ImageStd: []float32{0.26862954, 0.26130258, 0.27577711},
|
||||
}
|
||||
}
|
||||
|
||||
// Processor handles image preprocessing for Qwen-Image-Edit
|
||||
type Processor struct {
|
||||
Config *ProcessorConfig
|
||||
}
|
||||
|
||||
// Load loads the processor config
|
||||
func (p *Processor) Load(path string) error {
|
||||
p.Config = defaultProcessorConfig()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAndPreprocess loads an image and preprocesses it for both paths
|
||||
// Returns: condImage (for vision encoder), vaeImage (for VAE encoding)
|
||||
func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) {
|
||||
img, err := loadImageFile(imagePath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
origW := bounds.Dx()
|
||||
origH := bounds.Dy()
|
||||
ratio := float64(origW) / float64(origH)
|
||||
|
||||
// Calculate dimensions for condition image (vision encoder)
|
||||
// Python pipeline does TWO resizes:
|
||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
||||
finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
||||
|
||||
// Calculate dimensions for VAE image (1024x1024 area)
|
||||
// Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32)
|
||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
||||
|
||||
// Preprocess for condition (vision encoder) - two-step resize
|
||||
condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH)
|
||||
|
||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
||||
vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH)
|
||||
|
||||
return condImage, vaeImage, nil
|
||||
}
|
||||
|
||||
// preprocessImageLanczos does single-step Lanczos resize for vision encoder
|
||||
// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default)
|
||||
// Used by edit_plus pipeline for multi-image input
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array {
|
||||
resized := resizeImageLanczos(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
arr = p.normalizeImageNet(arr)
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline
|
||||
// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize)
|
||||
// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize)
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array {
|
||||
intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH))
|
||||
resized := resizeImageBicubic(intermediate, int(finalW), int(finalH))
|
||||
pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH))
|
||||
arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3})
|
||||
arr = p.normalizeImageNet(arr)
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImage converts image to tensor for vision encoder
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array {
|
||||
resized := resizeImageBicubic(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
if normalize {
|
||||
arr = p.normalizeImageNet(arr)
|
||||
}
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImageForVAE converts image to tensor for VAE encoding
|
||||
// Returns: [B, C, T, H, W] tensor in [-1, 1] range
|
||||
func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array {
|
||||
resized := resizeImageLanczos(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
|
||||
// Scale to [-1, 1]: arr * 2 - 1
|
||||
arr = mlx.MulScalar(arr, 2.0)
|
||||
arr = mlx.AddScalar(arr, -1.0)
|
||||
|
||||
// Transpose to [C, H, W] and make contiguous
|
||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
||||
|
||||
// Add batch and temporal dimensions [1, C, 1, H, W]
|
||||
arr = mlx.ExpandDims(arr, 0) // [1, C, H, W]
|
||||
arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W]
|
||||
|
||||
arr = mlx.ToBFloat16(arr)
|
||||
mlx.Eval(arr)
|
||||
return arr
|
||||
}
|
||||
|
||||
// smartResize implements Python Qwen2VL processor's smart_resize logic
|
||||
// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints
|
||||
func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) {
|
||||
// Round to factor
|
||||
hBar := int32(math.Round(float64(height)/float64(factor))) * factor
|
||||
wBar := int32(math.Round(float64(width)/float64(factor))) * factor
|
||||
|
||||
// Ensure minimum factor size
|
||||
if hBar < factor {
|
||||
hBar = factor
|
||||
}
|
||||
if wBar < factor {
|
||||
wBar = factor
|
||||
}
|
||||
|
||||
// Check pixel constraints
|
||||
total := hBar * wBar
|
||||
if total > maxPixels {
|
||||
// Scale down
|
||||
beta := math.Sqrt(float64(maxPixels) / float64(total))
|
||||
hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor
|
||||
} else if total < minPixels {
|
||||
// Scale up
|
||||
beta := math.Sqrt(float64(minPixels) / float64(total))
|
||||
hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar
|
||||
}
|
||||
|
||||
// calculateDimensions calculates width and height for a target area while maintaining ratio
|
||||
// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge)
|
||||
func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) {
|
||||
width := math.Sqrt(float64(targetArea) * ratio)
|
||||
height := width / ratio
|
||||
|
||||
m := float64(multiple)
|
||||
width = math.Round(width/m) * m
|
||||
height = math.Round(height/m) * m
|
||||
|
||||
// Ensure minimum dimensions
|
||||
if width < m {
|
||||
width = m
|
||||
}
|
||||
if height < m {
|
||||
height = m
|
||||
}
|
||||
|
||||
return int32(width), int32(height)
|
||||
}
|
||||
|
||||
// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS)
|
||||
func resizeImageLanczos(img image.Image, width, height int) image.Image {
|
||||
bounds := img.Bounds()
|
||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
// Lanczos3 kernel (a=3) to match PIL.LANCZOS
|
||||
lanczos3 := &draw.Kernel{
|
||||
Support: 3.0,
|
||||
At: func(t float64) float64 {
|
||||
if t == 0 {
|
||||
return 1.0
|
||||
}
|
||||
if t < 0 {
|
||||
t = -t
|
||||
}
|
||||
if t >= 3.0 {
|
||||
return 0.0
|
||||
}
|
||||
// sinc(t) * sinc(t/3)
|
||||
piT := math.Pi * t
|
||||
return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3))
|
||||
},
|
||||
}
|
||||
lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC)
|
||||
// Uses separable interpolation with PIL's coordinate mapping for exact match
|
||||
func resizeImageBicubic(img image.Image, width, height int) image.Image {
|
||||
bounds := img.Bounds()
|
||||
srcW := bounds.Dx()
|
||||
srcH := bounds.Dy()
|
||||
|
||||
// Convert to RGBA if needed
|
||||
var src *image.RGBA
|
||||
if rgba, ok := img.(*image.RGBA); ok {
|
||||
src = rgba
|
||||
} else {
|
||||
src = image.NewRGBA(bounds)
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
src.Set(x, y, img.At(x, y))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keys cubic with a=-0.5 (PIL BICUBIC)
|
||||
cubic := func(x float64) float64 {
|
||||
if x < 0 {
|
||||
x = -x
|
||||
}
|
||||
if x < 1 {
|
||||
return 1.5*x*x*x - 2.5*x*x + 1
|
||||
}
|
||||
if x < 2 {
|
||||
return -0.5*x*x*x + 2.5*x*x - 4*x + 2
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Horizontal pass: srcW -> width, keep srcH rows
|
||||
temp := image.NewRGBA(image.Rect(0, 0, width, srcH))
|
||||
for y := 0; y < srcH; y++ {
|
||||
for dstX := 0; dstX < width; dstX++ {
|
||||
// PIL coordinate mapping: center-to-center
|
||||
srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5
|
||||
baseX := int(math.Floor(srcXf))
|
||||
|
||||
var sumR, sumG, sumB, sumA, weightSum float64
|
||||
for i := -1; i <= 2; i++ {
|
||||
sx := baseX + i
|
||||
if sx < 0 {
|
||||
sx = 0
|
||||
}
|
||||
if sx >= srcW {
|
||||
sx = srcW - 1
|
||||
}
|
||||
|
||||
w := cubic(math.Abs(srcXf - float64(baseX+i)))
|
||||
c := src.RGBAAt(sx, y)
|
||||
sumR += float64(c.R) * w
|
||||
sumG += float64(c.G) * w
|
||||
sumB += float64(c.B) * w
|
||||
sumA += float64(c.A) * w
|
||||
weightSum += w
|
||||
}
|
||||
|
||||
temp.SetRGBA(dstX, y, color.RGBA{
|
||||
clampFloat(sumR, weightSum),
|
||||
clampFloat(sumG, weightSum),
|
||||
clampFloat(sumB, weightSum),
|
||||
clampFloat(sumA, weightSum),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Vertical pass: srcH -> height
|
||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
for x := 0; x < width; x++ {
|
||||
for dstY := 0; dstY < height; dstY++ {
|
||||
srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5
|
||||
baseY := int(math.Floor(srcYf))
|
||||
|
||||
var sumR, sumG, sumB, sumA, weightSum float64
|
||||
for j := -1; j <= 2; j++ {
|
||||
sy := baseY + j
|
||||
if sy < 0 {
|
||||
sy = 0
|
||||
}
|
||||
if sy >= srcH {
|
||||
sy = srcH - 1
|
||||
}
|
||||
|
||||
w := cubic(math.Abs(srcYf - float64(baseY+j)))
|
||||
c := temp.RGBAAt(x, sy)
|
||||
sumR += float64(c.R) * w
|
||||
sumG += float64(c.G) * w
|
||||
sumB += float64(c.B) * w
|
||||
sumA += float64(c.A) * w
|
||||
weightSum += w
|
||||
}
|
||||
|
||||
dst.SetRGBA(x, dstY, color.RGBA{
|
||||
clampFloat(sumR, weightSum),
|
||||
clampFloat(sumG, weightSum),
|
||||
clampFloat(sumB, weightSum),
|
||||
clampFloat(sumA, weightSum),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// LoadAndPreprocessMultiple loads multiple images and preprocesses them
|
||||
// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions)
|
||||
func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) {
|
||||
const vaeScaleFactor int32 = 8
|
||||
const patchSize int32 = 2
|
||||
|
||||
condImages := make([]*mlx.Array, len(imagePaths))
|
||||
vaeImages := make([]*mlx.Array, len(imagePaths))
|
||||
dims := make([]ImageDims, len(imagePaths))
|
||||
|
||||
for i, imagePath := range imagePaths {
|
||||
img, err := loadImageFile(imagePath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("image %d: %w", i, err)
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
origW := int32(bounds.Dx())
|
||||
origH := int32(bounds.Dy())
|
||||
ratio := float64(origW) / float64(origH)
|
||||
|
||||
// Calculate dimensions for condition image (vision encoder)
|
||||
// Python pipeline does TWO resizes:
|
||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
||||
condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
||||
|
||||
// Calculate dimensions for VAE image (1024x1024 area)
|
||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
||||
|
||||
// Calculate derived dimensions
|
||||
latentW := vaeW / vaeScaleFactor
|
||||
latentH := vaeH / vaeScaleFactor
|
||||
patchW := latentW / patchSize
|
||||
patchH := latentH / patchSize
|
||||
|
||||
dims[i] = ImageDims{
|
||||
OrigW: origW,
|
||||
OrigH: origH,
|
||||
CondW: condW,
|
||||
CondH: condH,
|
||||
VaeW: vaeW,
|
||||
VaeH: vaeH,
|
||||
LatentW: latentW,
|
||||
LatentH: latentH,
|
||||
PatchW: patchW,
|
||||
PatchH: patchH,
|
||||
}
|
||||
|
||||
fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n",
|
||||
i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH)
|
||||
|
||||
// Preprocess for condition (vision encoder) - two-step resize to match Python pipeline
|
||||
condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH)
|
||||
|
||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
||||
vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH)
|
||||
}
|
||||
|
||||
return condImages, vaeImages, dims, nil
|
||||
}
|
||||
@@ -1,608 +0,0 @@
|
||||
// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing.
|
||||
// It reuses components from qwen_image where possible.
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image editing.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
|
||||
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
|
||||
Width int32 // Output width (default: from input image)
|
||||
Height int32 // Output height (default: from input image)
|
||||
Steps int // Denoising steps (default: 50)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Qwen-Image-Edit diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Processor *Processor // Image processor for vision encoder
|
||||
TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image)
|
||||
Transformer *qwen_image.Transformer // Reuse qwen_image transformer
|
||||
VAE *VAE // Combined encoder + decoder
|
||||
}
|
||||
|
||||
// Load loads the Qwen-Image-Edit model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Qwen-Image-Edit model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer from processor directory
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
processorPath := filepath.Join(modelPath, "processor")
|
||||
tok, err := tokenizer.Load(processorPath)
|
||||
if err != nil {
|
||||
// Fallback to tokenizer directory
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
||||
tok, err = tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load processor (image preprocessing config)
|
||||
fmt.Print(" Loading processor... ")
|
||||
m.Processor = &Processor{}
|
||||
if err := m.Processor.Load(processorPath); err != nil {
|
||||
return fmt.Errorf("processor: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load vision-language text encoder (Qwen2.5-VL from qwen_image package)
|
||||
m.TextEncoder = &qwen_image.Qwen25VL{}
|
||||
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer (reuse qwen_image)
|
||||
m.Transformer = &qwen_image.Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE (encoder + decoder)
|
||||
m.VAE = &VAE{}
|
||||
if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAE)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
peak := mlx.MetalGetPeakMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
||||
time.Since(start).Seconds(),
|
||||
float64(mem)/(1024*1024*1024),
|
||||
float64(peak)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Edit edits an image based on a text prompt.
|
||||
// inputImagePath: path to input image
|
||||
// prompt: text description of desired edit
|
||||
func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// EditFromConfig edits images using the unified config struct.
|
||||
// Accepts one or more input images.
|
||||
func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
if len(inputImagePaths) == 0 {
|
||||
return nil, fmt.Errorf("no input images provided")
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result, err := m.edit(inputImagePaths, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n",
|
||||
len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n",
|
||||
len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// EditImage implements model.ImageEditModel interface.
|
||||
func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Edit(inputImagePath, prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// EditMultiImage edits using multiple source images.
|
||||
// This matches diffusers' QwenImageEditPlusPipeline behavior.
|
||||
func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
return m.EditFromConfig(inputImagePaths, cfg)
|
||||
}
|
||||
|
||||
// edit is the internal editing pipeline that handles one or more images.
|
||||
func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
|
||||
// Load and preprocess all input images
|
||||
fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths))
|
||||
condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preprocess images: %w", err)
|
||||
}
|
||||
for _, img := range condImages {
|
||||
mlx.Keep(img)
|
||||
}
|
||||
for _, img := range vaeImages {
|
||||
mlx.Keep(img)
|
||||
}
|
||||
mlx.Eval(append(condImages, vaeImages...)...)
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.Config
|
||||
vaeScaleFactor := int32(8)
|
||||
|
||||
// Output dimensions - if not specified, use first input image dimensions
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = inputDims[0].VaeW
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = inputDims[0].VaeH
|
||||
}
|
||||
|
||||
// Output (noise) latent dimensions
|
||||
outLatentH := cfg.Height / vaeScaleFactor
|
||||
outLatentW := cfg.Width / vaeScaleFactor
|
||||
outPH := outLatentH / tcfg.PatchSize
|
||||
outPW := outLatentW / tcfg.PatchSize
|
||||
noiseSeqLen := outPH * outPW
|
||||
imgSeqLen := noiseSeqLen
|
||||
|
||||
// Encode prompt with all images for conditioning
|
||||
posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding prompt: %w", err)
|
||||
}
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
|
||||
var negEmb *mlx.Array
|
||||
if useCFG {
|
||||
negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding negative prompt: %w", err)
|
||||
}
|
||||
mlx.Keep(negEmb)
|
||||
mlx.Eval(negEmb)
|
||||
}
|
||||
|
||||
// Pad sequences to same length for CFG
|
||||
txtLen := posEmb.Shape()[1]
|
||||
if useCFG {
|
||||
negLen := negEmb.Shape()[1]
|
||||
if negLen > txtLen {
|
||||
txtLen = negLen
|
||||
}
|
||||
if posEmb.Shape()[1] < txtLen {
|
||||
posEmb = padSequence(posEmb, txtLen)
|
||||
}
|
||||
if negEmb.Shape()[1] < txtLen {
|
||||
negEmb = padSequence(negEmb, txtLen)
|
||||
}
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Encode all input images to latents and concatenate
|
||||
fmt.Println("Encoding images to latents...")
|
||||
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
|
||||
for i, vaeImage := range vaeImages {
|
||||
imageLatents := m.VAE.Encode(vaeImage)
|
||||
imageLatents = m.VAE.Normalize(imageLatents)
|
||||
imageLatents2D := mlx.Squeeze(imageLatents, 2)
|
||||
packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize)
|
||||
mlx.Keep(packed)
|
||||
mlx.Eval(packed)
|
||||
allImageLatentsPacked[i] = packed
|
||||
}
|
||||
|
||||
imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1)
|
||||
mlx.Keep(imageLatentsPacked)
|
||||
mlx.Eval(imageLatentsPacked)
|
||||
|
||||
// Scheduler
|
||||
scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, noiseSeqLen)
|
||||
|
||||
// Init noise latents in packed format
|
||||
packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize
|
||||
packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed)
|
||||
latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize)
|
||||
mlx.Eval(latents)
|
||||
|
||||
// RoPE cache
|
||||
ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope)
|
||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Denoising loop
|
||||
fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps)
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
t := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
||||
mlx.Eval(timestep)
|
||||
|
||||
latents2D := mlx.Squeeze(latents, 2)
|
||||
patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize)
|
||||
latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
|
||||
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
|
||||
|
||||
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
|
||||
} else {
|
||||
output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]})
|
||||
}
|
||||
|
||||
noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds())
|
||||
}
|
||||
|
||||
// Free denoising temporaries
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
imageLatentsPacked.Free()
|
||||
|
||||
// Decode latents
|
||||
decoded := m.decodeAndPostprocess(latents)
|
||||
latents.Free()
|
||||
|
||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling.
|
||||
// This prevents CFG from inflating magnitude too much.
|
||||
func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array {
|
||||
// Upcast to float32 for precision
|
||||
posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32)
|
||||
negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32)
|
||||
|
||||
// CFG: pred = neg + scale * (pos - neg)
|
||||
diff := mlx.Sub(posF32, negF32)
|
||||
scaledDiff := mlx.MulScalar(diff, scale)
|
||||
combPred := mlx.Add(negF32, scaledDiff)
|
||||
|
||||
// Norm rescaling: rescale combined prediction to match conditional norm
|
||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true))
|
||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
||||
output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
||||
|
||||
mlx.Eval(output)
|
||||
return mlx.ToBFloat16(output)
|
||||
}
|
||||
|
||||
// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1].
|
||||
func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array {
|
||||
latents = m.VAE.Denormalize(latents)
|
||||
decoded := m.VAE.Decode(latents)
|
||||
|
||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
||||
decoded = mlx.Squeeze(decoded, 2)
|
||||
decoded = mlx.AddScalar(decoded, 1.0)
|
||||
decoded = mlx.DivScalar(decoded, 2.0)
|
||||
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
|
||||
mlx.Eval(decoded)
|
||||
return decoded
|
||||
}
|
||||
|
||||
// padSequence pads a sequence tensor to the target length with zeros
|
||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
// Pad on sequence dimension (axis 1)
|
||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
||||
}
|
||||
|
||||
// LoadPersistent is an alias for backward compatibility.
|
||||
func LoadPersistent(modelPath string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// PrepareRoPEMultiImage computes RoPE with interpolation for image editing.
|
||||
// Handles single or multiple input images with different resolutions.
|
||||
//
|
||||
// Parameters:
|
||||
// - outPH, outPW: output patch dimensions (noise latent resolution)
|
||||
// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...]
|
||||
// - txtLen: text sequence length
|
||||
// - axesDims: RoPE axis dimensions [16, 56, 56]
|
||||
//
|
||||
// Returns RoPE cache where:
|
||||
// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions
|
||||
// - First outPH*outPW positions are for noise latents (standard RoPE at output res)
|
||||
// - Following positions are for each input image (interpolated from output res)
|
||||
func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache {
|
||||
theta := float64(10000)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Compute base frequencies for each axis dimension
|
||||
freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta)
|
||||
freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta)
|
||||
freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta)
|
||||
|
||||
// Build frequency lookup tables
|
||||
posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
||||
posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false)
|
||||
posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false)
|
||||
negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image
|
||||
negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true)
|
||||
negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true)
|
||||
|
||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
||||
|
||||
// Helper to compute RoPE for a single position at output resolution with scale_rope
|
||||
computePosFreqs := func(framePos, y, x int32) []float32 {
|
||||
row := make([]float32, headDim)
|
||||
idx := 0
|
||||
|
||||
// Frame position
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
row[idx+i] = posFreqsT[framePos][i]
|
||||
}
|
||||
idx += len(freqsT) * 2
|
||||
|
||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
||||
outHHalf := outPH / 2
|
||||
hNegCount := outPH - outHHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += len(freqsH) * 2
|
||||
|
||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
||||
outWHalf := outPW / 2
|
||||
wNegCount := outPW - outWHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
// Helper to compute RoPE for frame -1 (used for last condition image)
|
||||
// This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:]
|
||||
computeNegFrameFreqs := func(y, x int32) []float32 {
|
||||
row := make([]float32, headDim)
|
||||
idx := 0
|
||||
|
||||
// Frame -1: use last row of negative frame frequencies
|
||||
negFrameIdx := maxIdx - 1
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
row[idx+i] = negFreqsT[negFrameIdx][i]
|
||||
}
|
||||
idx += len(freqsT) * 2
|
||||
|
||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
||||
outHHalf := outPH / 2
|
||||
hNegCount := outPH - outHHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += len(freqsH) * 2
|
||||
|
||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
||||
outWHalf := outPW / 2
|
||||
wNegCount := outPW - outWHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
// Total image sequence length: noise + all input images
|
||||
noiseSeqLen := outPH * outPW
|
||||
totalImgLen := noiseSeqLen
|
||||
for _, dims := range inputDims {
|
||||
totalImgLen += dims.PatchH * dims.PatchW
|
||||
}
|
||||
|
||||
imgFreqsData := make([]float32, totalImgLen*headDim)
|
||||
idx := int32(0)
|
||||
|
||||
// Segment 0: Noise latents - standard RoPE at output resolution (frame 0)
|
||||
for y := int32(0); y < outPH; y++ {
|
||||
for x := int32(0); x < outPW; x++ {
|
||||
row := computePosFreqs(0, y, x)
|
||||
copy(imgFreqsData[idx:], row)
|
||||
idx += headDim
|
||||
}
|
||||
}
|
||||
|
||||
// Segments 1..N: Edit image latents - INTERPOLATED RoPE
|
||||
// For single image: use frame 1 (matches original PrepareRoPEInterpolated)
|
||||
// For multiple images: Python uses frame -1 for the LAST condition image
|
||||
// (_compute_condition_freqs), positive indices for others.
|
||||
numImages := len(inputDims)
|
||||
lastImgIdx := numImages - 1
|
||||
for imgIdx, dims := range inputDims {
|
||||
inPH := dims.PatchH
|
||||
inPW := dims.PatchW
|
||||
|
||||
// Determine frame index for this image
|
||||
// Single image case: use frame 1 (like original PrepareRoPEInterpolated)
|
||||
// Multi-image case: last image uses frame -1, others use frame 1, 2, etc.
|
||||
useNegFrame := numImages > 1 && imgIdx == lastImgIdx
|
||||
|
||||
// Map each input position to an output position using linear interpolation
|
||||
for y := int32(0); y < inPH; y++ {
|
||||
for x := int32(0); x < inPW; x++ {
|
||||
// Interpolate: map input (y, x) to output grid position
|
||||
// This is the key fix from DiffSynth's forward_sampling
|
||||
var yOut, xOut int32
|
||||
if inPH == 1 {
|
||||
yOut = 0
|
||||
} else {
|
||||
// Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1)
|
||||
yOut = y * (outPH - 1) / (inPH - 1)
|
||||
}
|
||||
if inPW == 1 {
|
||||
xOut = 0
|
||||
} else {
|
||||
xOut = x * (outPW - 1) / (inPW - 1)
|
||||
}
|
||||
|
||||
var row []float32
|
||||
if useNegFrame {
|
||||
// Last image in multi-image uses frame -1
|
||||
row = computeNegFrameFreqs(yOut, xOut)
|
||||
} else {
|
||||
// Single image uses frame 1, multi-image uses frame 1, 2, etc.
|
||||
frameIdx := int32(imgIdx + 1)
|
||||
row = computePosFreqs(frameIdx, yOut, xOut)
|
||||
}
|
||||
copy(imgFreqsData[idx:], row)
|
||||
idx += headDim
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim})
|
||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
||||
|
||||
// Text frequencies - start after max video index
|
||||
maxVidIdx := max(outPH/2, outPW/2)
|
||||
|
||||
txtFreqsData := make([]float32, txtLen*headDim)
|
||||
idx = 0
|
||||
for t := int32(0); t < txtLen; t++ {
|
||||
pos := maxVidIdx + t
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
|
||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
||||
|
||||
return &qwen_image.RoPECache{
|
||||
ImgFreqs: imgFreqs,
|
||||
TxtFreqs: txtFreqs,
|
||||
}
|
||||
}
|
||||
@@ -1,225 +0,0 @@
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
)
|
||||
|
||||
// TestComputeAxisFreqs verifies frequency computation matches Python reference
|
||||
func TestComputeAxisFreqs(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
|
||||
// Expected values from Python:
|
||||
// freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim))
|
||||
expectedFreqsT := []float64{
|
||||
1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684,
|
||||
0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017,
|
||||
}
|
||||
|
||||
expectedFreqsH_first4 := []float64{
|
||||
1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494,
|
||||
}
|
||||
|
||||
expectedFreqsH_last4 := []float64{
|
||||
0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437,
|
||||
}
|
||||
|
||||
// Test temporal frequencies (dim=16)
|
||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
||||
if len(freqsT) != 8 {
|
||||
t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT))
|
||||
}
|
||||
for i, expected := range expectedFreqsT {
|
||||
if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff)
|
||||
}
|
||||
}
|
||||
|
||||
// Test height/width frequencies (dim=56)
|
||||
freqsH := qwen_image.ComputeAxisFreqs(56, theta)
|
||||
if len(freqsH) != 28 {
|
||||
t.Fatalf("expected 28 height frequencies, got %d", len(freqsH))
|
||||
}
|
||||
for i, expected := range expectedFreqsH_first4 {
|
||||
if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff)
|
||||
}
|
||||
}
|
||||
for i, expected := range expectedFreqsH_last4 {
|
||||
idx := 24 + i // last 4 of 28
|
||||
if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions
|
||||
func TestMakeFreqTable(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Test positive table
|
||||
posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
||||
|
||||
// Position 0 should give cos=1, sin=0 for all frequencies
|
||||
for i := 0; i < len(freqsT)*2; i += 2 {
|
||||
if posTable[0][i] != 1.0 {
|
||||
t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i])
|
||||
}
|
||||
if posTable[0][i+1] != 0.0 {
|
||||
t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1])
|
||||
}
|
||||
}
|
||||
|
||||
// Position 1, first frequency (1.0): angle = 1*1 = 1
|
||||
// cos(1) = 0.5403, sin(1) = 0.8415
|
||||
if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 {
|
||||
t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0])
|
||||
}
|
||||
if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 {
|
||||
t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1])
|
||||
}
|
||||
|
||||
// Test negative table
|
||||
negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true)
|
||||
|
||||
// negTable[4095] corresponds to position -1
|
||||
// cos(-1) = cos(1), sin(-1) = -sin(1)
|
||||
if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 {
|
||||
t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0])
|
||||
}
|
||||
if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 {
|
||||
t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1])
|
||||
}
|
||||
|
||||
// negTable[4094] corresponds to position -2
|
||||
// cos(-2) = cos(2), sin(-2) = -sin(2)
|
||||
cos2 := math.Cos(2.0)
|
||||
sin2 := math.Sin(2.0)
|
||||
if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 {
|
||||
t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0])
|
||||
}
|
||||
if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 {
|
||||
t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case
|
||||
func TestPrepareRoPE_QwenImage(t *testing.T) {
|
||||
if !mlx.GPUIsAvailable() {
|
||||
t.Skip("GPU not available")
|
||||
}
|
||||
|
||||
mlx.SetDefaultDeviceCPU()
|
||||
|
||||
// 4x4 patch grid, single image
|
||||
imgH, imgW := int32(4), int32(4)
|
||||
txtLen := int32(5)
|
||||
axesDims := []int32{16, 56, 56}
|
||||
|
||||
cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims)
|
||||
mlx.Eval(cache.ImgFreqs, cache.TxtFreqs)
|
||||
|
||||
// Check shapes
|
||||
imgShape := cache.ImgFreqs.Shape()
|
||||
if imgShape[0] != 16 { // 4*4 patches
|
||||
t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0])
|
||||
}
|
||||
|
||||
// For single image (frame=0), all temporal values should be cos=1, sin=0
|
||||
imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32)
|
||||
mlx.Eval(imgFreqsCPU)
|
||||
imgData := imgFreqsCPU.Data()
|
||||
|
||||
// Check first 16 values of patch 0 (temporal cos/sin pairs)
|
||||
for i := 0; i < 16; i += 2 {
|
||||
cosVal := imgData[i]
|
||||
sinVal := imgData[i+1]
|
||||
if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 {
|
||||
t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal)
|
||||
}
|
||||
if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 {
|
||||
t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal)
|
||||
}
|
||||
}
|
||||
|
||||
cache.ImgFreqs.Free()
|
||||
cache.TxtFreqs.Free()
|
||||
}
|
||||
|
||||
// TestScaleRopePositions verifies the centered position calculation for scale_rope=True
|
||||
func TestScaleRopePositions(t *testing.T) {
|
||||
// For a 4x4 grid with scale_rope=True:
|
||||
// hHalf = 2, wHalf = 2
|
||||
// hNegCount = 4 - 2 = 2 (positions 0,1 are negative)
|
||||
// wNegCount = 4 - 2 = 2 (positions 0,1 are negative)
|
||||
//
|
||||
// Height positions:
|
||||
// y=0: -(4-2) + 0 = -2
|
||||
// y=1: -(4-2) + 1 = -1
|
||||
// y=2: 2 - 2 = 0
|
||||
// y=3: 3 - 2 = 1
|
||||
//
|
||||
// Same for width
|
||||
|
||||
pH, pW := int32(4), int32(4)
|
||||
hHalf := pH / 2
|
||||
wHalf := pW / 2
|
||||
hNegCount := pH - hHalf
|
||||
wNegCount := pW - wHalf
|
||||
|
||||
expectedH := []int32{-2, -1, 0, 1}
|
||||
expectedW := []int32{-2, -1, 0, 1}
|
||||
|
||||
for y := int32(0); y < pH; y++ {
|
||||
var hPos int32
|
||||
if y < hNegCount {
|
||||
hPos = -(pH - hHalf) + y
|
||||
} else {
|
||||
hPos = y - hNegCount
|
||||
}
|
||||
if hPos != expectedH[y] {
|
||||
t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos)
|
||||
}
|
||||
}
|
||||
|
||||
for x := int32(0); x < pW; x++ {
|
||||
var wPos int32
|
||||
if x < wNegCount {
|
||||
wPos = -(pW - wHalf) + x
|
||||
} else {
|
||||
wPos = x - wNegCount
|
||||
}
|
||||
if wPos != expectedW[x] {
|
||||
t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoPEHeadDimensions verifies the head dimension breakdown
|
||||
func TestRoPEHeadDimensions(t *testing.T) {
|
||||
// axes_dims_rope = [16, 56, 56]
|
||||
// Each dimension uses half the values for frequencies
|
||||
// So we get: 8 + 28 + 28 = 64 frequency values
|
||||
// Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position
|
||||
|
||||
axesDims := []int32{16, 56, 56}
|
||||
expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2)
|
||||
expectedHeadDim := expectedFreqs * 2
|
||||
|
||||
if expectedFreqs != 64 {
|
||||
t.Errorf("expected 64 frequency values, got %d", expectedFreqs)
|
||||
}
|
||||
if expectedHeadDim != 128 {
|
||||
t.Errorf("expected head_dim=128, got %d", expectedHeadDim)
|
||||
}
|
||||
|
||||
// This should match the transformer's attention head dimension
|
||||
// hidden_size = 3072, num_heads = 24
|
||||
// head_dim = 3072 / 24 = 128
|
||||
}
|
||||
|
||||
@@ -1,640 +0,0 @@
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds Qwen-Image VAE configuration
|
||||
type VAEConfig struct {
|
||||
ZDim int32 `json:"z_dim"` // 16
|
||||
BaseDim int32 `json:"base_dim"` // 96
|
||||
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
|
||||
NumResBlocks int32 `json:"num_res_blocks"` // 2
|
||||
LatentsMean []float32 `json:"latents_mean"` // 16 values
|
||||
LatentsStd []float32 `json:"latents_std"` // 16 values
|
||||
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
|
||||
}
|
||||
|
||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
||||
func defaultVAEConfig() *VAEConfig {
|
||||
return &VAEConfig{
|
||||
ZDim: 16,
|
||||
BaseDim: 96,
|
||||
DimMult: []int32{1, 2, 4, 4},
|
||||
NumResBlocks: 2,
|
||||
LatentsMean: []float32{
|
||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632,
|
||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
||||
},
|
||||
LatentsStd: []float32{
|
||||
2.8184, 1.4541, 2.3275, 2.6558,
|
||||
1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579,
|
||||
1.6382, 1.1253, 2.8251, 1.916,
|
||||
},
|
||||
TemperalDownsample: []bool{false, true, true},
|
||||
}
|
||||
}
|
||||
|
||||
// VAE is the full VAE with encoder and decoder
|
||||
type VAE struct {
|
||||
Config *VAEConfig
|
||||
Encoder *VAEEncoder
|
||||
Decoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the VAE from a directory
|
||||
func (m *VAE) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")
|
||||
|
||||
cfg := defaultVAEConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Load weights as f32 for quality (matches Python default behavior)
|
||||
// VAE decoder precision is critical for final image quality
|
||||
fmt.Print(" Loading weights as f32... ")
|
||||
if err := weights.Load(mlx.DtypeFloat32); err != nil {
|
||||
return fmt.Errorf("failed to load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
// Load encoder
|
||||
fmt.Print(" Loading encoder... ")
|
||||
m.Encoder = &VAEEncoder{}
|
||||
if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("encoder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load decoder
|
||||
fmt.Print(" Loading decoder... ")
|
||||
m.Decoder = &VAEDecoder{}
|
||||
if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("decoder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes an image to latents
|
||||
// x: [B, C, T, H, W] image tensor in [-1, 1] range
|
||||
// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
|
||||
func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
|
||||
return m.Encoder.Encode(x)
|
||||
}
|
||||
|
||||
// Decode decodes latents to image
|
||||
// z: [B, C, T, H, W] latents (denormalized)
|
||||
// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
|
||||
func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
|
||||
return m.Decoder.Decode(z)
|
||||
}
|
||||
|
||||
// Normalize applies latent normalization
|
||||
// Input z should be f32 (from VAE encoder), output is f32 for transformer
|
||||
func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
// Mean/std are f32, will match z dtype through broadcasting
|
||||
return mlx.Div(mlx.Sub(z, mean), std)
|
||||
}
|
||||
|
||||
// Denormalize reverses latent normalization
|
||||
// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
|
||||
func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
// Convert latents to f32 for VAE decoder quality
|
||||
z = mlx.AsType(z, mlx.DtypeFloat32)
|
||||
|
||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
return mlx.Add(mlx.Mul(z, std), mean)
|
||||
}
|
||||
|
||||
// VAEEncoder is the encoder part of the VAE
|
||||
// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
|
||||
// - Blocks 0,1: ResBlocks (base_dim)
|
||||
// - Block 2: Downsample
|
||||
// - Blocks 3,4: ResBlocks (base_dim*2)
|
||||
// - Block 5: Downsample + temporal
|
||||
// - Blocks 6,7: ResBlocks (base_dim*4)
|
||||
// - Block 8: Downsample + temporal
|
||||
// - Blocks 9,10: ResBlocks (base_dim*4)
|
||||
type VAEEncoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
ConvIn *CausalConv3d
|
||||
Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers
|
||||
MidBlock *MidBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
QuantConv *CausalConv3d
|
||||
}
|
||||
|
||||
// EncoderBlock is either a ResBlock or a Downsample
|
||||
type EncoderBlock interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
IsDownsample() bool
|
||||
}
|
||||
|
||||
// EncoderResBlock wraps ResBlock
|
||||
type EncoderResBlock struct {
|
||||
*ResBlock
|
||||
}
|
||||
|
||||
func (b *EncoderResBlock) IsDownsample() bool { return false }
|
||||
|
||||
// EncoderDownsample is a downsample layer
|
||||
type EncoderDownsample struct {
|
||||
Resample *CausalConv3d
|
||||
TimeConv *CausalConv3d // Optional temporal downsample
|
||||
}
|
||||
|
||||
func (d *EncoderDownsample) IsDownsample() bool { return true }
|
||||
|
||||
func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Spatial downsample with stride 2
|
||||
// WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
|
||||
x = d.forwardSpatialDownsample(x)
|
||||
|
||||
// NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
|
||||
// with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
|
||||
// The Python forward checks: if feat_cache is not None ... then use time_conv
|
||||
// Since we don't support streaming, we skip time_conv entirely.
|
||||
return x
|
||||
}
|
||||
|
||||
// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
|
||||
func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
|
||||
xShape := x.Shape()
|
||||
B := xShape[0]
|
||||
T := xShape[1]
|
||||
H := xShape[2]
|
||||
W := xShape[3]
|
||||
C := xShape[4]
|
||||
|
||||
wShape := d.Resample.Weight.Shape()
|
||||
outC := wShape[0]
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
|
||||
// ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
|
||||
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W
|
||||
|
||||
// Apply 2D conv with stride 2
|
||||
weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
||||
x = conv2DStrided(x, weight, 2)
|
||||
|
||||
if d.Resample.Bias != nil {
|
||||
bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
// Output dims after stride 2: (H+1)/2, (W+1)/2
|
||||
outH := (H + 1) / 2
|
||||
outW := (W + 1) / 2
|
||||
|
||||
// Reshape back to [B, T, H', W', C]
|
||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// loadFromWeights loads the encoder from pre-loaded weights
|
||||
func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
||||
e.Config = cfg
|
||||
|
||||
// Conv in
|
||||
convIn, err := newCausalConv3d(weights, "encoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.ConvIn = convIn
|
||||
|
||||
// Encoder uses flat block structure:
|
||||
// dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
|
||||
// Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
|
||||
// That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
|
||||
e.Blocks = make([]EncoderBlock, 0, 11)
|
||||
|
||||
// Track dimensions
|
||||
dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
|
||||
blockIdx := 0
|
||||
|
||||
for stage := 0; stage < len(cfg.DimMult); stage++ {
|
||||
inDim := cfg.BaseDim
|
||||
if stage > 0 {
|
||||
inDim = dims[stage-1]
|
||||
}
|
||||
outDim := dims[stage]
|
||||
|
||||
// ResBlocks for this stage (num_res_blocks per stage)
|
||||
for r := int32(0); r < cfg.NumResBlocks; r++ {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
||||
currentInDim := inDim
|
||||
if r > 0 {
|
||||
currentInDim = outDim
|
||||
}
|
||||
block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
|
||||
}
|
||||
e.Blocks = append(e.Blocks, block)
|
||||
blockIdx++
|
||||
}
|
||||
|
||||
// Downsample after each stage except the last
|
||||
if stage < len(cfg.DimMult)-1 {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
||||
down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
|
||||
}
|
||||
e.Blocks = append(e.Blocks, down)
|
||||
blockIdx++
|
||||
}
|
||||
}
|
||||
|
||||
// Mid block
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.MidBlock = midBlock
|
||||
|
||||
// Norm out
|
||||
normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.NormOut = normOut
|
||||
|
||||
// Conv out
|
||||
convOut, err := newCausalConv3d(weights, "encoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.ConvOut = convOut
|
||||
|
||||
// Quant conv
|
||||
quantConv, err := newCausalConv3d(weights, "quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.QuantConv = quantConv
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
|
||||
func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
|
||||
block, err := newResBlock(weights, prefix, inDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &EncoderResBlock{block}, nil
|
||||
}
|
||||
|
||||
// newEncoderDownsample creates a downsample layer for the encoder
|
||||
func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
|
||||
resample, err := newCausalConv3d(weights, prefix+".resample.1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var timeConv *CausalConv3d
|
||||
if temporal {
|
||||
timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
|
||||
}
|
||||
|
||||
return &EncoderDownsample{
|
||||
Resample: resample,
|
||||
TimeConv: timeConv,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Encode encodes an image to latents
|
||||
// x: [B, C, T, H, W] image tensor (channels-first)
|
||||
// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
|
||||
func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
|
||||
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(x)
|
||||
|
||||
// Conv in
|
||||
x = e.ConvIn.Forward(x)
|
||||
|
||||
// Encoder blocks (mix of ResBlocks and Downsamplers)
|
||||
for _, block := range e.Blocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Mid block
|
||||
x = e.MidBlock.Forward(x)
|
||||
|
||||
// Norm + silu
|
||||
{
|
||||
prev := x
|
||||
x = e.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Conv out
|
||||
{
|
||||
prev := x
|
||||
x = e.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Quant conv
|
||||
{
|
||||
prev := x
|
||||
x = e.QuantConv.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Get mode from distribution (first half of channels = mean)
|
||||
// Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
|
||||
shape := x.Shape()
|
||||
latentC := shape[4] / 2
|
||||
x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
|
||||
|
||||
// Convert back to channels-first [N, C, T, H, W]
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the decoder part of the VAE
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
PostQuantConv *CausalConv3d
|
||||
ConvIn *CausalConv3d
|
||||
MidBlock *MidBlock
|
||||
UpBlocks []*UpBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
}
|
||||
|
||||
// loadFromWeights loads the decoder from pre-loaded weights
|
||||
func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
||||
d.Config = cfg
|
||||
|
||||
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.PostQuantConv = postQuantConv
|
||||
|
||||
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ConvIn = convIn
|
||||
|
||||
// Mid block
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.MidBlock = midBlock
|
||||
|
||||
// Up blocks (reversed dim_mult)
|
||||
numUpBlocks := len(cfg.DimMult)
|
||||
d.UpBlocks = make([]*UpBlock, numUpBlocks)
|
||||
|
||||
dimsMult := make([]int32, numUpBlocks+1)
|
||||
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
|
||||
}
|
||||
|
||||
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
|
||||
for i := range cfg.TemperalDownsample {
|
||||
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
|
||||
}
|
||||
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
inDim := cfg.BaseDim * dimsMult[i]
|
||||
outDim := cfg.BaseDim * dimsMult[i+1]
|
||||
|
||||
if i > 0 {
|
||||
inDim = inDim / 2
|
||||
}
|
||||
|
||||
upsampleMode := ""
|
||||
if i < numUpBlocks-1 {
|
||||
if temporalUpsample[i] {
|
||||
upsampleMode = "upsample3d"
|
||||
} else {
|
||||
upsampleMode = "upsample2d"
|
||||
}
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.UpBlocks[i] = upBlock
|
||||
}
|
||||
|
||||
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.NormOut = normOut
|
||||
|
||||
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ConvOut = convOut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode converts latents to image
|
||||
// z: [B, C, T, H, W] denormalized latents
|
||||
func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
|
||||
var x *mlx.Array
|
||||
|
||||
// Convert from channels-first to channels-last
|
||||
{
|
||||
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(z)
|
||||
}
|
||||
|
||||
// PostQuantConv
|
||||
x = d.PostQuantConv.Forward(z)
|
||||
z.Free()
|
||||
|
||||
// ConvIn
|
||||
{
|
||||
prev := x
|
||||
x = d.ConvIn.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Mid block
|
||||
x = d.MidBlock.Forward(x)
|
||||
|
||||
// Up blocks
|
||||
for _, upBlock := range d.UpBlocks {
|
||||
x = upBlock.Forward(x)
|
||||
}
|
||||
|
||||
// NormOut + silu
|
||||
{
|
||||
prev := x
|
||||
x = d.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// ConvOut
|
||||
{
|
||||
prev := x
|
||||
x = d.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Post-processing: clamp and convert back to channels-first
|
||||
{
|
||||
prev := x
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// DownBlock handles downsampling in encoder
|
||||
type DownBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Downsampler *Downsample
|
||||
}
|
||||
|
||||
// newDownBlock creates a down block
|
||||
func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var downsampler *Downsample
|
||||
if downsampleMode != "" {
|
||||
downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
|
||||
}
|
||||
|
||||
return &DownBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Downsampler: downsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies down block
|
||||
func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, block := range d.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if d.Downsampler != nil {
|
||||
prev := x
|
||||
x = d.Downsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Downsample handles spatial downsampling
|
||||
type Downsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newDownsample creates a downsampler
|
||||
func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Downsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies downsampling to channels-last input [B, T, H, W, C]
|
||||
func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := d.Conv.Shape()[0]
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding
|
||||
// For 3x3 stride 2: pad 1 on all sides
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
|
||||
// Conv with stride 2 using manual strided patching
|
||||
weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
|
||||
x = conv2DStrided(x, weight, 2)
|
||||
if d.Bias != nil {
|
||||
bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
x = mlx.Reshape(x, B, T, H/2, W/2, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// FlowMatchSchedulerConfig holds scheduler configuration
|
||||
type FlowMatchSchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
Shift float32 `json:"shift"` // 3.0
|
||||
UseDynamicShifting bool `json:"use_dynamic_shifting"` // false
|
||||
}
|
||||
|
||||
// DefaultFlowMatchSchedulerConfig returns default config
|
||||
func DefaultFlowMatchSchedulerConfig() *FlowMatchSchedulerConfig {
|
||||
return &FlowMatchSchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
Shift: 3.0,
|
||||
UseDynamicShifting: true, // Z-Image-Turbo uses dynamic shifting
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchEulerScheduler implements the Flow Match Euler discrete scheduler
|
||||
// This is used in Z-Image-Turbo for fast sampling
|
||||
type FlowMatchEulerScheduler struct {
|
||||
Config *FlowMatchSchedulerConfig
|
||||
Timesteps []float32 // Discretized timesteps
|
||||
Sigmas []float32 // Noise levels at each timestep
|
||||
NumSteps int // Number of inference steps
|
||||
}
|
||||
|
||||
// NewFlowMatchEulerScheduler creates a new scheduler
|
||||
func NewFlowMatchEulerScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchEulerScheduler {
|
||||
return &FlowMatchEulerScheduler{
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
||||
func (s *FlowMatchEulerScheduler) SetTimesteps(numSteps int) {
|
||||
s.SetTimestepsWithMu(numSteps, 0)
|
||||
}
|
||||
|
||||
// SetTimestepsWithMu sets up the scheduler with dynamic mu shift
|
||||
func (s *FlowMatchEulerScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Create evenly spaced timesteps from 1.0 to 0.0 (flow matching goes t=1 to t=0)
|
||||
// Match Python: np.linspace(1.0, 0.0, num_inference_steps + 1)
|
||||
s.Timesteps = make([]float32, numSteps+1)
|
||||
s.Sigmas = make([]float32, numSteps+1)
|
||||
|
||||
for i := 0; i <= numSteps; i++ {
|
||||
t := 1.0 - float32(i)/float32(numSteps)
|
||||
|
||||
// Apply time shift if using dynamic shifting
|
||||
if s.Config.UseDynamicShifting && mu != 0 {
|
||||
t = s.timeShift(mu, t)
|
||||
}
|
||||
|
||||
s.Timesteps[i] = t
|
||||
s.Sigmas[i] = t
|
||||
}
|
||||
}
|
||||
|
||||
// timeShift applies the dynamic time shift (match Python)
|
||||
func (s *FlowMatchEulerScheduler) timeShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
// exp(mu) / (exp(mu) + (1/t - 1))
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + (1.0/t - 1.0))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
// modelOutput: predicted velocity/noise from the model
|
||||
// timestepIdx: current timestep index
|
||||
// sample: current noisy sample
|
||||
// Returns: denoised sample for next step
|
||||
func (s *FlowMatchEulerScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Get current and next sigma
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
||||
// where v_t is the velocity predicted by the model
|
||||
dt := sigmaNext - sigma // This is negative (going from noise to clean)
|
||||
|
||||
// x_next = x + dt * velocity
|
||||
scaledOutput := mlx.MulScalar(modelOutput, dt)
|
||||
return mlx.Add(sample, scaledOutput)
|
||||
}
|
||||
|
||||
// ScaleSample scales the sample for model input (identity for flow matching)
|
||||
func (s *FlowMatchEulerScheduler) ScaleSample(sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Flow matching doesn't need scaling
|
||||
return sample
|
||||
}
|
||||
|
||||
// GetTimestep returns the timestep value at the given index
|
||||
func (s *FlowMatchEulerScheduler) GetTimestep(idx int) float32 {
|
||||
if idx < len(s.Timesteps) {
|
||||
return s.Timesteps[idx]
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// GetTimesteps returns all timesteps (implements Scheduler interface)
|
||||
func (s *FlowMatchEulerScheduler) GetTimesteps() []float32 {
|
||||
return s.Timesteps
|
||||
}
|
||||
|
||||
// AddNoise adds noise to clean samples for a given timestep
|
||||
// Used for img2img or inpainting
|
||||
func (s *FlowMatchEulerScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// In flow matching: x_t = (1-t) * x_0 + t * noise
|
||||
t := s.Timesteps[timestepIdx]
|
||||
oneMinusT := 1.0 - t
|
||||
|
||||
scaledClean := mlx.MulScalar(cleanSample, oneMinusT)
|
||||
scaledNoise := mlx.MulScalar(noise, t)
|
||||
|
||||
return mlx.Add(scaledClean, scaledNoise)
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling
|
||||
func (s *FlowMatchEulerScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return RandomNormal(shape, seed)
|
||||
}
|
||||
|
||||
// RandomNormal creates a random normal tensor using MLX
|
||||
func RandomNormal(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// GetLatentShape returns the latent shape for a given image size
|
||||
func GetLatentShape(batchSize, height, width, latentChannels int32, patchSize int32) []int32 {
|
||||
// Latent is 8x smaller than image (VAE downscale)
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
|
||||
return []int32{batchSize, latentChannels, latentH, latentW}
|
||||
}
|
||||
@@ -1,294 +0,0 @@
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Qwen3Config holds Qwen3 text encoder configuration
|
||||
type Qwen3Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
}
|
||||
|
||||
// loadQwen3Config loads text encoder config from a JSON file
|
||||
func loadQwen3Config(path string) (*Qwen3Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg Qwen3Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// Qwen3Attention implements Qwen3 attention with QK norms
|
||||
type Qwen3Attention struct {
|
||||
QProj *nn.Linear `weight:"q_proj"`
|
||||
KProj *nn.Linear `weight:"k_proj"`
|
||||
VProj *nn.Linear `weight:"v_proj"`
|
||||
OProj *nn.Linear `weight:"o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
RopeTheta float32
|
||||
}
|
||||
|
||||
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
|
||||
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
freqsArr := make([]float32, half)
|
||||
logTheta := float32(math.Log(float64(theta)))
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{half})
|
||||
|
||||
posArr := make([]float32, seqLen)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
posArr[i] = float32(i)
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{seqLen})
|
||||
|
||||
posExpanded := mlx.Reshape(pos, seqLen, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, half)
|
||||
args := mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
cosVals := mlx.Cos(args)
|
||||
sinVals := mlx.Sin(args)
|
||||
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
|
||||
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
|
||||
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
|
||||
|
||||
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
|
||||
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
|
||||
|
||||
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
|
||||
}
|
||||
|
||||
// Forward computes attention with causal masking
|
||||
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
q := attn.QProj.Forward(x)
|
||||
k := attn.KProj.Forward(x)
|
||||
v := attn.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
|
||||
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
|
||||
q = attn.QNorm.Forward(q, 1e-6)
|
||||
k = attn.KNorm.Forward(k, 1e-6)
|
||||
|
||||
q = applyRoPEQwen3(q, L, attn.RopeTheta)
|
||||
k = applyRoPEQwen3(k, L, attn.RopeTheta)
|
||||
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
if attn.NKVHeads < attn.NHeads {
|
||||
repeats := attn.NHeads / attn.NKVHeads
|
||||
k = repeatKV(k, repeats)
|
||||
v = repeatKV(v, repeats)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
|
||||
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
|
||||
|
||||
out = attn.OProj.Forward(out)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// repeatKV repeats key/value heads for GQA
|
||||
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
if repeats == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// Qwen3MLP implements Qwen3 SwiGLU MLP
|
||||
type Qwen3MLP struct {
|
||||
GateProj *nn.Linear `weight:"gate_proj"`
|
||||
UpProj *nn.Linear `weight:"up_proj"`
|
||||
DownProj *nn.Linear `weight:"down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the MLP
|
||||
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := m.GateProj.Forward(x)
|
||||
gate = mlx.SiLU(gate)
|
||||
up := m.UpProj.Forward(x)
|
||||
h := mlx.Mul(gate, up)
|
||||
return m.DownProj.Forward(h)
|
||||
}
|
||||
|
||||
// Qwen3Block represents a single Qwen3 transformer block
|
||||
type Qwen3Block struct {
|
||||
Attention *Qwen3Attention `weight:"self_attn"`
|
||||
MLP *Qwen3MLP `weight:"mlp"`
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the Qwen3 block
|
||||
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
||||
h := qb.InputLayerNorm.Forward(x, eps)
|
||||
attnOut := qb.Attention.Forward(h)
|
||||
x = mlx.Add(x, attnOut)
|
||||
|
||||
h = qb.PostAttnLayerNorm.Forward(x, eps)
|
||||
mlpOut := qb.MLP.Forward(h)
|
||||
x = mlx.Add(x, mlpOut)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
|
||||
type Qwen3TextEncoder struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Qwen3Block `weight:"model.layers"`
|
||||
FinalNorm *nn.RMSNorm `weight:"model.norm"`
|
||||
*Qwen3Config
|
||||
}
|
||||
|
||||
// Load loads the Qwen3 text encoder from a directory
|
||||
func (m *Qwen3TextEncoder) Load(path string) error {
|
||||
fmt.Println("Loading Qwen3 text encoder...")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadQwen3Config(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Qwen3Config = cfg
|
||||
|
||||
// Pre-allocate layers slice
|
||||
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
fmt.Print(" Loading weights via struct tags... ")
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Initialize computed fields
|
||||
m.FinalNorm.Eps = cfg.RMSNormEps
|
||||
for _, block := range m.Layers {
|
||||
// Attention
|
||||
block.Attention.NHeads = cfg.NumAttentionHeads
|
||||
block.Attention.NKVHeads = cfg.NumKeyValueHeads
|
||||
block.Attention.HeadDim = cfg.HeadDim
|
||||
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
block.Attention.RopeTheta = cfg.RopeTheta
|
||||
block.Attention.QNorm.Eps = cfg.RMSNormEps
|
||||
block.Attention.KNorm.Eps = cfg.RMSNormEps
|
||||
// Block norms
|
||||
block.InputLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Forward encodes text tokens
|
||||
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
|
||||
h := te.EmbedTokens.Forward(tokens)
|
||||
eps := te.RMSNormEps
|
||||
|
||||
for _, layer := range te.Layers {
|
||||
h = layer.Forward(h, eps)
|
||||
}
|
||||
|
||||
// Apply final RMS norm
|
||||
h = te.FinalNorm.Forward(h, eps)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// ApplyChatTemplate wraps prompt in Qwen3 chat format
|
||||
func ApplyChatTemplate(prompt string) string {
|
||||
return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
|
||||
}
|
||||
|
||||
// EncodePrompt encodes a text prompt using the tokenizer and encoder
|
||||
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
|
||||
formattedPrompt := ApplyChatTemplate(prompt)
|
||||
|
||||
tokens := tok.Encode(formattedPrompt, false)
|
||||
|
||||
if len(tokens) > maxLen {
|
||||
tokens = tokens[:maxLen]
|
||||
}
|
||||
|
||||
maskData := make([]float32, maxLen)
|
||||
for i := 0; i < len(tokens); i++ {
|
||||
maskData[i] = 1.0
|
||||
}
|
||||
|
||||
// Get PAD token (different from EOS for Qwen3)
|
||||
padToken := tok.PAD()
|
||||
if padToken < 0 {
|
||||
padToken = tok.EOS() // fallback
|
||||
}
|
||||
|
||||
paddedTokens := make([]int32, maxLen)
|
||||
copy(paddedTokens, tokens)
|
||||
for i := len(tokens); i < maxLen; i++ {
|
||||
paddedTokens[i] = padToken
|
||||
}
|
||||
|
||||
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
|
||||
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
|
||||
|
||||
embeddings := te.Forward(tokensArr)
|
||||
|
||||
return embeddings, maskArr
|
||||
}
|
||||
@@ -1,690 +0,0 @@
|
||||
// Package zimage implements the Z-Image diffusion transformer model.
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TransformerConfig holds Z-Image transformer configuration
|
||||
type TransformerConfig struct {
|
||||
Dim int32 `json:"dim"`
|
||||
NHeads int32 `json:"n_heads"`
|
||||
NKVHeads int32 `json:"n_kv_heads"`
|
||||
NLayers int32 `json:"n_layers"`
|
||||
NRefinerLayers int32 `json:"n_refiner_layers"`
|
||||
InChannels int32 `json:"in_channels"`
|
||||
PatchSize int32 `json:"-"` // Computed from AllPatchSize
|
||||
CapFeatDim int32 `json:"cap_feat_dim"`
|
||||
NormEps float32 `json:"norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
TScale float32 `json:"t_scale"`
|
||||
QKNorm bool `json:"qk_norm"`
|
||||
AxesDims []int32 `json:"axes_dims"`
|
||||
AxesLens []int32 `json:"axes_lens"`
|
||||
AllPatchSize []int32 `json:"all_patch_size"` // JSON array, PatchSize = first element
|
||||
}
|
||||
|
||||
// TimestepEmbedder creates sinusoidal timestep embeddings
|
||||
// Output dimension is 256 (fixed), used for AdaLN modulation
|
||||
type TimestepEmbedder struct {
|
||||
Linear1 *nn.Linear `weight:"mlp.0"`
|
||||
Linear2 *nn.Linear `weight:"mlp.2"`
|
||||
FreqEmbedSize int32 // 256 (computed)
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings -> [B, 256]
|
||||
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
||||
// t: [B] timesteps
|
||||
|
||||
// Create sinusoidal embedding
|
||||
half := te.FreqEmbedSize / 2
|
||||
|
||||
// freqs = exp(-log(10000) * arange(half) / half)
|
||||
freqs := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
||||
}
|
||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
// t[:, None] * freqs[None, :] -> [B, half]
|
||||
tExpanded := mlx.ExpandDims(t, 1) // [B, 1]
|
||||
args := mlx.Mul(tExpanded, freqsArr)
|
||||
|
||||
// embedding = [cos(args), sin(args)] -> [B, 256]
|
||||
cosArgs := mlx.Cos(args)
|
||||
sinArgs := mlx.Sin(args)
|
||||
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1)
|
||||
|
||||
// MLP: linear1 -> silu -> linear2
|
||||
h := te.Linear1.Forward(embedding)
|
||||
h = mlx.SiLU(h)
|
||||
h = te.Linear2.Forward(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// XEmbedder embeds image patches to model dimension
|
||||
type XEmbedder struct {
|
||||
Linear *nn.Linear `weight:"2-1"`
|
||||
}
|
||||
|
||||
// Forward embeds patchified image latents
|
||||
func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, L, in_channels * 4] -> [B, L, dim]
|
||||
return xe.Linear.Forward(x)
|
||||
}
|
||||
|
||||
// CapEmbedder projects caption features to model dimension
|
||||
type CapEmbedder struct {
|
||||
Norm *nn.RMSNorm `weight:"0"`
|
||||
Linear *nn.Linear `weight:"1"`
|
||||
PadToken *mlx.Array // loaded separately at root level
|
||||
}
|
||||
|
||||
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
|
||||
func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
|
||||
// RMSNorm on last axis (uses 1e-6)
|
||||
h := ce.Norm.Forward(capFeats, 1e-6)
|
||||
// Linear projection
|
||||
return ce.Linear.Forward(h)
|
||||
}
|
||||
|
||||
// FeedForward implements SwiGLU FFN
|
||||
type FeedForward struct {
|
||||
W1 *nn.Linear `weight:"w1"` // gate projection
|
||||
W2 *nn.Linear `weight:"w2"` // down projection
|
||||
W3 *nn.Linear `weight:"w3"` // up projection
|
||||
OutDim int32 // computed from W2
|
||||
}
|
||||
|
||||
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
|
||||
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
// Reshape for matmul
|
||||
x = mlx.Reshape(x, B*L, D)
|
||||
gate := ff.W1.Forward(x)
|
||||
gate = mlx.SiLU(gate)
|
||||
up := ff.W3.Forward(x)
|
||||
h := mlx.Mul(gate, up)
|
||||
out := ff.W2.Forward(h)
|
||||
|
||||
return mlx.Reshape(out, B, L, ff.OutDim)
|
||||
}
|
||||
|
||||
// Attention implements multi-head attention with QK norm
|
||||
type Attention struct {
|
||||
ToQ *nn.Linear `weight:"to_q"`
|
||||
ToK *nn.Linear `weight:"to_k"`
|
||||
ToV *nn.Linear `weight:"to_v"`
|
||||
ToOut *nn.Linear `weight:"to_out.0"`
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Dim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// Forward computes attention
|
||||
func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
// Project Q, K, V
|
||||
xFlat := mlx.Reshape(x, B*L, D)
|
||||
q := attn.ToQ.Forward(xFlat)
|
||||
k := attn.ToK.Forward(xFlat)
|
||||
v := attn.ToV.Forward(xFlat)
|
||||
|
||||
// Reshape to [B, L, nheads, head_dim]
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NHeads, attn.HeadDim)
|
||||
|
||||
// QK norm
|
||||
q = mlx.RMSNorm(q, attn.NormQ, 1e-5)
|
||||
k = mlx.RMSNorm(k, attn.NormK, 1e-5)
|
||||
|
||||
// Apply RoPE if provided
|
||||
if cos != nil && sin != nil {
|
||||
q = applyRoPE3D(q, cos, sin)
|
||||
k = applyRoPE3D(k, cos, sin)
|
||||
}
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// SDPA
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false)
|
||||
|
||||
// Transpose back and reshape
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
out = mlx.Reshape(out, B*L, attn.Dim)
|
||||
out = attn.ToOut.Forward(out)
|
||||
|
||||
return mlx.Reshape(out, B, L, attn.Dim)
|
||||
}
|
||||
|
||||
// applyRoPE3D applies 3-axis rotary position embeddings
|
||||
// x: [B, L, nheads, head_dim]
|
||||
// cos, sin: [B, L, 1, head_dim/2]
|
||||
func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
nheads := shape[2]
|
||||
headDim := shape[3]
|
||||
half := headDim / 2
|
||||
|
||||
// Create even/odd index arrays
|
||||
evenIdx := make([]int32, half)
|
||||
oddIdx := make([]int32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
evenIdx[i] = i * 2
|
||||
oddIdx[i] = i*2 + 1
|
||||
}
|
||||
evenIndices := mlx.NewArrayInt32(evenIdx, []int32{half})
|
||||
oddIndices := mlx.NewArrayInt32(oddIdx, []int32{half})
|
||||
|
||||
// Extract x1 (even indices) and x2 (odd indices) along last axis
|
||||
x1 := mlx.Take(x, evenIndices, 3) // [B, L, nheads, half]
|
||||
x2 := mlx.Take(x, oddIndices, 3) // [B, L, nheads, half]
|
||||
|
||||
// Apply rotation: [x1*cos - x2*sin, x1*sin + x2*cos]
|
||||
r1 := mlx.Sub(mlx.Mul(x1, cos), mlx.Mul(x2, sin))
|
||||
r2 := mlx.Add(mlx.Mul(x1, sin), mlx.Mul(x2, cos))
|
||||
|
||||
// Stack and reshape to interleave: [r1_0, r2_0, r1_1, r2_1, ...]
|
||||
r1 = mlx.ExpandDims(r1, 4) // [B, L, nheads, half, 1]
|
||||
r2 = mlx.ExpandDims(r2, 4) // [B, L, nheads, half, 1]
|
||||
stacked := mlx.Concatenate([]*mlx.Array{r1, r2}, 4) // [B, L, nheads, half, 2]
|
||||
return mlx.Reshape(stacked, B, L, nheads, headDim)
|
||||
}
|
||||
|
||||
// TransformerBlock is a single transformer block with optional AdaLN modulation
|
||||
type TransformerBlock struct {
|
||||
Attention *Attention `weight:"attention"`
|
||||
FeedForward *FeedForward `weight:"feed_forward"`
|
||||
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
|
||||
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
|
||||
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
|
||||
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
|
||||
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
|
||||
// Computed fields
|
||||
HasModulation bool
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// Forward applies the transformer block
|
||||
func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *mlx.Array, eps float32) *mlx.Array {
|
||||
if tb.AdaLN != nil && adaln != nil {
|
||||
// Compute modulation: [B, 256] -> [B, 4*dim]
|
||||
chunks := tb.AdaLN.Forward(adaln)
|
||||
|
||||
// Split into 4 parts: scale_msa, gate_msa, scale_mlp, gate_mlp
|
||||
chunkShape := chunks.Shape()
|
||||
chunkDim := chunkShape[1] / 4
|
||||
|
||||
scaleMSA := mlx.Slice(chunks, []int32{0, 0}, []int32{chunkShape[0], chunkDim})
|
||||
gateMSA := mlx.Slice(chunks, []int32{0, chunkDim}, []int32{chunkShape[0], chunkDim * 2})
|
||||
scaleMLP := mlx.Slice(chunks, []int32{0, chunkDim * 2}, []int32{chunkShape[0], chunkDim * 3})
|
||||
gateMLP := mlx.Slice(chunks, []int32{0, chunkDim * 3}, []int32{chunkShape[0], chunkDim * 4})
|
||||
|
||||
// Expand for broadcasting: [B, 1, dim]
|
||||
scaleMSA = mlx.ExpandDims(scaleMSA, 1)
|
||||
gateMSA = mlx.ExpandDims(gateMSA, 1)
|
||||
scaleMLP = mlx.ExpandDims(scaleMLP, 1)
|
||||
gateMLP = mlx.ExpandDims(gateMLP, 1)
|
||||
|
||||
// Attention with modulation
|
||||
normX := tb.AttentionNorm1.Forward(x, eps)
|
||||
normX = mlx.Mul(normX, mlx.AddScalar(scaleMSA, 1.0))
|
||||
attnOut := tb.Attention.Forward(normX, cos, sin)
|
||||
attnOut = tb.AttentionNorm2.Forward(attnOut, eps)
|
||||
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMSA), attnOut))
|
||||
|
||||
// FFN with modulation
|
||||
normFFN := tb.FFNNorm1.Forward(x, eps)
|
||||
normFFN = mlx.Mul(normFFN, mlx.AddScalar(scaleMLP, 1.0))
|
||||
ffnOut := tb.FeedForward.Forward(normFFN)
|
||||
ffnOut = tb.FFNNorm2.Forward(ffnOut, eps)
|
||||
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMLP), ffnOut))
|
||||
} else {
|
||||
// No modulation (context refiner)
|
||||
attnOut := tb.Attention.Forward(tb.AttentionNorm1.Forward(x, eps), cos, sin)
|
||||
x = mlx.Add(x, tb.AttentionNorm2.Forward(attnOut, eps))
|
||||
|
||||
ffnOut := tb.FeedForward.Forward(tb.FFNNorm1.Forward(x, eps))
|
||||
x = mlx.Add(x, tb.FFNNorm2.Forward(ffnOut, eps))
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// FinalLayer outputs the denoised patches
|
||||
type FinalLayer struct {
|
||||
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
|
||||
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
|
||||
OutDim int32 // computed from Output
|
||||
}
|
||||
|
||||
// Forward computes final output
|
||||
func (fl *FinalLayer) Forward(x *mlx.Array, c *mlx.Array) *mlx.Array {
|
||||
// c: [B, 256] -> scale: [B, dim]
|
||||
scale := mlx.SiLU(c)
|
||||
scale = fl.AdaLN.Forward(scale)
|
||||
scale = mlx.ExpandDims(scale, 1) // [B, 1, dim]
|
||||
|
||||
// LayerNorm (affine=False) then scale
|
||||
x = layerNormNoAffine(x, 1e-6)
|
||||
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
|
||||
|
||||
// Output projection
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
x = mlx.Reshape(x, B*L, D)
|
||||
x = fl.Output.Forward(x)
|
||||
|
||||
return mlx.Reshape(x, B, L, fl.OutDim)
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
|
||||
// Transformer is the full Z-Image DiT model
|
||||
type Transformer struct {
|
||||
TEmbed *TimestepEmbedder `weight:"t_embedder"`
|
||||
XEmbed *XEmbedder `weight:"all_x_embedder"`
|
||||
CapEmbed *CapEmbedder `weight:"cap_embedder"`
|
||||
NoiseRefiners []*TransformerBlock `weight:"noise_refiner"`
|
||||
ContextRefiners []*TransformerBlock `weight:"context_refiner"`
|
||||
Layers []*TransformerBlock `weight:"layers"`
|
||||
FinalLayer *FinalLayer `weight:"all_final_layer.2-1"`
|
||||
XPadToken *mlx.Array `weight:"x_pad_token"`
|
||||
CapPadToken *mlx.Array `weight:"cap_pad_token"`
|
||||
*TransformerConfig
|
||||
}
|
||||
|
||||
// Load loads the Z-Image transformer from a directory
|
||||
func (m *Transformer) Load(path string) error {
|
||||
fmt.Println("Loading Z-Image transformer...")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadTransformerConfig(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.TransformerConfig = cfg
|
||||
|
||||
// Pre-allocate slices for loader
|
||||
m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading weights via struct tags... ")
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Initialize computed fields
|
||||
m.TEmbed.FreqEmbedSize = 256
|
||||
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
|
||||
m.CapEmbed.Norm.Eps = 1e-6
|
||||
|
||||
for _, block := range m.NoiseRefiners {
|
||||
initTransformerBlock(block, cfg)
|
||||
}
|
||||
for _, block := range m.ContextRefiners {
|
||||
initTransformerBlock(block, cfg)
|
||||
}
|
||||
for _, block := range m.Layers {
|
||||
initTransformerBlock(block, cfg)
|
||||
}
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadTransformerConfig loads transformer config from a JSON file
|
||||
func loadTransformerConfig(path string) (*TransformerConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg TransformerConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
// Extract PatchSize from array
|
||||
if len(cfg.AllPatchSize) > 0 {
|
||||
cfg.PatchSize = cfg.AllPatchSize[0]
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// initTransformerBlock sets computed fields on a transformer block
|
||||
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
|
||||
block.Dim = cfg.Dim
|
||||
block.HasModulation = block.AdaLN != nil
|
||||
|
||||
// Init attention computed fields
|
||||
attn := block.Attention
|
||||
attn.NHeads = cfg.NHeads
|
||||
attn.HeadDim = cfg.Dim / cfg.NHeads
|
||||
attn.Dim = cfg.Dim
|
||||
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
|
||||
|
||||
// Init feedforward OutDim
|
||||
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
|
||||
|
||||
// Set eps on all RMSNorm layers
|
||||
block.AttentionNorm1.Eps = cfg.NormEps
|
||||
block.AttentionNorm2.Eps = cfg.NormEps
|
||||
block.FFNNorm1.Eps = cfg.NormEps
|
||||
block.FFNNorm2.Eps = cfg.NormEps
|
||||
}
|
||||
|
||||
// RoPECache holds precomputed RoPE values
|
||||
type RoPECache struct {
|
||||
ImgCos *mlx.Array
|
||||
ImgSin *mlx.Array
|
||||
CapCos *mlx.Array
|
||||
CapSin *mlx.Array
|
||||
UnifiedCos *mlx.Array
|
||||
UnifiedSin *mlx.Array
|
||||
ImgLen int32
|
||||
CapLen int32
|
||||
}
|
||||
|
||||
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
|
||||
// hTok and wTok are the number of tokens in each dimension (latentH/patchSize, latentW/patchSize).
|
||||
func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
|
||||
imgLen := hTok * wTok
|
||||
|
||||
// Image positions: grid over (1, H, W) starting at (capLen+1, 0, 0)
|
||||
imgPos := createCoordinateGrid(1, hTok, wTok, capLen+1, 0, 0)
|
||||
imgPos = mlx.ToBFloat16(imgPos)
|
||||
// Caption positions: grid over (capLen, 1, 1) starting at (1, 0, 0)
|
||||
capPos := createCoordinateGrid(capLen, 1, 1, 1, 0, 0)
|
||||
capPos = mlx.ToBFloat16(capPos)
|
||||
|
||||
// Compute RoPE from UNIFIED positions
|
||||
unifiedPos := mlx.Concatenate([]*mlx.Array{imgPos, capPos}, 1)
|
||||
unifiedCos, unifiedSin := prepareRoPE3D(unifiedPos, m.TransformerConfig.AxesDims)
|
||||
|
||||
// Slice RoPE for image and caption parts
|
||||
imgCos := mlx.Slice(unifiedCos, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
|
||||
imgSin := mlx.Slice(unifiedSin, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
|
||||
capCos := mlx.Slice(unifiedCos, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
|
||||
capSin := mlx.Slice(unifiedSin, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
|
||||
|
||||
return &RoPECache{
|
||||
ImgCos: imgCos,
|
||||
ImgSin: imgSin,
|
||||
CapCos: capCos,
|
||||
CapSin: capSin,
|
||||
UnifiedCos: unifiedCos,
|
||||
UnifiedSin: unifiedSin,
|
||||
ImgLen: imgLen,
|
||||
CapLen: capLen,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward runs the Z-Image transformer with precomputed RoPE
|
||||
func (m *Transformer) Forward(x *mlx.Array, t *mlx.Array, capFeats *mlx.Array, rope *RoPECache) *mlx.Array {
|
||||
imgLen := rope.ImgLen
|
||||
|
||||
// Timestep embedding -> [B, 256]
|
||||
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
|
||||
|
||||
// Embed image patches -> [B, L_img, dim]
|
||||
x = m.XEmbed.Forward(x)
|
||||
|
||||
// Embed caption features -> [B, L_cap, dim]
|
||||
capEmb := m.CapEmbed.Forward(capFeats)
|
||||
|
||||
eps := m.NormEps
|
||||
|
||||
// Noise refiner: refine image patches with modulation
|
||||
for _, refiner := range m.NoiseRefiners {
|
||||
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
|
||||
}
|
||||
|
||||
// Context refiner: refine caption (no modulation)
|
||||
for _, refiner := range m.ContextRefiners {
|
||||
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
|
||||
}
|
||||
|
||||
// Concatenate image and caption for joint attention
|
||||
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
|
||||
|
||||
// Main transformer layers use full unified RoPE
|
||||
for _, layer := range m.Layers {
|
||||
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
|
||||
}
|
||||
|
||||
// Extract image tokens only
|
||||
unifiedShape := unified.Shape()
|
||||
B := unifiedShape[0]
|
||||
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
|
||||
|
||||
// Final layer
|
||||
return m.FinalLayer.Forward(imgOut, temb)
|
||||
}
|
||||
|
||||
// ForwardWithCache runs the transformer with layer caching for faster inference.
|
||||
// On refresh steps (step % cacheInterval == 0), all layers are computed and cached.
|
||||
// On other steps, shallow layers (0 to cacheLayers-1) reuse cached outputs.
|
||||
func (m *Transformer) ForwardWithCache(
|
||||
x *mlx.Array,
|
||||
t *mlx.Array,
|
||||
capFeats *mlx.Array,
|
||||
rope *RoPECache,
|
||||
stepCache *cache.StepCache,
|
||||
step int,
|
||||
cacheInterval int,
|
||||
) *mlx.Array {
|
||||
imgLen := rope.ImgLen
|
||||
cacheLayers := stepCache.NumLayers()
|
||||
eps := m.NormEps
|
||||
|
||||
// Timestep embedding -> [B, 256]
|
||||
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
|
||||
|
||||
// Embed image patches -> [B, L_img, dim]
|
||||
x = m.XEmbed.Forward(x)
|
||||
|
||||
// Context refiners: compute once on step 0, reuse forever
|
||||
// (caption embedding doesn't depend on timestep or latents)
|
||||
var capEmb *mlx.Array
|
||||
if stepCache.GetConstant() != nil {
|
||||
capEmb = stepCache.GetConstant()
|
||||
} else {
|
||||
capEmb = m.CapEmbed.Forward(capFeats)
|
||||
for _, refiner := range m.ContextRefiners {
|
||||
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
|
||||
}
|
||||
stepCache.SetConstant(capEmb)
|
||||
}
|
||||
|
||||
// Noise refiners: always compute (depend on x which changes each step)
|
||||
for _, refiner := range m.NoiseRefiners {
|
||||
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
|
||||
}
|
||||
|
||||
// Concatenate image and caption for joint attention
|
||||
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
|
||||
|
||||
// Determine if this is a cache refresh step
|
||||
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
|
||||
|
||||
// Main transformer layers with caching
|
||||
for i, layer := range m.Layers {
|
||||
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
|
||||
// Use cached output for shallow layers
|
||||
unified = stepCache.Get(i)
|
||||
} else {
|
||||
// Compute layer
|
||||
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
|
||||
// Cache shallow layer outputs on refresh steps
|
||||
if i < cacheLayers && refreshCache {
|
||||
stepCache.Set(i, unified)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract image tokens only
|
||||
unifiedShape := unified.Shape()
|
||||
B := unifiedShape[0]
|
||||
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
|
||||
|
||||
// Final layer
|
||||
return m.FinalLayer.Forward(imgOut, temb)
|
||||
}
|
||||
|
||||
// createCoordinateGrid creates 3D position grid [1, d0*d1*d2, 3]
|
||||
func createCoordinateGrid(d0, d1, d2, s0, s1, s2 int32) *mlx.Array {
|
||||
// Create meshgrid and stack
|
||||
total := d0 * d1 * d2
|
||||
coords := make([]float32, total*3)
|
||||
|
||||
idx := 0
|
||||
for i := int32(0); i < d0; i++ {
|
||||
for j := int32(0); j < d1; j++ {
|
||||
for k := int32(0); k < d2; k++ {
|
||||
coords[idx*3+0] = float32(s0 + i)
|
||||
coords[idx*3+1] = float32(s1 + j)
|
||||
coords[idx*3+2] = float32(s2 + k)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mlx.NewArray(coords, []int32{1, total, 3})
|
||||
}
|
||||
|
||||
// prepareRoPE3D computes cos/sin for 3-axis RoPE
|
||||
// positions: [B, L, 3] with (h, w, t) coordinates
|
||||
// axesDims: [32, 48, 48] - dimensions for each axis
|
||||
// Returns: cos, sin each [B, L, 1, head_dim/2]
|
||||
func prepareRoPE3D(positions *mlx.Array, axesDims []int32) (*mlx.Array, *mlx.Array) {
|
||||
// Compute frequencies for each axis
|
||||
// dims = [32, 48, 48], so halves = [16, 24, 24]
|
||||
ropeTheta := float32(256.0)
|
||||
|
||||
freqs := make([]*mlx.Array, 3)
|
||||
for axis := 0; axis < 3; axis++ {
|
||||
half := axesDims[axis] / 2
|
||||
f := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
f[i] = float32(math.Exp(-math.Log(float64(ropeTheta)) * float64(i) / float64(half)))
|
||||
}
|
||||
freqs[axis] = mlx.NewArray(f, []int32{1, 1, 1, half})
|
||||
}
|
||||
|
||||
// Extract position coordinates
|
||||
shape := positions.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
// positions[:, :, 0] -> h positions
|
||||
posH := mlx.Slice(positions, []int32{0, 0, 0}, []int32{B, L, 1})
|
||||
posW := mlx.Slice(positions, []int32{0, 0, 1}, []int32{B, L, 2})
|
||||
posT := mlx.Slice(positions, []int32{0, 0, 2}, []int32{B, L, 3})
|
||||
|
||||
// Compute args: pos * freqs for each axis
|
||||
posH = mlx.ExpandDims(posH, 3) // [B, L, 1, 1]
|
||||
posW = mlx.ExpandDims(posW, 3)
|
||||
posT = mlx.ExpandDims(posT, 3)
|
||||
|
||||
argsH := mlx.Mul(posH, freqs[0]) // [B, L, 1, 16]
|
||||
argsW := mlx.Mul(posW, freqs[1]) // [B, L, 1, 24]
|
||||
argsT := mlx.Mul(posT, freqs[2]) // [B, L, 1, 24]
|
||||
|
||||
// Concatenate: [B, L, 1, 16+24+24=64]
|
||||
args := mlx.Concatenate([]*mlx.Array{argsH, argsW, argsT}, 3)
|
||||
|
||||
// Compute cos and sin
|
||||
return mlx.Cos(args), mlx.Sin(args)
|
||||
}
|
||||
|
||||
// PatchifyLatents converts latents [B, C, H, W] to patches [B, L, C*patch^2]
|
||||
// Matches Python: x.reshape(C, 1, 1, H_tok, 2, W_tok, 2).transpose(1,2,3,5,4,6,0).reshape(1,-1,C*4)
|
||||
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize // H_tok
|
||||
pW := W / patchSize // W_tok
|
||||
|
||||
// Match Python exactly: reshape treating B=1 as part of contiguous data
|
||||
// [1, C, H, W] -> [C, 1, 1, pH, 2, pW, 2]
|
||||
x := mlx.Reshape(latents, C, 1, 1, pH, patchSize, pW, patchSize)
|
||||
|
||||
// Python: transpose(1, 2, 3, 5, 4, 6, 0)
|
||||
// [C, 1, 1, pH, 2, pW, 2] -> [1, 1, pH, pW, 2, 2, C]
|
||||
x = mlx.Transpose(x, 1, 2, 3, 5, 4, 6, 0)
|
||||
|
||||
// [1, 1, pH, pW, 2, 2, C] -> [1, pH*pW, C*4]
|
||||
return mlx.Reshape(x, 1, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpatchifyLatents converts patches [B, L, C*patch^2] back to [B, C, H, W]
|
||||
// Matches Python: out.reshape(1,1,H_tok,W_tok,2,2,C).transpose(6,0,1,2,4,3,5).reshape(1,C,H,W)
|
||||
func UnpatchifyLatents(patches *mlx.Array, patchSize, H, W, C int32) *mlx.Array {
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [1, L, C*4] -> [1, 1, pH, pW, 2, 2, C]
|
||||
x := mlx.Reshape(patches, 1, 1, pH, pW, patchSize, patchSize, C)
|
||||
|
||||
// Python: transpose(6, 0, 1, 2, 4, 3, 5)
|
||||
// [1, 1, pH, pW, 2, 2, C] -> [C, 1, 1, pH, 2, pW, 2]
|
||||
x = mlx.Transpose(x, 6, 0, 1, 2, 4, 3, 5)
|
||||
|
||||
// [C, 1, 1, pH, 2, pW, 2] -> [1, C, H, W]
|
||||
return mlx.Reshape(x, 1, C, H, W)
|
||||
}
|
||||
@@ -1,650 +0,0 @@
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds VAE decoder configuration
|
||||
type VAEConfig struct {
|
||||
InChannels int32 `json:"in_channels"`
|
||||
OutChannels int32 `json:"out_channels"`
|
||||
LatentChannels int32 `json:"latent_channels"`
|
||||
BlockOutChannels []int32 `json:"block_out_channels"`
|
||||
LayersPerBlock int32 `json:"layers_per_block"`
|
||||
NormNumGroups int32 `json:"norm_num_groups"`
|
||||
ScalingFactor float32 `json:"scaling_factor"`
|
||||
ShiftFactor float32 `json:"shift_factor"`
|
||||
}
|
||||
|
||||
// loadVAEConfig loads VAE config from a JSON file
|
||||
func loadVAEConfig(path string) (*VAEConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg VAEConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// GroupNormLayer implements group normalization
|
||||
type GroupNormLayer struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
NumGroups int32
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// NewGroupNorm creates a group norm layer
|
||||
func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
|
||||
return &GroupNormLayer{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
NumGroups: numGroups,
|
||||
Eps: 1e-5,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies group normalization
|
||||
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, C, H, W]
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
// Reshape to [B, groups, C/groups, H, W]
|
||||
groupSize := C / gn.NumGroups
|
||||
x = mlx.Reshape(x, B, gn.NumGroups, groupSize, H, W)
|
||||
|
||||
// Compute mean and variance per group
|
||||
mean := mlx.Mean(x, 2, true)
|
||||
mean = mlx.Mean(mean, 3, true)
|
||||
mean = mlx.Mean(mean, 4, true)
|
||||
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), 2, true)
|
||||
variance = mlx.Mean(variance, 3, true)
|
||||
variance = mlx.Mean(variance, 4, true)
|
||||
|
||||
// Normalize
|
||||
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
|
||||
|
||||
// Reshape back to [B, C, H, W]
|
||||
xNorm = mlx.Reshape(xNorm, B, C, H, W)
|
||||
|
||||
// Scale and shift (weight and bias are [C])
|
||||
if gn.Weight != nil {
|
||||
weight := mlx.Reshape(gn.Weight, 1, C, 1, 1)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if gn.Bias != nil {
|
||||
bias := mlx.Reshape(gn.Bias, 1, C, 1, 1)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
|
||||
// Conv2D represents a 2D convolution layer
|
||||
// MLX uses NHWC format, but we store weights in OHWI format for MLX conv
|
||||
type Conv2D struct {
|
||||
Weight *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
|
||||
Bias *mlx.Array // [out_channels]
|
||||
Stride int32
|
||||
Padding int32
|
||||
}
|
||||
|
||||
// NewConv2D creates a Conv2D layer
|
||||
// weight comes in as [out_channels, in_channels, kH, kW] (OIHW from PyTorch)
|
||||
// we transpose to [out_channels, kH, kW, in_channels] (OHWI for MLX)
|
||||
func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
|
||||
// Transpose weight from OIHW to OHWI
|
||||
// [O, I, H, W] -> [O, H, W, I]
|
||||
weightOHWI := mlx.Transpose(weight, 0, 2, 3, 1)
|
||||
return &Conv2D{
|
||||
Weight: weightOHWI,
|
||||
Bias: bias,
|
||||
Stride: stride,
|
||||
Padding: padding,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies convolution
|
||||
// Input x is in NCHW format, we convert to NHWC for MLX, then back to NCHW
|
||||
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [N, C, H, W] -> [N, H, W, C]
|
||||
xNHWC := mlx.Transpose(x, 0, 2, 3, 1)
|
||||
|
||||
// Conv in NHWC format
|
||||
outNHWC := mlx.Conv2d(xNHWC, conv.Weight, conv.Stride, conv.Padding)
|
||||
|
||||
// Convert back to NCHW: [N, H, W, C] -> [N, C, H, W]
|
||||
out := mlx.Transpose(outNHWC, 0, 3, 1, 2)
|
||||
|
||||
if conv.Bias != nil {
|
||||
bias := mlx.Reshape(conv.Bias, 1, conv.Bias.Dim(0), 1, 1)
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ResnetBlock2D implements a ResNet block for VAE
|
||||
type ResnetBlock2D struct {
|
||||
Norm1 *GroupNormLayer
|
||||
Conv1 *Conv2D
|
||||
Norm2 *GroupNormLayer
|
||||
Conv2 *Conv2D
|
||||
ConvShortcut *Conv2D // nil if in_channels == out_channels
|
||||
}
|
||||
|
||||
// NewResnetBlock2D creates a ResNet block
|
||||
func NewResnetBlock2D(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*ResnetBlock2D, error) {
|
||||
norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm1Bias, err := weights.GetTensor(prefix + ".norm1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conv1Weight, err := weights.GetTensor(prefix + ".conv1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1Bias, err := weights.GetTensor(prefix + ".conv1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
norm2Weight, err := weights.GetTensor(prefix + ".norm2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2Bias, err := weights.GetTensor(prefix + ".norm2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conv2Weight, err := weights.GetTensor(prefix + ".conv2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2Bias, err := weights.GetTensor(prefix + ".conv2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block := &ResnetBlock2D{
|
||||
Norm1: NewGroupNorm(norm1Weight, norm1Bias, numGroups),
|
||||
Conv1: NewConv2D(conv1Weight, conv1Bias, 1, 1),
|
||||
Norm2: NewGroupNorm(norm2Weight, norm2Bias, numGroups),
|
||||
Conv2: NewConv2D(conv2Weight, conv2Bias, 1, 1),
|
||||
}
|
||||
|
||||
if weights.HasTensor(prefix + ".conv_shortcut.weight") {
|
||||
shortcutWeight, err := weights.GetTensor(prefix + ".conv_shortcut.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcutBias, err := weights.GetTensor(prefix + ".conv_shortcut.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
block.ConvShortcut = NewConv2D(shortcutWeight, shortcutBias, 1, 0)
|
||||
}
|
||||
|
||||
return block, nil
|
||||
}
|
||||
|
||||
// Forward applies the ResNet block with staged evaluation
|
||||
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
|
||||
// Stage 1: norm1
|
||||
{
|
||||
h = rb.Norm1.Forward(x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: silu + conv1
|
||||
{
|
||||
prev := h
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 3: norm2
|
||||
{
|
||||
prev := h
|
||||
h = rb.Norm2.Forward(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: silu + conv2
|
||||
{
|
||||
prev := h
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
{
|
||||
prev := h
|
||||
if rb.ConvShortcut != nil {
|
||||
shortcut := rb.ConvShortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// VAEAttentionBlock implements self-attention for VAE
|
||||
type VAEAttentionBlock struct {
|
||||
GroupNorm *GroupNormLayer
|
||||
ToQWeight *mlx.Array
|
||||
ToQBias *mlx.Array
|
||||
ToKWeight *mlx.Array
|
||||
ToKBias *mlx.Array
|
||||
ToVWeight *mlx.Array
|
||||
ToVBias *mlx.Array
|
||||
ToOutWeight *mlx.Array
|
||||
ToOutBias *mlx.Array
|
||||
NumHeads int32
|
||||
}
|
||||
|
||||
// NewVAEAttentionBlock creates an attention block
|
||||
func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
|
||||
normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normBias, err := weights.GetTensor(prefix + ".group_norm.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toQWeight, err := weights.GetTensor(prefix + ".to_q.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQBias, err := weights.GetTensor(prefix + ".to_q.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toKWeight, err := weights.GetTensor(prefix + ".to_k.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toKBias, err := weights.GetTensor(prefix + ".to_k.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toVWeight, err := weights.GetTensor(prefix + ".to_v.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toVBias, err := weights.GetTensor(prefix + ".to_v.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toOutWeight, err := weights.GetTensor(prefix + ".to_out.0.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toOutBias, err := weights.GetTensor(prefix + ".to_out.0.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &VAEAttentionBlock{
|
||||
GroupNorm: NewGroupNorm(normWeight, normBias, numGroups),
|
||||
ToQWeight: mlx.Transpose(toQWeight, 1, 0),
|
||||
ToQBias: toQBias,
|
||||
ToKWeight: mlx.Transpose(toKWeight, 1, 0),
|
||||
ToKBias: toKBias,
|
||||
ToVWeight: mlx.Transpose(toVWeight, 1, 0),
|
||||
ToVBias: toVBias,
|
||||
ToOutWeight: mlx.Transpose(toOutWeight, 1, 0),
|
||||
ToOutBias: toOutBias,
|
||||
NumHeads: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies attention with staged evaluation
|
||||
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
residual := x
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
var h *mlx.Array
|
||||
|
||||
// Stage 1: GroupNorm + reshape
|
||||
{
|
||||
h = ab.GroupNorm.Forward(x)
|
||||
h = mlx.Transpose(h, 0, 2, 3, 1)
|
||||
h = mlx.Reshape(h, B, H*W, C)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
var out *mlx.Array
|
||||
|
||||
// Stage 2: Q, K, V projections + attention
|
||||
{
|
||||
q := mlx.Linear(h, ab.ToQWeight)
|
||||
q = mlx.Add(q, ab.ToQBias)
|
||||
k := mlx.Linear(h, ab.ToKWeight)
|
||||
k = mlx.Add(k, ab.ToKBias)
|
||||
v := mlx.Linear(h, ab.ToVWeight)
|
||||
v = mlx.Add(v, ab.ToVBias)
|
||||
h.Free()
|
||||
|
||||
q = mlx.ExpandDims(q, 1)
|
||||
k = mlx.ExpandDims(k, 1)
|
||||
v = mlx.ExpandDims(v, 1)
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out = mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
out = mlx.Squeeze(out, 1)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
// Stage 3: Output projection + reshape + residual
|
||||
{
|
||||
prev := out
|
||||
out = mlx.Linear(out, ab.ToOutWeight)
|
||||
out = mlx.Add(out, ab.ToOutBias)
|
||||
out = mlx.Reshape(out, B, H, W, C)
|
||||
out = mlx.Transpose(out, 0, 3, 1, 2)
|
||||
out = mlx.Add(out, residual)
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// UpDecoderBlock2D implements an upsampling decoder block
|
||||
type UpDecoderBlock2D struct {
|
||||
ResnetBlocks []*ResnetBlock2D
|
||||
Upsample *Conv2D
|
||||
}
|
||||
|
||||
// NewUpDecoderBlock2D creates an up decoder block
|
||||
func NewUpDecoderBlock2D(weights *safetensors.ModelWeights, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
|
||||
resnets := make([]*ResnetBlock2D, numLayers)
|
||||
for i := int32(0); i < numLayers; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
resnet, err := NewResnetBlock2D(weights, resPrefix, numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resnets[i] = resnet
|
||||
}
|
||||
|
||||
var upsample *Conv2D
|
||||
if hasUpsample {
|
||||
upWeight, err := weights.GetTensor(prefix + ".upsamplers.0.conv.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upBias, err := weights.GetTensor(prefix + ".upsamplers.0.conv.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upsample = NewConv2D(upWeight, upBias, 1, 1)
|
||||
}
|
||||
|
||||
return &UpDecoderBlock2D{
|
||||
ResnetBlocks: resnets,
|
||||
Upsample: upsample,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the up decoder block with staged evaluation to reduce peak memory
|
||||
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, resnet := range ub.ResnetBlocks {
|
||||
prev := x
|
||||
x = resnet.Forward(x) // ResNet handles its own pools
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if ub.Upsample != nil {
|
||||
// Stage 1: Upsample2x (nearest neighbor)
|
||||
{
|
||||
prev := x
|
||||
x = Upsample2x(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Upsample conv
|
||||
{
|
||||
prev := x
|
||||
x = ub.Upsample.Forward(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEMidBlock is the middle block with attention
|
||||
type VAEMidBlock struct {
|
||||
Resnet1 *ResnetBlock2D
|
||||
Attention *VAEAttentionBlock
|
||||
Resnet2 *ResnetBlock2D
|
||||
}
|
||||
|
||||
// NewVAEMidBlock creates the mid block
|
||||
func NewVAEMidBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEMidBlock, error) {
|
||||
resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attention, err := NewVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resnet2, err := NewResnetBlock2D(weights, prefix+".resnets.1", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &VAEMidBlock{
|
||||
Resnet1: resnet1,
|
||||
Attention: attention,
|
||||
Resnet2: resnet2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the mid block with staged evaluation
|
||||
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
prev := x
|
||||
x = mb.Resnet1.Forward(x) // ResNet handles its own pools
|
||||
prev.Free()
|
||||
|
||||
// Attention handles its own pools
|
||||
prev = x
|
||||
x = mb.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = mb.Resnet2.Forward(x) // ResNet handles its own pools
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the full VAE decoder
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
ConvIn *Conv2D
|
||||
MidBlock *VAEMidBlock
|
||||
UpBlocks []*UpDecoderBlock2D
|
||||
ConvNormOut *GroupNormLayer
|
||||
ConvOut *Conv2D
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from a directory
|
||||
func (m *VAEDecoder) Load(path string) error {
|
||||
fmt.Println("Loading VAE decoder...")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadVAEConfig(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = cfg
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Load conv_in
|
||||
fmt.Print(" Loading conv_in... ")
|
||||
convInWeight, err := weights.GetTensor("decoder.conv_in.weight")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
convInBias, err := weights.GetTensor("decoder.conv_in.bias")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvIn = NewConv2D(convInWeight, convInBias, 1, 1)
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load mid block
|
||||
fmt.Print(" Loading mid block... ")
|
||||
m.MidBlock, err = NewVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load up blocks
|
||||
fmt.Print(" Loading up blocks... ")
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.UpBlocks = make([]*UpDecoderBlock2D, numBlocks)
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
hasUpsample := i < numBlocks-1
|
||||
m.UpBlocks[i], err = NewUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
fmt.Printf("✓ [%d blocks]\n", numBlocks)
|
||||
|
||||
// Load conv_norm_out
|
||||
fmt.Print(" Loading conv_norm_out... ")
|
||||
normWeight, err := weights.GetTensor("decoder.conv_norm_out.weight")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
normBias, err := weights.GetTensor("decoder.conv_norm_out.bias")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvNormOut = NewGroupNorm(normWeight, normBias, cfg.NormNumGroups)
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load conv_out
|
||||
fmt.Print(" Loading conv_out... ")
|
||||
convOutWeight, err := weights.GetTensor("decoder.conv_out.weight")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
convOutBias, err := weights.GetTensor("decoder.conv_out.bias")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode decodes latents to images.
|
||||
// Uses staged pools to free intermediate arrays and reduce peak memory.
|
||||
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
{
|
||||
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
|
||||
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
|
||||
h = vae.ConvIn.Forward(z)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
h = vae.MidBlock.Forward(h)
|
||||
|
||||
for _, upBlock := range vae.UpBlocks {
|
||||
h = upBlock.Forward(h)
|
||||
}
|
||||
|
||||
{
|
||||
prev := h
|
||||
h = vae.ConvNormOut.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = vae.ConvOut.Forward(h)
|
||||
// VAE outputs [-1, 1], convert to [0, 1]
|
||||
h = mlx.AddScalar(mlx.MulScalar(h, 0.5), 0.5)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Upsample2x performs 2x nearest neighbor upsampling using broadcast.
|
||||
// x: [B, C, H, W] -> [B, C, H*2, W*2]
|
||||
func Upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
// [B, C, H, W] -> [B, C, H, 1, W, 1]
|
||||
x = mlx.Reshape(x, B, C, H, 1, W, 1)
|
||||
// Broadcast to [B, C, H, 2, W, 2]
|
||||
x = mlx.BroadcastTo(x, []int32{B, C, H, 2, W, 2})
|
||||
// Reshape to [B, C, H*2, W*2]
|
||||
x = mlx.Reshape(x, B, C, H*2, W*2)
|
||||
|
||||
return x
|
||||
}
|
||||
@@ -1,361 +0,0 @@
|
||||
// Package zimage implements the Z-Image diffusion transformer model.
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
|
||||
// Layer caching options (speedup via shallow layer reuse)
|
||||
LayerCache bool // Enable layer caching (default: false)
|
||||
CacheInterval int // Refresh cache every N steps (default: 3)
|
||||
CacheLayers int // Number of shallow layers to cache (default: 15)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Z-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *Qwen3TextEncoder
|
||||
Transformer *Transformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the Z-Image model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Z-Image model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer", "tokenizer.json")
|
||||
tok, err := tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load text encoder
|
||||
m.TextEncoder = &Qwen3TextEncoder{}
|
||||
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
CFGScale: cfgScale,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 9 // Turbo default
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
if cfg.LayerCache {
|
||||
if cfg.CacheInterval <= 0 {
|
||||
cfg.CacheInterval = 3
|
||||
}
|
||||
if cfg.CacheLayers <= 0 {
|
||||
cfg.CacheLayers = 15 // Half of 30 layers
|
||||
}
|
||||
}
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.TransformerConfig
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
hTok := latentH / tcfg.PatchSize
|
||||
wTok := latentW / tcfg.PatchSize
|
||||
|
||||
// Text encoding with padding to multiple of 32
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512)
|
||||
if useCFG {
|
||||
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512)
|
||||
}
|
||||
|
||||
// Pad both to same length (multiple of 32)
|
||||
maxLen := posEmb.Shape()[1]
|
||||
if useCFG && negEmb.Shape()[1] > maxLen {
|
||||
maxLen = negEmb.Shape()[1]
|
||||
}
|
||||
if pad := (32 - (maxLen % 32)) % 32; pad > 0 {
|
||||
maxLen += pad
|
||||
}
|
||||
|
||||
posEmb = padToLength(posEmb, maxLen)
|
||||
if useCFG {
|
||||
negEmb = padToLength(negEmb, maxLen)
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
} else {
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
}
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchEulerScheduler(DefaultFlowMatchSchedulerConfig())
|
||||
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(hTok*wTok))
|
||||
|
||||
// Init latents [B, C, H, W]
|
||||
var latents *mlx.Array
|
||||
{
|
||||
latents = scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
}
|
||||
|
||||
// RoPE cache
|
||||
var ropeCache *RoPECache
|
||||
{
|
||||
ropeCache = m.Transformer.PrepareRoPECache(hTok, wTok, posEmb.Shape()[1])
|
||||
mlx.Keep(ropeCache.ImgCos, ropeCache.ImgSin, ropeCache.CapCos, ropeCache.CapSin,
|
||||
ropeCache.UnifiedCos, ropeCache.UnifiedSin)
|
||||
mlx.Eval(ropeCache.UnifiedCos)
|
||||
}
|
||||
|
||||
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
|
||||
var stepCache *cache.StepCache
|
||||
if cfg.LayerCache {
|
||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
||||
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
|
||||
cfg.CacheLayers, cfg.CacheInterval)
|
||||
}
|
||||
|
||||
// Denoising loop
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
// GPU capture on step 2 if requested
|
||||
if cfg.CapturePath != "" && i == 1 {
|
||||
mlx.MetalStartCapture(cfg.CapturePath)
|
||||
}
|
||||
|
||||
tCurr := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
|
||||
|
||||
patches := PatchifyLatents(latents, tcfg.PatchSize)
|
||||
|
||||
var output *mlx.Array
|
||||
if stepCache != nil {
|
||||
// Use layer caching for faster inference
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
// Note: CFG with layer cache shares the cache between pos/neg
|
||||
// This is approximate but fast - neg prompt uses same cached shallow layers
|
||||
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
output = mlx.Add(negOutput, scaledDiff)
|
||||
} else {
|
||||
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
}
|
||||
} else {
|
||||
// Standard forward without caching
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
|
||||
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
output = mlx.Add(negOutput, scaledDiff)
|
||||
} else {
|
||||
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
|
||||
}
|
||||
}
|
||||
|
||||
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
||||
noisePred = mlx.Neg(noisePred)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
|
||||
// Keep latents and any cached arrays
|
||||
if stepCache != nil {
|
||||
mlx.Keep(stepCache.Arrays()...)
|
||||
}
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
if cfg.CapturePath != "" && i == 1 {
|
||||
mlx.MetalStopCapture()
|
||||
}
|
||||
|
||||
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n",
|
||||
i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem)
|
||||
}
|
||||
|
||||
// Free denoising temporaries before VAE decode
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
ropeCache.ImgCos.Free()
|
||||
ropeCache.ImgSin.Free()
|
||||
ropeCache.CapCos.Free()
|
||||
ropeCache.CapSin.Free()
|
||||
ropeCache.UnifiedCos.Free()
|
||||
ropeCache.UnifiedSin.Free()
|
||||
if stepCache != nil {
|
||||
stepCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// padToLength pads a sequence tensor to the target length by repeating the last token.
|
||||
func padToLength(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
lastToken := mlx.Slice(x, []int32{0, currentLen - 1, 0}, []int32{shape[0], currentLen, shape[2]})
|
||||
padding := mlx.Tile(lastToken, []int32{1, padLen, 1})
|
||||
return mlx.Concatenate([]*mlx.Array{x, padding}, 1)
|
||||
}
|
||||
|
||||
// CalculateShift computes the mu shift value for dynamic scheduling
|
||||
func CalculateShift(imgSeqLen int32) float32 {
|
||||
baseSeqLen := float32(256)
|
||||
maxSeqLen := float32(4096)
|
||||
baseShift := float32(0.5)
|
||||
maxShift := float32(1.15)
|
||||
|
||||
m := (maxShift - baseShift) / (maxSeqLen - baseSeqLen)
|
||||
b := baseShift - m*baseSeqLen
|
||||
return float32(imgSeqLen)*m + b
|
||||
}
|
||||
@@ -1,201 +0,0 @@
|
||||
// Package nn provides neural network layer types.
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
// Layer is the interface for neural network layers with a Forward method.
|
||||
type Layer interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// Linear applies an affine transformation: y = x @ W.T + b
|
||||
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
|
||||
type Linear struct {
|
||||
Weight *mlx.Array `weight:"weight"` // [out_features, in_features]
|
||||
Bias *mlx.Array `weight:"bias,optional"` // [out_features] or nil
|
||||
}
|
||||
|
||||
// NewLinear creates a linear layer.
|
||||
// Weight should be [out_features, in_features].
|
||||
func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
|
||||
return &Linear{Weight: weight, Bias: bias}
|
||||
}
|
||||
|
||||
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
|
||||
// Quantizes the weight immediately and evaluates to break lazy dependencies.
|
||||
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
|
||||
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
|
||||
// Eval immediately so bf16 weight can be freed
|
||||
mlx.Eval(qw, scales, qbiases)
|
||||
return &QuantizedLinear{
|
||||
Weight: qw,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies the linear transformation: x @ W.T + bias
|
||||
func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
|
||||
w := mlx.Transpose(l.Weight, 1, 0)
|
||||
if l.Bias != nil {
|
||||
return mlx.AddMM(l.Bias, x, w, 1.0, 1.0)
|
||||
}
|
||||
return mlx.Linear(x, w)
|
||||
}
|
||||
|
||||
// ToQuantized converts this Linear to a QuantizedLinear.
|
||||
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
|
||||
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
|
||||
return &QuantizedLinear{
|
||||
Weight: qw,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: l.Bias,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// QuantizedLinear applies an affine transformation using quantized weights.
|
||||
// Equivalent to mlx.nn.QuantizedLinear.
|
||||
type QuantizedLinear struct {
|
||||
Weight *mlx.Array // Quantized weight data
|
||||
Scales *mlx.Array // Scale factors for dequantization
|
||||
QBiases *mlx.Array // Quantization biases (NOT layer bias)
|
||||
Bias *mlx.Array // Layer bias [output_dims] or nil
|
||||
GroupSize int
|
||||
Bits int
|
||||
Mode string
|
||||
}
|
||||
|
||||
// Forward applies the quantized linear transformation.
|
||||
func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
|
||||
out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
|
||||
if ql.Bias != nil {
|
||||
out = mlx.Add(out, ql.Bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// RMSNorm represents an RMS normalization layer.
|
||||
type RMSNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Eps float32 // optional: used if Forward called with eps=0
|
||||
}
|
||||
|
||||
// NewRMSNorm creates an RMSNorm layer (for models not using weight loader).
|
||||
func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
|
||||
return &RMSNorm{Weight: weight, Eps: eps}
|
||||
}
|
||||
|
||||
// Forward applies RMS normalization. If eps=0, uses stored Eps.
|
||||
func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
||||
if eps == 0 {
|
||||
eps = rn.Eps
|
||||
}
|
||||
return mlx.RMSNorm(x, rn.Weight, eps)
|
||||
}
|
||||
|
||||
// Embedding represents an embedding layer.
|
||||
type Embedding struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
}
|
||||
|
||||
// NewEmbedding creates an embedding layer.
|
||||
func NewEmbedding(weight *mlx.Array) *Embedding {
|
||||
return &Embedding{Weight: weight}
|
||||
}
|
||||
|
||||
// Forward looks up embeddings by indices.
|
||||
func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
|
||||
return mlx.Take(e.Weight, indices, 0)
|
||||
}
|
||||
|
||||
// RepeatKV repeats K/V tensors for grouped query attention
|
||||
// x: [B, num_kv_heads, S, head_dim] -> [B, num_heads, S, head_dim]
|
||||
func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array {
|
||||
if repeatFactor == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
// [B, num_kv_heads, S, head_dim] -> [B, num_kv_heads, 1, S, head_dim]
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
// Repeat along the new axis
|
||||
reps := []int32{1, 1, repeatFactor, 1, 1}
|
||||
x = mlx.Tile(x, reps)
|
||||
// Reshape: [B, num_kv_heads, repeat, S, head_dim] -> [B, num_kv_heads * repeat, S, head_dim]
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeatFactor, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// ApplyCausalMask applies causal (lower triangular) mask to attention scores
|
||||
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
|
||||
// scores: [B, num_heads, S, S]
|
||||
shape := scores.Shape()
|
||||
seqLen := shape[2]
|
||||
|
||||
// Create causal mask: 1 for positions to keep, 0 for positions to mask
|
||||
mask := mlx.Tri(seqLen, seqLen, 0)
|
||||
|
||||
// Where mask is 0, set score to -inf
|
||||
negInf := mlx.NewScalarArray(float32(-1e9))
|
||||
|
||||
// Broadcast mask to match scores shape
|
||||
mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, S, S]
|
||||
|
||||
// Use where: if mask > 0, keep scores, else -inf
|
||||
return mlx.Where(mask, scores, negInf)
|
||||
}
|
||||
|
||||
// ApplyCausalMaskWithOffset applies causal mask for cached attention
|
||||
// scores: [B, num_heads, queryLen, keyLen] where keyLen = cacheLen + queryLen
|
||||
// offset: the starting position of the new queries (i.e., cache length)
|
||||
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
|
||||
if offset == 0 {
|
||||
return ApplyCausalMask(scores)
|
||||
}
|
||||
|
||||
shape := scores.Shape()
|
||||
queryLen := shape[2]
|
||||
keyLen := shape[3]
|
||||
|
||||
// For cached attention, new queries can attend to all cached keys plus
|
||||
// new keys up to and including their position.
|
||||
mask := mlx.Tri(queryLen, keyLen, int(offset))
|
||||
|
||||
negInf := mlx.NewScalarArray(float32(-1e9))
|
||||
mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, queryLen, keyLen]
|
||||
|
||||
return mlx.Where(mask, scores, negInf)
|
||||
}
|
||||
|
||||
// LayerNorm represents a standard layer normalization layer (with bias).
|
||||
type LayerNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias"`
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// Forward applies layer normalization: (x - mean) / sqrt(var + eps) * weight + bias
|
||||
func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
eps := ln.Eps
|
||||
if eps == 0 {
|
||||
eps = 1e-5
|
||||
}
|
||||
// Compute mean and variance along last dimension
|
||||
mean := mlx.Mean(x, -1, true)
|
||||
centered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Mul(centered, centered), -1, true)
|
||||
normalized := mlx.Mul(centered, mlx.RSqrt(mlx.AddScalar(variance, eps)))
|
||||
|
||||
// Scale and shift
|
||||
out := mlx.Mul(normalized, ln.Weight)
|
||||
if ln.Bias != nil {
|
||||
out = mlx.Add(out, ln.Bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,354 +0,0 @@
|
||||
package nn
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly.
|
||||
func TestLinearNoBias(t *testing.T) {
|
||||
// Weight: [out=2, in=3] -> transposed at forward time
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
1, 2, 3, // row 0
|
||||
4, 5, 6, // row 1
|
||||
}, []int32{2, 3})
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
// Input: [1, 3]
|
||||
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Expected: [1,1,1] @ [[1,4],[2,5],[3,6]] = [6, 15]
|
||||
data := out.Data()
|
||||
if len(data) != 2 || data[0] != 6 || data[1] != 15 {
|
||||
t.Errorf("expected [6, 15], got %v", data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLinearWithBias verifies Linear with bias computes x @ w.T + b correctly.
|
||||
func TestLinearWithBias(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
1, 2, 3,
|
||||
4, 5, 6,
|
||||
}, []int32{2, 3})
|
||||
bias := mlx.NewArrayFloat32([]float32{10, 20}, []int32{2})
|
||||
mlx.Eval(weight, bias)
|
||||
|
||||
linear := NewLinear(weight, bias)
|
||||
|
||||
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Expected: [6, 15] + [10, 20] = [16, 35]
|
||||
data := out.Data()
|
||||
if len(data) != 2 || data[0] != 16 || data[1] != 35 {
|
||||
t.Errorf("expected [16, 35], got %v", data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLinearBatched verifies Linear works with batched input.
|
||||
func TestLinearBatched(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
1, 0,
|
||||
0, 1,
|
||||
}, []int32{2, 2}) // Identity
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
// Batch of 3 inputs
|
||||
x := mlx.NewArrayFloat32([]float32{
|
||||
1, 2,
|
||||
3, 4,
|
||||
5, 6,
|
||||
}, []int32{3, 2})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Identity should return same values
|
||||
data := out.Data()
|
||||
expected := []float32{1, 2, 3, 4, 5, 6}
|
||||
for i, v := range expected {
|
||||
if data[i] != v {
|
||||
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRMSNorm verifies RMSNorm computation.
|
||||
func TestRMSNorm(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{1, 1, 1, 1}, []int32{4})
|
||||
mlx.Eval(weight)
|
||||
|
||||
norm := NewRMSNorm(weight, 1e-5)
|
||||
|
||||
// Input with known RMS
|
||||
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := norm.Forward(x, 0) // eps=0 uses stored Eps
|
||||
mlx.Eval(out)
|
||||
|
||||
// RMS of [2,2,2,2] = 2, so normalized = [1,1,1,1]
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.Abs(float64(v-1.0)) > 1e-4 {
|
||||
t.Errorf("at %d: expected ~1.0, got %f", i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRMSNormWithScale verifies RMSNorm applies weight scaling.
|
||||
func TestRMSNormWithScale(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{4})
|
||||
mlx.Eval(weight)
|
||||
|
||||
norm := NewRMSNorm(weight, 1e-5)
|
||||
|
||||
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := norm.Forward(x, 0) // eps=0 uses stored Eps
|
||||
mlx.Eval(out)
|
||||
|
||||
// Normalized [1,1,1,1] * weight [2,2,2,2] = [2,2,2,2]
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.Abs(float64(v-2.0)) > 1e-4 {
|
||||
t.Errorf("at %d: expected ~2.0, got %f", i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedding verifies embedding lookup.
|
||||
func TestEmbedding(t *testing.T) {
|
||||
// Embedding table: 4 tokens, dim 3
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
0, 0, 0, // token 0
|
||||
1, 1, 1, // token 1
|
||||
2, 2, 2, // token 2
|
||||
3, 3, 3, // token 3
|
||||
}, []int32{4, 3})
|
||||
mlx.Eval(weight)
|
||||
|
||||
emb := NewEmbedding(weight)
|
||||
|
||||
// Look up tokens [1, 3, 0]
|
||||
indices := mlx.NewArrayInt32([]int32{1, 3, 0}, []int32{3})
|
||||
mlx.Eval(indices)
|
||||
|
||||
out := emb.Forward(indices)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
expected := []float32{1, 1, 1, 3, 3, 3, 0, 0, 0}
|
||||
for i, v := range expected {
|
||||
if data[i] != v {
|
||||
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRepeatKV verifies K/V repetition for GQA.
|
||||
func TestRepeatKV(t *testing.T) {
|
||||
// [B=1, num_kv_heads=2, S=2, head_dim=2]
|
||||
x := mlx.NewArrayFloat32([]float32{
|
||||
// head 0
|
||||
1, 2, // pos 0
|
||||
3, 4, // pos 1
|
||||
// head 1
|
||||
5, 6, // pos 0
|
||||
7, 8, // pos 1
|
||||
}, []int32{1, 2, 2, 2})
|
||||
mlx.Eval(x)
|
||||
|
||||
// Repeat factor 2: 2 kv heads -> 4 heads
|
||||
out := RepeatKV(x, 2)
|
||||
mlx.Eval(out)
|
||||
|
||||
shape := out.Shape()
|
||||
if shape[0] != 1 || shape[1] != 4 || shape[2] != 2 || shape[3] != 2 {
|
||||
t.Errorf("expected shape [1,4,2,2], got %v", shape)
|
||||
}
|
||||
|
||||
data := out.Data()
|
||||
// After repeat: head0, head0, head1, head1
|
||||
expected := []float32{
|
||||
1, 2, 3, 4, // head 0 (original)
|
||||
1, 2, 3, 4, // head 0 (repeat)
|
||||
5, 6, 7, 8, // head 1 (original)
|
||||
5, 6, 7, 8, // head 1 (repeat)
|
||||
}
|
||||
for i, v := range expected {
|
||||
if data[i] != v {
|
||||
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRepeatKVNoOp verifies RepeatKV with factor 1 returns input unchanged.
|
||||
func TestRepeatKVNoOp(t *testing.T) {
|
||||
x := mlx.NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{1, 1, 2, 2})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := RepeatKV(x, 1)
|
||||
// Should return same pointer
|
||||
if out != x {
|
||||
t.Error("RepeatKV with factor 1 should return input unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyCausalMask verifies causal masking.
|
||||
func TestApplyCausalMask(t *testing.T) {
|
||||
// [B=1, heads=1, S=3, S=3] - all ones
|
||||
scores := mlx.Ones(1, 1, 3, 3)
|
||||
mlx.Eval(scores)
|
||||
|
||||
out := ApplyCausalMask(scores)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
// Lower triangular should be 1, upper should be -1e9
|
||||
// Row 0: [1, -inf, -inf]
|
||||
// Row 1: [1, 1, -inf]
|
||||
// Row 2: [1, 1, 1]
|
||||
if data[0] != 1 || data[1] >= 0 || data[2] >= 0 {
|
||||
t.Errorf("row 0 wrong: %v", data[0:3])
|
||||
}
|
||||
if data[3] != 1 || data[4] != 1 || data[5] >= 0 {
|
||||
t.Errorf("row 1 wrong: %v", data[3:6])
|
||||
}
|
||||
if data[6] != 1 || data[7] != 1 || data[8] != 1 {
|
||||
t.Errorf("row 2 wrong: %v", data[6:9])
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyCausalMaskWithOffset verifies causal masking with cache offset.
|
||||
func TestApplyCausalMaskWithOffset(t *testing.T) {
|
||||
// Simulating: cache has 2 tokens, adding 1 new query
|
||||
// scores: [B=1, heads=1, queryLen=1, keyLen=3]
|
||||
scores := mlx.Ones(1, 1, 1, 3)
|
||||
mlx.Eval(scores)
|
||||
|
||||
out := ApplyCausalMaskWithOffset(scores, 2)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
// With offset=2, query at position 2 can attend to all 3 positions
|
||||
if data[0] != 1 || data[1] != 1 || data[2] != 1 {
|
||||
t.Errorf("expected [1, 1, 1], got %v", data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyCausalMaskWithOffsetZero verifies offset=0 falls back to regular causal.
|
||||
func TestApplyCausalMaskWithOffsetZero(t *testing.T) {
|
||||
scores := mlx.Ones(1, 1, 2, 2)
|
||||
mlx.Eval(scores)
|
||||
|
||||
out := ApplyCausalMaskWithOffset(scores, 0)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
// Standard causal: [1, -inf], [1, 1]
|
||||
if data[0] != 1 || data[1] >= 0 {
|
||||
t.Errorf("row 0 wrong: %v", data[0:2])
|
||||
}
|
||||
if data[2] != 1 || data[3] != 1 {
|
||||
t.Errorf("row 1 wrong: %v", data[2:4])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLinearSmall benchmarks small Linear forward pass.
|
||||
func BenchmarkLinearSmall(b *testing.B) {
|
||||
weight := mlx.RandomNormal([]int32{256, 256}, 42)
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
x := mlx.RandomNormal([]int32{1, 256}, 43)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLinearLarge benchmarks larger Linear forward pass.
|
||||
func BenchmarkLinearLarge(b *testing.B) {
|
||||
weight := mlx.RandomNormal([]int32{4096, 4096}, 42)
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
x := mlx.RandomNormal([]int32{1, 4096}, 43)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRMSNorm benchmarks RMSNorm forward pass.
|
||||
func BenchmarkRMSNorm(b *testing.B) {
|
||||
weight := mlx.Ones(4096)
|
||||
mlx.Eval(weight)
|
||||
|
||||
norm := NewRMSNorm(weight, 1e-5)
|
||||
|
||||
x := mlx.RandomNormal([]int32{1, 4096}, 42)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := norm.Forward(x, 0)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEmbedding benchmarks embedding lookup.
|
||||
func BenchmarkEmbedding(b *testing.B) {
|
||||
// Typical vocab size
|
||||
weight := mlx.RandomNormal([]int32{32000, 4096}, 42)
|
||||
mlx.Eval(weight)
|
||||
|
||||
emb := NewEmbedding(weight)
|
||||
|
||||
// Single token lookup
|
||||
indices := mlx.NewArrayInt32([]int32{1000}, []int32{1})
|
||||
mlx.Eval(indices)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := emb.Forward(indices)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRepeatKV benchmarks K/V repetition.
|
||||
func BenchmarkRepeatKV(b *testing.B) {
|
||||
// Typical GQA setup: 8 kv heads -> 32 heads
|
||||
x := mlx.RandomNormal([]int32{1, 8, 512, 128}, 42)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := RepeatKV(x, 4)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
@@ -1,168 +0,0 @@
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// LoadModule loads weights into a struct using reflection and struct tags.
|
||||
//
|
||||
// Struct tags use the format: `weight:"path[,optional]"`
|
||||
// - path: the weight name suffix (appended to prefix)
|
||||
// - optional: if present, missing weights don't cause errors
|
||||
// - "-": skip this field entirely
|
||||
// - no tag on struct pointer: recurse with current prefix
|
||||
// - no tag on *mlx.Array: skip (computed fields don't need loading)
|
||||
//
|
||||
// For slices of struct pointers, the loader iterates with .0, .1, .2... suffixes.
|
||||
// The slice must be pre-allocated to the correct length.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Attention struct {
|
||||
// QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
// KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
// Cache *mlx.Array // no tag = skipped (computed field)
|
||||
// }
|
||||
//
|
||||
// err := LoadModule(&attn, weights, "model.layers.0")
|
||||
func LoadModule(dst any, weights *ModelWeights, prefix string) error {
|
||||
v := reflect.ValueOf(dst)
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
return fmt.Errorf("LoadModule: dst must be a non-nil pointer")
|
||||
}
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("LoadModule: dst must be a pointer to struct, got %v", v.Kind())
|
||||
}
|
||||
|
||||
var errs []string
|
||||
loadStruct(v, weights, prefix, &errs, false)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("LoadModule: missing weights:\n %s", strings.Join(errs, "\n "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadStruct recursively loads weights into a struct value.
|
||||
func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string, parentOptional bool) {
|
||||
t := v.Type()
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldVal := v.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !fieldVal.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse tag
|
||||
tag, hasTag := field.Tag.Lookup("weight")
|
||||
if tag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse tag options
|
||||
optional := parentOptional
|
||||
weightPath := tag
|
||||
if idx := strings.Index(tag, ","); idx != -1 {
|
||||
weightPath = tag[:idx]
|
||||
if strings.Contains(tag[idx+1:], "optional") {
|
||||
optional = true
|
||||
}
|
||||
}
|
||||
|
||||
// Build full path
|
||||
fullPath := joinPath(prefix, weightPath)
|
||||
|
||||
// For struct pointers without a tag, recurse with current prefix
|
||||
if !hasTag && fieldVal.Kind() == reflect.Ptr {
|
||||
elemType := fieldVal.Type().Elem()
|
||||
if elemType.Kind() == reflect.Struct && elemType != reflect.TypeOf(mlx.Array{}) {
|
||||
if fieldVal.IsNil() {
|
||||
fieldVal.Set(reflect.New(elemType))
|
||||
}
|
||||
loadStruct(fieldVal.Elem(), weights, prefix, errs, optional)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Handle by kind
|
||||
switch fieldVal.Kind() {
|
||||
case reflect.Ptr:
|
||||
elemType := fieldVal.Type().Elem()
|
||||
|
||||
// *mlx.Array - load directly (but skip if no tag - computed fields)
|
||||
if fieldVal.Type() == reflect.TypeOf((*mlx.Array)(nil)) {
|
||||
if !hasTag {
|
||||
continue // no tag on *mlx.Array = computed field, skip
|
||||
}
|
||||
arr, err := weights.GetTensor(fullPath)
|
||||
if err != nil {
|
||||
if !optional {
|
||||
*errs = append(*errs, fullPath)
|
||||
}
|
||||
continue
|
||||
}
|
||||
fieldVal.Set(reflect.ValueOf(arr))
|
||||
continue
|
||||
}
|
||||
|
||||
// Pointer to struct - allocate and recurse
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
if optional && !hasWeightsWithPrefix(weights, fullPath) {
|
||||
continue
|
||||
}
|
||||
if fieldVal.IsNil() {
|
||||
fieldVal.Set(reflect.New(elemType))
|
||||
}
|
||||
loadStruct(fieldVal.Elem(), weights, fullPath, errs, optional)
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
elemType := fieldVal.Type().Elem()
|
||||
if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct {
|
||||
loadSlice(fieldVal, weights, fullPath, errs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hasWeightsWithPrefix checks if any weights exist with the given prefix.
|
||||
func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
|
||||
for _, name := range weights.ListTensors() {
|
||||
if strings.HasPrefix(name, prefix+".") || name == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// loadSlice loads weights into each element of a slice of struct pointers.
|
||||
func loadSlice(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string) {
|
||||
elemStructType := v.Type().Elem().Elem()
|
||||
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
elem := v.Index(i)
|
||||
if elem.IsNil() {
|
||||
elem.Set(reflect.New(elemStructType))
|
||||
}
|
||||
loadStruct(elem.Elem(), weights, fmt.Sprintf("%s.%d", prefix, i), errs, false)
|
||||
}
|
||||
}
|
||||
|
||||
// joinPath joins path segments with dots, handling empty segments.
|
||||
func joinPath(prefix, suffix string) string {
|
||||
if prefix == "" {
|
||||
return suffix
|
||||
}
|
||||
if suffix == "" {
|
||||
return prefix
|
||||
}
|
||||
return prefix + "." + suffix
|
||||
}
|
||||
@@ -1,278 +0,0 @@
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SafetensorHeader represents the JSON header of a safetensors file
|
||||
type SafetensorHeader map[string]TensorInfo
|
||||
|
||||
// TensorInfo contains metadata about a tensor
|
||||
type TensorInfo struct {
|
||||
Dtype string `json:"dtype"`
|
||||
Shape []int32 `json:"shape"`
|
||||
DataOffsets [2]int `json:"data_offsets"`
|
||||
}
|
||||
|
||||
// parseSafetensorHeader reads only the JSON header from a safetensors file.
|
||||
func parseSafetensorHeader(path string) (SafetensorHeader, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
return nil, fmt.Errorf("failed to read header size: %w", err)
|
||||
}
|
||||
|
||||
headerBytes := make([]byte, headerSize)
|
||||
if _, err := f.Read(headerBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to read header: %w", err)
|
||||
}
|
||||
|
||||
var header SafetensorHeader
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse header: %w", err)
|
||||
}
|
||||
|
||||
delete(header, "__metadata__")
|
||||
return header, nil
|
||||
}
|
||||
|
||||
// dtypeFromString converts safetensors dtype string to mlx.Dtype
|
||||
func dtypeFromString(s string) mlx.Dtype {
|
||||
switch strings.ToUpper(s) {
|
||||
case "F32", "FLOAT32":
|
||||
return mlx.DtypeFloat32
|
||||
case "F16", "FLOAT16":
|
||||
return mlx.DtypeFloat16
|
||||
case "BF16", "BFLOAT16":
|
||||
return mlx.DtypeBFloat16
|
||||
case "I32", "INT32":
|
||||
return mlx.DtypeInt32
|
||||
case "I64", "INT64":
|
||||
return mlx.DtypeInt64
|
||||
case "U8", "UINT8":
|
||||
return mlx.DtypeUint8
|
||||
default:
|
||||
return mlx.DtypeFloat32
|
||||
}
|
||||
}
|
||||
|
||||
// ModelWeights manages weights from multiple safetensor files.
|
||||
type ModelWeights struct {
|
||||
dir string // Model directory
|
||||
tensorFiles map[string]string // tensor name -> file path
|
||||
tensorInfo map[string]TensorInfo // tensor name -> metadata
|
||||
nativeCache map[string]*mlx.SafetensorsFile // file path -> loaded native handle
|
||||
cache map[string]*mlx.Array // tensor name -> array (after Load)
|
||||
}
|
||||
|
||||
// LoadModelWeights scans safetensor files and builds a tensor index.
|
||||
// This only reads JSON headers, not tensor data.
|
||||
func LoadModelWeights(dir string) (*ModelWeights, error) {
|
||||
mw := &ModelWeights{
|
||||
dir: dir,
|
||||
tensorFiles: make(map[string]string),
|
||||
tensorInfo: make(map[string]TensorInfo),
|
||||
nativeCache: make(map[string]*mlx.SafetensorsFile),
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read directory: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
path := filepath.Join(dir, entry.Name())
|
||||
|
||||
header, err := parseSafetensorHeader(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %w", entry.Name(), err)
|
||||
}
|
||||
|
||||
for name, info := range header {
|
||||
mw.tensorFiles[name] = path
|
||||
mw.tensorInfo[name] = info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(mw.tensorFiles) == 0 {
|
||||
return nil, fmt.Errorf("no safetensor files found in %s", dir)
|
||||
}
|
||||
|
||||
return mw, nil
|
||||
}
|
||||
|
||||
// Load loads all tensors into cache with the specified dtype.
|
||||
// If dtype is 0, tensors are loaded in their original dtype.
|
||||
// Automatically uses streaming (memory-efficient) when dtype conversion is needed,
|
||||
// or native loading when tensors are already in the target dtype.
|
||||
func (mw *ModelWeights) Load(dtype mlx.Dtype) error {
|
||||
if dtype == 0 {
|
||||
return mw.loadNative()
|
||||
}
|
||||
|
||||
// Check if any tensor needs conversion
|
||||
needsConversion := false
|
||||
for name := range mw.tensorFiles {
|
||||
info := mw.tensorInfo[name]
|
||||
if dtypeFromString(info.Dtype) != dtype {
|
||||
needsConversion = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsConversion {
|
||||
return mw.loadStreaming(dtype)
|
||||
}
|
||||
return mw.loadNative()
|
||||
}
|
||||
|
||||
// loadNative loads all tensors using the native memory-mapped loader.
|
||||
func (mw *ModelWeights) loadNative() error {
|
||||
mw.cache = make(map[string]*mlx.Array)
|
||||
|
||||
fileToTensors := make(map[string][]string)
|
||||
for name, path := range mw.tensorFiles {
|
||||
fileToTensors[path] = append(fileToTensors[path], name)
|
||||
}
|
||||
|
||||
for path, names := range fileToTensors {
|
||||
native, err := mlx.LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load %s: %w", path, err)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
arr := native.Get(name)
|
||||
if arr == nil {
|
||||
native.Free()
|
||||
return fmt.Errorf("tensor %q not found in %s", name, path)
|
||||
}
|
||||
mw.cache[name] = arr
|
||||
}
|
||||
|
||||
mw.nativeCache[path] = native
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadStreaming loads tensors with dtype conversion.
|
||||
// Uses the same pattern as Python: replace each entry in the map after conversion,
|
||||
// so the original tensor loses its reference and can be freed.
|
||||
func (mw *ModelWeights) loadStreaming(dtype mlx.Dtype) error {
|
||||
mw.cache = make(map[string]*mlx.Array)
|
||||
|
||||
fileToTensors := make(map[string][]string)
|
||||
for name, path := range mw.tensorFiles {
|
||||
fileToTensors[path] = append(fileToTensors[path], name)
|
||||
}
|
||||
|
||||
for path, names := range fileToTensors {
|
||||
native, err := mlx.LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load %s: %w", path, err)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
src := native.Get(name)
|
||||
if src == nil {
|
||||
native.Free()
|
||||
return fmt.Errorf("tensor %q not found in %s", name, path)
|
||||
}
|
||||
|
||||
dst := mlx.AsType(src, dtype)
|
||||
mlx.Eval(dst)
|
||||
native.Set(name, dst)
|
||||
mw.cache[name] = dst
|
||||
}
|
||||
|
||||
native.Free()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns a tensor from cache. Call Load() first.
|
||||
func (mw *ModelWeights) Get(name string) (*mlx.Array, error) {
|
||||
if mw.cache == nil {
|
||||
return nil, fmt.Errorf("cache not initialized: call Load() first")
|
||||
}
|
||||
arr, ok := mw.cache[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tensor %q not found in cache", name)
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
// GetTensor loads a tensor using the native loader without caching.
|
||||
// For bulk loading, use Load() + Get() instead.
|
||||
func (mw *ModelWeights) GetTensor(name string) (*mlx.Array, error) {
|
||||
if mw.cache != nil {
|
||||
if arr, ok := mw.cache[name]; ok {
|
||||
return arr, nil
|
||||
}
|
||||
}
|
||||
|
||||
path, ok := mw.tensorFiles[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tensor %q not found", name)
|
||||
}
|
||||
|
||||
native, ok := mw.nativeCache[path]
|
||||
if !ok {
|
||||
var err error
|
||||
native, err = mlx.LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load %s: %w", path, err)
|
||||
}
|
||||
mw.nativeCache[path] = native
|
||||
}
|
||||
|
||||
return native.Get(name), nil
|
||||
}
|
||||
|
||||
// GetTensorInfo returns metadata about a tensor without loading it.
|
||||
func (mw *ModelWeights) GetTensorInfo(name string) (TensorInfo, bool) {
|
||||
info, ok := mw.tensorInfo[name]
|
||||
return info, ok
|
||||
}
|
||||
|
||||
// ListTensors returns all tensor names.
|
||||
func (mw *ModelWeights) ListTensors() []string {
|
||||
names := make([]string, 0, len(mw.tensorFiles))
|
||||
for name := range mw.tensorFiles {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// HasTensor checks if a tensor exists.
|
||||
func (mw *ModelWeights) HasTensor(name string) bool {
|
||||
_, ok := mw.tensorFiles[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ReleaseAll releases all cached native file handles.
|
||||
func (mw *ModelWeights) ReleaseAll() {
|
||||
for path, native := range mw.nativeCache {
|
||||
native.Free()
|
||||
delete(mw.nativeCache, path)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
func TestLoadModelWeights(t *testing.T) {
|
||||
// Skip if no model available
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
// Check we found tensors
|
||||
tensors := mw.ListTensors()
|
||||
if len(tensors) == 0 {
|
||||
t.Fatal("no tensors found")
|
||||
}
|
||||
t.Logf("found %d tensors", len(tensors))
|
||||
|
||||
// Check HasTensor
|
||||
if !mw.HasTensor(tensors[0]) {
|
||||
t.Errorf("HasTensor(%q) = false", tensors[0])
|
||||
}
|
||||
if mw.HasTensor("nonexistent.weight") {
|
||||
t.Error("HasTensor returned true for nonexistent tensor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensor(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
tensors := mw.ListTensors()
|
||||
if len(tensors) == 0 {
|
||||
t.Skip("no tensors")
|
||||
}
|
||||
|
||||
// Load first tensor
|
||||
arr, err := mw.GetTensor(tensors[0])
|
||||
if err != nil {
|
||||
t.Fatalf("GetTensor(%q): %v", tensors[0], err)
|
||||
}
|
||||
|
||||
// Verify it has a shape
|
||||
shape := arr.Shape()
|
||||
if len(shape) == 0 {
|
||||
t.Error("tensor has no shape")
|
||||
}
|
||||
t.Logf("%s: shape=%v dtype=%v", tensors[0], shape, arr.Dtype())
|
||||
}
|
||||
|
||||
func TestLoadWithDtype(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
// Load all tensors as bfloat16
|
||||
if err := mw.Load(mlx.DtypeBFloat16); err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
|
||||
// Get a tensor from cache
|
||||
tensors := mw.ListTensors()
|
||||
arr, err := mw.Get(tensors[0])
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
|
||||
// Verify dtype (unless it was already bf16)
|
||||
t.Logf("%s: dtype=%v", tensors[0], arr.Dtype())
|
||||
}
|
||||
|
||||
func TestLookupTensor(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
// HasTensor returns false for nonexistent
|
||||
if mw.HasTensor("nonexistent") {
|
||||
t.Error("HasTensor should return false for nonexistent")
|
||||
}
|
||||
|
||||
// HasTensor returns true for existing tensor
|
||||
tensors := mw.ListTensors()
|
||||
if !mw.HasTensor(tensors[0]) {
|
||||
t.Error("HasTensor should return true for existing tensor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSafetensorHeader(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
// Find a safetensors file
|
||||
entries, err := os.ReadDir(modelDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var stFile string
|
||||
for _, e := range entries {
|
||||
if filepath.Ext(e.Name()) == ".safetensors" {
|
||||
stFile = filepath.Join(modelDir, e.Name())
|
||||
break
|
||||
}
|
||||
}
|
||||
if stFile == "" {
|
||||
t.Skip("no safetensors file found")
|
||||
}
|
||||
|
||||
header, err := parseSafetensorHeader(stFile)
|
||||
if err != nil {
|
||||
t.Fatalf("parseSafetensorHeader: %v", err)
|
||||
}
|
||||
|
||||
if len(header) == 0 {
|
||||
t.Error("header is empty")
|
||||
}
|
||||
|
||||
// Check a tensor has valid info
|
||||
for name, info := range header {
|
||||
if info.Dtype == "" {
|
||||
t.Errorf("%s: empty dtype", name)
|
||||
}
|
||||
if len(info.Shape) == 0 {
|
||||
t.Errorf("%s: empty shape", name)
|
||||
}
|
||||
break // just check one
|
||||
}
|
||||
}
|
||||
@@ -1,166 +0,0 @@
|
||||
# Tokenizer
|
||||
|
||||
Fast, correct tokenizer for LLM inference supporting BPE, SentencePiece, and WordPiece algorithms.
|
||||
|
||||
## Features
|
||||
|
||||
- **BPE (Byte Pair Encoding)** - GPT-2/Llama style with byte-level encoding
|
||||
- **SentencePiece** - Gemma style with `▁` space handling
|
||||
- **WordPiece** - BERT style with `##` continuation tokens
|
||||
- **Parallel encoding** - Automatic parallelization for inputs >4KB
|
||||
- **HuggingFace compatible** - Loads `tokenizer.json` directly
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import "github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
|
||||
// Load from HuggingFace model directory
|
||||
tok, err := tokenizer.Load("./weights/Llama-3.2-1B")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Encode text to token IDs
|
||||
ids := tok.Encode("Hello, world!", false) // false = don't add BOS
|
||||
|
||||
// Decode back to text
|
||||
text := tok.Decode(ids)
|
||||
|
||||
// Check special tokens
|
||||
if tok.IsEOS(ids[len(ids)-1]) {
|
||||
// End of sequence
|
||||
}
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
Benchmarks on Apple M3 Max:
|
||||
|
||||
| Input Size | Encode | Decode | Tokens |
|
||||
|------------|--------|--------|--------|
|
||||
| 1 KB | 14.5 MB/s | 267 MB/s | 231 |
|
||||
| 10 KB | 10.9 MB/s | 321 MB/s | 2,301 |
|
||||
| 100 KB | 8.9 MB/s | 311 MB/s | 23,001 |
|
||||
| 1 MB | 9.6 MB/s | 321 MB/s | 230,001 |
|
||||
|
||||
Comparison with other implementations (10 MB input):
|
||||
|
||||
| Implementation | Encode Speed | Notes |
|
||||
|----------------|--------------|-------|
|
||||
| Engine (this) | ~10 MB/s | stdlib RE2, parallel >4KB |
|
||||
| tiktoken (Rust) | ~17 MB/s | Highly optimized regex |
|
||||
| llama.cpp (C++) | ~2 MB/s | Single-threaded only |
|
||||
| Ollama (Go) | ~2-3 MB/s | regexp2 backtracking |
|
||||
|
||||
## Correctness
|
||||
|
||||
The tokenizer matches HuggingFace transformers exactly. Verified with:
|
||||
|
||||
- 82 rigorous test cases for Gemma (SentencePiece)
|
||||
- 458 fuzz test cases covering Unicode edge cases
|
||||
- Full 0x00-0xFF byte roundtrip for BPE
|
||||
|
||||
Run tests:
|
||||
```bash
|
||||
go test ./tokenizer/... -v
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Load(path)
|
||||
│
|
||||
├─ tokenizer.json → loadFromTokenizerJSON()
|
||||
│ ├─ Parse vocab, merges, added_tokens
|
||||
│ ├─ Detect type: BPE / SentencePiece / WordPiece
|
||||
│ ├─ Compile pretokenizer regex (BPE only)
|
||||
│ └─ Load special tokens from config files
|
||||
│
|
||||
└─ vocab.json + merges.txt → loadVocabMerges()
|
||||
|
||||
Encode(text)
|
||||
│
|
||||
├─ Split by special tokens
|
||||
├─ Apply pretokenizer regex (BPE) or space→▁ (SentencePiece)
|
||||
├─ For each chunk:
|
||||
│ ├─ Fast path: single token lookup
|
||||
│ └─ Slow path: BPE merge algorithm
|
||||
└─ Parallel for inputs >4KB
|
||||
|
||||
Decode(ids)
|
||||
│
|
||||
├─ Look up each token
|
||||
└─ Apply inverse transform:
|
||||
├─ BPE: byte-level decode (0x0100 → 0x00, etc.)
|
||||
├─ SentencePiece: ▁→space, <0xNN>→byte
|
||||
└─ WordPiece: strip ## prefix
|
||||
```
|
||||
|
||||
## Key Implementation Details
|
||||
|
||||
### BPE Byte-Level Encoding
|
||||
|
||||
GPT-2 style encoding maps bytes to Unicode codepoints to handle arbitrary binary data:
|
||||
|
||||
```go
|
||||
// Precomputed table: byte → rune
|
||||
var byteToRune [256]rune // 0x00→0x0100, 0x20→0x0120, etc.
|
||||
```
|
||||
|
||||
### Pretokenizer Regex
|
||||
|
||||
HuggingFace patterns use PCRE features not supported by Go's RE2. We rewrite:
|
||||
|
||||
```go
|
||||
// PCRE (HuggingFace)
|
||||
`\s+(?!\S)|\s+`
|
||||
|
||||
// RE2 (Go) - with post-processing for whitespace boundaries
|
||||
`\s+`
|
||||
```
|
||||
|
||||
### Special Token Handling
|
||||
|
||||
Special tokens are matched greedily (longest first) before pretokenization:
|
||||
|
||||
```go
|
||||
// Sorted by length, checked with HasPrefix
|
||||
for _, tok := range sortedSpecialTokens {
|
||||
if strings.HasPrefix(remaining, tok) {
|
||||
// Found special token
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Opportunities
|
||||
|
||||
Potential optimizations not yet implemented:
|
||||
|
||||
| Optimization | Expected Gain | Complexity |
|
||||
|--------------|---------------|------------|
|
||||
| Aho-Corasick for special tokens | 2-3x for many special tokens | Medium |
|
||||
| Custom regex engine (like tiktoken) | 1.5-2x | High |
|
||||
| SIMD byte scanning | 1.3-1.5x for pretokenizer | Medium |
|
||||
| Assembly BPE merge loop | 1.2-1.5x | High |
|
||||
| Memoization for repeated substrings | Variable | Low |
|
||||
|
||||
Current bottleneck is the pretokenizer regex (~60% of encode time). tiktoken achieves ~17 MB/s with a hand-tuned Rust regex engine.
|
||||
|
||||
## Not Yet Implemented
|
||||
|
||||
| Feature | Used By | Notes |
|
||||
|---------|---------|-------|
|
||||
| Unigram tokenizer | T5, ALBERT, mBART | Different algorithm (not BPE) |
|
||||
| Unicode normalizers | Some multilingual models | NFD, NFKC, lowercase, etc. |
|
||||
| Custom pretokenizers | Model-specific | Beyond standard patterns |
|
||||
|
||||
Most HuggingFace models use BPE or SentencePiece, which are fully supported. WordPiece (BERT-style) is also supported with standard `[UNK]` fallback for out-of-vocabulary characters.
|
||||
|
||||
## Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `tokenizer.go` | Main implementation (~1000 lines) |
|
||||
| `tokenizer_test.go` | Tests and benchmarks |
|
||||
| `testdata/` | Mini tokenizer for unit tests |
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user