Compare commits

..

14 Commits

Author SHA1 Message Date
ParthSareen
6b7456ca1f wip 2026-01-07 01:58:37 -08:00
ParthSareen
44179b7e53 x/agent: use stdlib path package for path normalization
Replace custom normalizePath function with stdlib path.Clean.
Use path.IsAbs and path.Dir for cleaner, more robust code.
Add sibling escape detection to prevent traversal attacks like
"tools/a/b/../../../etc", which normalizes to "etc", a sibling of the allowed directory.
2026-01-06 18:09:10 -08:00
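The core of the fix, as a minimal Go sketch (function names are illustrative, not the actual x/agent code): cleaning alone is not enough, because the traversal above cleans to a sibling of the allowed directory, so a prefix check has to run after path.Clean.

package main

import (
	"fmt"
	"path"
	"strings"
)

// normalizeUnder cleans p and reports whether the result still lives under base.
func normalizeUnder(base, p string) (string, bool) {
	cleaned := path.Clean(p)
	if path.IsAbs(cleaned) {
		return cleaned, false
	}
	prefix := path.Clean(base) + "/"
	return cleaned, strings.HasPrefix(cleaned+"/", prefix)
}

func main() {
	fmt.Println(normalizeUnder("tools", "tools/a/b/../../file.go")) // tools/file.go true
	fmt.Println(normalizeUnder("tools", "tools/a/b/../../../etc"))  // etc false (sibling escape)
}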
ParthSareen
359be5b658 x/cmd: handle 500 errors by informing model and retrying
When the server returns a 500 error (often due to tool parsing failures),
instead of failing outright, send the error message and the model's response
back to the model so it can correct itself and retry.

- Includes both the error message and the model's failed response
- Limits to 3 consecutive retries to prevent infinite loops
- Resets retry counter on successful responses
2026-01-06 16:55:08 -08:00
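Roughly, the retry policy amounts to the following sketch; the helper names, message wording, and error type are assumptions, not the real x/cmd code.

package main

import (
	"errors"
	"fmt"
)

// statusError stands in for the server error type; the real code inspects the
// status returned by client.Chat().
type statusError struct{ code int }

func (e statusError) Error() string { return fmt.Sprintf("status %d", e.code) }

const maxConsecutiveRetries = 3

// chatWithRetry sends the conversation and, on a 500, appends the error text to
// the history before retrying, at most three consecutive times.
func chatWithRetry(send func(history []string) (string, error), history []string) (string, error) {
	retries := 0
	for {
		reply, err := send(history)
		if err == nil {
			return reply, nil // a successful response ends the loop (and the retry budget resets)
		}
		var se statusError
		if errors.As(err, &se) && se.code == 500 && retries < maxConsecutiveRetries {
			retries++
			history = append(history, "previous attempt failed: "+err.Error()+"; please retry")
			continue
		}
		return "", err
	}
}

func main() {
	calls := 0
	send := func([]string) (string, error) {
		calls++
		if calls < 3 {
			return "", statusError{500}
		}
		return "ok", nil
	}
	fmt.Println(chatWithRetry(send, []string{"hello"}))
}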
ParthSareen
820e51e144 x/cmd: add --yolo/-y flag to skip tool approval prompts
Add a -y/--yolo flag that skips all interactive tool approval prompts.
Dangerous command patterns (rm -rf, sudo, etc.) are still blocked.

Usage: ollama run model --experimental -y
2026-01-06 16:47:26 -08:00
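The approval gate, sketched; the pattern list and names here are illustrative, not the shipped implementation.

package main

import (
	"fmt"
	"strings"
)

var dangerousPatterns = []string{"rm -rf", "sudo ", "mkfs", "> /dev/"} // illustrative subset

// approveTool decides whether a bash command may run: --yolo skips the interactive
// prompt, but known-dangerous patterns are still refused outright.
func approveTool(command string, yolo bool, prompt func(string) bool) bool {
	for _, p := range dangerousPatterns {
		if strings.Contains(command, p) {
			return false // blocked even in yolo mode
		}
	}
	if yolo {
		return true
	}
	return prompt("Run `" + command + "`? [y/N] ")
}

func main() {
	alwaysNo := func(string) bool { return false }
	fmt.Println(approveTool("ls -la", true, alwaysNo))        // true: yolo skips the prompt
	fmt.Println(approveTool("sudo rm -rf /", true, alwaysNo)) // false: still blocked
}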
ParthSareen
8470c25fa9 x/cmd: handle 401 from Chat API with sign-in prompt
When client.Chat() returns a 401 AuthorizationError, prompt the user
to sign in instead of just showing "Error: 401 Unauthorized".

This handles the case where users need to authenticate to use cloud
models, not just web search.
2026-01-06 15:43:11 -08:00
ParthSareen
c8b599bd44 x/agent: fix path traversal vulnerability in hierarchical prefix matching
Reject any path containing ".." when creating allowlist prefixes.
This prevents attacks where approving "cat tools/file.go" would allow
"cat tools/../../etc/passwd" via the hierarchical prefix matching.

Commands with ".." now require individual approval each time.
Absolute paths are likewise rejected from prefix creation.

Added tests for path traversal scenarios.
2026-01-06 15:41:57 -08:00
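A small sketch of the prefix-eligibility guard (hypothetical names): any ".." component or absolute path disqualifies the command from creating a reusable allowlist prefix, so it falls back to per-use approval.

package main

import (
	"fmt"
	"path"
	"strings"
)

func eligibleForPrefix(p string) bool {
	if path.IsAbs(p) {
		return false
	}
	for _, part := range strings.Split(p, "/") {
		if part == ".." {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(eligibleForPrefix("tools/file.go"))          // true
	fmt.Println(eligibleForPrefix("tools/../../etc/passwd")) // false: approved individually each time
	fmt.Println(eligibleForPrefix("/etc/passwd"))            // false
}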
ParthSareen
59928c536b x/cmd: add context-aware tool output truncation for LLM
Implement dual-limit tool output truncation to prevent context overflow:
- 4k tokens (~16k chars) for local models on local servers
- 10k tokens (~40k chars) for cloud models or remote servers

This preserves more of the context window for local models, which typically
have smaller context windows, while allowing larger outputs for cloud services.
2026-01-06 15:36:03 -08:00
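Sketch of the dual limit, assuming the ~4 characters per token heuristic implied by the numbers above (names and the truncation marker are illustrative):

package main

import "fmt"

const (
	localLimitChars = 4_000 * 4  // ~4k tokens for a local model on a local server
	cloudLimitChars = 10_000 * 4 // ~10k tokens for cloud models or remote servers
)

func truncateToolOutput(out string, localModel, localServer bool) string {
	limit := cloudLimitChars
	if localModel && localServer {
		limit = localLimitChars
	}
	if len(out) <= limit {
		return out
	}
	// Byte-based cut for brevity; a real implementation would avoid splitting a UTF-8 rune.
	return out[:limit] + "\n... [output truncated]"
}

func main() {
	long := string(make([]byte, 50_000))
	fmt.Println(len(truncateToolOutput(long, true, true)))  // ~16k chars
	fmt.Println(len(truncateToolOutput(long, false, true))) // ~40k chars
}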
ParthSareen
0b4850812f x/agent: fix hierarchical prefix matching for Windows paths
Normalize backslashes to forward slashes in extractBashPrefix to ensure
consistent cross-platform behavior. Use string-based path splitting
instead of filepath.Dir to avoid platform-specific differences.

Add cross-platform test for Windows-style backslash paths.
2026-01-06 15:16:28 -08:00
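Illustrative sketch of the normalization; the real extractBashPrefix does more, but the core idea is to rewrite backslashes and split on "/" directly instead of calling filepath.Dir:

package main

import (
	"fmt"
	"strings"
)

// dirPrefix returns the directory portion of a path with a trailing "/",
// behaving the same on every platform.
func dirPrefix(p string) string {
	p = strings.ReplaceAll(p, `\`, "/")
	if i := strings.LastIndex(p, "/"); i >= 0 {
		return p[:i+1]
	}
	return ""
}

func main() {
	fmt.Println(dirPrefix(`tools\subdir\file.go`)) // tools/subdir/
	fmt.Println(dirPrefix("tools/file.go"))        // tools/
	fmt.Println(dirPrefix("file.go"))              // ""
}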
ParthSareen
9383082070 x: add tests for tool disabling, auth error, and helper functions
- Add tests for OLLAMA_AGENT_DISABLE_WEBSEARCH/BASH env vars
- Add tests for ErrWebSearchAuthRequired error type
- Add tests for isLocalModel, isLocalServer, truncateToolOutputForLocalModel
2026-01-06 14:51:27 -08:00
ParthSareen
85e48af46a x/cmd: add tool output toggle and interactive signin flow
- Add Ctrl+O toggle to expand/collapse tool output inline
- Show tools available in grey text at startup
- Add interactive signin flow when web search returns 401:
  prompts user, shows signin URL, polls until auth completes
- Truncate tool output for local models to prevent context overflow
- Update help text with Ctrl+O keyboard shortcut
2026-01-06 14:48:03 -08:00
ParthSareen
aa9a1477b3 x/agent: improve approval UX with hierarchical matching and signin prompt
- Add hierarchical prefix matching for bash commands: if "cat:tools/"
  is approved, subdirectories like "cat:tools/subdir/" are also allowed
- Show "Uses internet via ollama.com" notice in web_search approval popup
- Add PromptYesNo function for interactive yes/no prompts
- Add tests for hierarchical prefix matching
2026-01-06 14:47:22 -08:00
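A minimal sketch of the hierarchical matching (the "tool:path/" key format is taken from the example above; the helper name is hypothetical):

package main

import (
	"fmt"
	"strings"
)

func allowedByPrefix(approved map[string]bool, key string) bool {
	if approved[key] {
		return true
	}
	for prefix := range approved {
		if strings.HasSuffix(prefix, "/") && strings.HasPrefix(key, prefix) {
			return true
		}
	}
	return false
}

func main() {
	approved := map[string]bool{"cat:tools/": true}
	fmt.Println(allowedByPrefix(approved, "cat:tools/subdir/file.go")) // true
	fmt.Println(allowedByPrefix(approved, "cat:secrets/key.pem"))      // false
}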
ParthSareen
aed714a676 x/tools: use Ollama key signing for web search authentication
Replace OLLAMA_API_KEY environment variable with Ollama's native key
signing mechanism (~/.ollama/id_ed25519). Add ErrWebSearchAuthRequired
error type for handling 401 responses.
2026-01-06 14:45:08 -08:00
ParthSareen
064c6a984e x/tools: add environment variables to disable tools
Add OLLAMA_AGENT_DISABLE_WEBSEARCH and OLLAMA_AGENT_DISABLE_BASH
environment variables to selectively disable tools in the agent loop.
2026-01-06 14:44:18 -08:00
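Illustratively, the tool list could be assembled as below; whether the variables are checked for mere presence or for a truthy value is an assumption of this sketch.

package main

import (
	"fmt"
	"os"
)

func enabledTools() []string {
	var tools []string
	if os.Getenv("OLLAMA_AGENT_DISABLE_WEBSEARCH") == "" {
		tools = append(tools, "web_search")
	}
	if os.Getenv("OLLAMA_AGENT_DISABLE_BASH") == "" {
		tools = append(tools, "bash")
	}
	return tools
}

func main() {
	os.Setenv("OLLAMA_AGENT_DISABLE_BASH", "1")
	fmt.Println(enabledTools()) // [web_search]
}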
ParthSareen
3aaa8d5564 readline: add Ctrl+O support for expanding tool output
Add CharCtrlO constant and ErrExpandOutput error to enable Ctrl+O
as a keyboard shortcut for expanding truncated tool output in the
agent loop.
2026-01-06 14:44:04 -08:00
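For reference, a sketch of what those two additions amount to (Ctrl+O is byte 0x0F in ASCII; the exact declarations in the readline package may differ):

package main

import (
	"errors"
	"fmt"
)

const CharCtrlO = 0x0f // Ctrl+O as read from the terminal

// ErrExpandOutput is the sentinel the agent loop would check for to re-render
// the last tool output untruncated.
var ErrExpandOutput = errors.New("expand tool output")

func main() {
	fmt.Printf("CharCtrlO = %#x, sentinel = %v\n", CharCtrlO, ErrExpandOutput)
}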
176 changed files with 714 additions and 37116 deletions

View File

@@ -68,7 +68,6 @@ jobs:
name: bundles-darwin
path: |
dist/*.tgz
dist/*.tar.zst
dist/*.zip
dist/*.dmg
@@ -393,13 +392,13 @@ jobs:
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
done
- uses: actions/upload-artifact@v4
with:
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: |
*.tar.zst
*.tgz
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push:
@@ -532,7 +531,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!

View File

@@ -2,22 +2,6 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
# host architecture, but downstream projects (like MLX) use it to detect the
# target architecture.
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
# Single architecture specified
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
set(CMAKE_SYSTEM_PROCESSOR "arm64")
endif()
endif()
include(CheckLanguage)
include(GNUInstallDirs)
@@ -28,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
set(CMAKE_CXX_EXTENSIONS OFF)
set(GGML_BUILD ON)
set(GGML_SHARED ON)
@@ -163,48 +147,14 @@ if(CMAKE_HIP_COMPILER)
endif()
endif()
if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()
endif()

View File

@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 4",
"CMAKE_CUDA_FLAGS": "-t 2",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
@@ -83,28 +83,6 @@
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan"
}
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
],
"buildPresets": [
@@ -162,21 +140,6 @@
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
}
]
}

View File

@@ -131,36 +131,8 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -181,7 +153,6 @@ FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -200,7 +171,7 @@ COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin

View File

@@ -1,778 +0,0 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

View File

@@ -1,953 +0,0 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}

View File

@@ -46,8 +46,6 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -98,10 +96,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
return imagegenclient.CreateModel(args[0], ".", p)
}
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
@@ -463,15 +457,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
// Check if this is a known image generation model (skip Show/Pull)
if imagegen.HasTensorLayers(name) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -535,7 +520,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
yoloMode, _ := cmd.Flags().GetBool("yolo")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -837,11 +822,6 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
}
func ShowHandler(cmd *cobra.Command, args []string) error {
// Check if this is an image generation model
if imagegen.HasTensorLayers(args[0]) {
return imagegen.Show(args[0], os.Stdout)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -1785,10 +1765,7 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
runCmd.Flags().BoolP("yolo", "y", false, "Skip all tool approval prompts (use with caution)")
stopCmd := &cobra.Command{
Use: "stop MODEL",

View File

@@ -6,14 +6,11 @@ import (
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"maps"
"os"
"slices"
"strings"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -21,13 +18,8 @@ type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
// TODO is this needed?
ModelType string `json:"model_type"`
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
HiddenSize uint32 `json:"hidden_size"`
ModelType string `json:"model_type"`
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
}
@@ -41,94 +33,8 @@ type AdapterParameters struct {
} `json:"lora_parameters"`
}
type KV map[string]any
func (kv KV) Architecture() string {
return kv.String("general.architecture", "unknown")
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
[]uint8 | []int8 | []uint16 | []int16 |
[]uint32 | []int32 | []uint64 | []int64 |
[]string | []float32 | []float64 | []bool
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
}
return defaultValue[0], false
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
return val
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
return val
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
return val
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
return val
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
return val
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (ModelParameters) KV(t *Tokenizer) KV {
kv := KV{
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
kv := ggml.KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
@@ -157,7 +63,7 @@ func (ModelParameters) KV(t *Tokenizer) KV {
return kv
}
func (p AdapterParameters) KV() KV {
func (p AdapterParameters) KV() ggml.KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
@@ -165,7 +71,7 @@ func (p AdapterParameters) KV() KV {
alpha = p.LoraParameters.Alpha
}
kv := KV{
kv := ggml.KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
@@ -182,14 +88,9 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
type ModelKV interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) KV
}
type ModelConverter interface {
ModelKV
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -206,7 +107,7 @@ type moreParser interface {
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ofs.Config) KV
KV(ggml.KV) ggml.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -214,7 +115,7 @@ type AdapterConverter interface {
Replacements() []string
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -225,8 +126,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return err
}
arch := baseKV.Architecture()
if arch == "" {
arch, ok := baseKV["general.architecture"]
if !ok {
return errors.New("architecture not set for the base model")
}
@@ -252,19 +153,23 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
}
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizer files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return nil, nil, err
return err
}
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return nil, nil, err
return err
}
if len(p.Architectures) < 1 {
return nil, nil, errors.New("unknown architecture")
return errors.New("unknown architecture")
}
var conv ModelConverter
@@ -312,22 +217,22 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return nil, nil, err
return err
}
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return nil, nil, err
return err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil {
return nil, nil, err
return err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -349,19 +254,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
return conv, t, nil
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizer files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
kv, t, err := LoadModelMetadata(fsys)
if err != nil {
return err
}
conv := kv.(ModelConverter)
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
@@ -371,7 +263,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return writeFile(f, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)
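The `ConvertModel` entry point restored above takes an `fs.FS` for the input checkpoint and an `*os.File` for the GGUF output. As a minimal, hedged sketch of how it might be invoked (the checkpoint path and output filename are placeholders, not part of this change):

```go
package main

import (
	"log"
	"os"

	"github.com/ollama/ollama/convert"
)

func main() {
	// Destination for the converted model; the filename is arbitrary.
	out, err := os.Create("model.gguf")
	if err != nil {
		log.Fatal(err)
	}
	defer out.Close()

	// ConvertModel reads config.json, the tokenizer files, and the safetensors
	// tensors from the provided fs.FS, then writes a GGUF file to out.
	if err := convert.ConvertModel(os.DirFS("/path/to/hf-checkpoint"), out); err != nil {
		log.Fatal(err)
	}
}
```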


@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *bertModel) KV(t *Tokenizer) KV {
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false


@@ -24,7 +24,7 @@ type commandrModel struct {
var _ ModelConverter = (*commandrModel)(nil)
func (p *commandrModel) KV(t *Tokenizer) KV {
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r"


@@ -47,7 +47,7 @@ type deepseek2Model struct {
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) KV {
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"


@@ -41,7 +41,7 @@ type deepseekocr struct {
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) KV {
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers


@@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) KV {
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings


@@ -1,5 +1,7 @@
package convert
import "github.com/ollama/ollama/fs/ggml"
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
@@ -7,7 +9,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) KV {
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings


@@ -6,7 +6,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -16,7 +15,7 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv


@@ -3,6 +3,8 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
@@ -53,7 +55,7 @@ const (
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) KV {
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"


@@ -38,7 +38,7 @@ type gemma3nModel struct {
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) KV {
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {


@@ -37,7 +37,7 @@ type gptossModel struct {
var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) KV {
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)


@@ -48,7 +48,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) KV {
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize


@@ -35,7 +35,7 @@ type llama4Model struct {
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) KV {
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"


@@ -7,7 +7,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -19,13 +18,13 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
return kv
}


@@ -60,7 +60,7 @@ type mistral3Model struct {
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) KV {
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize


@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize


@@ -12,7 +12,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
func (p *mixtralModel) KV(t *Tokenizer) KV {
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 {


@@ -34,7 +34,7 @@ type mllamaModel struct {
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) KV {
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"


@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) KV {
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)


@@ -34,7 +34,7 @@ type olmoModel struct {
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) KV {
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers


@@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) KV {
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings


@@ -22,7 +22,7 @@ type qwen2Model struct {
var _ ModelConverter = (*qwen2Model)(nil)
func (q *qwen2Model) KV(t *Tokenizer) KV {
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen2"
kv["qwen2.block_count"] = q.HiddenLayers


@@ -29,7 +29,7 @@ type qwen25VLModel struct {
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"


@@ -32,7 +32,7 @@ type qwen3Model struct {
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) KV {
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"


@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"


@@ -19,7 +19,6 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
fsc "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -29,7 +28,7 @@ type tensorData struct {
Shape []int `json:"shape"`
}
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -60,10 +59,9 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors)
return r, m.KV(), m.Tensors()
}
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
actual := make(map[string]string)
for k := range kv.Keys() {
v := kv.Value(k)
for k, v := range kv {
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
@@ -279,7 +277,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
func TestConvertAdapter(t *testing.T) {
type AdapterCase struct {
Name string
BaseKV KV
BaseKV map[string]any
Expected map[string]string
}


@@ -14,7 +14,6 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
### Resources


@@ -1,406 +0,0 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='qwen3-coder',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
tools: [
{
name: "get_weather",
description: "Get the current weather in a location",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});
for (const block of message.content) {
if (block.type === "tool_use") {
console.log("Tool:", block.name);
console.log("Input:", block.input);
}
}
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
## Models
Ollama supports both local and cloud models.
### Local models
Pull a local model before use:
```shell
ollama pull qwen3-coder
```
Recommended local models:
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
### Cloud models
Cloud models are available immediately without pulling:
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
```shell
ollama cp qwen3-coder claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |
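To make the partial-support notes above concrete, here is a hedged sketch of a raw request against a local server with extended thinking enabled; the model name and prompt are placeholders, and per the table `budget_tokens` is accepted but not enforced:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// Thinking is requested with a budget; the budget is accepted but not enforced.
	body := []byte(`{
	  "model": "qwen3-coder",
	  "max_tokens": 512,
	  "thinking": {"type": "enabled", "budget_tokens": 1024},
	  "messages": [{"role": "user", "content": "Why is the sky blue?"}]
	}`)

	resp, err := http.Post("http://localhost:11434/v1/messages", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// The response body includes thinking and text content blocks plus usage counts.
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}
```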


@@ -32,9 +32,7 @@
"codeblocks": "system"
},
"contextual": {
"options": [
"copy"
]
"options": ["copy"]
},
"navbar": {
"links": [
@@ -54,9 +52,7 @@
"display": "simple"
},
"examples": {
"languages": [
"curl"
]
"languages": ["curl"]
}
},
"redirects": [
@@ -101,7 +97,6 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
@@ -144,8 +139,7 @@
"/api/streaming",
"/api/usage",
"/api/errors",
"/api/openai-compatibility",
"/api/anthropic-compatibility"
"/api/openai-compatibility"
]
},
{


@@ -1,69 +0,0 @@
---
title: Claude Code
---
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
<CodeGroup>
```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```
```powershell Windows
irm https://claude.ai/install.ps1 | iex
```
</CodeGroup>
## Usage with Ollama
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
2. Run Claude Code with an Ollama model:
```shell
claude --model qwen3-coder
```
Or run with environment variables inline:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model


@@ -1,5 +1,5 @@
---
title: "Linux"
title: Linux
---
## Install
@@ -13,7 +13,8 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
<Note>
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
</Note>
Download and extract the package:
@@ -112,7 +113,11 @@ sudo systemctl status ollama
```
<Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
</Note>
## Customizing
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
sudo userdel ollama
sudo groupdel ollama
sudo rm -r /usr/share/ollama
```
```


@@ -1,7 +1,5 @@
package fs
import "iter"
type Config interface {
Architecture() string
String(string, ...string) string
@@ -13,8 +11,4 @@ type Config interface {
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
Len() int
Keys() iter.Seq[string]
Value(key string) any
}


@@ -6,9 +6,7 @@ import (
"errors"
"fmt"
"io"
"iter"
"log/slog"
"maps"
"math"
"slices"
"strings"
@@ -241,18 +239,6 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
return val.values
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"bert",


@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"github.com/ollama/ollama/fs"
"golang.org/x/sync/errgroup"
)
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
return err
}
for _, key := range slices.Sorted(kv.Keys()) {
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
return err
}
}
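With this change `WriteGGUF` accepts the map-backed `ggml.KV` directly instead of the `fs.Config` interface, enumerating keys with `maps.Keys` and taking the count from `len(kv)`. A small, hedged sketch of the new call shape (the keys and empty tensor list are illustrative only):

```go
package main

import (
	"log"
	"os"

	"github.com/ollama/ollama/fs/ggml"
)

func main() {
	f, err := os.Create("example.gguf") // placeholder output path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Illustrative keys; real converters populate many more entries.
	kv := ggml.KV{
		"general.architecture": "llama",
		"llama.block_count":    uint32(32),
	}

	// WriteGGUF sorts maps.Keys(kv) and writes uint64(len(kv)) as the
	// key-value count, matching the hunk above. No tensors are written here.
	if err := ggml.WriteGGUF(f, kv, nil); err != nil {
		log.Fatal(err)
	}
}
```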


@@ -1,149 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
type AnthropicWriter struct {
BaseWriter
stream bool
id string
model string
converter *anthropic.StreamConverter
}
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
var errData struct {
Error string `json:"error"`
}
if err := json.Unmarshal(data, &errData); err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
}
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
if w.stream {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
}
}
return len(data), nil
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := anthropic.ToMessagesResponse(w.id, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
func (w *AnthropicWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req anthropic.MessagesRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.MaxTokens <= 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
return
}
chatReq, err := anthropic.FromMessagesRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
messageID := anthropic.GenerateMessageID()
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
}
if req.Stream {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
c.Next()
}
}


@@ -1,584 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
_ = json.Unmarshal(bodyBytes, capturedRequest)
c.Next()
}
}
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAnthropicMessagesMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.ChatRequest
err anthropic.ErrorResponse
}
var capturedRequest *api.ChatRequest
stream := true
testCases := []testCase{
{
name: "basic message",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with system prompt",
body: `{
"model": "test-model",
"max_tokens": 1024,
"system": "You are helpful.",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with options",
body: `{
"model": "test-model",
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{
"num_predict": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop": []string{"\n", "END"},
},
Stream: &False,
},
},
{
name: "streaming",
body: `{
"model": "test-model",
"max_tokens": 1024,
"stream": true,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &stream,
},
},
{
name: "with tools",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"}
],
"tools": [{
"name": "get_weather",
"description": "Get current weather",
"input_schema": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: testProps(map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with tool result",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"},
{"role": "assistant", "content": [
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
]},
{"role": "user", "content": [
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with thinking enabled",
body: `{
"model": "test-model",
"max_tokens": 1024,
"thinking": {"type": "enabled", "budget_tokens": 1000},
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
Think: &api.ThinkValue{Value: true},
},
},
{
name: "missing model error",
body: `{
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "model is required",
},
},
},
{
name: "missing max_tokens error",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "max_tokens is required and must be positive",
},
},
},
{
name: "missing messages error",
body: `{
"model": "test-model",
"max_tokens": 1024
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "messages is required",
},
},
},
{
name: "tool_use missing id error",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "assistant", "content": [
{"type": "tool_use", "name": "test"}
]}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "tool_use block missing required 'id' field",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
router.Handle(http.MethodPost, "/v1/messages", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Type != "" {
// Expect error
if resp.Code == http.StatusOK {
t.Fatalf("expected error response, got 200 OK")
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != tc.err.Type {
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
}
if errResp.Error.Type != tc.err.Error.Type {
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
}
if errResp.Error.Message != tc.err.Error.Message {
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
}
if capturedRequest == nil {
t.Fatal("request was not captured")
}
// Compare relevant fields
if capturedRequest.Model != tc.req.Model {
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
}
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
t.Errorf("messages mismatch (-want +got):\n%s", diff)
}
if tc.req.Stream != nil && capturedRequest.Stream != nil {
if *tc.req.Stream != *capturedRequest.Stream {
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
}
}
if tc.req.Think != nil {
if capturedRequest.Think == nil {
t.Error("expected Think to be set")
} else if capturedRequest.Think.Value != tc.req.Think.Value {
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
}
}
})
}
}
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("streaming sets correct headers", func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Check headers were set
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
}
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
}
c.Status(http.StatusOK)
})
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
})
}
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != "invalid_request_error" {
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
}
}
func TestAnthropicWriter_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate Ollama response
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
data, _ := json.Marshal(resp)
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.Code)
}
var result anthropic.MessagesResponse
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 {
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
}
if result.Usage.OutputTokens != 5 {
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
}
}
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
errorPayload any
wantErrorType string
wantMessage string
}{
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
{
name: "404 with gin.H error (model not found)",
statusCode: http.StatusNotFound,
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
wantErrorType: "not_found_error",
wantMessage: "model 'nonexistent' not found",
},
{
name: "400 with gin.H error (bad request)",
statusCode: http.StatusBadRequest,
errorPayload: gin.H{"error": "model is required"},
wantErrorType: "invalid_request_error",
wantMessage: "model is required",
},
{
name: "500 with gin.H error (internal error)",
statusCode: http.StatusInternalServerError,
errorPayload: gin.H{"error": "something went wrong"},
wantErrorType: "api_error",
wantMessage: "something went wrong",
},
{
name: "404 with api.StatusError",
statusCode: http.StatusNotFound,
errorPayload: api.StatusError{
StatusCode: http.StatusNotFound,
ErrorMessage: "model not found via StatusError",
},
wantErrorType: "not_found_error",
wantMessage: "model not found via StatusError",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate what routes.go does - set status and write error JSON
data, _ := json.Marshal(tt.errorPayload)
c.Writer.WriteHeader(tt.statusCode)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.statusCode {
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != tt.wantErrorType {
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
}
if errResp.Error.Message != tt.wantMessage {
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
}
})
}
}


@@ -21,7 +21,6 @@ import (
"golang.org/x/text/encoding/unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/fs/ggml"
)
@@ -802,7 +801,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
var base convert.KV = map[string]any{"general.architecture": "test"}
base := map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {


@@ -1,33 +0,0 @@
package progress
import (
"fmt"
"strings"
)
// StepBar displays step-based progress (e.g., for image generation steps).
type StepBar struct {
message string
current int
total int
}
func NewStepBar(message string, total int) *StepBar {
return &StepBar{message: message, total: total}
}
func (s *StepBar) Set(current int) {
s.current = current
}
func (s *StepBar) String() string {
percent := float64(s.current) / float64(s.total) * 100
barWidth := s.total
empty := barWidth - s.current
// "Generating 0% ▕ ▏ 0/9"
return fmt.Sprintf("%s %3.0f%% ▕%s%s▏ %d/%d",
s.message, percent,
strings.Repeat("█", s.current), strings.Repeat(" ", empty),
s.current, s.total)
}


@@ -6,6 +6,9 @@ import (
var ErrInterrupt = errors.New("Interrupt")
// ErrExpandOutput is returned when user presses Ctrl+O to expand tool output
var ErrExpandOutput = errors.New("ExpandOutput")
type InterruptError struct {
Line []rune
}


@@ -206,6 +206,9 @@ func (i *Instance) Readline() (string, error) {
buf.DeleteBefore()
case CharCtrlL:
buf.ClearScreen()
case CharCtrlO:
// Ctrl+O - expand tool output
return "", ErrExpandOutput
case CharCtrlW:
buf.DeleteWord()
case CharCtrlZ:


@@ -3,7 +3,6 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
)
func Execute(args []string) error {
@@ -12,19 +11,12 @@ func Execute(args []string) error {
}
var newRunner bool
var imageRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
if args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--image-engine" {
args = args[1:]
imageRunner = true
}
if imageRunner {
return imagerunner.Execute(args)
} else if newRunner {
if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)


@@ -42,39 +42,18 @@ shift $(( $OPTIND - 1 ))
_build_darwin() {
for ARCH in $ARCHS; do
status "Building darwin $ARCH"
INSTALL_PREFIX=dist/darwin-$ARCH/
INSTALL_PREFIX=dist/darwin-$ARCH/
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
if [ "$ARCH" = "amd64" ]; then
status "Building darwin $ARCH dynamic backends"
BUILD_DIR=build/darwin-$ARCH
cmake -B $BUILD_DIR \
cmake -B build/darwin-$ARCH \
-DCMAKE_OSX_ARCHITECTURES=x86_64 \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \
-DMLX_ENGINE=ON \
-DMLX_ENABLE_X64_MAC=ON \
-DOLLAMA_RUNNER_DIR=./
cmake --build $BUILD_DIR --target ggml-cpu -j
cmake --build $BUILD_DIR --target mlx mlxc -j
cmake --install $BUILD_DIR --component CPU
cmake --install $BUILD_DIR --component MLX
# Override CGO flags to point to the amd64 build directory
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
else
BUILD_DIR=build
cmake --preset MLX \
-DOLLAMA_RUNNER_DIR=./ \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 \
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX
cmake --build --preset MLX --parallel
cmake --install $BUILD_DIR --component MLX
# Use default CGO flags from mlx.go for arm64
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
cmake --build build/darwin-$ARCH --target ggml-cpu -j
cmake --install build/darwin-$ARCH --component CPU
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
done
}
@@ -82,12 +61,10 @@ _sign_darwin() {
status "Creating universal binary..."
mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
chmod +x dist/darwin/ollama
chmod +x dist/darwin/imagegen
if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
for F in dist/darwin/ollama dist/darwin-amd64/lib/ollama/*; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done
@@ -154,23 +131,17 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin-amd64/lib/ollama/*.so dist/darwin-amd64/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
else
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
chmod a+x dist/Ollama.app/Contents/Resources/ollama
# Sign
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
@@ -178,7 +149,7 @@ _build_macapp() {
rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then


@@ -12,17 +12,6 @@ set -eu
. $(dirname $0)/env.sh
# Check for required tools
if ! command -v zstd >/dev/null 2>&1; then
echo "ERROR: zstd is required but not installed." >&2
echo "Please install zstd:" >&2
echo " - macOS: brew install zstd" >&2
echo " - Debian/Ubuntu: sudo apt-get install zstd" >&2
echo " - RHEL/CentOS/Fedora: sudo dnf install zstd" >&2
echo " - Arch: sudo pacman -S zstd" >&2
exit 1
fi
mkdir -p dist
docker buildx build \
@@ -48,68 +37,19 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
.
fi
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
deduplicate_cuda_libs() {
local base_dir="$1"
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
}
# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
deduplicate_cuda_libs "./dist/linux_amd64"
deduplicate_cuda_libs "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
deduplicate_cuda_libs "./dist"
fi
# buildx behavior changes for single vs. multiplatform
echo "Compressing linux tar bundles..."
if echo $PLATFORM | grep "," > /dev/null ; then
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
tar c -C ./dist/linux_amd64 --exclude rocm . | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
tar c -C ./dist/ --exclude rocm bin lib | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/ ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
fi


@@ -66,36 +66,6 @@ if [ -n "$NEEDS" ]; then
exit 1
fi
# Function to download and extract with fallback from zst to tgz
download_and_extract() {
local url_base="$1"
local dest_dir="$2"
local filename="$3"
# Check if .tar.zst is available
if curl --fail --silent --head --location "${url_base}/${filename}.tar.zst${VER_PARAM}" >/dev/null 2>&1; then
# zst file exists - check if we have zstd tool
if ! available zstd; then
error "This version requires zstd for extraction. Please install zstd and try again:
- Debian/Ubuntu: sudo apt-get install zstd
- RHEL/CentOS/Fedora: sudo dnf install zstd
- Arch: sudo pacman -S zstd"
fi
status "Downloading ${filename}.tar.zst"
curl --fail --show-error --location --progress-bar \
"${url_base}/${filename}.tar.zst${VER_PARAM}" | \
zstd -d | $SUDO tar -xf - -C "${dest_dir}"
return 0
fi
# Fall back to .tgz for older versions
status "Downloading ${filename}.tgz"
curl --fail --show-error --location --progress-bar \
"${url_base}/${filename}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "${dest_dir}"
}
for BINDIR in /usr/local/bin /usr/bin /bin; do
echo $PATH | grep -q $BINDIR && break || continue
done
@@ -108,7 +78,10 @@ fi
status "Installing ollama to $OLLAMA_INSTALL_DIR"
$SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}"
status "Downloading Linux ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
status "Making ollama accessible in the PATH in $BINDIR"
@@ -118,9 +91,15 @@ fi
# Check for NVIDIA JetPack systems with additional downloads
if [ -f /etc/nv_tegra_release ] ; then
if grep R36 /etc/nv_tegra_release > /dev/null ; then
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack6"
status "Downloading JetPack 6 components"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack6.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
elif grep R35 /etc/nv_tegra_release > /dev/null ; then
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack5"
status "Downloading JetPack 5 components"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack5.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
else
warning "Unsupported JetPack version detected. GPU may not be supported"
fi
@@ -243,7 +222,10 @@ if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdg
fi
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-rocm"
status "Downloading Linux ROCm ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-rocm.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
install_success
status "AMD GPU ready."


@@ -26,7 +26,6 @@ import (
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
@@ -455,7 +454,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return layers, nil
}
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
for _, l := range baseLayers {
if l.GGML != nil {
return l.KV(), nil


@@ -30,7 +30,6 @@ import (
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen/transfer"
)
var (
@@ -74,11 +73,6 @@ type Model struct {
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImageGeneration}
}
// Check for completion capability
if m.ModelPath != "" {
f, err := gguf.Open(m.ModelPath)
@@ -561,24 +555,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := mp.GetManifestPath()
if err != nil {
return err
}
manifestJSON, err := os.ReadFile(manifestPath)
if err != nil {
return err
}
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
@@ -644,15 +620,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
@@ -667,6 +634,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
@@ -675,11 +643,13 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
@@ -687,11 +657,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
}
for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
@@ -725,148 +690,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}
// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []Layer) bool {
for _, layer := range layers {
if layer.MediaType == MediaTypeImageTensor {
return true
}
}
return false
}
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
}
}
destDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pulling model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}
if err := transfer.Download(ctx, transfer.DownloadOptions{
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: mp.GetNamespaceRepository(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
}); err != nil {
return err
}
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
return os.WriteFile(fp, manifestJSON, 0o644)
}
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
From: layer.From,
}
}
srcDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pushing model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}
return transfer.Upload(ctx, transfer.UploadOptions{
Blobs: blobs,
BaseURL: baseURL,
SrcDir: srcDir,
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
})
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)


@@ -47,15 +47,6 @@ func TestModelCapabilities(t *testing.T) {
model Model
expectedCaps []model.Capability
}{
{
name: "model with image generation capability via config",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
},
{
name: "model with completion capability",
model: Model{


@@ -13,14 +13,9 @@ type Layer struct {
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
status string
}
const (
MediaTypeImageTensor = "application/vnd.ollama.image.tensor"
)
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {


@@ -50,8 +50,6 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -164,29 +162,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}
// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}
return runner.llama, nil
}
func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
@@ -214,12 +189,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -1575,12 +1544,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)
if rc != nil {
// wrap old with new
rs := &registry.Local{


@@ -22,7 +22,6 @@ import (
gocmpopts "github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/types/model"
@@ -42,7 +41,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
var base convert.KV = map[string]any{"general.architecture": "test"}
base := map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {


@@ -21,7 +21,6 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
type LlmRequest struct {
@@ -195,14 +194,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation model before attempting GGML load
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadImageGen(pending) {
break
}
continue
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -552,48 +543,6 @@ iGPUScan:
return false
}
// loadImageGen loads an image generation model.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Use model name for imagegen (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName)
if err != nil {
req.errCh <- err
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
llama: server,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
refCount: 1,
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()
// Set up expiration timer
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
return true
}
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
if len(allGpus) == 0 {
return


@@ -6,7 +6,6 @@ import (
"errors"
"log/slog"
"os"
"slices"
"testing"
"time"
@@ -17,7 +16,6 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
func TestMain(m *testing.M) {
@@ -806,61 +804,3 @@ func (s *mockLlm) GetPort() int { return -
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
// Simulate an image gen runner already loaded
imageGenRunner := &runnerRef{
model: &Model{Name: "z-image", ModelPath: "/fake/image/model"},
modelPath: "/fake/image/model",
llama: &mockLlm{vramSize: 21 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{}},
sessionDuration: 5 * time.Millisecond,
refCount: 0, // idle
}
s.loadedMu.Lock()
s.loaded["/fake/image/model"] = imageGenRunner
s.loadedMu.Unlock()
// Verify the image gen runner is loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
// findRunnerToUnload should find the idle image gen runner
runner := s.findRunnerToUnload()
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}


@@ -381,28 +381,6 @@ func (t templateTools) String() string {
return string(bts)
}
// templateArgs is a map type with JSON string output for templates.
type templateArgs map[string]any
func (t templateArgs) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
// templateProperties is a map type with JSON string output for templates.
type templateProperties map[string]api.ToolProperty
func (t templateProperties) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
// templateTool is a template-compatible representation of api.Tool
// with Properties as a regular map for template ranging.
type templateTool struct {
@@ -418,11 +396,11 @@ type templateToolFunction struct {
}
type templateToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties templateProperties `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties map[string]api.ToolProperty `json:"properties"`
}
// templateToolCall is a template-compatible representation of api.ToolCall
@@ -435,7 +413,7 @@ type templateToolCall struct {
type templateToolCallFunction struct {
Index int
Name string
Arguments templateArgs
Arguments map[string]any
}
// templateMessage is a template-compatible representation of api.Message
@@ -468,7 +446,7 @@ func convertToolsForTemplate(tools api.Tools) templateTools {
Defs: tool.Function.Parameters.Defs,
Items: tool.Function.Parameters.Items,
Required: tool.Function.Parameters.Required,
Properties: templateProperties(tool.Function.Parameters.Properties.ToMap()),
Properties: tool.Function.Parameters.Properties.ToMap(),
},
},
}
@@ -490,7 +468,7 @@ func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
Function: templateToolCallFunction{
Index: tc.Function.Index,
Name: tc.Function.Name,
Arguments: templateArgs(tc.Function.Arguments.ToMap()),
Arguments: tc.Function.Arguments.ToMap(),
},
})
}


@@ -613,159 +613,3 @@ func TestCollate(t *testing.T) {
})
}
}
func TestTemplateArgumentsJSON(t *testing.T) {
// Test that {{ .Function.Arguments }} outputs valid JSON, not map[key:value]
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{ .Function.Arguments }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("location", "Tokyo")
args.Set("unit", "celsius")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:Tokyo unit:celsius]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Arguments output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Arguments not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplatePropertiesJSON(t *testing.T) {
// Test that {{ .Function.Parameters.Properties }} outputs valid JSON
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{ .Function.Parameters.Properties }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:{...}]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Properties output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Properties not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplateArgumentsRange(t *testing.T) {
// Test that we can range over Arguments in templates
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{- range $k, $v := .Function.Arguments }}{{ $k }}={{ $v }};{{- end }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("city", "Tokyo")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "city=Tokyo;" {
t.Errorf("Range over Arguments failed, got: %s, want: city=Tokyo;", got)
}
}
func TestTemplatePropertiesRange(t *testing.T) {
// Test that we can range over Properties in templates
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{- range $name, $prop := .Function.Parameters.Properties }}{{ $name }}:{{ $prop.Type }};{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "location:string;" {
t.Errorf("Range over Properties failed, got: %s, want: location:string;", got)
}
}


@@ -3,13 +3,12 @@ package model
type Capability string
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImageGeneration = Capability("image")
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
)
func (c Capability) String() string {


@@ -1,24 +0,0 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx).
Support is currently limited to macOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA and other GPU vendors soon. To build:
```
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install --component MLX
go build -tags mlx .
```
On Linux, use the "MLX CUDA 13" or "MLX CUDA 12" preset to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
## Image Generation
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
```
go build -o imagegen ./x/imagegen/cmd/engine
```


@@ -33,29 +33,10 @@ type ApprovalResult struct {
// Option labels for the selector (numbered for quick selection)
var optionLabels = []string{
"1. Execute once",
"2. Allow for this session",
"2. Always allow",
"3. Deny",
}
// toolDisplayNames maps internal tool names to human-readable display names.
var toolDisplayNames = map[string]string{
"bash": "Bash",
"web_search": "Web Search",
}
// ToolDisplayName returns the human-readable display name for a tool.
func ToolDisplayName(toolName string) string {
if displayName, ok := toolDisplayNames[toolName]; ok {
return displayName
}
// Default: capitalize first letter and replace underscores with spaces
name := strings.ReplaceAll(toolName, "_", " ")
if len(name) > 0 {
return strings.ToUpper(name[:1]) + name[1:]
}
return toolName
}
// autoAllowCommands are commands that are always allowed without prompting.
// These are zero-risk, read-only commands.
var autoAllowCommands = map[string]bool{
@@ -494,32 +475,16 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// This prevents buffered input from causing double-press issues
flushStdin(fd)
// Check if bash command targets paths outside cwd
isWarning := false
var warningMsg string
var allowlistInfo string
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
if isCommandOutsideCwd(cmd) {
isWarning = true
warningMsg = "command targets paths outside project"
}
if prefix := extractBashPrefix(cmd); prefix != "" {
colonIdx := strings.Index(prefix, ":")
if colonIdx != -1 {
cmdName := prefix[:colonIdx]
dirPath := prefix[colonIdx+1:]
if dirPath != "./" {
allowlistInfo = fmt.Sprintf("%s in %s directory (includes subdirs)", cmdName, dirPath)
} else {
allowlistInfo = fmt.Sprintf("%s in %s directory", cmdName, dirPath)
}
}
}
isWarning = isCommandOutsideCwd(cmd)
}
}
// Run interactive selector
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning, warningMsg, allowlistInfo)
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
if err != nil {
term.Restore(fd, oldState)
return ApprovalResult{Decision: ApprovalDeny}, err
@@ -544,12 +509,11 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// formatToolDisplay creates the display string for a tool call.
func formatToolDisplay(toolName string, args map[string]any) string {
var sb strings.Builder
displayName := ToolDisplayName(toolName)
// For bash, show command directly
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
return sb.String()
}
@@ -558,7 +522,7 @@ func formatToolDisplay(toolName string, args map[string]any) string {
// For web search, show query and internet notice
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
sb.WriteString("Uses internet via ollama.com")
return sb.String()
@@ -566,7 +530,7 @@ func formatToolDisplay(toolName string, args map[string]any) string {
}
// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
if len(args) > 0 {
sb.WriteString("\nArguments: ")
first := true
@@ -583,28 +547,24 @@ func formatToolDisplay(toolName string, args map[string]any) string {
// selectorState holds the state for the interactive selector
type selectorState struct {
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command has warning
warningMessage string // dynamic warning message to display
allowlistInfo string // show what will be allowlisted (for "Allow for this session" option)
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command targets paths outside cwd (red box)
}
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool, warningMessage string, allowlistInfo string) (int, string, error) {
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
state := &selectorState{
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
warningMessage: warningMessage,
allowlistInfo: allowlistInfo,
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
}
// Get terminal size
@@ -764,7 +724,7 @@ func wrapText(text string, maxWidth int) []string {
// getHintLines returns the hint text wrapped to terminal width
func getHintLines(state *selectorState) []string {
hint := "up/down select, enter confirm, 1-3 quick select, ctrl+c cancel"
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
if state.termWidth >= len(hint)+1 {
return []string{hint}
}
@@ -774,70 +734,86 @@ func getHintLines(state *selectorState) []string {
// calculateTotalLines calculates how many lines the selector will use
func calculateTotalLines(state *selectorState) int {
toolLines := strings.Split(state.toolDisplay, "\n")
toolLines := wrapText(state.toolDisplay, state.innerWidth)
hintLines := getHintLines(state)
// warning line (if applicable) + tool lines + blank line + options + blank line + hint lines
// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
warningLines := 0
if state.isWarning {
warningLines = 2 // warning line + blank line after
warningLines = 1
}
return warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
}
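As a quick sanity check of the new layout: with no warning line, a two-line tool display, the three options, and a single hint line, the selector occupies 1 + 0 + 2 + 1 + 3 + 1 + 1 = 9 lines.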
// renderSelectorBox renders the selector (minimal, no box)
// renderSelectorBox renders the complete selector box
func renderSelectorBox(state *selectorState) {
toolLines := strings.Split(state.toolDisplay, "\n")
toolLines := wrapText(state.toolDisplay, state.innerWidth)
hintLines := getHintLines(state)
// Draw warning line if needed
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
if state.warningMessage != "" {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m %s\033[K\r\n", state.warningMessage)
} else {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
boxColor = "\033[91m" // bright red
}
// Draw box top
fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw warning line if needed (inside the box)
if state.isWarning {
warning := "!! OUTSIDE PROJECT !!"
padding := (state.innerWidth - len(warning)) / 2
if padding < 0 {
padding = 0
}
fmt.Fprintf(os.Stderr, "\033[K\r\n") // blank line after warning
fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
}
// Draw tool info (plain white)
// Draw tool info
for _, line := range toolLines {
fmt.Fprintf(os.Stderr, "%s\033[K\r\n", line)
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
}
// Blank line separator
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Draw separator
fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw options with numbers (Deny option includes reason input)
for i, label := range optionLabels {
if i == 2 {
if i == 2 { // Deny option - show with reason input beside it
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
} else {
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
if len(displayLabel) > state.innerWidth-2 {
displayLabel = displayLabel[:state.innerWidth-5] + "..."
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
}
}
}
// Blank line before hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Draw box bottom
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw hint (dark grey)
// Draw hint (may be multiple lines)
for i, line := range hintLines {
if i == len(hintLines)-1 {
// Last line - no newline
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
} else {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
@@ -849,39 +825,50 @@ func renderSelectorBox(state *selectorState) {
func updateSelectorOptions(state *selectorState) {
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Move up to the first option line
// Cursor is at end of last hint line, need to go up:
// (hint lines - 1) + 1 (blank line) + numOptions
// (hint lines - 1) + 1 (bottom border) + numOptions
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw options (Deny option includes reason input)
for i, label := range optionLabels {
if i == 2 {
if i == 2 { // Deny option
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
} else {
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
if len(displayLabel) > state.innerWidth-2 {
displayLabel = displayLabel[:state.innerWidth-5] + "..."
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
}
}
}
// Blank line + hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Redraw bottom and hint
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
@@ -895,26 +882,36 @@ func updateSelectorOptions(state *selectorState) {
func updateReasonInput(state *selectorState) {
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Move up to the Deny line (3rd option, index 2)
// Cursor is at end of last hint line, need to go up:
// (hint lines - 1) + 1 (blank line) + 1 (Deny is last option)
// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
linesToMove := len(hintLines) - 1 + 1 + 1
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw Deny line with reason
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if state.selected == 2 {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
// Blank line + hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Redraw bottom and hint
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
@@ -938,10 +935,11 @@ func clearSelectorBox(state *selectorState) {
// fallbackApproval handles approval when terminal control isn't available.
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, toolDisplay)
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Allow for this session [3] Deny")
fmt.Fprint(os.Stderr, "choice: ")
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
fmt.Fprint(os.Stderr, "Choice: ")
var input string
fmt.Scanln(&input)
@@ -984,16 +982,19 @@ func (a *ApprovalManager) AllowedTools() []string {
// FormatApprovalResult returns a formatted string showing the approval result.
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
var label string
displayName := ToolDisplayName(toolName)
var status string
var icon string
switch result.Decision {
case ApprovalOnce:
label = "Approved"
status = "Approved"
icon = "\033[32m✓\033[0m"
case ApprovalAlways:
label = "Always allowed"
status = "Always allowed"
icon = "\033[32m✓\033[0m"
case ApprovalDeny:
label = "Denied"
status = "Denied"
icon = "\033[31m✗\033[0m"
}
// Format based on tool type
@@ -1003,7 +1004,7 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
if len(cmd) > 40 {
cmd = cmd[:37] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, cmd)
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
}
}
@@ -1013,11 +1014,11 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
if len(query) > 40 {
query = query[:37] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, query)
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
}
}
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
}
// FormatDenyResult returns the tool result message when a tool is denied.
@@ -1048,14 +1049,15 @@ func PromptYesNo(question string) (bool, error) {
renderYesNo := func() {
// Move to start of line and clear
fmt.Fprintf(os.Stderr, "\r\033[K")
fmt.Fprintf(os.Stderr, "%s ", question)
fmt.Fprintf(os.Stderr, "\033[36m%s\033[0m ", question)
for i, opt := range options {
if i == selected {
fmt.Fprintf(os.Stderr, "\033[1m%s\033[0m ", opt)
fmt.Fprintf(os.Stderr, "\033[1;32m[%s]\033[0m ", opt)
} else {
fmt.Fprintf(os.Stderr, "\033[37m%s\033[0m ", opt)
fmt.Fprintf(os.Stderr, "\033[90m %s \033[0m ", opt)
}
}
fmt.Fprintf(os.Stderr, "\033[90m(←/→ or y/n, Enter to confirm)\033[0m")
}
renderYesNo()
@@ -1102,3 +1104,108 @@ func PromptYesNo(question string) (bool, error) {
}
}
}
// CloudModelOption represents a suggested cloud model for the selection prompt.
type CloudModelOption struct {
Name string
Description string
}
// PromptModelChoice displays a model selection prompt with multiple options.
// Returns the selected model name, or empty string if user declined or cancelled.
func PromptModelChoice(question string, models []CloudModelOption) (string, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return "", err
}
defer term.Restore(fd, oldState)
// Build options: models + "No thanks, continue"
optionCount := len(models) + 1
selected := 0
// Total lines: question + models + "no thanks" + hint = optionCount + 2
totalLines := optionCount + 2
// Hide cursor
fmt.Fprint(os.Stderr, "\033[?25l")
defer fmt.Fprint(os.Stderr, "\033[?25h")
firstRender := true
render := func() {
if !firstRender {
fmt.Fprintf(os.Stderr, "\033[%dA\r", totalLines-1)
}
firstRender = false
// \r\n needed in raw mode for proper line breaks
fmt.Fprintf(os.Stderr, "\033[K\033[36m%s\033[0m\r\n", question)
for i, model := range models {
fmt.Fprintf(os.Stderr, "\033[K")
if i == selected {
fmt.Fprintf(os.Stderr, " \033[1;32m> %s\033[0m \033[90m%s\033[0m\r\n", model.Name, model.Description)
} else {
fmt.Fprintf(os.Stderr, " \033[90m%s %s\033[0m\r\n", model.Name, model.Description)
}
}
fmt.Fprintf(os.Stderr, "\033[K")
if selected == len(models) {
fmt.Fprintf(os.Stderr, " \033[1;32m> No thanks, continue\033[0m\r\n")
} else {
fmt.Fprintf(os.Stderr, " \033[90mNo thanks, continue\033[0m\r\n")
}
fmt.Fprintf(os.Stderr, "\033[K\033[90m(↑/↓ to navigate, Enter to confirm)\033[0m")
}
render()
buf := make([]byte, 3)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
return "", err
}
if n == 1 {
switch buf[0] {
case 'j', 'J':
if selected < optionCount-1 {
selected++
}
render()
case 'k', 'K':
if selected > 0 {
selected--
}
render()
case '\r', '\n':
fmt.Fprintf(os.Stderr, "\n")
if selected < len(models) {
return models[selected].Name, nil
}
return "", nil
case 3: // Ctrl+C
fmt.Fprintf(os.Stderr, "\n")
return "", nil
}
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
switch buf[2] {
case 'A': // Up
if selected > 0 {
selected--
}
render()
case 'B': // Down
if selected < optionCount-1 {
selected++
}
render()
}
}
}
}

x/agent/prompt_test.go (new file, 25 lines)

@@ -0,0 +1,25 @@
package agent
import (
"testing"
)
func TestCloudModelOptionStruct(t *testing.T) {
// Test that the struct is defined correctly
models := []CloudModelOption{
{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
}
if len(models) != 2 {
t.Errorf("expected 2 models, got %d", len(models))
}
if models[0].Name != "glm-4.7:cloud" {
t.Errorf("expected glm-4.7:cloud, got %s", models[0].Name)
}
if models[1].Description != "Qwen3 Coder 480B" {
t.Errorf("expected 'Qwen3 Coder 480B', got %s", models[1].Description)
}
}

x/cmd/cloudmodel_test.go (new file, 41 lines)

@@ -0,0 +1,41 @@
package cmd
import (
"errors"
"testing"
)
func TestCloudModelSwitchRequest(t *testing.T) {
// Test the error type
req := &CloudModelSwitchRequest{Model: "glm-4.7:cloud"}
// Test Error() method
errMsg := req.Error()
expected := "switch to model: glm-4.7:cloud"
if errMsg != expected {
t.Errorf("expected %q, got %q", expected, errMsg)
}
// Test errors.As
var err error = req
var switchReq *CloudModelSwitchRequest
if !errors.As(err, &switchReq) {
t.Error("errors.As should return true for CloudModelSwitchRequest")
}
if switchReq.Model != "glm-4.7:cloud" {
t.Errorf("expected model glm-4.7:cloud, got %s", switchReq.Model)
}
}
func TestSuggestedCloudModels(t *testing.T) {
// Verify the suggested models are defined
if len(suggestedCloudModels) == 0 {
t.Error("suggestedCloudModels should not be empty")
}
// Check first model
if suggestedCloudModels[0].Name != "glm-4.7:cloud" {
t.Errorf("expected first model to be glm-4.7:cloud, got %s", suggestedCloudModels[0].Name)
}
}


@@ -37,6 +37,22 @@ const (
charsPerToken = 4
)
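truncateToolOutput itself is not shown in this diff, only its call site further down; as a rough sketch of what a character-budget truncation built on the charsPerToken constant could look like (the token budget, marker text, and helper name are assumptions, and unicode/utf8 is assumed to be imported):
```
// Sketch only: a character-budget truncation in the spirit of truncateToolOutput.
// The budget and the trailing notice are illustrative, not the real values.
func truncateByTokenBudget(output string, maxTokens int) string {
	maxChars := maxTokens * charsPerToken // charsPerToken = 4, i.e. roughly 4 characters per token
	if len(output) <= maxChars {
		return output
	}
	// Cut on a rune boundary so a multi-byte UTF-8 sequence is never split.
	cut := maxChars
	for cut > 0 && !utf8.RuneStart(output[cut]) {
		cut--
	}
	return output[:cut] + "\n... [tool output truncated to fit the context window]"
}
```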
// suggestedCloudModels are the models suggested to users after signing in.
// TODO(parthsareen): Dynamically recommend models based on user context instead of hardcoding
var suggestedCloudModels = []agent.CloudModelOption{
{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
}
// CloudModelSwitchRequest signals that the user wants to switch to a different model.
type CloudModelSwitchRequest struct {
Model string
}
func (c *CloudModelSwitchRequest) Error() string {
return fmt.Sprintf("switch to model: %s", c.Model)
}
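Since the switch request is surfaced as an error from Chat() (see the 401 handling later in this diff), a caller can detect it with errors.As and rerun the turn against the suggested model. A minimal sketch of such a caller; the wrapper function and its name are illustrative, not part of this change:
```
// Illustrative only: how a caller of Chat might honor a CloudModelSwitchRequest.
// runTurn is a hypothetical wrapper; Chat, RunOptions, and CloudModelSwitchRequest
// are the package's own types.
func runTurn(ctx context.Context, opts RunOptions) (*api.Message, error) {
	msg, err := Chat(ctx, opts)
	var switchReq *CloudModelSwitchRequest
	if errors.As(err, &switchReq) {
		opts.Model = switchReq.Model // rerun the same turn with the suggested cloud model
		return Chat(ctx, opts)
	}
	return msg, err
}
```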
// isLocalModel checks if the model is running locally (not a cloud model).
// TODO: Improve local/cloud model identification - could check model metadata
func isLocalModel(modelName string) bool {
@@ -91,8 +107,8 @@ func waitForOllamaSignin(ctx context.Context) error {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
fmt.Fprintf(os.Stderr, " %s\n\n", aErr.SigninURL)
fmt.Fprintf(os.Stderr, " \033[90mwaiting for sign in to complete...\033[0m")
fmt.Fprintf(os.Stderr, " \033[36m%s\033[0m\n\n", aErr.SigninURL)
fmt.Fprintf(os.Stderr, " \033[90mWaiting for sign in to complete...\033[0m")
// Poll until auth succeeds
ticker := time.NewTicker(2 * time.Second)
@@ -106,7 +122,7 @@ func waitForOllamaSignin(ctx context.Context) error {
case <-ticker.C:
user, whoamiErr := client.Whoami(ctx)
if whoamiErr == nil && user != nil && user.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K \033[1msigned in:\033[0m %s\n", user.Name)
fmt.Fprintf(os.Stderr, "\r\033[K \033[32mSigned in as %s\033[0m\n", user.Name)
return nil
}
// Still waiting, show dot
@@ -119,6 +135,21 @@ func waitForOllamaSignin(ctx context.Context) error {
return nil
}
// promptCloudModelSuggestion shows cloud model suggestions after successful sign-in.
// Returns the selected model name, or empty string if user declines.
func promptCloudModelSuggestion() string {
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, "\033[1;36mTry cloud models for free!\033[0m\n")
fmt.Fprintf(os.Stderr, "\033[90mCloud models offer powerful capabilities without local hardware requirements.\033[0m\n")
fmt.Fprintf(os.Stderr, "\n")
selectedModel, err := agent.PromptModelChoice("Try a cloud model now?", suggestedCloudModels)
if err != nil || selectedModel == "" {
return ""
}
return selectedModel
}
// RunOptions contains options for running an interactive agent session.
type RunOptions struct {
Model string
@@ -137,6 +168,47 @@ type RunOptions struct {
// YoloMode skips all tool approval prompts
YoloMode bool
// LastToolOutput stores the full output of the last tool execution
// for Ctrl+O expansion. Updated by Chat(), read by caller.
LastToolOutput *string
// LastToolOutputTruncated stores the truncated version shown inline
LastToolOutputTruncated *string
// ActiveModel points to the current model name - can be updated mid-turn
// for model switching. If nil, opts.Model is used.
ActiveModel *string
}
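Callers that want the sign-in flow to be able to swap models mid-turn point ActiveModel at a variable they keep around; a small sketch, with the starting model name assumed:
```
// Sketch: wiring ActiveModel so Chat can switch models mid-turn.
// "llama3.2" is a placeholder starting model, not a value from this change.
func newRunOptions() RunOptions {
	activeModel := "llama3.2"
	return RunOptions{
		Model:       activeModel,
		ActiveModel: &activeModel, // getActiveModel and the web-search retry path read/write through this
	}
}
```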
// getActiveModel returns the current model name, checking ActiveModel pointer first.
func getActiveModel(opts *RunOptions) string {
if opts.ActiveModel != nil && *opts.ActiveModel != "" {
return *opts.ActiveModel
}
return opts.Model
}
// showModelConnection displays "Connecting to X on ollama.com" for cloud models.
func showModelConnection(ctx context.Context, modelName string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
info, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
if err != nil {
return err
}
if info.RemoteHost != "" {
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
return nil
}
// Chat runs an agent chat loop with tool support.
@@ -236,7 +308,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
// Agentic loop: continue until no more tool calls
for {
req := &api.ChatRequest{
Model: opts.Model,
Model: getActiveModel(&opts),
Messages: messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
@@ -260,17 +332,20 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, nil
}
// Check for 401 Unauthorized - prompt user to sign in
var authErr api.AuthorizationError
if errors.As(err, &authErr) {
p.StopAndClear()
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m cloud model requires authentication\n")
fmt.Fprintf(os.Stderr, "\033[33mAuthentication required to use this cloud model.\033[0m\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the chat request
fmt.Fprintf(os.Stderr, "\033[90mretrying...\033[0m\n")
continue // Retry the loop
suggestedModel := promptCloudModelSuggestion()
if suggestedModel != "" {
return nil, &CloudModelSwitchRequest{Model: suggestedModel}
}
fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n")
continue
}
}
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
@@ -283,11 +358,11 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
p.StopAndClear()
if consecutiveErrors >= 3 {
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m too many consecutive errors, giving up\n")
fmt.Fprintf(os.Stderr, "\033[31m✗ Too many consecutive errors, giving up\033[0m\n")
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
}
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m server error (attempt %d/3): %s\n", consecutiveErrors, statusErr.ErrorMessage)
fmt.Fprintf(os.Stderr, "\033[33m⚠ Server error (attempt %d/3): %s\033[0m\n", consecutiveErrors, statusErr.ErrorMessage)
// Include both the model's response and the error so it can learn
assistantContent := fullResponse.String()
@@ -353,8 +428,8 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
if cmd, ok := args["command"].(string); ok {
// Check if command is denied (dangerous pattern)
if denied, pattern := agent.IsDenied(cmd); denied {
fmt.Fprintf(os.Stderr, "\033[1mblocked:\033[0m %s\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, " matches dangerous pattern: %s\n", pattern)
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: agent.FormatDeniedResult(cmd, pattern),
@@ -364,11 +439,10 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
// Check if command is auto-allowed (safe command)
// TODO(parthsareen): re-enable with tighter scoped allowlist
// if agent.IsAutoAllowed(cmd) {
// fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
// skipApproval = true
// }
if agent.IsAutoAllowed(cmd) {
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
skipApproval = true
}
}
}
@@ -376,7 +450,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
// In yolo mode, skip all approval prompts
if opts.YoloMode {
if !skipApproval {
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
}
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
result, err := approval.RequestApproval(toolName, args)
@@ -406,22 +480,23 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
} else if !skipApproval {
// Already allowed - show running indicator
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
}
// Execute the tool
toolResult, err := toolRegistry.Execute(call)
if err != nil {
// Check if web search needs authentication
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
// Prompt user to sign in
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m web search requires authentication\n")
fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
// Get signin URL and wait for auth completion
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the web search
fmt.Fprintf(os.Stderr, "\033[90mretrying web search...\033[0m\n")
suggestedModel := promptCloudModelSuggestion()
if suggestedModel != "" && opts.ActiveModel != nil {
*opts.ActiveModel = suggestedModel
showModelConnection(ctx, suggestedModel)
}
fmt.Fprintf(os.Stderr, "\033[90mRetrying web search...\033[0m\n")
toolResult, err = toolRegistry.Execute(call)
if err == nil {
goto toolSuccess
@@ -429,7 +504,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
}
}
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m %v\n", err)
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Error: %v", err),
@@ -440,17 +515,27 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
toolSuccess:
// Display tool output (truncated for display)
truncatedOutput := ""
if toolResult != "" {
output := toolResult
if len(output) > 300 {
output = output[:300] + "... (truncated)"
output = output[:300] + "... (truncated, press Ctrl+O to expand)"
}
truncatedOutput = output
// Show result in grey, indented
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
}
// Store full and truncated output for Ctrl+O toggle
if opts.LastToolOutput != nil {
*opts.LastToolOutput = toolResult
}
if opts.LastToolOutputTruncated != nil {
*opts.LastToolOutputTruncated = truncatedOutput
}
// Truncate output to prevent context overflow
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
toolResultForLLM := truncateToolOutput(toolResult, getActiveModel(&opts))
toolResults = append(toolResults, api.Message{
Role: "tool",
@@ -500,18 +585,17 @@ func truncateUTF8(s string, limit int) string {
// formatToolShort returns a short description of a tool call.
func formatToolShort(toolName string, args map[string]any) string {
displayName := agent.ToolDisplayName(toolName)
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(cmd, 50))
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
}
}
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(query, 50))
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
}
}
return displayName
return toolName
}
// Helper types and functions for display
@@ -610,25 +694,28 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
return out
}
// checkModelCapabilities checks if the model supports tools.
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, err error) {
// checkModelCapabilities checks if the model supports tools and thinking.
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, supportsThinking bool, err error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return false, err
return false, false, err
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
if err != nil {
return false, err
return false, false, err
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityTools {
return true, nil
supportsTools = true
}
if cap == model.CapabilityThinking {
supportsThinking = true
}
}
return false, nil
return supportsTools, supportsThinking, nil
}
// GenerateInteractive runs an interactive agent session.
@@ -648,28 +735,29 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
// Check if model supports tools
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
// Check if model supports tools and thinking
supportsTools, supportsThinking, err := checkModelCapabilities(cmd.Context(), modelName)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m could not check model capabilities: %v\n", err)
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
supportsTools = false
supportsThinking = false
}
// Track if session is using thinking mode
usingThinking := think != nil && supportsThinking
// Create tool registry only if model supports tools
var toolRegistry *tools.Registry
if supportsTools {
toolRegistry = tools.DefaultRegistry()
if toolRegistry.Has("bash") {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
fmt.Fprintln(os.Stderr, "Models can read files on your computer, or run commands (after you allow them).")
fmt.Fprintln(os.Stderr)
if toolRegistry.Count() > 0 {
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
}
if yoloMode {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
fmt.Fprintf(os.Stderr, "\033[33m⚠ YOLO mode: All tool approvals will be skipped\033[0m\n")
}
} else {
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
}
// Create approval manager for session
@@ -678,6 +766,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var messages []api.Message
var sb strings.Builder
// Track last tool output for Ctrl+O toggle
var lastToolOutput string
var lastToolOutputTruncated string
var toolOutputExpanded bool
for {
line, err := scanner.Readline()
switch {
@@ -690,6 +783,20 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
sb.Reset()
continue
case errors.Is(err, readline.ErrExpandOutput):
// Ctrl+O pressed - toggle between expanded and collapsed tool output
if lastToolOutput == "" {
fmt.Fprintf(os.Stderr, "\033[90mNo tool output to expand\033[0m\n")
} else if toolOutputExpanded {
// Currently expanded, show truncated
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutputTruncated, "\n", "\n "))
toolOutputExpanded = false
} else {
// Currently collapsed, show full
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutput, "\n", "\n "))
toolOutputExpanded = true
}
continue
case err != nil:
return err
}
@@ -726,26 +833,44 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if sb.Len() > 0 {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
toolOutputExpanded = false
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
}
retryChat:
for {
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
LastToolOutput: &lastToolOutput,
LastToolOutputTruncated: &lastToolOutputTruncated,
ActiveModel: &modelName,
}
assistant, err := Chat(cmd.Context(), opts)
if err != nil {
return err
}
if assistant != nil {
messages = append(messages, *assistant)
assistant, err := Chat(cmd.Context(), opts)
if err != nil {
var switchReq *CloudModelSwitchRequest
if errors.As(err, &switchReq) {
newModel := switchReq.Model
if err := switchToModel(cmd.Context(), newModel, &modelName, &supportsTools, &supportsThinking, &toolRegistry, usingThinking); err != nil {
fmt.Fprintf(os.Stderr, "\033[33m%v\033[0m\n", err)
fmt.Fprintf(os.Stderr, "\033[90mContinuing with %s...\033[0m\n", modelName)
}
continue retryChat
}
return err
}
if assistant != nil {
messages = append(messages, *assistant)
}
break retryChat
}
sb.Reset()
@@ -753,6 +878,52 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
}
// switchToModel handles model switching with capability checks and UI updates.
func switchToModel(ctx context.Context, newModel string, modelName *string, supportsTools, supportsThinking *bool, toolRegistry **tools.Registry, usingThinking bool) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return fmt.Errorf("could not create client: %w", err)
}
newSupportsTools, newSupportsThinking, capErr := checkModelCapabilities(ctx, newModel)
if capErr != nil {
return fmt.Errorf("could not check model capabilities: %w", capErr)
}
// TODO(parthsareen): Handle thinking -> non-thinking model switch gracefully
if usingThinking && !newSupportsThinking {
return fmt.Errorf("%s does not support thinking mode", newModel)
}
// Show "Connecting to X on ollama.com" for cloud models
info, err := client.Show(ctx, &api.ShowRequest{Model: newModel})
if err == nil && info.RemoteHost != "" {
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
*modelName = newModel
*supportsTools = newSupportsTools
*supportsThinking = newSupportsThinking
if *supportsTools {
if *toolRegistry == nil {
*toolRegistry = tools.DefaultRegistry()
}
if (*toolRegistry).Count() > 0 {
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join((*toolRegistry).Names(), ", "))
}
} else {
*toolRegistry = nil
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
}
return nil
}
// showToolsStatus displays the current tools and approval status.
func showToolsStatus(registry *tools.Registry, approval *agent.ApprovalManager, supportsTools bool) {
if !supportsTools || registry == nil {

x/imagegen/.gitignore vendored

@@ -1,38 +0,0 @@
# Build directories
build/
dist/
# CMake
CMakeCache.txt
CMakeFiles/
cmake_install.cmake
Makefile
*.cmake
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# macOS
.DS_Store
*.dSYM/
# Go
*.exe
*.exe~
*.dll
*.so
*.dylib
# Python
*.npy
/engine
weights
outputs
prompt.txt
negative.txt


@@ -1,236 +0,0 @@
# Image Generation in Ollama (Experimental)
Generate images from text prompts using local AI models.
## Quick Start
```bash
# Run with a prompt
ollama run z-image "a sunset over mountains"
Generating: step 30/30
Image saved to: /tmp/ollama-image-1704067200.png
```
On macOS, the generated image will automatically open in Preview.
## Supported Models
| Model | VRAM Required | Notes |
|-------|---------------|-------|
| z-image | ~12GB | Based on Flux architecture |
## CLI Usage
```bash
# Generate an image
ollama run z-image "a cat playing piano"
# Check if model is running
ollama ps
# Stop the model
ollama stop z-image
```
## API
### OpenAI-Compatible Endpoint
```bash
POST /v1/images/generations
```
**Request:**
```json
{
"model": "z-image",
"prompt": "a sunset over mountains",
"size": "1024x1024",
"response_format": "b64_json"
}
```
**Response:**
```json
{
"created": 1704067200,
"data": [
{
"b64_json": "iVBORw0KGgo..."
}
]
}
```
### Example: cURL
```bash
curl http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "z-image",
"prompt": "a white cat",
"size": "1024x1024"
}'
```
### Example: Save to File
```bash
curl -s http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "z-image",
"prompt": "a white cat",
"size": "1024x1024"
}' | jq -r '.data[0].b64_json' | base64 -d > image.png
```
### Streaming Progress
Enable streaming to receive progress updates via SSE:
```bash
curl http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{"model": "z-image", "prompt": "a sunset", "stream": true}'
```
Events:
```
event: progress
data: {"step": 1, "total": 30}
event: progress
data: {"step": 2, "total": 30}
...
event: done
data: {"created": 1704067200, "data": [{"b64_json": "..."}]}
```
## Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| model | string | required | Model name |
| prompt | string | required | Text description of image |
| size | string | "1024x1024" | Image dimensions (WxH) |
| n | int | 1 | Number of images (currently only 1 supported) |
| response_format | string | "b64_json" | "b64_json" or "url" |
| stream | bool | false | Enable progress streaming |
## Requirements
- macOS with Apple Silicon (M1/M2/M3/M4)
- CUDA: tested on CUDA 12 (Blackwell); broader hardware testing is planned
- Sufficient VRAM (see model table above)
- Ollama built with MLX support
## Limitations
- macOS only (uses MLX backend)
- Single image per request
- Fixed step count (30 steps)
- Modelfiles not yet supported (use `ollama create` from model directory)
---
# Tensor Model Storage Format
Tensor models store each tensor as a separate blob with metadata in the manifest. This enables faster downloads (parallel fetching) and deduplication (shared tensors are stored once).
## Manifest Structure
The manifest follows the standard ollama format with tensor-specific layer metadata:
```json
{
"schemaVersion": 2,
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
"config": { "digest": "sha256:...", "size": 1234 },
"layers": [
{
"mediaType": "application/vnd.ollama.image.tensor",
"digest": "sha256:25b36eed...",
"size": 49807448,
"name": "text_encoder/model.layers.0.mlp.down_proj.weight",
"dtype": "BF16",
"shape": [2560, 9728]
},
{
"mediaType": "application/vnd.ollama.image.json",
"digest": "sha256:abc123...",
"size": 512,
"name": "text_encoder/config.json"
}
]
}
```
Each tensor layer includes:
- `name`: Path-style tensor name (e.g., `text_encoder/model.layers.0.mlp.down_proj.weight`)
- `dtype`: Data type (BF16, F32, etc.)
- `shape`: Tensor dimensions
Config layers use the same path-style naming (e.g., `tokenizer/tokenizer.json`).
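For illustration only, a layer entry could be modeled in Go roughly like this (field names are inferred from the JSON above; this is not the server's actual type):
```go
// Layer mirrors one entry in the manifest's "layers" array.
// Tensor layers carry dtype and shape; config layers leave them empty.
type Layer struct {
	MediaType string  `json:"mediaType"`
	Digest    string  `json:"digest"`
	Size      int64   `json:"size"`
	Name      string  `json:"name"`            // path-style, e.g. "tokenizer/tokenizer.json"
	Dtype     string  `json:"dtype,omitempty"` // "BF16", "F32", ... (tensor layers only)
	Shape     []int64 `json:"shape,omitempty"` // tensor dimensions (tensor layers only)
}
```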
## Blob Format
Each tensor blob is a minimal safetensors file:
```
[8 bytes: header size (uint64 LE)]
[~80 bytes: JSON header, padded to 8-byte alignment]
[N bytes: raw tensor data]
```
Header contains a single tensor named `"data"`:
```json
{"data":{"dtype":"BF16","shape":[2560,9728],"data_offsets":[0,49807360]}}
```
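To make the framing concrete, here is a minimal Go sketch (an assumed helper, not part of the codebase) that parses one blob according to the layout above:
```go
package main

import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"os"
)

// tensorHeader mirrors the single-entry JSON header shown above.
type tensorHeader struct {
	Dtype       string   `json:"dtype"`
	Shape       []int64  `json:"shape"`
	DataOffsets [2]int64 `json:"data_offsets"`
}

func main() {
	blob, err := os.ReadFile(os.Args[1]) // path to one sha256-... blob
	if err != nil {
		panic(err)
	}
	// First 8 bytes: JSON header size, little-endian.
	n := binary.LittleEndian.Uint64(blob[:8])
	var header map[string]tensorHeader
	if err := json.Unmarshal(blob[8:8+n], &header); err != nil {
		panic(err)
	}
	t := header["data"] // the blob always contains a single tensor named "data"
	fmt.Println("dtype:", t.Dtype, "shape:", t.Shape, "data bytes:", len(blob[8+n:]))
}
```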
## Why Include the Header?
The ~88 byte safetensors header enables MLX's native `mlx_load_safetensors` function, which:
1. **Uses mmap** - Maps file directly into memory, no copies
2. **Zero-copy to GPU** - MLX reads directly from mapped pages
3. **No custom code** - Standard MLX API, battle-tested
Without the header, we'd need custom C++ code to create MLX arrays from raw mmap'd data. MLX's public API doesn't expose this - it always copies when creating arrays from external pointers.
The overhead is negligible: 88 bytes per tensor = ~100KB total for a 13GB model (0.0007%).
## Why Per-Tensor Blobs?
**Deduplication**: Blobs are content-addressed by SHA256. If two models share identical tensors (same weights, dtype, shape), they share the same blob file.
Example: Model A and Model B both use the same text encoder. The text encoder's 400 tensors are stored once, referenced by both manifests.
```
~/.ollama/models/
blobs/
sha256-25b36eed... <- shared by both models
sha256-abc123...
manifests/
library/model-a/latest <- references sha256-25b36eed
library/model-b/latest <- references sha256-25b36eed
```
## Import Flow
```
cd ./weights/Z-Image-Turbo
ollama create z-image
1. Scan component directories (text_encoder/, transformer/, vae/)
2. For each .safetensors file:
- Extract individual tensors
- Wrap each in minimal safetensors format (88B header + data)
- Write to blob store (SHA256 content-addressed)
- Add layer entry to manifest with path-style name
3. Copy config files (*.json) as config layers
4. Write manifest
```
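To make step 2 concrete, here is a hedged Go sketch of wrapping one raw tensor in the minimal safetensors framing described above (the helper name and the exact padding choice are assumptions):
```go
package main

import (
	"encoding/binary"
	"encoding/json"
	"fmt"
)

// wrapTensor frames raw tensor bytes as a minimal single-tensor safetensors
// blob: an 8-byte little-endian header size, a JSON header padded to an
// 8-byte boundary, then the raw data.
func wrapTensor(dtype string, shape []int64, data []byte) ([]byte, error) {
	header, err := json.Marshal(map[string]any{
		"data": map[string]any{
			"dtype":        dtype,
			"shape":        shape,
			"data_offsets": []int64{0, int64(len(data))},
		},
	})
	if err != nil {
		return nil, err
	}
	// Pad the JSON header to an 8-byte boundary (trailing spaces are one valid choice).
	for len(header)%8 != 0 {
		header = append(header, ' ')
	}
	var size [8]byte
	binary.LittleEndian.PutUint64(size[:], uint64(len(header)))
	blob := append(size[:], header...)
	blob = append(blob, data...)
	return blob, nil
}

func main() {
	data := make([]byte, 2560*9728*2) // a BF16 tensor of shape [2560, 9728]
	blob, err := wrapTensor("BF16", []int64{2560, 9728}, data)
	if err != nil {
		panic(err)
	}
	fmt.Println("blob size:", len(blob)) // tensor data plus the ~88-byte header
}
```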


@@ -1,235 +0,0 @@
package api
import (
"encoding/base64"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/imagegen"
)
// RunnerScheduler is the interface for scheduling a model runner.
// This is implemented by server.Server to avoid circular imports.
type RunnerScheduler interface {
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
}
// RegisterRoutes registers the image generation API routes.
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
r.POST("/v1/images/generations", func(c *gin.Context) {
ImageGenerationHandler(c, scheduler)
})
}
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
var req ImageGenerationRequest
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Validate required fields
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
return
}
if req.Prompt == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
return
}
// Apply defaults
if req.N == 0 {
req.N = 1
}
if req.Size == "" {
req.Size = "1024x1024"
}
if req.ResponseFormat == "" {
req.ResponseFormat = "b64_json"
}
// Verify model exists
if imagegen.ResolveModelName(req.Model) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
return
}
// Parse size
width, height := parseSize(req.Size)
// Build options - we repurpose NumCtx/NumGPU for width/height
opts := api.Options{}
opts.NumCtx = int(width)
opts.NumGPU = int(height)
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
}
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Options: &opts,
}
if req.Stream {
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
} else {
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
}
}
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imagePath string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imagePath = extractPath(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
c.SSEvent("progress", progress)
c.Writer.Flush()
}
}
})
if err != nil {
c.SSEvent("error", gin.H{"error": err.Error()})
return
}
c.SSEvent("done", buildResponse(imagePath, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imagePath string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imagePath = extractPath(resp.Content)
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
return
}
c.JSON(http.StatusOK, buildResponse(imagePath, format))
}
func parseSize(size string) (int32, int32) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return 1024, 1024
}
w, _ := strconv.Atoi(parts[0])
h, _ := strconv.Atoi(parts[1])
if w == 0 {
w = 1024
}
if h == 0 {
h = 1024
}
return int32(w), int32(h)
}
func extractPath(content string) string {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
return strings.TrimSpace(content[idx+16:])
}
return ""
}
func parseProgress(content string) ImageProgressEvent {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imagePath, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imagePath == "" {
return resp
}
if format == "url" {
resp.Data[0].URL = "file://" + imagePath
} else {
data, err := os.ReadFile(imagePath)
if err == nil {
resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
}
}
return resp
}
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
// This allows routes.go to delegate image generation with minimal code.
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
opts := api.Options{}
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: prompt,
Options: &opts,
}
// Stream responses via channel
ch := make(chan any)
go func() {
defer close(ch)
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
ch <- GenerateResponse{
Model: modelName,
CreatedAt: time.Now().UTC(),
Response: resp.Content,
Done: resp.Done,
}
})
if err != nil {
// Log error but don't block - channel is already being consumed
_ = err
}
}()
streamFn(c, ch)
}
// GenerateResponse matches api.GenerateResponse structure for streaming.
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
}


@@ -1,31 +0,0 @@
// Package api provides OpenAI-compatible image generation API types.
package api
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`
}
// ImageData contains the generated image data.
type ImageData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
// ImageProgressEvent is sent during streaming to indicate generation progress.
type ImageProgressEvent struct {
Step int `json:"step"`
Total int `json:"total"`
}


@@ -1,156 +0,0 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
type Cache interface {
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
Offset() int
Len() int
State() []*mlx.Array
}
type KVCache struct {
keys, values *mlx.Array
offset int
step int
}
func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
prev := c.offset
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if needed
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
nSteps := (c.step + seqLen - 1) / c.step
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
if c.keys != nil {
if prev%c.step != 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
}
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
c.offset += seqLen
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
}
func (c *KVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
keys, values *mlx.Array
offset int
maxSize int
step int
idx int
}
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, step: 256}
}
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
if seqLen > 1 {
return c.updateConcat(k, v, seqLen)
}
return c.updateInPlace(k, v)
}
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if not yet at max
if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
var cap int
if c.keys != nil {
cap = int(c.keys.Shape()[2])
}
newSize := min(c.step, c.maxSize-cap)
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
if c.keys != nil {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
// Rotate when hitting max
if c.idx >= c.maxSize {
c.idx = 0
}
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
c.offset++
c.idx++
validLen := int32(min(c.offset, c.maxSize))
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
}
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
if c.keys == nil {
c.keys, c.values = k, v
} else {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
}
c.offset += seqLen
// Trim to max_size to maintain sliding window
cap := int(c.keys.Shape()[2])
if trim := cap - c.maxSize; trim > 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
}
c.idx = int(c.keys.Shape()[2])
return c.keys, c.values
}
func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *RotatingKVCache) Offset() int { return c.offset }
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }


@@ -1,164 +0,0 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
// StepCache caches layer outputs across diffusion denoising steps.
// Based on DeepCache (CVPR 2024) and Learning-to-Cache (NeurIPS 2024):
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
// - Single-stream: use Get/Set for the single output per layer
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
// Usage (single-stream):
//
// cache := NewStepCache(15) // cache first 15 layers
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3) // refresh every 3 steps
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// output = cache.Get(i) // reuse cached
// } else {
// output = layer.Forward(input)
// if i < 15 && refresh {
// cache.Set(i, output)
// }
// }
// }
// }
// cache.Free() // cleanup when done
//
// Usage (dual-stream):
//
// cache := NewStepCache(15)
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3)
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// imgH, txtH = cache.Get(i), cache.Get2(i)
// } else {
// imgH, txtH = layer.Forward(imgH, txtH, ...)
// if i < 15 && refresh {
// cache.Set(i, imgH)
// cache.Set2(i, txtH)
// }
// }
// }
// }
type StepCache struct {
layers []*mlx.Array // cached layer outputs (stream 1)
layers2 []*mlx.Array // cached layer outputs (stream 2, for dual-stream models)
constant *mlx.Array // optional constant (e.g., text embeddings)
}
// NewStepCache creates a cache for the given number of layers.
func NewStepCache(numLayers int) *StepCache {
return &StepCache{
layers: make([]*mlx.Array, numLayers),
layers2: make([]*mlx.Array, numLayers),
}
}
// ShouldRefresh returns true if the cache should be refreshed at this step.
// Refresh happens on step 0, interval, 2*interval, etc.
func (c *StepCache) ShouldRefresh(step, interval int) bool {
return step%interval == 0
}
// Get returns the cached output for a layer, or nil if not cached.
func (c *StepCache) Get(layer int) *mlx.Array {
if layer < len(c.layers) {
return c.layers[layer]
}
return nil
}
// Set stores a layer output (stream 1), freeing any previous value.
func (c *StepCache) Set(layer int, arr *mlx.Array) {
if layer < len(c.layers) {
if c.layers[layer] != nil {
c.layers[layer].Free()
}
c.layers[layer] = arr
}
}
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Get2(layer int) *mlx.Array {
if layer < len(c.layers2) {
return c.layers2[layer]
}
return nil
}
// Set2 stores a layer output (stream 2), freeing any previous value.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
if layer < len(c.layers2) {
if c.layers2[layer] != nil {
c.layers2[layer].Free()
}
c.layers2[layer] = arr
}
}
// GetConstant returns the cached constant value.
func (c *StepCache) GetConstant() *mlx.Array {
return c.constant
}
// SetConstant stores a constant value, freeing any previous value.
func (c *StepCache) SetConstant(arr *mlx.Array) {
if c.constant != nil {
c.constant.Free()
}
c.constant = arr
}
// Arrays returns all non-nil cached arrays (for pool.Keep).
func (c *StepCache) Arrays() []*mlx.Array {
var result []*mlx.Array
if c.constant != nil {
result = append(result, c.constant)
}
for _, arr := range c.layers {
if arr != nil {
result = append(result, arr)
}
}
for _, arr := range c.layers2 {
if arr != nil {
result = append(result, arr)
}
}
return result
}
// Free releases all cached arrays. Call when generation completes.
func (c *StepCache) Free() {
if c.constant != nil {
c.constant.Free()
c.constant = nil
}
for i, arr := range c.layers {
if arr != nil {
arr.Free()
c.layers[i] = nil
}
}
for i, arr := range c.layers2 {
if arr != nil {
arr.Free()
c.layers2[i] = nil
}
}
}
// NumLayers returns the number of layers this cache can store.
func (c *StepCache) NumLayers() int {
return len(c.layers)
}


@@ -1,539 +0,0 @@
// cli.go provides CLI commands for image generation models.
//
// TODO (jmorganca): Integrate these commands into cmd/cmd.go when stable.
// Currently these are separate to keep experimental code isolated.
package imagegen
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"strconv"
"strings"
"time"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
)
// ImageGenOptions holds options for image generation.
// These can be set via environment variables or interactive commands.
type ImageGenOptions struct {
Width int
Height int
Steps int
Seed int
NegativePrompt string
}
// DefaultOptions returns the default image generation options.
func DefaultOptions() ImageGenOptions {
return ImageGenOptions{
Width: 1024,
Height: 1024,
Steps: 9,
Seed: 0, // 0 means random
}
}
// Show displays information about an image generation model.
func Show(modelName string, w io.Writer) error {
manifest, err := LoadManifest(modelName)
if err != nil {
return fmt.Errorf("failed to load manifest: %w", err)
}
// Count total size
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
totalSize += layer.Size
}
}
// Read model_index.json for architecture
var architecture string
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
}
if json.Unmarshal(data, &index) == nil {
architecture = index.Architecture
}
}
// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
paramCount := totalSize / 2
paramStr := formatParamCount(paramCount)
// Print Model info
fmt.Fprintln(w, " Model")
if architecture != "" {
fmt.Fprintf(w, " %-20s %s\n", "architecture", architecture)
}
fmt.Fprintf(w, " %-20s %s\n", "parameters", paramStr)
fmt.Fprintf(w, " %-20s %s\n", "quantization", "BF16")
fmt.Fprintln(w)
// Print Capabilities
fmt.Fprintln(w, " Capabilities")
fmt.Fprintf(w, " %s\n", "image")
fmt.Fprintln(w)
return nil
}
// formatParamCount formats parameter count as human-readable string.
func formatParamCount(count int64) string {
if count >= 1_000_000_000 {
return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
}
if count >= 1_000_000 {
return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
}
return fmt.Sprintf("%d", count)
}
// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("width", 1024, "Image width")
cmd.Flags().Int("height", 1024, "Image height")
cmd.Flags().Int("steps", 9, "Denoising steps")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt")
cmd.Flags().MarkHidden("width")
cmd.Flags().MarkHidden("height")
cmd.Flags().MarkHidden("steps")
cmd.Flags().MarkHidden("seed")
cmd.Flags().MarkHidden("negative")
}
// RunCLI handles the CLI for image generation models.
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
// Verify it's a valid image gen model
if ResolveModelName(name) == "" {
return fmt.Errorf("unknown image generation model: %s", name)
}
// Get options from flags (with env var defaults)
opts := DefaultOptions()
if cmd != nil && cmd.Flags() != nil {
if v, err := cmd.Flags().GetInt("width"); err == nil && v > 0 {
opts.Width = v
}
if v, err := cmd.Flags().GetInt("height"); err == nil && v > 0 {
opts.Height = v
}
if v, err := cmd.Flags().GetInt("steps"); err == nil && v > 0 {
opts.Steps = v
}
if v, err := cmd.Flags().GetInt("seed"); err == nil && v != 0 {
opts.Seed = v
}
if v, err := cmd.Flags().GetString("negative"); err == nil && v != "" {
opts.NegativePrompt = v
}
}
if interactive {
return runInteractive(cmd, name, keepAlive, opts)
}
// One-shot generation
return generateImageWithOptions(cmd, name, prompt, keepAlive, opts)
}
// generateImageWithOptions generates an image with the given options.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Build request with image gen options encoded in Options fields
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
}
if keepAlive != nil {
req.KeepAlive = keepAlive
}
// Show loading spinner until generation starts
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
var stepBar *progress.StepBar
var imagePath string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
}
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
}
return nil
})
p.Stop()
if err != nil {
return err
}
if imagePath != "" {
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
}
return nil
}
// runInteractive runs an interactive REPL for image generation.
func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
Placeholder: "Describe an image to generate (/help for commands)",
})
if err != nil {
return err
}
if envconfig.NoHistory() {
scanner.HistoryDisable()
}
for {
line, err := scanner.Readline()
switch {
case errors.Is(err, io.EOF):
fmt.Println()
return nil
case errors.Is(err, readline.ErrInterrupt):
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
continue
case err != nil:
return err
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
// Handle commands
switch {
case strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"):
printInteractiveHelp(opts)
continue
case strings.HasPrefix(line, "/set "):
if err := handleSetCommand(line[5:], &opts); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
}
continue
case strings.HasPrefix(line, "/show"):
printCurrentSettings(opts)
continue
case strings.HasPrefix(line, "/"):
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
continue
}
// Generate image with current options
req := &api.GenerateRequest{
Model: modelName,
Prompt: line,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
}
if keepAlive != nil {
req.KeepAlive = keepAlive
}
// Show loading spinner until generation starts
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
var stepBar *progress.StepBar
var imagePath string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
}
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
}
return nil
})
p.Stop()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
continue
}
// Copy image to current directory with descriptive name
if imagePath != "" {
// Create filename from prompt (sanitized)
safeName := sanitizeFilename(line)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
// Copy file to CWD
if err := copyFile(imagePath, newName); err != nil {
fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
} else {
displayImageInTerminal(newName)
fmt.Printf("Image saved to: %s\n", newName)
}
}
fmt.Println()
}
}
// sanitizeFilename removes characters that aren't safe for filenames.
func sanitizeFilename(s string) string {
s = strings.ToLower(s)
s = strings.ReplaceAll(s, " ", "-")
// Remove any character that's not alphanumeric or hyphen
var result strings.Builder
for _, r := range s {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}
// copyFile copies a file from src to dst.
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}
// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
fmt.Fprintln(os.Stderr, "Commands:")
fmt.Fprintln(os.Stderr, " /set width <n> Set image width (current:", opts.Width, ")")
fmt.Fprintln(os.Stderr, " /set height <n> Set image height (current:", opts.Height, ")")
fmt.Fprintln(os.Stderr, " /set steps <n> Set denoising steps (current:", opts.Steps, ")")
fmt.Fprintln(os.Stderr, " /set seed <n> Set random seed (current:", opts.Seed, ", 0=random)")
fmt.Fprintln(os.Stderr, " /set negative <s> Set negative prompt")
fmt.Fprintln(os.Stderr, " /show Show current settings")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "Or type a prompt to generate an image.")
fmt.Fprintln(os.Stderr)
}
// printCurrentSettings prints the current image generation settings.
func printCurrentSettings(opts ImageGenOptions) {
fmt.Fprintf(os.Stderr, "Current settings:\n")
fmt.Fprintf(os.Stderr, " width: %d\n", opts.Width)
fmt.Fprintf(os.Stderr, " height: %d\n", opts.Height)
fmt.Fprintf(os.Stderr, " steps: %d\n", opts.Steps)
fmt.Fprintf(os.Stderr, " seed: %d (0=random)\n", opts.Seed)
if opts.NegativePrompt != "" {
fmt.Fprintf(os.Stderr, " negative: %s\n", opts.NegativePrompt)
}
fmt.Fprintln(os.Stderr)
}
// handleSetCommand handles /set commands to change options.
func handleSetCommand(args string, opts *ImageGenOptions) error {
parts := strings.SplitN(args, " ", 2)
if len(parts) < 2 {
return fmt.Errorf("usage: /set <option> <value>")
}
key := strings.ToLower(parts[0])
value := strings.TrimSpace(parts[1])
switch key {
case "width", "w":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("width must be a positive integer")
}
opts.Width = v
fmt.Fprintf(os.Stderr, "Set width to %d\n", v)
case "height", "h":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("height must be a positive integer")
}
opts.Height = v
fmt.Fprintf(os.Stderr, "Set height to %d\n", v)
case "steps", "s":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("steps must be a positive integer")
}
opts.Steps = v
fmt.Fprintf(os.Stderr, "Set steps to %d\n", v)
case "seed":
v, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("seed must be an integer")
}
opts.Seed = v
fmt.Fprintf(os.Stderr, "Set seed to %d\n", v)
case "negative", "neg", "n":
opts.NegativePrompt = value
if value == "" {
fmt.Fprintln(os.Stderr, "Cleared negative prompt")
} else {
fmt.Fprintf(os.Stderr, "Set negative prompt to: %s\n", value)
}
default:
return fmt.Errorf("unknown option: %s (try /help)", key)
}
return nil
}
// displayImageInTerminal attempts to render an image inline in the terminal.
// Supports iTerm2, Kitty, WezTerm, Ghostty, and other terminals with inline image support.
// Returns true if the image was displayed, false otherwise.
func displayImageInTerminal(imagePath string) bool {
// Check if terminal supports inline images
termProgram := os.Getenv("TERM_PROGRAM")
kittyWindowID := os.Getenv("KITTY_WINDOW_ID")
weztermPane := os.Getenv("WEZTERM_PANE")
ghostty := os.Getenv("GHOSTTY_RESOURCES_DIR")
// Read the image file
data, err := os.ReadFile(imagePath)
if err != nil {
return false
}
encoded := base64.StdEncoding.EncodeToString(data)
switch {
case termProgram == "iTerm.app" || termProgram == "WezTerm" || weztermPane != "":
// iTerm2/WezTerm inline image protocol
// ESC ] 1337 ; File = [arguments] : base64 BEL
fmt.Printf("\033]1337;File=inline=1;preserveAspectRatio=1:%s\a\n", encoded)
return true
case kittyWindowID != "" || ghostty != "" || termProgram == "ghostty":
// Kitty graphics protocol (also used by Ghostty)
// Send in chunks for large images
const chunkSize = 4096
for i := 0; i < len(encoded); i += chunkSize {
end := i + chunkSize
if end > len(encoded) {
end = len(encoded)
}
chunk := encoded[i:end]
if i == 0 {
// First chunk: a=T (transmit), f=100 (PNG), m=1 (more chunks follow) or m=0 (last chunk)
more := 1
if end >= len(encoded) {
more = 0
}
fmt.Printf("\033_Ga=T,f=100,m=%d;%s\033\\", more, chunk)
} else if end >= len(encoded) {
// Last chunk
fmt.Printf("\033_Gm=0;%s\033\\", chunk)
} else {
// Middle chunk
fmt.Printf("\033_Gm=1;%s\033\\", chunk)
}
}
fmt.Println()
return true
default:
return false
}
}


@@ -1,130 +0,0 @@
// Package client provides client-side model creation for tensor-based models.
//
// This package is in x/ because the tensor model storage format is under development.
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
// cannot import server. This sub-package can import server because server doesn't
// import it.
//
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
// 1. Add proper API endpoints for tensor model creation
// 2. Move tensor extraction to server-side
// 3. Remove this package
// 4. Follow the same client→server pattern as regular model creation
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
// MinOllamaVersion is the minimum Ollama version required for image generation models.
const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
status := "importing image generation model"
spinner := progress.NewSpinner(status)
p.Add("imagegen", spinner)
// Create layer callback for config files
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
// Create manifest writer callback
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create a proper config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: []string{"image"},
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
}
}
return server.WriteManifest(name, configLayer, serverLayers)
}
// Progress callback
progressFn := func(msg string) {
spinner.Stop()
status = msg
spinner = progress.NewSpinner(status)
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created image generation model '%s'\n", modelName)
return nil
}


@@ -1,35 +0,0 @@
# MLX Engine
Experimental MLX backend for running models on Apple Silicon and CUDA.
## Build
```bash
go build -tags mlx -o engine ./x/imagegen/cmd/engine
```
## Text Generation
```bash
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
```
Options:
- `-temperature` - sampling temperature (default 0.7)
- `-top-p` - nucleus sampling (default 0.9)
- `-top-k` - top-k sampling (default 40)
Supports: Llama, Gemma3, GPT-OSS
## Image Generation
```bash
./engine -zimage -model /path/to/z-image -prompt "a cat" -output cat.png
```
Options:
- `-width`, `-height` - image dimensions (default 1024x1024)
- `-steps` - denoising steps (default 9)
- `-seed` - random seed (default 42)


@@ -1,359 +0,0 @@
//go:build mlx
package main
import (
"context"
"fmt"
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Dedicated stream for generation (like mlx-lm's generation_stream)
var generationStream *mlx.Stream
// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
// This handles cases where tokenizers output partial multi-byte sequences.
type utf8Streamer struct {
buffer []byte
}
// Write adds decoded text to the buffer and returns complete UTF-8 characters.
func (s *utf8Streamer) Write(text string) string {
s.buffer = append(s.buffer, text...)
// Find the last position that ends with a complete UTF-8 character
validLen := 0
for i := 0; i < len(s.buffer); {
r, size := utf8.DecodeRune(s.buffer[i:])
if r == utf8.RuneError && size == 1 {
// Invalid or incomplete UTF-8 sequence at this position
// Check if it could be a valid start of a multi-byte sequence
if len(s.buffer)-i < 4 {
// Might be incomplete, keep it in buffer
break
}
// Definitely invalid, skip this byte
i++
validLen = i
} else {
i += size
validLen = i
}
}
if validLen == 0 {
return ""
}
result := string(s.buffer[:validLen])
s.buffer = s.buffer[validLen:]
return result
}
// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
func (s *utf8Streamer) Flush() string {
if len(s.buffer) == 0 {
return ""
}
result := string(s.buffer)
s.buffer = nil
return result
}
func init() {
generationStream = mlx.NewStream()
}
// withStream runs fn with the generation stream as default
func withStream(fn func()) {
orig := mlx.GetDefaultStream()
mlx.SetDefaultStream(generationStream)
fn()
mlx.SetDefaultStream(orig)
}
type Model interface {
Tokenizer() *tokenizer.Tokenizer
VocabSize() int32
NewCache(maxSeqLen int32) []cache.Cache
Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
}
// ChatModel is an optional interface for models that support chat formatting
type ChatModel interface {
FormatPrompt(prompt string) string
}
// MultimodalModel is for models that support image input
type MultimodalModel interface {
Model
FormatPromptWithImage(prompt string) string
ExpandImageTokens(tokens []int32) []int32
ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
ImageSize() int32 // Returns expected image size for preprocessing
}
// ImageLoader loads and preprocesses an image for multimodal models
// Returns nil if path is empty
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)
type input struct {
Prompt string
Image *mlx.Array // Optional preprocessed image for multimodal models
MaxTokens int
Temperature float32
TopP float32
TopK int
WiredLimitGB int // Metal wired memory limit in GB (default 32)
}
type output struct {
Text string
Done bool
PrefillTokSec float64
GenTokSec float64
}
// Decoder wraps model + cache for autoregressive generation.
type Decoder struct {
model Model
caches []cache.Cache
vocabSize int32
temp float32
topK int
topP float32
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
}
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
caches := m.NewCache(0)
return &Decoder{
model: m,
caches: caches,
vocabSize: m.VocabSize(),
temp: temp,
topK: topK,
topP: topP,
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
}
}
// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
d.image = img
}
func (d *Decoder) prefill(inputIDs []int32) int {
processed := 0
// Track old cache state to free after each chunk
var oldCacheState []*mlx.Array
// For multimodal models with an image, we need to process all tokens together
// in the first forward pass so the image embeddings can be inserted properly.
// Skip chunking for multimodal prefill.
isMultimodal := d.image != nil
// Process all-but-1 tokens in chunks, eval cache state for memory management
// Skip chunking for multimodal - process everything in the final step
if !isMultimodal {
for len(inputIDs) > 1 {
chunkSize := min(2048, len(inputIDs)-1)
if chunkSize <= 0 {
break
}
chunk := inputIDs[:chunkSize]
// Save old cache state before forward
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
var cacheState []*mlx.Array
withStream(func() {
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
d.model.Forward(x, d.caches)
for _, c := range d.caches {
cacheState = append(cacheState, c.State()...)
}
})
mlx.Eval(cacheState...)
// Free old cache state
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
inputIDs = inputIDs[chunkSize:]
processed += chunkSize
}
}
// Save old cache state before final step
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
// Final token + sampling (or all tokens for multimodal)
withStream(func() {
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
mlx.Eval(x) // Materialize before any other evals
var logits *mlx.Array
// Use ForwardWithImage if we have an image and model supports it
if d.image != nil {
if mm, ok := d.model.(MultimodalModel); ok {
logits = mm.ForwardWithImage(x, d.image, d.caches)
d.image = nil // Only use image for first forward
} else {
logits = d.model.Forward(x, d.caches)
}
} else {
logits = d.model.Forward(x, d.caches)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Free old cache state from before final step
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
mlx.ClearCache()
return processed + len(inputIDs)
}
func (d *Decoder) step() int32 {
prevToken := d.token
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
d.oldCacheState = append(d.oldCacheState, c.State()...)
}
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
mlx.Keep(d.token)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
arr.Free()
}
return val
}
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
mlx.EnableCompile()
wiredLimit := in.WiredLimitGB
if wiredLimit <= 0 {
wiredLimit = 32 // default 32GB
}
mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)
temp := in.Temperature
if temp < 0 {
temp = 0.7
}
tok := m.Tokenizer()
dec := NewDecoder(m, temp, in.TopK, in.TopP)
// Apply chat template - use image template if we have an image
prompt := in.Prompt
var tokens []int32
if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
prompt = mm.FormatPromptWithImage(prompt)
tokens = tok.Encode(prompt, true)
tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
dec.SetImage(in.Image)
} else if cm, ok := m.(ChatModel); ok {
prompt = cm.FormatPrompt(prompt)
tokens = tok.Encode(prompt, true)
} else {
tokens = tok.Encode(prompt, true)
}
prefillStart := time.Now()
prefillTokens := dec.prefill(tokens)
// Prefill measurement should include time to first token (like mlx-lm)
// step() waits for prefill to complete and returns the first token
firstToken := dec.step()
prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()
genStart := time.Now()
maxTokens := max(in.MaxTokens, 100)
var genTokens int
// UTF-8 streamer to handle partial multi-byte characters
streamer := &utf8Streamer{}
// Handle first token
genTokens++
if tok.IsEOS(firstToken) {
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
return nil
}
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
cb(output{Text: text})
}
for n := 1; n < maxTokens; n++ {
if ctx.Err() != nil {
return ctx.Err()
}
token := dec.step()
genTokens++
if tok.IsEOS(token) {
break
}
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
cb(output{Text: text})
}
if n%256 == 0 {
mlx.ClearCache()
}
}
// Flush any remaining buffered bytes
if text := streamer.Flush(); text != "" {
cb(output{Text: text})
}
fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
cb(output{Done: true, PrefillTokSec: prefillTokSec,
GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
return nil
}


@@ -1,89 +0,0 @@
//go:build mlx
package main
import (
"fmt"
"image"
"image/png"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// saveImageArray saves an MLX array as a PNG image.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func saveImageArray(arr *mlx.Array, path string) error {
img, err := arrayToImage(arr)
if err != nil {
return err
}
return savePNG(img, path)
}
func savePNG(img *image.RGBA, path string) error {
if filepath.Ext(path) != ".png" {
path = path + ".png"
}
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
return png.Encode(f, img)
}
func arrayToImage(arr *mlx.Array) (*image.RGBA, error) {
shape := arr.Shape()
if len(shape) != 4 {
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
}
// Transform to [H, W, C] for image conversion
img := mlx.Squeeze(arr, 0)
arr.Free()
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
mlx.Eval(img)
imgShape := img.Shape()
H := int(imgShape[0])
W := int(imgShape[1])
C := int(imgShape[2])
if C != 3 {
img.Free()
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
}
// Copy to CPU and free GPU memory
data := img.Data()
img.Free()
// Write directly to Pix slice (faster than SetRGBA)
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
pix := goImg.Pix
for y := 0; y < H; y++ {
for x := 0; x < W; x++ {
srcIdx := (y*W + x) * C
dstIdx := (y*W + x) * 4
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
pix[dstIdx+3] = 255
}
}
return goImg, nil
}
func clampF(v, min, max float32) float32 {
if v < min {
return min
}
if v > max {
return max
}
return v
}


@@ -1,286 +0,0 @@
//go:build mlx
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"runtime/pprof"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// stringSlice is a flag type that accumulates multiple values
type stringSlice []string
func (s *stringSlice) String() string {
return fmt.Sprintf("%v", *s)
}
func (s *stringSlice) Set(value string) error {
*s = append(*s, value)
return nil
}
func main() {
modelPath := flag.String("model", "", "Model directory")
prompt := flag.String("prompt", "Hello", "Prompt")
// Text generation params
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
temperature := flag.Float64("temperature", 0.7, "Temperature")
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
topK := flag.Int("top-k", 40, "Top-k sampling")
imagePath := flag.String("image", "", "Image path for multimodal models")
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
steps := flag.Int("steps", 9, "Denoising steps")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
// Utility flags
listTensors := flag.Bool("list", false, "List tensors only")
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
flag.Parse()
if *modelPath == "" {
flag.Usage()
return
}
// CPU profiling
if *cpuProfile != "" {
f, err := os.Create(*cpuProfile)
if err != nil {
log.Fatal(err)
}
defer f.Close()
if err := pprof.StartCPUProfile(f); err != nil {
log.Fatal(err)
}
defer pprof.StopCPUProfile()
}
var err error
// Handle legacy mode flags that aren't unified yet
switch {
case *zimageFlag:
m := &zimage.Model{}
if loadErr := m.Load(*modelPath); loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImageEdit:
if len(inputImages) == 0 {
log.Fatal("qwen-image-edit requires at least one -input-image")
}
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
// For image editing, use 0 for dimensions to auto-detect from input image
// unless explicitly overridden from defaults
editWidth := int32(0)
editHeight := int32(0)
if *width != 1024 {
editWidth = int32(*width)
}
if *height != 1024 {
editHeight = int32(*height)
}
cfg := &qwen_image_edit.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: editWidth,
Height: editHeight,
Steps: *steps,
Seed: *seed,
}
var img *mlx.Array
img, err = m.EditFromConfig(inputImages, cfg)
if err == nil {
err = saveImageArray(img, *out)
}
case *listTensors:
err = listModelTensors(*modelPath)
default:
// llm path
// Use a separate variable so the outer err is not shadowed and the error
// returned by generate below still reaches the final check.
m, loadErr := load(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
// Load image if provided and model supports it
var image *mlx.Array
if *imagePath != "" {
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
if err != nil {
log.Fatal("load image:", err)
}
} else {
log.Fatal("model does not support image input")
}
}
err = generate(context.Background(), m, input{
Prompt: *prompt,
Image: image,
MaxTokens: *maxTokens,
Temperature: float32(*temperature),
TopP: float32(*topP),
TopK: *topK,
WiredLimitGB: *wiredLimitGB,
}, func(out output) {
if out.Text != "" {
fmt.Print(out.Text)
}
if out.Done {
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
}
})
}
if err != nil {
log.Fatal(err)
}
}
func listModelTensors(modelPath string) error {
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return err
}
for _, name := range weights.ListTensors() {
info, _ := weights.GetTensorInfo(name)
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
}
return nil
}
// loadModel builds and evaluates a model using the common load pattern.
// Release safetensors BEFORE eval - lazy arrays have captured their data,
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
func loadModel[T Model](build func() T, cleanup func()) T {
m := build()
weights := mlx.Collect(m)
cleanup()
mlx.Eval(weights...)
return m
}
func load(modelPath string) (Model, error) {
kind, err := detectModelKind(modelPath)
if err != nil {
return nil, fmt.Errorf("detect model kind: %w", err)
}
switch kind {
case "gpt_oss":
return gpt_oss.Load(modelPath)
case "gemma3":
return gemma3.Load(modelPath)
case "gemma3_text":
return gemma3.LoadText(modelPath)
default:
return llama.Load(modelPath)
}
}
func detectModelKind(modelPath string) (string, error) {
indexPath := filepath.Join(modelPath, "model_index.json")
if _, err := os.Stat(indexPath); err == nil {
data, err := os.ReadFile(indexPath)
if err != nil {
return "zimage", nil
}
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err == nil {
switch index.ClassName {
case "FluxPipeline", "ZImagePipeline":
return "zimage", nil
}
}
return "zimage", nil
}
configPath := filepath.Join(modelPath, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
}
var cfg struct {
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return "", fmt.Errorf("parse config.json: %w", err)
}
return cfg.ModelType, nil
}


@@ -1,49 +0,0 @@
//go:build mlx
package main
import "github.com/ollama/ollama/x/imagegen/mlx"
// sampleTopK samples from top-k logits using global random state
func sampleTopK(scaledLogits *mlx.Array, k int) *mlx.Array {
neg := mlx.Neg(scaledLogits)
indices := mlx.Argpartition(neg, k-1, -1)
topKIdx := mlx.Slice(indices, []int32{0}, []int32{int32(k)})
values := mlx.TakeAlongAxis(scaledLogits, topKIdx, -1)
sampled := mlx.RandomCategorical(values, -1, 1)
return mlx.Take(topKIdx, sampled, -1)
}
// sampleTopP samples using nucleus sampling with global random state
func sampleTopP(scaledLogits *mlx.Array, p float32, vocabSize int32) *mlx.Array {
sorted := mlx.Argsort(mlx.Neg(scaledLogits), -1)
sortedLogits := mlx.TakeAlongAxis(scaledLogits, sorted, -1)
probs := mlx.Softmax(sortedLogits, -1)
cumProbs := mlx.Cumsum(probs, -1)
mask := mlx.LessScalar(cumProbs, p)
negInf := mlx.FullDtype(float32(-1e9), scaledLogits.Dtype(), vocabSize)
masked := mlx.Where(mask, sortedLogits, negInf)
sampled := mlx.RandomCategorical(masked, -1, 1)
return mlx.Take(sorted, sampled, -1)
}
// sample samples from logits at the last position
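// Precedence: temp == 0 always takes the greedy argmax; otherwise top-k is used when
// 0 < topK < vocab, else top-p when 0 < topP < 1, else plain categorical sampling of
// the temperature-scaled logits.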
func sample(logits *mlx.Array, temp float32, topK int, topP float32, vocab int32) *mlx.Array {
// Get last position logits: [1, L, vocab] -> [vocab]
shape := logits.Shape()
seqLen := shape[1]
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocab})
lastLogits = mlx.Reshape(lastLogits, vocab)
if temp == 0 {
return mlx.Argmax(lastLogits, -1, false)
}
scaled := mlx.DivScalar(lastLogits, temp)
if topK > 0 && topK < int(vocab) {
return sampleTopK(scaled, topK)
}
if topP > 0 && topP < 1.0 {
return sampleTopP(scaled, topP, vocab)
}
return mlx.RandomCategorical(scaled, -1, 1)
}


@@ -1,183 +0,0 @@
package imagegen
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// IsTensorModelDir checks if the directory contains a tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"}
for _, component := range components {
componentDir := filepath.Join(modelDir, component)
if _, err := os.Stat(componentDir); os.IsNotExist(err) {
continue
}
// Find all safetensors files in this component
entries, err := os.ReadDir(componentDir)
if err != nil {
return fmt.Errorf("failed to read %s: %w", component, err)
}
for _, entry := range entries {
if !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
stPath := filepath.Join(componentDir, entry.Name())
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
tensorNames := extractor.ListTensors()
fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName
layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
}
layers = append(layers, layer)
}
extractor.Close()
}
}
// Import config files
configFiles := []string{
"model_index.json",
"text_encoder/config.json",
"text_encoder/generation_config.json",
"transformer/config.json",
"vae/config.json",
"scheduler/scheduler_config.json",
"tokenizer/tokenizer.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
}
for _, cfgPath := range configFiles {
fullPath := filepath.Join(modelDir, cfgPath)
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
continue
}
fn(fmt.Sprintf("importing config %s", cfgPath))
var r io.Reader
// For model_index.json, normalize to Ollama format
if cfgPath == "model_index.json" {
data, err := os.ReadFile(fullPath)
if err != nil {
return fmt.Errorf("failed to read %s: %w", cfgPath, err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("failed to parse %s: %w", cfgPath, err)
}
// Rename _class_name to architecture, remove diffusers-specific fields
if className, ok := cfg["_class_name"]; ok {
cfg["architecture"] = className
delete(cfg, "_class_name")
}
delete(cfg, "_diffusers_version")
data, err = json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
}
r = bytes.NewReader(data)
} else {
f, err := os.Open(fullPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
}
defer f.Close()
r = f
}
layer, err := createLayer(r, "application/vnd.ollama.image.json", cfgPath)
if err != nil {
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
}
// Use model_index.json as the config layer
if cfgPath == "model_index.json" {
configLayer = layer
}
layers = append(layers, layer)
}
if configLayer.Digest == "" {
return fmt.Errorf("model_index.json not found in %s", modelDir)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
return fmt.Errorf("failed to write manifest: %w", err)
}
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}


@@ -1,107 +0,0 @@
//go:build mlx
package imagegen
import (
"bytes"
"encoding/base64"
"fmt"
"image"
"image/png"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SaveImage saves an MLX array as a PNG image file.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func SaveImage(arr *mlx.Array, path string) error {
img, err := ArrayToImage(arr)
if err != nil {
return err
}
if filepath.Ext(path) != ".png" {
path = path + ".png"
}
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
return png.Encode(f, img)
}
// EncodeImageBase64 encodes an MLX array as a base64-encoded PNG.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func EncodeImageBase64(arr *mlx.Array) (string, error) {
img, err := ArrayToImage(arr)
if err != nil {
return "", err
}
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}
// ArrayToImage converts an MLX array to a Go image.RGBA.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
shape := arr.Shape()
if len(shape) != 4 {
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
}
// Transform to [H, W, C] for image conversion
img := mlx.Squeeze(arr, 0)
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
mlx.Eval(img)
imgShape := img.Shape()
H := int(imgShape[0])
W := int(imgShape[1])
C := int(imgShape[2])
if C != 3 {
img.Free()
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
}
// Copy to CPU and free GPU memory
data := img.Data()
img.Free()
// Write directly to Pix slice (faster than SetRGBA)
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
pix := goImg.Pix
for y := 0; y < H; y++ {
for x := 0; x < W; x++ {
srcIdx := (y*W + x) * C
dstIdx := (y*W + x) * 4
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
pix[dstIdx+3] = 255
}
}
return goImg, nil
}
func clampF(v, min, max float32) float32 {
if v < min {
return min
}
if v > max {
return max
}
return v
}


@@ -1,177 +0,0 @@
package imagegen
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strings"
)
// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
Name string `json:"name,omitempty"` // Path-style name: "component/tensor" or "path/to/config.json"
}
// Manifest represents the manifest JSON structure.
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config ManifestLayer `json:"config"`
Layers []ManifestLayer `json:"layers"`
}
// ModelManifest holds a parsed manifest with helper methods.
type ModelManifest struct {
Manifest *Manifest
BlobDir string
}
// DefaultBlobDir returns the default blob storage directory.
func DefaultBlobDir() string {
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
// Every platform currently stores blobs under the same path in the user's home directory.
return filepath.Join(home, ".ollama", "models", "blobs")
}
// DefaultManifestDir returns the default manifest storage directory.
func DefaultManifestDir() string {
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
return filepath.Join(home, ".ollama", "models", "manifests")
}
// LoadManifest loads a manifest for the given model name.
// Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag"
func LoadManifest(modelName string) (*ModelManifest, error) {
manifestPath := resolveManifestPath(modelName)
data, err := os.ReadFile(manifestPath)
if err != nil {
return nil, fmt.Errorf("read manifest: %w", err)
}
var manifest Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return nil, fmt.Errorf("parse manifest: %w", err)
}
return &ModelManifest{
Manifest: &manifest,
BlobDir: DefaultBlobDir(),
}, nil
}
// resolveManifestPath converts a model name to a manifest file path.
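// For example (illustrative names): "llama3" resolves to
// <manifests>/registry.ollama.ai/library/llama3/latest, while
// "myhost.com/team/model:v1" resolves to <manifests>/myhost.com/team/model/v1.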
func resolveManifestPath(modelName string) string {
// Parse model name into components
// Default: registry.ollama.ai/library/<name>/<tag>
host := "registry.ollama.ai"
namespace := "library"
name := modelName
tag := "latest"
// Handle explicit tag
if idx := strings.LastIndex(name, ":"); idx != -1 {
tag = name[idx+1:]
name = name[:idx]
}
// Handle full path like "host/namespace/name"
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
host = parts[0]
namespace = parts[1]
name = parts[2]
case 2:
namespace = parts[0]
name = parts[1]
}
return filepath.Join(DefaultManifestDir(), host, namespace, name, tag)
}
// BlobPath returns the full path to a blob given its digest.
func (m *ModelManifest) BlobPath(digest string) string {
// Convert "sha256:abc123" to "sha256-abc123"
blobName := strings.Replace(digest, ":", "-", 1)
return filepath.Join(m.BlobDir, blobName)
}
// GetTensorLayers returns all tensor layers for a given component.
// Component should be "text_encoder", "transformer", or "vae".
// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight").
func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
prefix := component + "/"
var layers []ManifestLayer
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) {
layers = append(layers, layer)
}
}
return layers
}
// GetConfigLayer returns the config layer for a given path.
func (m *ModelManifest) GetConfigLayer(configPath string) *ManifestLayer {
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
return &layer
}
}
return nil
}
// ReadConfig reads and returns the content of a config file.
func (m *ModelManifest) ReadConfig(configPath string) ([]byte, error) {
layer := m.GetConfigLayer(configPath)
if layer == nil {
return nil, fmt.Errorf("config %q not found in manifest", configPath)
}
blobPath := m.BlobPath(layer.Digest)
return os.ReadFile(blobPath)
}
// ReadConfigJSON reads and unmarshals a config file.
func (m *ModelManifest) ReadConfigJSON(configPath string, v any) error {
data, err := m.ReadConfig(configPath)
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
// OpenBlob opens a blob for reading.
func (m *ModelManifest) OpenBlob(digest string) (io.ReadCloser, error) {
return os.Open(m.BlobPath(digest))
}
// HasTensorLayers returns true if the manifest has any tensor layers.
func (m *ModelManifest) HasTensorLayers() bool {
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
return true
}
}
return false
}


@@ -1,102 +0,0 @@
// Package imagegen provides experimental image generation capabilities for Ollama.
//
// This package is in x/ because the tensor model storage format is under development.
// The goal is to integrate these capabilities into the main Ollama packages once
// the format is stable.
//
// TODO (jmorganca): Integrate into main packages when stable:
// - CLI commands → cmd/
// - API endpoints → api/
// - Model creation → server/
package imagegen
import (
"encoding/json"
"fmt"
"runtime"
)
// GB is a convenience constant for gigabytes.
const GB = 1024 * 1024 * 1024
// SupportedBackends lists the backends that support image generation.
var SupportedBackends = []string{"metal", "cuda", "cpu"}
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
}
// CheckPlatformSupport validates that image generation is supported on the current platform.
// Returns nil if supported, or an error describing why it's not supported.
func CheckPlatformSupport() error {
switch runtime.GOOS {
case "darwin":
// macOS: Metal is supported via MLX
if runtime.GOARCH != "arm64" {
return fmt.Errorf("image generation on macOS requires Apple Silicon (arm64), got %s", runtime.GOARCH)
}
return nil
case "linux", "windows":
// Linux/Windows: CUDA support (requires mlx or cuda build)
// The actual backend availability is checked at runtime
return nil
default:
return fmt.Errorf("image generation is not supported on %s", runtime.GOOS)
}
}
// CheckMemoryRequirements validates that there's enough memory for image generation.
// Returns nil if memory is sufficient, or an error if not.
func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
required := EstimateVRAM(modelName)
if availableMemory < required {
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
required/GB, availableMemory/GB)
}
return nil
}
// ResolveModelName checks if a model name is a known image generation model.
// Returns the normalized model name if found, empty string otherwise.
func ResolveModelName(modelName string) string {
manifest, err := LoadManifest(modelName)
if err == nil && manifest.HasTensorLayers() {
return modelName
}
return ""
}
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
// Returns a conservative default of 21GB if the model type cannot be determined.
func EstimateVRAM(modelName string) uint64 {
manifest, err := LoadManifest(modelName)
if err != nil {
return 21 * GB
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return 21 * GB
}
// Parse just the class name
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err != nil {
return 21 * GB
}
if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
return estimate
}
return 21 * GB
}
// HasTensorLayers checks if the given model has tensor layers.
func HasTensorLayers(modelName string) bool {
return ResolveModelName(modelName) != ""
}


@@ -1,110 +0,0 @@
package imagegen
import (
"runtime"
"testing"
)
func TestCheckPlatformSupport(t *testing.T) {
err := CheckPlatformSupport()
switch runtime.GOOS {
case "darwin":
if runtime.GOARCH == "arm64" {
if err != nil {
t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
}
} else {
if err == nil {
t.Error("Expected error on darwin/non-arm64")
}
}
case "linux", "windows":
if err != nil {
t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
}
default:
if err == nil {
t.Errorf("Expected error on unsupported platform %s", runtime.GOOS)
}
}
}
func TestCheckMemoryRequirements(t *testing.T) {
tests := []struct {
name string
availableMemory uint64
wantErr bool
}{
{
name: "sufficient memory",
availableMemory: 32 * GB,
wantErr: false,
},
{
name: "exactly enough memory",
availableMemory: 21 * GB,
wantErr: false,
},
{
name: "insufficient memory",
availableMemory: 16 * GB,
wantErr: true,
},
{
name: "zero memory",
availableMemory: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Use a non-existent model name which will default to 21GB estimate
err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
if (err != nil) != tt.wantErr {
t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestModelVRAMEstimates(t *testing.T) {
// Verify the VRAM estimates map has expected entries
expected := map[string]uint64{
"ZImagePipeline": 21 * GB,
"FluxPipeline": 21 * GB,
"QwenImagePipeline": 80 * GB,
}
for name, expectedVRAM := range expected {
if actual, ok := modelVRAMEstimates[name]; !ok {
t.Errorf("Missing VRAM estimate for %s", name)
} else if actual != expectedVRAM {
t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
}
}
}
func TestEstimateVRAMDefault(t *testing.T) {
// Non-existent model should return default 21GB
vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
if vram != 21*GB {
t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
}
}
func TestHasTensorLayers(t *testing.T) {
// Non-existent model should return false
if HasTensorLayers("nonexistent-model") {
t.Error("HasTensorLayers() should return false for non-existent model")
}
}
func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string
result := ResolveModelName("nonexistent-model")
if result != "" {
t.Errorf("ResolveModelName() = %q, want empty string", result)
}
}


@@ -1,46 +0,0 @@
# MLX Memory Management
> This package will get consolidated with `x/ml/backend/mlx` in the future.
## Automatic Tracking
All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed.
### API
```go
result := mlx.Matmul(x, w) // arrays automatically tracked
mlx.Eval(result) // free non-kept, eval result (auto-kept)
```
### Key Functions
- `mlx.Eval(outputs...)` - free non-kept arrays, then evaluate (outputs auto-kept)
- `mlx.AsyncEval(outputs...)` - async version of Eval (outputs auto-kept)
- `mlx.Keep(arrays...)` - mark arrays to survive cleanup (for weights, caches)
- `array.Free()` - mark array for cleanup on next Eval
### Loop Pattern
```go
for step := 0; step < maxTokens; step++ {
logits := model.Forward(token, caches)
oldToken := token
token = sample(logits)
// Keep cache state across iterations
for _, c := range caches {
mlx.Keep(c.State()...)
}
oldToken.Free() // mark for cleanup
mlx.AsyncEval(token) // frees old, evals new
}
```
### Notes
- `Eval()` and `AsyncEval()` auto-keep their outputs
- `Free()` marks for cleanup - actual free happens during next Eval
- Use `Keep()` for weights and cache state that must survive multiple Eval cycles
- Arrays created inside compiled closures are managed by MLX, not tracked
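### Weight Pattern
Model weights are the long-lived case: collect them once after building the module tree and mark them so no Eval cycle reclaims them. A minimal sketch using only functions from this package (`model` stands in for whatever module you built; the loaders in this repo rely on `Eval()` auto-keeping its outputs, so the explicit `Keep()` is shown for illustration):
```go
weights := mlx.Collect(model) // every weight array in the module tree
mlx.Keep(weights...)          // survive future Eval/AsyncEval cleanup passes
mlx.Eval(weights...)          // materialize once up front
```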


@@ -1,173 +0,0 @@
//go:build mlx
package mlx
/*
#include "mlx/c/mlx.h"
#include <stdlib.h>
// Forward declaration for Go callback
extern int goClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
// Destructor for payload (Go handle)
extern void goClosureDestructor(void* payload);
*/
import "C"
import (
"runtime/cgo"
"sync"
"unsafe"
)
// inClosureCallback is set to true during closure callback execution.
var inClosureCallback bool
var closureCallbackMu sync.Mutex
// InClosureCallback returns true if we're currently executing inside a closure callback.
func InClosureCallback() bool {
closureCallbackMu.Lock()
defer closureCallbackMu.Unlock()
return inClosureCallback
}
// CompiledFunc is a compiled MLX function that can be called efficiently.
// All intermediate arrays during execution stay inside MLX - only inputs
// and outputs cross the Go boundary.
type CompiledFunc struct {
closure C.mlx_closure
compiled C.mlx_closure
}
// ClosureFunc is the signature for functions that can be compiled.
// It takes a slice of input arrays and returns a slice of output arrays.
type ClosureFunc func(inputs []*Array) []*Array
// Compile compiles a Go function into an optimized MLX closure.
// The function is traced once during compilation, then subsequent calls
// run the optimized graph without creating Go intermediate arrays.
//
// Example:
//
// compiled := mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
// a, b := inputs[0], inputs[1]
// c := mlx.Add(a, b)
// d := mlx.Mul(c, c)
// return []*mlx.Array{d}
// })
// defer compiled.Free()
//
// result := compiled.Call(x, y)[0]
func Compile(fn ClosureFunc) *CompiledFunc {
return CompileShapeless(fn, false)
}
// CompileShapeless compiles with optional shapeless mode.
// If shapeless=true, the function works for any input shape after tracing.
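// For example, the gemma3 package compiles its GELU activation shapelessly so the same
// traced graph serves both multi-token prefill and single-token decode shapes.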
func CompileShapeless(fn ClosureFunc, shapeless bool) *CompiledFunc {
// Create a cgo.Handle to prevent the Go function from being GC'd
handle := cgo.NewHandle(fn)
// Create the closure from the Go callback
closure := C.mlx_closure_new_func_payload(
(*[0]byte)(C.goClosureCallback),
unsafe.Pointer(handle),
(*[0]byte)(C.goClosureDestructor),
)
// Compile the closure
compiled := C.mlx_closure_new()
C.mlx_compile(&compiled, closure, C.bool(shapeless))
return &CompiledFunc{
closure: closure,
compiled: compiled,
}
}
// Call invokes the compiled function with the given inputs.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
// Pack inputs into vector
inputVec := C.mlx_vector_array_new()
for _, arr := range inputs {
C.mlx_vector_array_append_value(inputVec, arr.c)
}
// Apply compiled closure
outputVec := C.mlx_vector_array_new()
C.mlx_closure_apply(&outputVec, cf.compiled, inputVec)
C.mlx_vector_array_free(inputVec)
// Unpack outputs
numOutputs := int(C.mlx_vector_array_size(outputVec))
outputs := make([]*Array, numOutputs)
for i := 0; i < numOutputs; i++ {
var arr C.mlx_array
C.mlx_vector_array_get(&arr, outputVec, C.size_t(i))
outputs[i] = newArray(arr)
}
C.mlx_vector_array_free(outputVec)
return outputs
}
// CallEval invokes the compiled function and evaluates the results.
func (cf *CompiledFunc) CallEval(inputs ...*Array) []*Array {
outputs := cf.Call(inputs...)
Eval(outputs...)
return outputs
}
// Free releases the compiled function resources.
func (cf *CompiledFunc) Free() {
C.mlx_closure_free(cf.compiled)
C.mlx_closure_free(cf.closure)
}
// borrowArray wraps a C array WITHOUT setting up GC cleanup.
// Use this for arrays we don't own (e.g., borrowed references in callbacks).
func borrowArray(array C.mlx_array) *Array {
return &Array{c: array}
}
//export goClosureCallback
func goClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) C.int {
// Set flag to disable AddCleanup during callback
closureCallbackMu.Lock()
inClosureCallback = true
closureCallbackMu.Unlock()
defer func() {
closureCallbackMu.Lock()
inClosureCallback = false
closureCallbackMu.Unlock()
}()
// Recover the Go function from the handle
handle := cgo.Handle(payload)
fn := handle.Value().(ClosureFunc)
// Convert input vector to Go slice - use borrowArray since MLX owns these
numInputs := int(C.mlx_vector_array_size(input))
inputs := make([]*Array, numInputs)
for i := 0; i < numInputs; i++ {
var arr C.mlx_array
C.mlx_vector_array_get(&arr, input, C.size_t(i))
inputs[i] = borrowArray(arr) // Don't set up cleanup - MLX owns these
}
// Call the Go function
outputs := fn(inputs)
// Build output vector
*res = C.mlx_vector_array_new()
for _, arr := range outputs {
C.mlx_vector_array_append_value(*res, arr.c)
}
return 0
}
//export goClosureDestructor
func goClosureDestructor(payload unsafe.Pointer) {
handle := cgo.Handle(payload)
handle.Delete()
}


File diff suppressed because it is too large


File diff suppressed because it is too large


@@ -1,614 +0,0 @@
//go:build mlx
package gemma3
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// TextConfig holds configuration for the text model
type TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
HeadDim int32 `json:"head_dim"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
SlidingWindow int32 `json:"sliding_window"`
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
// Computed fields
Scale float32 `json:"-"`
}
// TextModel is the Gemma 3 text-only model
type TextModel struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*DecoderLayer `weight:"model.layers"`
Norm *nn.RMSNorm `weight:"model.norm"`
Output *nn.Linear `weight:"-"` // Tied to EmbedTokens, set manually
// Precomputed (1 + weight) for Gemma-style RMSNorm to avoid allocation per forward
NormScaled *mlx.Array `weight:"-"`
tok *tokenizer.Tokenizer
*TextConfig
}
// DecoderLayer is a single transformer block
type DecoderLayer struct {
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
Attention *Attention
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
PreFFNorm *nn.RMSNorm `weight:"pre_feedforward_layernorm"`
MLP *MLP
PostFFNorm *nn.RMSNorm `weight:"post_feedforward_layernorm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
InputNormScaled *mlx.Array `weight:"-"`
PostAttnNormScaled *mlx.Array `weight:"-"`
PreFFNormScaled *mlx.Array `weight:"-"`
PostFFNormScaled *mlx.Array `weight:"-"`
// Whether this layer uses sliding window attention
IsSliding bool
LayerIdx int32
}
// Attention implements Gemma 3 attention with Q/K normalization
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
QNorm *nn.RMSNorm `weight:"self_attn.q_norm"`
KNorm *nn.RMSNorm `weight:"self_attn.k_norm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
QNormScaled *mlx.Array `weight:"-"`
KNormScaled *mlx.Array `weight:"-"`
}
// MLP is the feed-forward network with GELU activation
type MLP struct {
GateProj *nn.Linear `weight:"mlp.gate_proj"`
UpProj *nn.Linear `weight:"mlp.up_proj"`
DownProj *nn.Linear `weight:"mlp.down_proj"`
}
// LoadText loads the text-only Gemma 3 model
func LoadText(modelPath string) (*TextModel, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg TextConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Compute scale
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
// Set defaults if not specified
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
if cfg.RopeLocalBaseFreq == 0 {
cfg.RopeLocalBaseFreq = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &TextModel{
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
TextConfig: &cfg,
tok: tok,
}
// Initialize layer metadata
for i := range m.Layers {
m.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), cfg.SlidingWindowPattern),
}
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Tied embeddings for output
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
// Precompute (1 + weight) for Gemma-style RMSNorm to avoid per-forward allocation
precomputeGemmaScaledWeights(m)
return m, nil
}
// precomputeGemmaScaledWeights computes (1 + weight) for all RMSNorm layers
// This avoids creating temporary arrays on every forward pass
func precomputeGemmaScaledWeights(m *TextModel) {
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
for _, layer := range m.Layers {
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
}
// Eval all the precomputed weights
var scaled []*mlx.Array
scaled = append(scaled, m.NormScaled)
for _, layer := range m.Layers {
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
layer.PreFFNormScaled, layer.PostFFNormScaled,
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
}
mlx.Eval(scaled...)
}
// isLayerSliding determines if a layer uses sliding window attention
// Pattern N means: layers 0 to N-2 use sliding window, layer N-1 uses full attention,
// then the cycle repeats (layers N to 2N-2 sliding, layer 2N-1 full, etc.)
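// Example: with pattern=6 (the Gemma 3 default), layers 0-4 use sliding window attention,
// layer 5 is full attention, layers 6-10 are sliding, layer 11 is full, and so on.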
func isLayerSliding(layerIdx, pattern int32) bool {
if pattern <= 0 {
return false // No sliding window
}
// Layer is full attention if (layerIdx + 1) % pattern == 0
return (layerIdx+1)%pattern != 0
}
// Forward runs the text model forward pass
func (m *TextModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
// Get embeddings and scale by sqrt(hidden_size)
h := m.EmbedTokens.Forward(tokens)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
for i, layer := range m.Layers {
h = layer.Forward(h, caches[i], B, L, m.TextConfig)
}
// Final norm and output projection
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.RMSNormEps))
}
// Forward runs a decoder layer
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
// Pre-attention norm (use precomputed scaled weight)
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
// Attention
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
// Post-attention norm and residual
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
// Pre-FFN norm
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
// MLP
mlpOut := l.MLP.Forward(normed)
// Post-FFN norm and residual
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
return mlx.Add(h, mlpOut)
}
// Forward runs attention with Q/K normalization
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
// Q/K normalization after reshaping (use precomputed scaled weight)
q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
// Apply RoPE with appropriate theta
ropeTheta := cfg.RopeTheta
if isSliding {
ropeTheta = cfg.RopeLocalBaseFreq
}
q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
// Update cache
k, v = c.Update(k, v, int(L))
// Repeat K/V for GQA if needed
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
// Attention
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
// compiledGeluApprox is a singleton compiled GELU function shared across all layers
var compiledGeluApprox *mlx.CompiledFunc
// getCompiledGeluApprox returns the compiled GELU function, creating it once if needed
func getCompiledGeluApprox() *mlx.CompiledFunc {
if compiledGeluApprox == nil {
compiledGeluApprox = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
return []*mlx.Array{geluApproxImpl(inputs[0])}
}, true)
}
return compiledGeluApprox
}
// Forward runs the MLP with GELU approximation (tanh variant)
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := getCompiledGeluApprox().Call(m.GateProj.Forward(x))[0]
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
}
// geluApproxImpl computes GELU using the tanh approximation (gelu_pytorch_tanh):
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
func geluApproxImpl(x *mlx.Array) *mlx.Array {
// Constants
const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi)
const coeff = 0.044715
// x^3
x3 := mlx.Mul(mlx.Mul(x, x), x)
// x + 0.044715 * x^3
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
// sqrt(2/pi) * (x + 0.044715 * x^3)
scaled := mlx.MulScalar(inner, sqrt2OverPi)
// tanh(...)
tanh := mlx.Tanh(scaled)
// 1 + tanh(...)
onePlusTanh := mlx.AddScalar(tanh, 1.0)
// 0.5 * x * (1 + tanh(...))
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusTanh)
}
// gemmaRMSNorm applies Gemma-style RMS normalization: x * rsqrt(mean(x^2) + eps) * (1 + weight)
// Uses mlx.RMSNorm fast kernel with pre-computed (1 + weight)
func gemmaRMSNorm(x, weight *mlx.Array, eps float32) *mlx.Array {
// Gemma uses (1 + weight) instead of weight
scaledWeight := mlx.AddScalar(weight, 1.0)
return mlx.RMSNorm(x, scaledWeight, eps)
}
// Interface methods
func (m *TextModel) NumLayers() int { return len(m.Layers) }
func (m *TextModel) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
func (m *TextModel) VocabSize() int32 { return m.TextConfig.VocabSize }
// Tokenizer returns the tokenizer wrapped to add BOS and apply chat template
func (m *TextModel) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
// FormatPrompt applies the Gemma 3 chat template to a prompt
func (m *TextModel) FormatPrompt(prompt string) string {
// Gemma 3 chat format: <start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
func (m *TextModel) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
if m.Layers[i].IsSliding {
// Use rotating cache for sliding window layers
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
// Use regular cache for global attention layers
caches[i] = cache.NewKVCache()
}
}
return caches
}
// Config holds config for the full multimodal model
type Config struct {
TextConfig TextConfig `json:"text_config"`
VisionConfig VisionConfig `json:"vision_config"`
// Image token config (from config.json)
BOITokenIndex int32 `json:"boi_token_index"` // <start_of_image> = 255999
EOITokenIndex int32 `json:"eoi_token_index"` // <end_of_image> = 256000
ImageTokenIndex int32 `json:"image_token_index"` // <image_soft_token> = 262144
MMTokensPerImage int32 `json:"mm_tokens_per_image"` // 256
}
// Model is the full Gemma 3 multimodal model
type Model struct {
VisionTower *VisionTower `weight:"vision_tower"`
Projector *MultiModalProjector `weight:"multi_modal_projector"`
TextModel *TextModel `weight:"language_model"`
Config *Config
tok *tokenizer.Tokenizer
}
// Load loads the full multimodal Gemma 3 model
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Set defaults for text config (multimodal config often has incomplete text_config)
// These defaults match transformers.Gemma3TextConfig defaults
tc := &cfg.TextConfig
if tc.HeadDim == 0 {
tc.HeadDim = 256 // Gemma 3 uses head_dim=256
}
if tc.NumAttentionHeads == 0 {
// Gemma 3 4B uses 8 attention heads (cannot infer from hidden_size/head_dim)
tc.NumAttentionHeads = 8
}
if tc.NumKeyValueHeads == 0 {
// Gemma 3 4B uses 4 KV heads (GQA with 2:1 ratio)
tc.NumKeyValueHeads = 4
}
if tc.VocabSize == 0 {
tc.VocabSize = 262208 // Gemma 3 vocab size (not 262144!)
}
if tc.RopeTheta == 0 {
tc.RopeTheta = 1000000
}
if tc.RopeLocalBaseFreq == 0 {
tc.RopeLocalBaseFreq = 10000
}
if tc.RMSNormEps == 0 {
tc.RMSNormEps = 1e-6
}
if tc.SlidingWindowPattern == 0 {
tc.SlidingWindowPattern = 6
}
if tc.MaxPositionEmbeddings == 0 {
tc.MaxPositionEmbeddings = 131072 // Gemma 3 4B default
}
// Compute text model scale
tc.Scale = float32(1.0 / math.Sqrt(float64(tc.HeadDim)))
// Set defaults for image token config
if cfg.BOITokenIndex == 0 {
cfg.BOITokenIndex = 255999 // <start_of_image>
}
if cfg.EOITokenIndex == 0 {
cfg.EOITokenIndex = 256000 // <end_of_image>
}
if cfg.ImageTokenIndex == 0 {
cfg.ImageTokenIndex = 262144 // <image_soft_token>
}
if cfg.MMTokensPerImage == 0 {
cfg.MMTokensPerImage = 256
}
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
VisionTower: &VisionTower{
Embeddings: &VisionEmbeddings{},
Encoder: make([]*VisionEncoderLayer, cfg.VisionConfig.NumHiddenLayers),
Config: &cfg.VisionConfig,
},
Projector: &MultiModalProjector{},
TextModel: &TextModel{
Layers: make([]*DecoderLayer, cfg.TextConfig.NumHiddenLayers),
TextConfig: &cfg.TextConfig,
},
Config: &cfg,
tok: tok,
}
// Initialize text layer metadata
for i := range m.TextModel.Layers {
m.TextModel.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), cfg.TextConfig.SlidingWindowPattern),
}
}
// Initialize vision encoder layers
for i := range m.VisionTower.Encoder {
m.VisionTower.Encoder[i] = &VisionEncoderLayer{}
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Tied embeddings for text output
m.TextModel.Output = nn.NewLinear(m.TextModel.EmbedTokens.Weight, nil)
m.TextModel.tok = tok
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
// Precompute (1 + weight) for Gemma-style RMSNorm
precomputeGemmaScaledWeights(m.TextModel)
// Precompute projector's scaled weight
m.Projector.SoftEmbNormScaled = mlx.AddScalar(m.Projector.SoftEmbNorm.Weight, 1.0)
mlx.Eval(m.Projector.SoftEmbNormScaled)
return m, nil
}
// Forward runs the text-only forward pass
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
return m.TextModel.Forward(tokens, caches)
}
// ForwardWithImage runs the multimodal forward pass
// tokens: [B, L] input token IDs (with image placeholder tokens)
// image: [B, H, W, C] preprocessed image tensor
func (m *Model) ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
cfg := m.Config.TextConfig
// Find image token position FIRST before any eval that might free tokens
imageStartPos := int32(-1)
if image != nil && B == 1 {
tokenData := tokens.DataInt32() // This evals tokens
for i, t := range tokenData {
if t == m.Config.ImageTokenIndex {
imageStartPos = int32(i)
break
}
}
}
// Get text embeddings and scale
h := m.TextModel.EmbedTokens.Forward(tokens)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(cfg.HiddenSize))))
// Process image if provided
if image != nil && imageStartPos >= 0 {
// Vision tower: [B, H, W, C] -> [B, num_patches, vision_hidden]
visionFeatures := m.VisionTower.Forward(image)
// Project to text space: [B, num_patches, vision_hidden] -> [B, 256, text_hidden]
imageEmbeds := m.Projector.Forward(visionFeatures, cfg.RMSNormEps)
// Eval h and imageEmbeds together so neither gets freed
mlx.Eval(h, imageEmbeds)
// Cast imageEmbeds to match text embeddings dtype (bf16)
if imageEmbeds.Dtype() != h.Dtype() {
imageEmbeds = mlx.AsType(imageEmbeds, h.Dtype())
mlx.Eval(imageEmbeds)
}
// Insert image embeddings at the known position
h = m.insertImageEmbeddingsAt(h, imageEmbeds, imageStartPos)
}
// Run through text model layers
for i, layer := range m.TextModel.Layers {
h = layer.Forward(h, caches[i], B, L, m.TextModel.TextConfig)
}
// Final norm and output projection
return m.TextModel.Output.Forward(mlx.RMSNorm(h, m.TextModel.NormScaled, cfg.RMSNormEps))
}
// insertImageEmbeddingsAt replaces image placeholder tokens with actual image embeddings
// at a known position (to avoid re-scanning tokens after eval)
// textEmbeds: [B, L, hidden_size] text embeddings
// imageEmbeds: [B, 256, hidden_size] image embeddings from projector
// startPos: starting position of image tokens in the sequence
func (m *Model) insertImageEmbeddingsAt(textEmbeds, imageEmbeds *mlx.Array, startPos int32) *mlx.Array {
numImageTokens := imageEmbeds.Shape()[1]
L := textEmbeds.Shape()[1]
// Split text embeddings: [0:startPos] + imageEmbeds + [startPos+256:L]
afterStart := startPos + numImageTokens
// Slice before image tokens: textEmbeds[:, 0:startPos, :]
before := mlx.SliceAxis(textEmbeds, 1, 0, startPos)
// Slice after image tokens: textEmbeds[:, startPos+256:L, :]
after := mlx.SliceAxis(textEmbeds, 1, afterStart, L)
// Concatenate: before + imageEmbeds + after along axis 1
return mlx.Concatenate([]*mlx.Array{before, imageEmbeds, after}, 1)
}
// Interface methods for Model
func (m *Model) NumLayers() int { return len(m.TextModel.Layers) }
func (m *Model) MaxContextLength() int32 { return m.Config.TextConfig.MaxPositionEmbeddings }
func (m *Model) VocabSize() int32 { return m.Config.TextConfig.VocabSize }
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { return m.TextModel.NewCache(maxSeqLen) }
func (m *Model) ImageSize() int32 { return m.Config.VisionConfig.ImageSize }
// FormatPrompt applies the Gemma 3 multimodal chat template
func (m *Model) FormatPrompt(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
// FormatPromptWithImage applies the Gemma 3 multimodal chat template with image
func (m *Model) FormatPromptWithImage(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n<start_of_image>%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
// ExpandImageTokens expands <start_of_image> into 256 image placeholder tokens
// Input tokens containing boi_token (255999) are expanded to:
// boi_token + 256 * image_token + eoi_token
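// For example, with the default indices a single 255999 in the input becomes
// 255999, followed by 262144 repeated 256 times, then 256000.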
func (m *Model) ExpandImageTokens(tokens []int32) []int32 {
result := make([]int32, 0, len(tokens)+int(m.Config.MMTokensPerImage)+1)
for _, t := range tokens {
if t == m.Config.BOITokenIndex {
// Expand: boi + 256 * image_token + eoi
result = append(result, m.Config.BOITokenIndex)
for i := int32(0); i < m.Config.MMTokensPerImage; i++ {
result = append(result, m.Config.ImageTokenIndex)
}
result = append(result, m.Config.EOITokenIndex)
} else {
result = append(result, t)
}
}
return result
}


@@ -1,58 +0,0 @@
//go:build mlx
package gemma3
import (
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"os"
"github.com/ollama/ollama/x/imagegen/mlx"
"golang.org/x/image/draw"
)
// ProcessImage loads and preprocesses an image for the vision tower
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open image: %w", err)
}
defer f.Close()
img, _, err := image.Decode(f)
if err != nil {
return nil, fmt.Errorf("decode image: %w", err)
}
return ProcessImageData(img, imageSize)
}
// ProcessImageData preprocesses an image.Image for the vision tower
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
// Resize to target size using bilinear interpolation
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
// Convert to float32 array [H, W, C] and normalize
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
data := make([]float32, imageSize*imageSize*3)
idx := 0
for y := int32(0); y < imageSize; y++ {
for x := int32(0); x < imageSize; x++ {
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
// RGBA returns 16-bit values, convert to 8-bit
data[idx] = float32(r>>8)/127.5 - 1.0
data[idx+1] = float32(g>>8)/127.5 - 1.0
data[idx+2] = float32(b>>8)/127.5 - 1.0
idx += 3
}
}
// Create MLX array [1, H, W, C] for NHWC layout
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
mlx.Eval(arr) // Materialize to prevent use-after-free
return arr, nil
}
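A quick check of the SigLIP normalization arithmetic used in ProcessImageData, in plain Go with no MLX dependency: (pixel/255 - 0.5)/0.5 and pixel/127.5 - 1 are the same function and map 0..255 onto [-1, 1]:
package main

import "fmt"

func main() {
	for _, px := range []float32{0, 64, 127.5, 255} {
		a := (px/255.0 - 0.5) / 0.5 // normalization as usually written for SigLIP
		b := px/127.5 - 1.0         // algebraically identical form used in ProcessImageData
		fmt.Printf("pixel %6.1f -> %+.4f (%+.4f)\n", px, a, b)
	}
	// pixel    0.0 -> -1.0000 (-1.0000)
	// pixel   64.0 -> -0.4980 (-0.4980)
	// pixel  127.5 -> +0.0000 (+0.0000)
	// pixel  255.0 -> +1.0000 (+1.0000)
}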

View File

@@ -1,50 +0,0 @@
//go:build mlx
package gemma3
import (
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// MultiModalProjector projects vision features to text embedding space
type MultiModalProjector struct {
// mm_input_projection_weight: [vision_hidden, text_hidden]
InputProjection *mlx.Array `weight:"mm_input_projection_weight"`
SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
SoftEmbNormScaled *mlx.Array `weight:"-"`
}
// Forward projects vision features to text space
// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152])
// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560])
func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array {
// Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152]
// 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens
B := visionFeatures.Shape()[0]
visionHidden := visionFeatures.Shape()[2]
// Reshape to [B, 64, 64, hidden]
gridSize := int32(64) // sqrt(4096)
pooledSize := int32(16) // 64/4
h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden)
// Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling
h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden)
// Average over pooling dimensions (axes 2 and 4)
h = mlx.Mean(h, 4, false)
h = mlx.Mean(h, 2, false)
// h is now [B, 16, 16, hidden], reshape to [B, 256, hidden]
numTokens := pooledSize * pooledSize
h = mlx.Reshape(h, B, numTokens, visionHidden)
// Apply Gemma-style RMS norm (use precomputed 1 + weight)
h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps)
// Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden]
return mlx.Linear(h, p.InputProjection)
}
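A plain-Go sketch of the 4x4 average pooling the projector performs, shrunk to an 8x8 grid pooled down to 2x2 so the numbers can be checked by hand (the real model pools a 64x64 patch grid down to 16x16 = 256 tokens):
package main

import "fmt"

func main() {
	const grid, pool = 8, 4 // real model: grid=64, pool=4 -> 16x16 tokens
	// Fill an 8x8 "feature map" where each cell holds its row index.
	x := make([][]float32, grid)
	for r := range x {
		x[r] = make([]float32, grid)
		for c := range x[r] {
			x[r][c] = float32(r)
		}
	}
	// Average-pool non-overlapping pool x pool windows.
	out := make([][]float32, grid/pool)
	for pr := range out {
		out[pr] = make([]float32, grid/pool)
		for pc := range out[pr] {
			var sum float32
			for r := 0; r < pool; r++ {
				for c := 0; c < pool; c++ {
					sum += x[pr*pool+r][pc*pool+c]
				}
			}
			out[pr][pc] = sum / float32(pool*pool)
		}
	}
	fmt.Println(out) // [[1.5 1.5] [5.5 5.5]]
}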

View File

@@ -1,138 +0,0 @@
//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// VisionConfig holds configuration for the SigLIP vision tower
type VisionConfig struct {
HiddenSize int32 `json:"hidden_size"`
ImageSize int32 `json:"image_size"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
PatchSize int32 `json:"patch_size"`
}
// VisionTower is the SigLIP vision encoder
type VisionTower struct {
Embeddings *VisionEmbeddings `weight:"vision_model.embeddings"`
Encoder []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
PostLayerNorm *nn.LayerNorm `weight:"vision_model.post_layernorm"`
Config *VisionConfig
}
// VisionEmbeddings handles patch and position embeddings
type VisionEmbeddings struct {
// PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX
PatchWeight *mlx.Array `weight:"patch_embedding.weight"`
PatchBias *mlx.Array `weight:"patch_embedding.bias"`
PosEmbed *nn.Embedding `weight:"position_embedding"`
}
// VisionEncoderLayer is a single transformer encoder layer
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `weight:"layer_norm1"`
Attention *VisionAttention `weight:"self_attn"`
LayerNorm2 *nn.LayerNorm `weight:"layer_norm2"`
MLP *VisionMLP `weight:"mlp"`
}
// VisionAttention implements multi-head self-attention
type VisionAttention struct {
QProj *nn.Linear `weight:"q_proj"`
KProj *nn.Linear `weight:"k_proj"`
VProj *nn.Linear `weight:"v_proj"`
OutProj *nn.Linear `weight:"out_proj"`
}
// VisionMLP is the feed-forward network
type VisionMLP struct {
FC1 *nn.Linear `weight:"fc1"`
FC2 *nn.Linear `weight:"fc2"`
}
// Forward runs the vision tower on preprocessed images
// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX)
// Output: [B, num_patches, hidden_size]
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
// Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O]
// Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C]
weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1)
h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding
// Add bias: [O] -> [1, 1, 1, O] for broadcasting
bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0])
h = mlx.Add(h, bias)
// h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden]
B := h.Shape()[0]
gridH, gridW := h.Shape()[1], h.Shape()[2]
hidden := h.Shape()[3]
numPatches := gridH * gridW
h = mlx.Reshape(h, B, numPatches, hidden)
// Add position embeddings
posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
posEmbed := v.Embeddings.PosEmbed.Forward(posIds)
h = mlx.Add(h, posEmbed)
// Encoder layers
headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
scale := float32(1.0 / math.Sqrt(float64(headDim)))
for _, layer := range v.Encoder {
h = layer.Forward(h, v.Config, scale)
}
// Final layer norm
h = v.PostLayerNorm.Forward(h)
return h
}
// Forward runs a vision encoder layer
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
// Pre-norm attention
h := l.LayerNorm1.Forward(x)
h = l.Attention.Forward(h, cfg, scale)
x = mlx.Add(x, h)
// Pre-norm MLP
h = l.LayerNorm2.Forward(x)
h = l.MLP.Forward(h)
return mlx.Add(x, h)
}
// Forward runs multi-head self-attention
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
B, L := x.Shape()[0], x.Shape()[1]
headDim := cfg.HiddenSize / cfg.NumAttentionHeads
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
// Scaled dot-product attention (no causal mask for vision)
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
// Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden]
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize)
return a.OutProj.Forward(out)
}
// Forward runs the MLP with GELU activation
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
h := mlx.GELU(m.FC1.Forward(x))
return m.FC2.Forward(h)
}
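For orientation, a small sketch of the patch-grid and attention-scale arithmetic this tower implies. The concrete values come from config.json and are not spelled out in this file; image size 896, patch size 14, hidden 1152 and 16 heads are assumptions here, chosen to match the 64x64 = 4096 patch grid mentioned in the projector comments above:
package main

import (
	"fmt"
	"math"
)

func main() {
	// Assumed SigLIP-style config; the real numbers are read from config.json at load time.
	imageSize, patchSize := int32(896), int32(14)
	hidden, heads := int32(1152), int32(16)

	grid := imageSize / patchSize
	numPatches := grid * grid
	headDim := hidden / heads
	scale := 1.0 / math.Sqrt(float64(headDim))

	fmt.Printf("grid=%dx%d patches=%d headDim=%d scale=%.4f\n",
		grid, grid, numPatches, headDim, scale)
	// grid=64x64 patches=4096 headDim=72 scale=0.1179
}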

View File

@@ -1,487 +0,0 @@
//go:build mlx
package gpt_oss
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// RopeScaling holds YaRN or other RoPE scaling configuration
type RopeScaling struct {
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
}
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
SlidingWindow int32 `json:"sliding_window"`
NumLocalExperts int32 `json:"num_local_experts"`
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
LayerTypes []string `json:"layer_types"`
SwiGLULimit float32 `json:"swiglu_limit"`
RopeScaling *RopeScaling `json:"rope_scaling"`
Scale float32 `json:"-"` // computed: 1/sqrt(HeadDim)
}
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
Sinks *mlx.Array `weight:"self_attn.sinks,optional"`
YarnFreqs *mlx.Array // computed
YarnMscale float32
}
// swiGLU applies the GPT-OSS custom SwiGLU activation.
// Formula: (gate * sigmoid(alpha * gate)) * (up + 1)
// with clipping: gate to [None, limit], up to [-limit, limit]
func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array {
// Clip gate to [None, limit]
gateClipped := mlx.ClipScalar(gate, 0, limit, false, true)
// Clip up to [-limit, limit]
upClipped := mlx.ClipScalar(up, -limit, limit, true, true)
// glu_scaled = alpha * gate_clipped
gluScaled := mlx.MulScalar(gateClipped, alpha)
// sig = sigmoid(glu_scaled)
sig := mlx.Sigmoid(gluScaled)
// out_glu = gate_clipped * sig
outGlu := mlx.Mul(gateClipped, sig)
// result = out_glu * (up_clipped + 1)
return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0))
}
// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers
var compiledSwiGLU *mlx.CompiledFunc
// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed
func getCompiledSwiGLU() *mlx.CompiledFunc {
if compiledSwiGLU == nil {
const alpha float32 = 1.702
const limit float32 = 7.0
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
}, true) // shapeless=true so it works for any input size
}
return compiledSwiGLU
}
// ComputeYarnFreqs computes YaRN-modified RoPE frequencies
// Based on mlx-lm's YarnRoPE implementation
func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) {
// yarn_find_correction_dim
yarnFindCorrectionDim := func(numRotations float64) float64 {
return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base)))
}
// yarn_find_correction_range
low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast))))
high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow))))
if low < 0 {
low = 0
}
if high > int(dims)-1 {
high = int(dims) - 1
}
// yarn_get_mscale
yarnGetMscale := func(scale, mscale float64) float64 {
if scale <= 1 {
return 1.0
}
return 0.1*mscale*math.Log(scale) + 1.0
}
mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0))
// Compute frequencies
// freq_extra = base ** (arange(0, dims, 2) / dims)
// freq_inter = scaling_factor * freq_extra
halfDims := dims / 2
freqData := make([]float32, halfDims)
for i := int32(0); i < halfDims; i++ {
exp := float64(2*i) / float64(dims)
freqExtra := math.Pow(float64(base), exp)
freqInter := float64(scalingFactor) * freqExtra
// linear ramp mask
var freqMask float64
if low == high {
freqMask = 0.0
} else {
t := (float64(i) - float64(low)) / float64(high-low)
if t < 0 {
t = 0
}
if t > 1 {
t = 1
}
freqMask = 1.0 - t
}
// Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask))
freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask)))
}
return mlx.NewArray(freqData, []int32{halfDims}), mscale
}
// initYarn initializes YaRN RoPE if configured
func (a *Attention) initYarn(cfg *Config) {
a.YarnMscale = 1.0
if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" {
a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs(
cfg.HeadDim,
cfg.RopeTheta,
cfg.RopeScaling.Factor,
cfg.RopeScaling.OriginalMaxPositionEmbeddings,
cfg.RopeScaling.BetaFast,
cfg.RopeScaling.BetaSlow,
)
}
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
offset := 0
if c != nil {
offset = c.Offset()
}
if a.YarnFreqs != nil {
if a.YarnMscale != 1.0 {
q = mlx.MulScalar(q, a.YarnMscale)
}
q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
} else {
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
}
if c != nil {
k, v = c.Update(k, v, int(L))
}
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
// CreateSlidingWindowMask creates a causal mask with sliding window
// Mirrors mlx-lm's create_causal_mask with window_size
func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array {
// Build mask aligned to actual cache length (may be rotated)
// rinds covers existing keys: [keyStart, keyStart+keyLen)
// linds covers new queries: [queryStart, queryStart+seqLen)
rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1) // [keyLen]
linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen]
linds = mlx.ExpandDims(linds, 1) // [seqLen, 1]
rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen]
causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen]
windowLimit := mlx.AddScalar(rinds, float32(windowSize))
windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen]
return mlx.LogicalAnd(causalMask, windowMask)
}
// MoE represents the Mixture of Experts SwiGLU layer with quantized experts.
type MoE struct {
Router *nn.Linear `weight:"mlp.router"`
TopK int32
HiddenSize int32
GroupSize int
Bits int
// Expert weights (loaded manually via sanitizeExpertWeights)
GateBlocks, GateScales, GateBias *mlx.Array
UpBlocks, UpScales, UpBias *mlx.Array
DownBlocks, DownScales, DownBias *mlx.Array
}
func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array {
logits := moe.Router.Forward(x)
neg := mlx.Neg(logits)
part := mlx.Argpartition(neg, int(moe.TopK)-1, -1)
topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK})
topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1)
weights := mlx.Softmax(topKVal, -1)
xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize)
idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK)
doSort := B*L >= 64
var invOrder *mlx.Array
sorted := false
n := B * L * moe.TopK
if doSort {
idxAll := mlx.Flatten(idxFlat)
order := mlx.Argsort(idxAll, 0)
invOrder = mlx.Argsort(order, 0)
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1)
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
sorted = true
}
gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
if moe.GateBias != nil {
gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2))
}
if moe.UpBias != nil {
up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2))
}
hidden := getCompiledSwiGLU().Call(gate, up)[0]
down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
if moe.DownBias != nil {
down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2))
}
if doSort {
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize)
} else {
down = mlx.Squeeze(down, 2)
}
ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1)
return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize)
}
type Block struct {
Attention *Attention
MLP *MoE
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
LayerType string // "sliding_attention" or "full_attention"
}
func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg))
return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L))
}
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Block `weight:"-"` // loaded manually due to MoE sanitization
Norm *nn.RMSNorm `weight:"model.norm"`
LMHead *nn.Linear `weight:"lm_head"`
tok *tokenizer.Tokenizer
*Config
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) NewCache(int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i, layer := range m.Layers {
if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 {
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
caches[i] = cache.NewKVCache()
}
}
return caches
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
x := m.EmbedTokens.Forward(tokens)
// Find representative cache indices for sliding window attention
var swaIdx int = -1
for i, layer := range m.Layers {
if layer.LayerType == "sliding_attention" {
swaIdx = i
break
}
}
// Create masks once at model level
var fullMask, swaMask *mlx.Array
var fullMaskMode, swaMaskMode string
if L > 1 {
fullMaskMode = "causal"
if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil {
c := caches[swaIdx]
offset := c.Offset()
windowSize := int(m.SlidingWindow)
cacheLen := min(int(L), windowSize)
if offset > 0 {
cacheLen = min(c.Len()+int(L), windowSize)
}
if int(L) > windowSize {
swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize)
} else {
swaMaskMode = "causal"
}
} else {
swaMaskMode = "causal"
}
}
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil {
c = caches[i]
}
mask, maskMode := fullMask, fullMaskMode
if layer.LayerType == "sliding_attention" {
mask, maskMode = swaMask, swaMaskMode
}
x = layer.Forward(x, c, B, L, mask, maskMode, m.Config)
}
return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps))
}
// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays.
// MXFP4 quantized weights require contiguous memory - strided views give wrong results.
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) {
gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks")
gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales")
gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias")
downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks")
downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales")
downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias")
moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias}
if gateUpBlocks != nil {
gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1)
s := gub.Shape()
moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
}
if gateUpScales != nil {
s := gateUpScales.Shape()
moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
}
if gateUpBias != nil {
s := gateUpBias.Shape()
moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2}))
moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2}))
}
if downBlocks != nil {
moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1)
}
return moe
}
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
// Load simple weights via struct tags
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Load layers with custom MoE handling
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
layer := &Block{}
if err := safetensors.LoadModule(layer, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d: %w", i, err)
}
// Initialize attention YaRN
layer.Attention.initYarn(&cfg)
// Load MoE with weight sanitization
moe := sanitizeExpertWeights(weights, prefix)
moe.Router = layer.MLP.Router // Router was loaded by LoadModule
moe.TopK = cfg.NumExpertsPerTok
moe.HiddenSize = cfg.HiddenSize
layer.MLP = moe
// Set layer type
layer.LayerType = "full_attention"
if int(i) < len(cfg.LayerTypes) {
layer.LayerType = cfg.LayerTypes[i]
}
m.Layers[i] = layer
}
// Release safetensors BEFORE eval: the lazy arrays have already captured the data
// they need, so freeing the mmap now reduces peak memory during materialization
weights.ReleaseAll()
mlx.Eval(mlx.Collect(m)...)
return m, nil
}
func (m *Model) MaxContextLength() int32 {
if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 {
return m.RopeScaling.OriginalMaxPositionEmbeddings
}
return 131072
}
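A plain-Go sketch of the mask shape produced by CreateSlidingWindowMask, with queries and keys both starting at position 0 (no cache offset): query i may attend to key j only when j <= i (causal) and i < j + windowSize (window):
package main

import "fmt"

func main() {
	const seqLen, windowSize = 6, 3
	for i := 0; i < seqLen; i++ { // query position
		for j := 0; j < seqLen; j++ { // key position
			causal := i >= j
			inWindow := i < j+windowSize
			if causal && inWindow {
				fmt.Print("1 ")
			} else {
				fmt.Print(". ")
			}
		}
		fmt.Println()
	}
	// 1 . . . . .
	// 1 1 . . . .
	// 1 1 1 . . .
	// . 1 1 1 . .
	// . . 1 1 1 .
	// . . . 1 1 1
}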

View File

@@ -1,152 +0,0 @@
//go:build mlx
package llama
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
HeadDim int32 `json:"-"`
Scale float32 `json:"-"`
}
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Layer `weight:"model.layers"`
Norm *nn.RMSNorm `weight:"model.norm"`
Output *nn.Linear `weight:"lm_head,optional"`
tok *tokenizer.Tokenizer
*Config
}
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm `weight:"input_layernorm"`
MLPNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
}
type MLP struct {
GateProj *nn.Linear `weight:"mlp.gate_proj"`
UpProj *nn.Linear `weight:"mlp.up_proj"`
DownProj *nn.Linear `weight:"mlp.down_proj"`
}
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
return m, nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
h = layer.Forward(h, caches[i], B, L, m.Config)
}
return m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps))
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
k, v = c.Update(k, v, int(L))
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
}
// Interface methods
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
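A plain-Go sanity check of the AsStrided call in Attention.Forward: for a row-major [B, L, nHeads*headDim] buffer, the strides {L*nHeads*headDim, headDim, nHeads*headDim, 1} read element (b, h, l, d) of the [B, nHeads, L, headDim] view from the same offset as element (b, l, h*headDim+d) of the original layout, so the op is a head split with no data movement:
package main

import "fmt"

func main() {
	const B, L, nHeads, headDim = 1, 3, 2, 4
	// Flat row-major buffer for shape [B, L, nHeads*headDim].
	buf := make([]int, B*L*nHeads*headDim)
	for i := range buf {
		buf[i] = i
	}
	// Strides passed to AsStrided for the target shape [B, nHeads, L, headDim].
	strides := [4]int{L * nHeads * headDim, headDim, nHeads * headDim, 1}
	ok := true
	for b := 0; b < B; b++ {
		for h := 0; h < nHeads; h++ {
			for l := 0; l < L; l++ {
				for d := 0; d < headDim; d++ {
					strided := b*strides[0] + h*strides[1] + l*strides[2] + d*strides[3]
					direct := (b*L+l)*(nHeads*headDim) + h*headDim + d
					if buf[strided] != buf[direct] {
						ok = false
					}
				}
			}
		}
	}
	fmt.Println("strided view matches head split:", ok) // true
}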

View File

@@ -1,66 +0,0 @@
//go:build mlx
package qwen_image
import (
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestPipelineOutput runs the full pipeline (integration test).
// Skips if model weights not found. Requires ~50GB VRAM.
func TestPipelineOutput(t *testing.T) {
modelPath := "../../../weights/Qwen-Image-2512"
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + modelPath)
}
// Load model
pm, err := LoadPersistent(modelPath)
if err != nil {
t.Skipf("Skipping: failed to load model: %v", err)
}
// Run 2-step pipeline (minimum for stable scheduler)
cfg := &GenerateConfig{
Prompt: "a cat",
Width: 256,
Height: 256,
Steps: 2,
Seed: 42,
}
output, err := pm.GenerateFromConfig(cfg)
if err != nil {
t.Fatalf("Pipeline failed: %v", err)
}
mlx.Eval(output)
// Verify output shape [1, C, H, W]
shape := output.Shape()
if len(shape) != 4 {
t.Errorf("Expected 4D output, got %v", shape)
}
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
}
// Verify values in expected range [0, 1]
data := output.Data()
minVal, maxVal := float32(1.0), float32(0.0)
for _, v := range data {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
if minVal < -0.1 || maxVal > 1.1 {
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
}
}
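For reference, a stdlib-only sketch of turning a planar [3, H, W] float tensor in [0, 1], like the output this test checks, into a PNG. The planar, row-major layout assumed for the decoded data is my assumption, not something this file states:
package main

import (
	"image"
	"image/color"
	"image/png"
	"os"
)

// clampByte maps a float in [0, 1] (clamping out-of-range values) to 0..255.
func clampByte(v float32) uint8 {
	if v < 0 {
		v = 0
	}
	if v > 1 {
		v = 1
	}
	return uint8(v*255 + 0.5)
}

// writePNG encodes a planar [3, H, W] float tensor with values in [0, 1].
func writePNG(data []float32, h, w int, path string) error {
	img := image.NewRGBA(image.Rect(0, 0, w, h))
	plane := h * w
	for y := 0; y < h; y++ {
		for x := 0; x < w; x++ {
			i := y*w + x
			img.Set(x, y, color.RGBA{
				R: clampByte(data[i]),
				G: clampByte(data[plane+i]),
				B: clampByte(data[2*plane+i]),
				A: 255,
			})
		}
	}
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	defer f.Close()
	return png.Encode(f, img)
}

func main() {
	// 2x2 toy tensor, every channel 0.5 -> a mid-grey out.png.
	grey := make([]float32, 3*2*2)
	for i := range grey {
		grey[i] = 0.5
	}
	if err := writePNG(grey, 2, 2, "out.png"); err != nil {
		panic(err)
	}
}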

View File

File diff suppressed because it is too large

View File

@@ -1,350 +0,0 @@
//go:build mlx
// Package qwen_image implements the Qwen-Image diffusion transformer model.
package qwen_image
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
// Layer caching (DeepCache/Learning-to-Cache speedup)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 25)
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Qwen-Image diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen25VL
Transformer *Transformer
VAEDecoder *VAEDecoder
}
// Load loads the Qwen-Image model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err := tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
m.TextEncoder = &Qwen25VL{}
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 30
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
}
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
imgSeqLen := pH * pW
// Text encoding
var posEmb, negEmb *mlx.Array
{
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
if useCFG {
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
} else {
mlx.Keep(posEmb)
mlx.Eval(posEmb)
}
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
// Init latents [B, C, T, H, W]
var latents *mlx.Array
{
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
}
// RoPE cache
var ropeCache *RoPECache
{
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs)
}
// Layer cache for DeepCache/Learning-to-Cache speedup
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
latents2D := mlx.Squeeze(latents, 2)
patches := PackLatents(latents2D, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
// True CFG: run twice and combine with norm rescaling
// Note: layer caching with CFG is not supported yet (would need 2 caches)
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
combPred := mlx.Add(negOutput, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
} else if stepCache != nil {
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
} else {
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
}
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep cached arrays alive across cleanup
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
}
// Free denoising temporaries before VAE decode
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode (Decode manages its own pools for staged memory)
decoded := m.VAEDecoder.Decode(latents)
latents.Free()
// Post-process: squeeze temporal dim and rescale to [0, 1]
{
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
mlx.Eval(decoded)
}
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is an alias for backward compatibility.
// Use m := &Model{}; m.Load(path) instead.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}
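A plain-Go sketch of the CFG combination with norm rescaling used in the denoising loop above, on a toy 3-vector; the real code computes the norms per position along the last axis rather than over the whole tensor:
package main

import (
	"fmt"
	"math"
)

func norm(v []float64) float64 {
	var s float64
	for _, x := range v {
		s += x * x
	}
	return math.Sqrt(s)
}

func main() {
	pos := []float64{1.0, 2.0, 2.0} // conditional prediction
	neg := []float64{0.5, 1.0, 1.5} // unconditional (negative-prompt) prediction
	const cfgScale = 4.0

	comb := make([]float64, len(pos))
	for i := range pos {
		comb[i] = neg[i] + cfgScale*(pos[i]-neg[i]) // standard CFG combine
	}
	// Rescale the combined prediction back to the conditional prediction's norm.
	ratio := norm(pos) / norm(comb)
	for i := range comb {
		comb[i] *= ratio
	}
	fmt.Printf("norm(pos)=%.4f norm(rescaled)=%.4f\n", norm(pos), norm(comb))
	// The two norms match; only the direction of the guided prediction changes.
}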

View File

@@ -1,218 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.5
MaxShift float32 `json:"max_shift"` // 0.9
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
}
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.5,
MaxShift: 0.9, // Matches scheduler_config.json
BaseImageSeqLen: 256,
MaxImageSeqLen: 8192,
ShiftTerminal: 0.02,
UseDynamicShift: true,
}
}
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32
Sigmas []float32
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// CalculateShift computes the dynamic shift based on image sequence length
// This matches Python's calculate_shift function
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
b := baseShift - m*float32(baseSeqLen)
mu := float32(imageSeqLen)*m + b
return mu
}
// SetTimesteps sets up the scheduler for the given number of inference steps
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
// 1. Create sigmas from sigma_max to sigma_min (linspace)
// 2. Apply time_shift with mu (if dynamic shifting)
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
s.NumSteps = numSteps
// Calculate mu for dynamic shifting
var mu float32
if s.Config.UseDynamicShift {
mu = CalculateShift(
imageSeqLen,
s.Config.BaseImageSeqLen,
s.Config.MaxImageSeqLen,
s.Config.BaseShift,
s.Config.MaxShift,
)
}
// Step 1: Create sigmas from 1.0 to 1/num_steps
// Python (pipeline_qwenimage.py:639):
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
sigmas := make([]float32, numSteps)
sigmaMax := float32(1.0)
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
if numSteps == 1 {
sigmas[0] = sigmaMax
} else {
for i := 0; i < numSteps; i++ {
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
}
}
// Step 2: Apply time shift if using dynamic shifting
if s.Config.UseDynamicShift && mu != 0 {
for i := range sigmas {
sigmas[i] = s.timeShift(mu, sigmas[i])
}
}
// Step 3: Apply stretch_shift_to_terminal
if s.Config.ShiftTerminal > 0 {
sigmas = s.stretchShiftToTerminal(sigmas)
}
// Step 4: Append terminal sigma (0) and store
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
// before passing to transformer. We skip both steps and just use sigmas directly.
s.Sigmas = make([]float32, numSteps+1)
s.Timesteps = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
s.Sigmas[i] = sigmas[i]
s.Timesteps[i] = sigmas[i]
}
s.Sigmas[numSteps] = 0.0
s.Timesteps[numSteps] = 0.0
}
// stretchShiftToTerminal stretches and shifts the timestep schedule
// so the final value equals shift_terminal (matches Python behavior)
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
if len(sigmas) == 0 {
return sigmas
}
// one_minus_z = 1 - t
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
// stretched_t = 1 - (one_minus_z / scale_factor)
lastSigma := sigmas[len(sigmas)-1]
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
if scaleFactor < 1e-6 {
return sigmas
}
result := make([]float32, len(sigmas))
for i, t := range sigmas {
oneMinusZ := 1.0 - t
result[i] = 1.0 - (oneMinusZ / scaleFactor)
}
return result
}
// timeShift applies the dynamic time shift (exponential)
// exp(mu) / (exp(mu) + (1/t - 1))
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
// modelOutput: predicted velocity from the transformer
// sample: current noisy sample
// timestepIdx: current timestep index
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
// Get current and next sigma
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 to avoid precision issues (matches Python diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to original dtype
return mlx.ToBFloat16(result)
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormal(shape, uint64(seed))
}
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
// This matches how Python diffusers generates noise - directly in packed space.
// Generating in unpacked format and then packing produces different spatial
// correlation structure, which affects model output quality.
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
shape := []int32{batchSize, seqLen, channels}
return mlx.RandomNormal(shape, uint64(seed))
}
// GetLatentShape returns the latent shape for a given image size
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
func GetLatentShape(batchSize, height, width int32) []int32 {
latentH := height / 8
latentW := width / 8
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
}
// GetPatchedLatentShape returns the patchified latent shape
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
latentH := height / 8
latentW := width / 8
pH := latentH / patchSize
pW := latentW / patchSize
inChannels := 16 * patchSize * patchSize // 16 latent channels * patch_size^2 (= 64 for patch_size 2)
return []int32{batchSize, pH * pW, inChannels}
}
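A plain-Go sketch that walks the schedule math above end to end for a handful of steps: linspace sigmas, the dynamic shift mu from CalculateShift (imageSeqLen = 4096 gives mu ≈ 0.6935, matching the golden value in the test file that follows), the exponential timeShift, and the stretch so the last non-terminal sigma lands on shift_terminal = 0.02:
package main

import (
	"fmt"
	"math"
)

func main() {
	const (
		baseShift, maxShift   = 0.5, 0.9
		baseSeqLen, maxSeqLen = 256.0, 8192.0
		shiftTerminal         = 0.02
		imageSeqLen           = 4096.0
		steps                 = 5
	)
	// CalculateShift: linear interpolation of mu over the image sequence length.
	m := (maxShift - baseShift) / (maxSeqLen - baseSeqLen)
	b := baseShift - m*baseSeqLen
	mu := imageSeqLen*m + b
	fmt.Printf("mu = %.4f\n", mu) // ~0.6935

	// linspace from 1.0 down to 1/steps, then timeShift, then stretch to terminal.
	sigmas := make([]float64, steps)
	for i := range sigmas {
		t := 1.0 + float64(i)*(1.0/steps-1.0)/float64(steps-1)
		expMu := math.Exp(mu)
		sigmas[i] = expMu / (expMu + (1.0/t - 1.0)) // timeShift
	}
	scale := (1.0 - sigmas[steps-1]) / (1.0 - shiftTerminal)
	for i, s := range sigmas {
		sigmas[i] = 1.0 - (1.0-s)/scale // stretchShiftToTerminal
	}
	fmt.Println(sigmas) // starts at 1.0, ends at 0.02, strictly decreasing
}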

View File

@@ -1,135 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"testing"
)
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
// Golden values generated via:
//
// python3 -c "
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
// import numpy as np
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
// sigmas = np.linspace(1.0, 1.0/30, 30)
// s.set_timesteps(sigmas=sigmas, mu=mu)
// print(s.sigmas.numpy())"
func TestSchedulerSetTimesteps(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Golden values from Python diffusers (first 3, last 3 before terminal)
wantFirst := []float32{1.000000, 0.982251, 0.963889}
wantLast := []float32{0.142924, 0.083384, 0.020000}
// Check first 3
for i, want := range wantFirst {
got := scheduler.Sigmas[i]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
}
}
// Check last 3 (indices 27, 28, 29)
for i, want := range wantLast {
idx := 27 + i
got := scheduler.Sigmas[idx]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
}
}
// Check terminal is 0
if scheduler.Sigmas[30] != 0.0 {
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
}
// Check length
if len(scheduler.Sigmas) != 31 {
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
}
}
// TestSchedulerProperties tests mathematical invariants of the scheduler.
func TestSchedulerProperties(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Property: sigmas monotonically decreasing
for i := 1; i < len(scheduler.Sigmas); i++ {
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
}
}
// Property: first sigma should be ~1.0 (with time shift)
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
}
// Property: terminal sigma should be exactly 0
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
}
// Property: last non-terminal sigma should be shift_terminal (0.02)
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
if abs32(lastNonTerminal-0.02) > 1e-5 {
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
}
// Property: length = steps + 1
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
t.Errorf("sigmas length should be steps+1: got %d, want %d",
len(scheduler.Sigmas), scheduler.NumSteps+1)
}
}
// TestCalculateShift verifies the mu calculation against Python reference.
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
func TestCalculateShift(t *testing.T) {
cases := []struct {
imgSeqLen int32
want float32
}{
{256, 0.5}, // base case
{8192, 0.9}, // max case
{4096, 0.6935}, // middle case (rounded)
}
for _, c := range cases {
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
if abs32(got-c.want) > 0.001 {
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
}
}
}
// TestSchedulerStep verifies the Euler step formula.
func TestSchedulerStep(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Verify dt calculation for first step
sigma0 := scheduler.Sigmas[0]
sigma1 := scheduler.Sigmas[1]
expectedDt := sigma1 - sigma0
// dt should be negative (sigmas decrease)
if expectedDt >= 0 {
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
}
}
func abs32(x float32) float32 {
return float32(math.Abs(float64(x)))
}

Some files were not shown because too many files have changed in this diff