Mirror of https://github.com/ollama/ollama.git (synced 2026-01-24 07:20:57 -05:00)

Compare commits: parth/agen ... imagegen-r (27 commits)
| SHA1 |
|---|
| bb1a5617b6 |
| 0d3648c1be |
| 02a2401596 |
| e4b488a7b5 |
| 98079ddd79 |
| d70942f47b |
| 58e4701557 |
| dbf47ee55a |
| af7ea6e96e |
| 8f1e0140e7 |
| 35c3c9e3c2 |
| d06acbcb19 |
| 9667c2282f |
| a937a68317 |
| 2185112d84 |
| 91926601dc |
| 361d6c16c2 |
| 7e2496e88e |
| 5b84e29882 |
| 7cc2a653f2 |
| 2584940016 |
| c6d4c0c7f2 |
| 1ef4241727 |
| 68fafd3002 |
| 2b2cda7a2b |
| 3cfe9fe146 |
| a23b559b4c |
.github/ISSUE_TEMPLATE/10_bug_report.yml (vendored), 2 changes:

```diff
@@ -13,7 +13,7 @@ body:
     id: logs
     attributes:
       label: Relevant log output
-      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
+      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
       render: shell
     validations:
       required: false
```
.github/workflows/release.yaml (vendored), 6 changes:

```diff
@@ -372,13 +372,17 @@ jobs:
           outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
           cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
           cache-to: type=inline
+      - name: Deduplicate CUDA libraries
+        run: |
+          ./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
       - run: |
           for COMPONENT in bin/* lib/ollama/*; do
             case "$COMPONENT" in
-            bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+            bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
             lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
             lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
             lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+            lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
             lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
             lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
             lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
```
CMakeLists.txt:

```diff
@@ -48,9 +48,10 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
   set(GGML_CPU_ALL_VARIANTS ON)
 endif()
 
-if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
+if(APPLE)
   set(CMAKE_BUILD_RPATH "@loader_path")
   set(CMAKE_INSTALL_RPATH "@loader_path")
+  set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
 endif()
 
 set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
@@ -189,13 +190,21 @@ if(MLX_ENGINE)
     install(TARGETS mlx mlxc
         RUNTIME_DEPENDENCIES
             DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
-            PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
+            PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
            PRE_EXCLUDE_REGEXES ".*"
         RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
         LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
         FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
     )
 
+    # Install the Metal library for macOS arm64 (must be colocated with the binary)
+    # Metal backend is only built for arm64, not x86_64
+    if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+        install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
+            DESTINATION ${OLLAMA_INSTALL_DIR}
+            COMPONENT MLX)
+    endif()
+
     # Manually install cudart and cublas since they might not be picked up as direct dependencies
     if(CUDAToolkit_FOUND)
         file(GLOB CUDART_LIBS
```
Dockerfile:

```diff
@@ -161,10 +161,9 @@
 ARG GOFLAGS="'-ldflags=-w -s'"
 ENV CGO_ENABLED=1
 ARG CGO_CFLAGS
 ARG CGO_CXXFLAGS
-# TODO wire up the actual MLX engine here instead of building the main binary...
-RUN mkdir -p dist/bin
-RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine
+RUN --mount=type=cache,target=/root/.cache/go-build \
+    go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
 
 FROM base AS build
 WORKDIR /go/src/github.com/ollama/ollama
@@ -205,7 +204,7 @@ COPY --from=build /bin/ollama /bin/ollama
 
 FROM ubuntu:24.04
 RUN apt-get update \
-    && apt-get install -y ca-certificates libvulkan1 \
+    && apt-get install -y ca-certificates libvulkan1 libopenblas0 \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*
 COPY --from=archive /bin /usr/bin
```
README.md, 42 changes:

````diff
@@ -48,7 +48,7 @@ ollama run gemma3
 
 ## Model library
 
-Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
+Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
 
 Here are some example models that can be downloaded:
 
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
 | Code Llama | 7B | 3.8GB | `ollama run codellama` |
 | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
 | LLaVA | 7B | 4.5GB | `ollama run llava` |
-| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
+| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
 
 > [!NOTE]
 > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
 ./ollama run llama3.2
 ```
 
+## Building with MLX (experimental)
+
+First build the MLX libraries:
+
+```shell
+cmake --preset MLX
+cmake --build --preset MLX --parallel
+cmake --install build --component MLX
+```
+
+Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (it needs to be in the same directory as `ollama`):
+
+```shell
+go build -tags mlx -o ollama-mlx .
+```
+
+Finally, start the server:
+
+```
+./ollama serve
+```
+
+### Building MLX with CUDA
+
+When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
+
+```shell
+cmake --preset 'MLX CUDA 13'
+cmake --build --preset 'MLX CUDA 13' --parallel
+cmake --install build --component MLX
+```
+
 ## REST API
 
 Ollama has a REST API for running and managing models.
@@ -421,7 +453,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
 - [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
 - [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
-- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
+- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
 - [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
 - [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
 - [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
@@ -493,7 +525,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Database
 
 - [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
-  - [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
+  - [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
 - [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
 - [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
 - [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -636,6 +668,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
 
 ### Observability
 
+- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
 - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
 - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -644,4 +677,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
 
 ### Security
 
+- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
````
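The MLX section added above ends by starting `./ollama serve`. For anyone who wants to confirm the rebuilt server responds, here is a minimal sketch (illustrative only, not part of the diff) using the repository's own `api` Go client; it assumes a server running on the default address resolved from `OLLAMA_HOST`:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// Resolves the server address from OLLAMA_HOST, defaulting to localhost:11434.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Listing locally available models doubles as a simple liveness check.
	models, err := client.List(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	for _, m := range models.Models {
		fmt.Println(m.Name)
	}
}
```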
anthropic/anthropic.go (new file, 778 lines):

```go
package anthropic

import (
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"strings"
	"time"

	"github.com/ollama/ollama/api"
)

// Error types matching Anthropic API
type Error struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

type ErrorResponse struct {
	Type      string `json:"type"` // always "error"
	Error     Error  `json:"error"`
	RequestID string `json:"request_id,omitempty"`
}

// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
	var etype string
	switch code {
	case http.StatusBadRequest:
		etype = "invalid_request_error"
	case http.StatusUnauthorized:
		etype = "authentication_error"
	case http.StatusForbidden:
		etype = "permission_error"
	case http.StatusNotFound:
		etype = "not_found_error"
	case http.StatusTooManyRequests:
		etype = "rate_limit_error"
	case http.StatusServiceUnavailable, 529:
		etype = "overloaded_error"
	default:
		etype = "api_error"
	}

	return ErrorResponse{
		Type:      "error",
		Error:     Error{Type: etype, Message: message},
		RequestID: generateID("req"),
	}
}

// Request types

// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
	Model         string          `json:"model"`
	MaxTokens     int             `json:"max_tokens"`
	Messages      []MessageParam  `json:"messages"`
	System        any             `json:"system,omitempty"` // string or []ContentBlock
	Stream        bool            `json:"stream,omitempty"`
	Temperature   *float64        `json:"temperature,omitempty"`
	TopP          *float64        `json:"top_p,omitempty"`
	TopK          *int            `json:"top_k,omitempty"`
	StopSequences []string        `json:"stop_sequences,omitempty"`
	Tools         []Tool          `json:"tools,omitempty"`
	ToolChoice    *ToolChoice     `json:"tool_choice,omitempty"`
	Thinking      *ThinkingConfig `json:"thinking,omitempty"`
	Metadata      *Metadata       `json:"metadata,omitempty"`
}

// MessageParam represents a message in the request
type MessageParam struct {
	Role    string `json:"role"`    // "user" or "assistant"
	Content any    `json:"content"` // string or []ContentBlock
}

// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
	Type string `json:"type"` // text, image, tool_use, tool_result, thinking

	// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
	Text *string `json:"text,omitempty"`

	// For image blocks
	Source *ImageSource `json:"source,omitempty"`

	// For tool_use blocks
	ID    string `json:"id,omitempty"`
	Name  string `json:"name,omitempty"`
	Input any    `json:"input,omitempty"`

	// For tool_result blocks
	ToolUseID string `json:"tool_use_id,omitempty"`
	Content   any    `json:"content,omitempty"` // string or []ContentBlock
	IsError   bool   `json:"is_error,omitempty"`

	// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
	Thinking  *string `json:"thinking,omitempty"`
	Signature string  `json:"signature,omitempty"`
}

// ImageSource represents the source of an image
type ImageSource struct {
	Type      string `json:"type"` // "base64" or "url"
	MediaType string `json:"media_type,omitempty"`
	Data      string `json:"data,omitempty"`
	URL       string `json:"url,omitempty"`
}

// Tool represents a tool definition
type Tool struct {
	Type        string          `json:"type,omitempty"` // "custom" for user-defined tools
	Name        string          `json:"name"`
	Description string          `json:"description,omitempty"`
	InputSchema json.RawMessage `json:"input_schema,omitempty"`
}

// ToolChoice controls how the model uses tools
type ToolChoice struct {
	Type                   string `json:"type"` // "auto", "any", "tool", "none"
	Name                   string `json:"name,omitempty"`
	DisableParallelToolUse bool   `json:"disable_parallel_tool_use,omitempty"`
}

// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
	Type         string `json:"type"` // "enabled" or "disabled"
	BudgetTokens int    `json:"budget_tokens,omitempty"`
}

// Metadata for the request
type Metadata struct {
	UserID string `json:"user_id,omitempty"`
}

// Response types

// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
	ID           string         `json:"id"`
	Type         string         `json:"type"` // "message"
	Role         string         `json:"role"` // "assistant"
	Model        string         `json:"model"`
	Content      []ContentBlock `json:"content"`
	StopReason   string         `json:"stop_reason,omitempty"`
	StopSequence string         `json:"stop_sequence,omitempty"`
	Usage        Usage          `json:"usage"`
}

// Usage contains token usage information
type Usage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

// Streaming event types

// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
	Type    string           `json:"type"` // "message_start"
	Message MessagesResponse `json:"message"`
}

// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
	Type         string       `json:"type"` // "content_block_start"
	Index        int          `json:"index"`
	ContentBlock ContentBlock `json:"content_block"`
}

// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
	Type  string `json:"type"` // "content_block_delta"
	Index int    `json:"index"`
	Delta Delta  `json:"delta"`
}

// Delta represents an incremental update
type Delta struct {
	Type        string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
	Text        string `json:"text,omitempty"`
	PartialJSON string `json:"partial_json,omitempty"`
	Thinking    string `json:"thinking,omitempty"`
	Signature   string `json:"signature,omitempty"`
}

// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
	Type  string `json:"type"` // "content_block_stop"
	Index int    `json:"index"`
}

// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
	Type  string       `json:"type"` // "message_delta"
	Delta MessageDelta `json:"delta"`
	Usage DeltaUsage   `json:"usage"`
}

// MessageDelta contains stop information
type MessageDelta struct {
	StopReason   string `json:"stop_reason,omitempty"`
	StopSequence string `json:"stop_sequence,omitempty"`
}

// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
	OutputTokens int `json:"output_tokens"`
}

// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
	Type string `json:"type"` // "message_stop"
}

// PingEvent is a keepalive event
type PingEvent struct {
	Type string `json:"type"` // "ping"
}

// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
	Type  string `json:"type"` // "error"
	Error Error  `json:"error"`
}

// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
	var messages []api.Message

	if r.System != nil {
		switch sys := r.System.(type) {
		case string:
			if sys != "" {
				messages = append(messages, api.Message{Role: "system", Content: sys})
			}
		case []any:
			// System can be an array of content blocks
			var content strings.Builder
			for _, block := range sys {
				if blockMap, ok := block.(map[string]any); ok {
					if blockMap["type"] == "text" {
						if text, ok := blockMap["text"].(string); ok {
							content.WriteString(text)
						}
					}
				}
			}
			if content.Len() > 0 {
				messages = append(messages, api.Message{Role: "system", Content: content.String()})
			}
		}
	}

	for _, msg := range r.Messages {
		converted, err := convertMessage(msg)
		if err != nil {
			return nil, err
		}
		messages = append(messages, converted...)
	}

	options := make(map[string]any)

	options["num_predict"] = r.MaxTokens

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	}

	if r.TopK != nil {
		options["top_k"] = *r.TopK
	}

	if len(r.StopSequences) > 0 {
		options["stop"] = r.StopSequences
	}

	var tools api.Tools
	for _, t := range r.Tools {
		tool, err := convertTool(t)
		if err != nil {
			return nil, err
		}
		tools = append(tools, tool)
	}

	var think *api.ThinkValue
	if r.Thinking != nil && r.Thinking.Type == "enabled" {
		think = &api.ThinkValue{Value: true}
	}

	stream := r.Stream

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Options:  options,
		Stream:   &stream,
		Tools:    tools,
		Think:    think,
	}, nil
}

// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
	var messages []api.Message
	role := strings.ToLower(msg.Role)

	switch content := msg.Content.(type) {
	case string:
		messages = append(messages, api.Message{Role: role, Content: content})

	case []any:
		var textContent strings.Builder
		var images []api.ImageData
		var toolCalls []api.ToolCall
		var thinking string
		var toolResults []api.Message

		for _, block := range content {
			blockMap, ok := block.(map[string]any)
			if !ok {
				return nil, errors.New("invalid content block format")
			}

			blockType, _ := blockMap["type"].(string)

			switch blockType {
			case "text":
				if text, ok := blockMap["text"].(string); ok {
					textContent.WriteString(text)
				}

			case "image":
				source, ok := blockMap["source"].(map[string]any)
				if !ok {
					return nil, errors.New("invalid image source")
				}

				sourceType, _ := source["type"].(string)
				if sourceType == "base64" {
					data, _ := source["data"].(string)
					decoded, err := base64.StdEncoding.DecodeString(data)
					if err != nil {
						return nil, fmt.Errorf("invalid base64 image data: %w", err)
					}
					images = append(images, decoded)
				} else {
					return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
				}
				// URL images would need to be fetched - skip for now

			case "tool_use":
				id, ok := blockMap["id"].(string)
				if !ok {
					return nil, errors.New("tool_use block missing required 'id' field")
				}
				name, ok := blockMap["name"].(string)
				if !ok {
					return nil, errors.New("tool_use block missing required 'name' field")
				}
				tc := api.ToolCall{
					ID: id,
					Function: api.ToolCallFunction{
						Name: name,
					},
				}
				if input, ok := blockMap["input"].(map[string]any); ok {
					tc.Function.Arguments = mapToArgs(input)
				}
				toolCalls = append(toolCalls, tc)

			case "tool_result":
				toolUseID, _ := blockMap["tool_use_id"].(string)
				var resultContent string

				switch c := blockMap["content"].(type) {
				case string:
					resultContent = c
				case []any:
					for _, cb := range c {
						if cbMap, ok := cb.(map[string]any); ok {
							if cbMap["type"] == "text" {
								if text, ok := cbMap["text"].(string); ok {
									resultContent += text
								}
							}
						}
					}
				}

				toolResults = append(toolResults, api.Message{
					Role:       "tool",
					Content:    resultContent,
					ToolCallID: toolUseID,
				})

			case "thinking":
				if t, ok := blockMap["thinking"].(string); ok {
					thinking = t
				}
			}
		}

		if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
			m := api.Message{
				Role:      role,
				Content:   textContent.String(),
				Images:    images,
				ToolCalls: toolCalls,
				Thinking:  thinking,
			}
			messages = append(messages, m)
		}

		// Add tool results as separate messages
		messages = append(messages, toolResults...)

	default:
		return nil, fmt.Errorf("invalid message content type: %T", content)
	}

	return messages, nil
}

// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
	var params api.ToolFunctionParameters
	if len(t.InputSchema) > 0 {
		if err := json.Unmarshal(t.InputSchema, &params); err != nil {
			return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
		}
	}

	return api.Tool{
		Type: "function",
		Function: api.ToolFunction{
			Name:        t.Name,
			Description: t.Description,
			Parameters:  params,
		},
	}, nil
}

// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
	var content []ContentBlock

	if r.Message.Thinking != "" {
		content = append(content, ContentBlock{
			Type:     "thinking",
			Thinking: ptr(r.Message.Thinking),
		})
	}

	if r.Message.Content != "" {
		content = append(content, ContentBlock{
			Type: "text",
			Text: ptr(r.Message.Content),
		})
	}

	for _, tc := range r.Message.ToolCalls {
		content = append(content, ContentBlock{
			Type:  "tool_use",
			ID:    tc.ID,
			Name:  tc.Function.Name,
			Input: tc.Function.Arguments,
		})
	}

	stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)

	return MessagesResponse{
		ID:         id,
		Type:       "message",
		Role:       "assistant",
		Model:      r.Model,
		Content:    content,
		StopReason: stopReason,
		Usage: Usage{
			InputTokens:  r.Metrics.PromptEvalCount,
			OutputTokens: r.Metrics.EvalCount,
		},
	}
}

// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
	if hasToolCalls {
		return "tool_use"
	}

	switch reason {
	case "stop":
		return "end_turn"
	case "length":
		return "max_tokens"
	default:
		if reason != "" {
			return "stop_sequence"
		}
		return ""
	}
}
```
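The conversion half of the file above is enough to round-trip a non-streaming exchange. As an illustration only (this sketch is not part of the new file), the following builds a small `MessagesRequest`, converts it with `FromMessagesRequest`, and shapes a hypothetical Ollama reply back into the Anthropic response format with `ToMessagesResponse`:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/anthropic"
	"github.com/ollama/ollama/api"
)

func main() {
	// An Anthropic-style request as it would arrive over the Messages API.
	req := anthropic.MessagesRequest{
		Model:     "llama3.2",
		MaxTokens: 256,
		Messages: []anthropic.MessageParam{
			{Role: "user", Content: "Say hello in one sentence."},
		},
	}

	// Convert to Ollama's native chat request; MaxTokens lands in num_predict.
	chatReq, err := anthropic.FromMessagesRequest(req)
	if err != nil {
		panic(err)
	}
	fmt.Println("model:", chatReq.Model, "num_predict:", chatReq.Options["num_predict"])

	// A hypothetical non-streaming Ollama reply, converted back to the
	// Anthropic response shape.
	resp := anthropic.ToMessagesResponse(anthropic.GenerateMessageID(), api.ChatResponse{
		Model:      "llama3.2",
		Message:    api.Message{Role: "assistant", Content: "Hello there!"},
		Done:       true,
		DoneReason: "stop",
		Metrics:    api.Metrics{PromptEvalCount: 12, EvalCount: 4},
	})

	out, _ := json.MarshalIndent(resp, "", "  ")
	fmt.Println(string(out)) // stop_reason "end_turn", usage {12, 4}
}
```

The file continues below with the streaming converter.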
anthropic/anthropic.go (continued):

```go
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
	ID              string
	Model           string
	firstWrite      bool
	contentIndex    int
	inputTokens     int
	outputTokens    int
	thinkingStarted bool
	thinkingDone    bool
	textStarted     bool
	toolCallsSent   map[string]bool
}

func NewStreamConverter(id, model string) *StreamConverter {
	return &StreamConverter{
		ID:            id,
		Model:         model,
		firstWrite:    true,
		toolCallsSent: make(map[string]bool),
	}
}

// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
	Event string
	Data  any
}

// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
	var events []StreamEvent

	if c.firstWrite {
		c.firstWrite = false
		c.inputTokens = r.Metrics.PromptEvalCount

		events = append(events, StreamEvent{
			Event: "message_start",
			Data: MessageStartEvent{
				Type: "message_start",
				Message: MessagesResponse{
					ID:      c.ID,
					Type:    "message",
					Role:    "assistant",
					Model:   c.Model,
					Content: []ContentBlock{},
					Usage: Usage{
						InputTokens:  c.inputTokens,
						OutputTokens: 0,
					},
				},
			},
		})
	}

	if r.Message.Thinking != "" && !c.thinkingDone {
		if !c.thinkingStarted {
			c.thinkingStarted = true
			events = append(events, StreamEvent{
				Event: "content_block_start",
				Data: ContentBlockStartEvent{
					Type:  "content_block_start",
					Index: c.contentIndex,
					ContentBlock: ContentBlock{
						Type:     "thinking",
						Thinking: ptr(""),
					},
				},
			})
		}

		events = append(events, StreamEvent{
			Event: "content_block_delta",
			Data: ContentBlockDeltaEvent{
				Type:  "content_block_delta",
				Index: c.contentIndex,
				Delta: Delta{
					Type:     "thinking_delta",
					Thinking: r.Message.Thinking,
				},
			},
		})
	}

	if r.Message.Content != "" {
		if c.thinkingStarted && !c.thinkingDone {
			c.thinkingDone = true
			events = append(events, StreamEvent{
				Event: "content_block_stop",
				Data: ContentBlockStopEvent{
					Type:  "content_block_stop",
					Index: c.contentIndex,
				},
			})
			c.contentIndex++
		}

		if !c.textStarted {
			c.textStarted = true
			events = append(events, StreamEvent{
				Event: "content_block_start",
				Data: ContentBlockStartEvent{
					Type:  "content_block_start",
					Index: c.contentIndex,
					ContentBlock: ContentBlock{
						Type: "text",
						Text: ptr(""),
					},
				},
			})
		}

		events = append(events, StreamEvent{
			Event: "content_block_delta",
			Data: ContentBlockDeltaEvent{
				Type:  "content_block_delta",
				Index: c.contentIndex,
				Delta: Delta{
					Type: "text_delta",
					Text: r.Message.Content,
				},
			},
		})
	}

	for _, tc := range r.Message.ToolCalls {
		if c.toolCallsSent[tc.ID] {
			continue
		}

		if c.textStarted {
			events = append(events, StreamEvent{
				Event: "content_block_stop",
				Data: ContentBlockStopEvent{
					Type:  "content_block_stop",
					Index: c.contentIndex,
				},
			})
			c.contentIndex++
			c.textStarted = false
		}

		argsJSON, err := json.Marshal(tc.Function.Arguments)
		if err != nil {
			slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
			continue
		}

		events = append(events, StreamEvent{
			Event: "content_block_start",
			Data: ContentBlockStartEvent{
				Type:  "content_block_start",
				Index: c.contentIndex,
				ContentBlock: ContentBlock{
					Type:  "tool_use",
					ID:    tc.ID,
					Name:  tc.Function.Name,
					Input: map[string]any{},
				},
			},
		})

		events = append(events, StreamEvent{
			Event: "content_block_delta",
			Data: ContentBlockDeltaEvent{
				Type:  "content_block_delta",
				Index: c.contentIndex,
				Delta: Delta{
					Type:        "input_json_delta",
					PartialJSON: string(argsJSON),
				},
			},
		})

		events = append(events, StreamEvent{
			Event: "content_block_stop",
			Data: ContentBlockStopEvent{
				Type:  "content_block_stop",
				Index: c.contentIndex,
			},
		})

		c.toolCallsSent[tc.ID] = true
		c.contentIndex++
	}

	if r.Done {
		if c.textStarted {
			events = append(events, StreamEvent{
				Event: "content_block_stop",
				Data: ContentBlockStopEvent{
					Type:  "content_block_stop",
					Index: c.contentIndex,
				},
			})
		} else if c.thinkingStarted && !c.thinkingDone {
			events = append(events, StreamEvent{
				Event: "content_block_stop",
				Data: ContentBlockStopEvent{
					Type:  "content_block_stop",
					Index: c.contentIndex,
				},
			})
		}

		c.outputTokens = r.Metrics.EvalCount
		stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)

		events = append(events, StreamEvent{
			Event: "message_delta",
			Data: MessageDeltaEvent{
				Type: "message_delta",
				Delta: MessageDelta{
					StopReason: stopReason,
				},
				Usage: DeltaUsage{
					OutputTokens: c.outputTokens,
				},
			},
		})

		events = append(events, StreamEvent{
			Event: "message_stop",
			Data: MessageStopEvent{
				Type: "message_stop",
			},
		})
	}

	return events
}

// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
	b := make([]byte, 12)
	if _, err := rand.Read(b); err != nil {
		// Fallback to time-based ID if crypto/rand fails
		return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
	}
	return fmt.Sprintf("%s_%x", prefix, b)
}

// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
	return generateID("msg")
}

// ptr returns a pointer to the given string value
func ptr(s string) *string {
	return &s
}

// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
	args := api.NewToolCallFunctionArguments()
	for k, v := range m {
		args.Set(k, v)
	}
	return args
}
```
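To see the streaming converter end to end, here is an illustrative driver (again, not part of the new file): it pushes two hypothetical Ollama chunks through a `StreamConverter` and prints the resulting events with SSE framing, which is the wire format an Anthropic-compatible endpoint would emit:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/anthropic"
	"github.com/ollama/ollama/api"
)

func main() {
	conv := anthropic.NewStreamConverter(anthropic.GenerateMessageID(), "llama3.2")

	// Two hypothetical chunks from a streaming Ollama chat: a content delta,
	// then the final chunk carrying done_reason and eval metrics.
	chunks := []api.ChatResponse{
		{Model: "llama3.2", Message: api.Message{Role: "assistant", Content: "Hello"}, Metrics: api.Metrics{PromptEvalCount: 12}},
		{Model: "llama3.2", Message: api.Message{Role: "assistant", Content: " world!"}, Done: true, DoneReason: "stop", Metrics: api.Metrics{EvalCount: 5}},
	}

	for _, chunk := range chunks {
		for _, ev := range conv.Process(chunk) {
			data, err := json.Marshal(ev.Data)
			if err != nil {
				continue
			}
			// SSE framing: a named event line followed by a JSON data line.
			fmt.Printf("event: %s\ndata: %s\n\n", ev.Event, data)
		}
	}
}
```

The first chunk yields message_start, content_block_start, and a text_delta; the final chunk yields the remaining delta, content_block_stop, message_delta with the mapped stop reason, and message_stop, matching the sequence the tests below assert.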
anthropic/anthropic_test.go (new file, 953 lines; excerpt):

```go
package anthropic

import (
	"encoding/base64"
	"encoding/json"
	"testing"

	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
)

const (
	testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)

// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
	args := api.NewToolCallFunctionArguments()
	for k, v := range m {
		args.Set(k, v)
	}
	return args
}

func TestFromMessagesRequest_Basic(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{Role: "user", Content: "Hello"},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Model != "test-model" {
		t.Errorf("expected model 'test-model', got %q", result.Model)
	}

	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
	}

	if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
		t.Errorf("unexpected message: %+v", result.Messages[0])
	}

	if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
		t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
	}
}

func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		System:    "You are a helpful assistant.",
		Messages: []MessageParam{
			{Role: "user", Content: "Hello"},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}

	if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
		t.Errorf("unexpected system message: %+v", result.Messages[0])
	}
}

func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		System: []any{
			map[string]any{"type": "text", "text": "You are helpful."},
			map[string]any{"type": "text", "text": " Be concise."},
		},
		Messages: []MessageParam{
			{Role: "user", Content: "Hello"},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}

	if result.Messages[0].Content != "You are helpful. Be concise." {
		t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
	}
}

func TestFromMessagesRequest_WithOptions(t *testing.T) {
	temp := 0.7
	topP := 0.9
	topK := 40
	req := MessagesRequest{
		Model:         "test-model",
		MaxTokens:     2048,
		Messages:      []MessageParam{{Role: "user", Content: "Hello"}},
		Temperature:   &temp,
		TopP:          &topP,
		TopK:          &topK,
		StopSequences: []string{"\n", "END"},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Options["temperature"] != 0.7 {
		t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
	}
	if result.Options["top_p"] != 0.9 {
		t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
	}
	if result.Options["top_k"] != 40 {
		t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
	}
	if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
		t.Errorf("stop sequences mismatch: %s", diff)
	}
}

func TestFromMessagesRequest_WithImage(t *testing.T) {
	imgData, _ := base64.StdEncoding.DecodeString(testImage)

	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "user",
				Content: []any{
					map[string]any{"type": "text", "text": "What's in this image?"},
					map[string]any{
						"type": "image",
						"source": map[string]any{
							"type":       "base64",
							"media_type": "image/png",
							"data":       testImage,
						},
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
	}

	if result.Messages[0].Content != "What's in this image?" {
		t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
	}

	if len(result.Messages[0].Images) != 1 {
		t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
	}

	if string(result.Messages[0].Images[0]) != string(imgData) {
		t.Error("image data mismatch")
	}
}

func TestFromMessagesRequest_WithToolUse(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{Role: "user", Content: "What's the weather in Paris?"},
			{
				Role: "assistant",
				Content: []any{
					map[string]any{
						"type":  "tool_use",
						"id":    "call_123",
						"name":  "get_weather",
						"input": map[string]any{"location": "Paris"},
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}

	if len(result.Messages[1].ToolCalls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
	}

	tc := result.Messages[1].ToolCalls[0]
	if tc.ID != "call_123" {
		t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
	}
	if tc.Function.Name != "get_weather" {
		t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
	}
}

func TestFromMessagesRequest_WithToolResult(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "user",
				Content: []any{
					map[string]any{
						"type":        "tool_result",
						"tool_use_id": "call_123",
						"content":     "The weather in Paris is sunny, 22°C",
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
	}

	msg := result.Messages[0]
	if msg.Role != "tool" {
		t.Errorf("expected role 'tool', got %q", msg.Role)
	}
	if msg.ToolCallID != "call_123" {
		t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
	}
	if msg.Content != "The weather in Paris is sunny, 22°C" {
		t.Errorf("unexpected content: %q", msg.Content)
	}
}

func TestFromMessagesRequest_WithTools(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
		Tools: []Tool{
			{
				Name:        "get_weather",
				Description: "Get current weather",
				InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Tools) != 1 {
		t.Fatalf("expected 1 tool, got %d", len(result.Tools))
	}

	tool := result.Tools[0]
	if tool.Type != "function" {
		t.Errorf("expected type 'function', got %q", tool.Type)
	}
	if tool.Function.Name != "get_weather" {
		t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
	}
	if tool.Function.Description != "Get current weather" {
		t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
	}
}

func TestFromMessagesRequest_WithThinking(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
		Thinking:  &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Think == nil {
		t.Fatal("expected Think to be set")
	}
	if v, ok := result.Think.Value.(bool); !ok || !v {
		t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
	}
}

// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{Role: "user", Content: "Hello"},
			{
				Role: "assistant",
				Content: []any{
					map[string]any{
						"type":     "thinking",
						"thinking": "Let me think about this...",
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}

	assistantMsg := result.Messages[1]
	if assistantMsg.Thinking != "Let me think about this..." {
		t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
	}
}

func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "assistant",
				Content: []any{
					map[string]any{
						"type": "tool_use",
						"name": "get_weather",
					},
				},
			},
		},
	}

	_, err := FromMessagesRequest(req)
	if err == nil {
		t.Fatal("expected error for missing tool_use id")
	}
	if err.Error() != "tool_use block missing required 'id' field" {
		t.Errorf("unexpected error message: %v", err)
	}
}

func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "assistant",
				Content: []any{
					map[string]any{
						"type": "tool_use",
						"id":   "call_123",
					},
				},
			},
		},
	}

	_, err := FromMessagesRequest(req)
	if err == nil {
		t.Fatal("expected error for missing tool_use name")
	}
	if err.Error() != "tool_use block missing required 'name' field" {
		t.Errorf("unexpected error message: %v", err)
	}
}

func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
		Tools: []Tool{
			{
				Name:        "bad_tool",
				InputSchema: json.RawMessage(`{invalid json`),
			},
		},
	}

	_, err := FromMessagesRequest(req)
	if err == nil {
		t.Fatal("expected error for invalid tool schema")
	}
}

func TestToMessagesResponse_Basic(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:    "assistant",
			Content: "Hello there!",
		},
		Done:       true,
		DoneReason: "stop",
		Metrics: api.Metrics{
			PromptEvalCount: 10,
			EvalCount:       5,
		},
	}

	result := ToMessagesResponse("msg_123", resp)

	if result.ID != "msg_123" {
		t.Errorf("expected ID 'msg_123', got %q", result.ID)
	}
	if result.Type != "message" {
		t.Errorf("expected type 'message', got %q", result.Type)
	}
	if result.Role != "assistant" {
		t.Errorf("expected role 'assistant', got %q", result.Role)
	}
	if len(result.Content) != 1 {
		t.Fatalf("expected 1 content block, got %d", len(result.Content))
	}
	if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
		t.Errorf("unexpected content: %+v", result.Content[0])
	}
	if result.StopReason != "end_turn" {
		t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
	}
	if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
		t.Errorf("unexpected usage: %+v", result.Usage)
	}
}

func TestToMessagesResponse_WithToolCalls(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_123",
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: testArgs(map[string]any{"location": "Paris"}),
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
	}

	result := ToMessagesResponse("msg_123", resp)

	if len(result.Content) != 1 {
		t.Fatalf("expected 1 content block, got %d", len(result.Content))
	}
	if result.Content[0].Type != "tool_use" {
		t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
	}
	if result.Content[0].ID != "call_123" {
		t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
	}
	if result.Content[0].Name != "get_weather" {
		t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
	}
	if result.StopReason != "tool_use" {
		t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
	}
}

func TestToMessagesResponse_WithThinking(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:     "assistant",
			Content:  "The answer is 42.",
			Thinking: "Let me think about this...",
		},
		Done:       true,
		DoneReason: "stop",
	}

	result := ToMessagesResponse("msg_123", resp)

	if len(result.Content) != 2 {
		t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
	}
	if result.Content[0].Type != "thinking" {
		t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
	}
	if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
		t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
	}
	if result.Content[1].Type != "text" {
		t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
	}
}

func TestMapStopReason(t *testing.T) {
	tests := []struct {
		reason       string
		hasToolCalls bool
		want         string
	}{
		{"stop", false, "end_turn"},
		{"length", false, "max_tokens"},
		{"stop", true, "tool_use"},
		{"other", false, "stop_sequence"},
		{"", false, ""},
	}

	for _, tt := range tests {
		got := mapStopReason(tt.reason, tt.hasToolCalls)
		if got != tt.want {
			t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
		}
	}
}

func TestNewError(t *testing.T) {
	tests := []struct {
		code int
		want string
	}{
		{400, "invalid_request_error"},
		{401, "authentication_error"},
		{403, "permission_error"},
		{404, "not_found_error"},
		{429, "rate_limit_error"},
		{500, "api_error"},
		{503, "overloaded_error"},
		{529, "overloaded_error"},
	}

	for _, tt := range tests {
		result := NewError(tt.code, "test message")
		if result.Type != "error" {
			t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
		}
		if result.Error.Type != tt.want {
			t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
		}
		if result.Error.Message != "test message" {
			t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
		}
		if result.RequestID == "" {
			t.Errorf("NewError(%d) request_id should not be empty", tt.code)
		}
	}
}

func TestGenerateMessageID(t *testing.T) {
	id1 := GenerateMessageID()
	id2 := GenerateMessageID()

	if id1 == "" {
		t.Error("GenerateMessageID returned empty string")
	}
	if id1 == id2 {
		t.Error("GenerateMessageID returned duplicate IDs")
	}
	if len(id1) < 10 {
		t.Errorf("GenerateMessageID returned short ID: %q", id1)
	}
	if id1[:4] != "msg_" {
		t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
	}
}

func TestStreamConverter_Basic(t *testing.T) {
	conv := NewStreamConverter("msg_123", "test-model")

	// First chunk
	resp1 := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:    "assistant",
			Content: "Hello",
		},
		Metrics: api.Metrics{PromptEvalCount: 10},
	}

	events1 := conv.Process(resp1)
	if len(events1) < 3 {
		t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
	}

	// Should have message_start, content_block_start, content_block_delta
	if events1[0].Event != "message_start" {
		t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
	}
	if events1[1].Event != "content_block_start" {
		t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
	}
	if events1[2].Event != "content_block_delta" {
		t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
	}

	// Final chunk
	resp2 := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:    "assistant",
			Content: " world!",
		},
		Done:       true,
		DoneReason: "stop",
		Metrics:    api.Metrics{EvalCount: 5},
	}

	events2 := conv.Process(resp2)

	// Should have content_block_delta, content_block_stop, message_delta, message_stop
	hasStop := false
	for _, e := range events2 {
		if e.Event == "message_stop" {
			hasStop = true
		}
	}
	if !hasStop {
		t.Error("expected message_stop event in final chunk")
	}
}

func TestStreamConverter_WithToolCalls(t *testing.T) {
	conv := NewStreamConverter("msg_123", "test-model")

	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_123",
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: testArgs(map[string]any{"location": "Paris"}),
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
		Metrics:    api.Metrics{PromptEvalCount: 10, EvalCount: 5},
	}

	events := conv.Process(resp)

	hasToolStart := false
	hasToolDelta := false
	for _, e := range events {
		if e.Event == "content_block_start" {
			if start, ok := e.Data.(ContentBlockStartEvent); ok {
				if start.ContentBlock.Type == "tool_use" {
					hasToolStart = true
				}
			}
		}
		if e.Event == "content_block_delta" {
			if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
				if delta.Delta.Type == "input_json_delta" {
					hasToolDelta = true
				}
			}
		}
	}

	if !hasToolStart {
		t.Error("expected tool_use content_block_start event")
	}
	if !hasToolDelta {
		t.Error("expected input_json_delta event")
	}
}

func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
	// Test that unmarshalable arguments (like channels) are handled gracefully
	// and don't cause a panic or corrupt stream
	conv := NewStreamConverter("msg_123", "test-model")

	// Create a channel which cannot be JSON marshaled
	unmarshalable := make(chan int)
	badArgs := api.NewToolCallFunctionArguments()
	badArgs.Set("channel", unmarshalable)

	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_bad",
					Function: api.ToolCallFunction{
						Name:      "bad_function",
						Arguments: badArgs,
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
	}

	// Should not panic and should skip the unmarshalable tool call
	events := conv.Process(resp)

	// Verify no tool_use block was started (since marshal failed before block start)
	hasToolStart := false
	for _, e := range events {
		if e.Event == "content_block_start" {
			if start, ok := e.Data.(ContentBlockStartEvent); ok {
				if start.ContentBlock.Type == "tool_use" {
					hasToolStart = true
				}
			}
		}
	}

	if hasToolStart {
		t.Error("expected no tool_use block when arguments cannot be marshaled")
	}
}

func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
	// Test that valid tool calls still work when mixed with invalid ones
	conv := NewStreamConverter("msg_123", "test-model")

	unmarshalable := make(chan int)
	badArgs := api.NewToolCallFunctionArguments()
	badArgs.Set("channel", unmarshalable)

	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_good",
					Function: api.ToolCallFunction{
						Name:      "good_function",
						Arguments: testArgs(map[string]any{"location": "Paris"}),
					},
				},
				{
					ID: "call_bad",
					Function: api.ToolCallFunction{
						Name:      "bad_function",
						Arguments: badArgs,
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
	}

	events := conv.Process(resp)

	// Count tool_use blocks - should only have 1 (the valid one)
	toolStartCount := 0
	toolDeltaCount := 0
	for _, e := range events {
		if e.Event == "content_block_start" {
			if start, ok := e.Data.(ContentBlockStartEvent); ok {
				if start.ContentBlock.Type == "tool_use" {
					toolStartCount++
					if start.ContentBlock.Name != "good_function" {
						t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
					}
				}
			}
		}
		if e.Event == "content_block_delta" {
			if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
				if delta.Delta.Type == "input_json_delta" {
					toolDeltaCount++
				}
			}
		}
	}

	if toolStartCount != 1 {
		t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
	}
	if toolDeltaCount != 1 {
		t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
	}
}

// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
	tests := []struct {
		name     string
		block    ContentBlock
		wantKeys []string
	}{
		{
			name: "text block includes empty text field",
			block: ContentBlock{
```
|
||||
Type: "text",
|
||||
Text: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
{
|
||||
name: "thinking block includes empty thinking field",
|
||||
block: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "thinking"},
|
||||
},
|
||||
{
|
||||
name: "text block with content",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr("hello"),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.block)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range tt.wantKeys {
|
||||
if _, ok := result[key]; !ok {
|
||||
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Content: "hello"},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundTextStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "text" {
|
||||
foundTextStart = true
|
||||
// Marshal and verify the text field is present
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["text"]; !ok {
|
||||
t.Error("content_block_start for text should include 'text' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundTextStart {
|
||||
t.Error("expected text content_block_start event")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundThinkingStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "thinking" {
|
||||
foundThinkingStart = true
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["thinking"]; !ok {
|
||||
t.Error("content_block_start for thinking should include 'thinking' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundThinkingStart {
|
||||
t.Error("expected thinking content_block_start event")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
	return nil
}

const maxBufferSize = 512 * format.KiloByte
const maxBufferSize = 8 * format.MegaByte

func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
	var buf io.Reader

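The constant above caps how large a single streamed response line may grow (raised from 512 KB to 8 MB). A minimal, self-contained sketch of why such a cap matters, assuming the client reads the stream line-by-line with a `bufio.Scanner` (the actual stream implementation is elided in this hunk, so the scanner usage here is an illustration, not the real client code):

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

// Mirrors the spirit of maxBufferSize = 8 * format.MegaByte above.
const maxBufferSize = 8 * 1024 * 1024

func main() {
	// Cap the scanner's line buffer: an oversized streamed JSON line then
	// fails cleanly with bufio.ErrTooLong instead of growing without bound.
	scanner := bufio.NewScanner(strings.NewReader(`{"status":"ok"}` + "\n"))
	scanner.Buffer(make([]byte, 0, 64*1024), maxBufferSize)
	for scanner.Scan() {
		fmt.Println(scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		fmt.Println("stream error:", err)
	}
}
```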
33
cmd/cmd.go
@@ -46,6 +46,8 @@ import (
	"github.com/ollama/ollama/types/syncmap"
	"github.com/ollama/ollama/version"
	xcmd "github.com/ollama/ollama/x/cmd"
	"github.com/ollama/ollama/x/imagegen"
	imagegenclient "github.com/ollama/ollama/x/imagegen/client"
)

const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -96,6 +98,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
	filename, err := getModelfileName(cmd)
	if os.IsNotExist(err) {
		if filename == "" {
			// No Modelfile found - check if current directory is an image gen model
			if imagegen.IsTensorModelDir(".") {
				quantize, _ := cmd.Flags().GetString("quantize")
				return imagegenclient.CreateModel(args[0], ".", quantize, p)
			}
			reader = strings.NewReader("FROM .\n")
		} else {
			return errModelfileNotFound
@@ -457,6 +464,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
	}

	name := args[0]

	info, err := func() (*api.ShowResponse, error) {
		showReq := &api.ShowRequest{Name: name}
		info, err := client.Show(cmd.Context(), showReq)
@@ -518,9 +526,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
		return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
	}

	// Check if this is an image generation model
	if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
		if opts.Prompt == "" && !interactive {
			return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
		}
		return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
	}

	// Check for experimental flag
	isExperimental, _ := cmd.Flags().GetBool("experimental")
	yoloMode, _ := cmd.Flags().GetBool("yolo")
	yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
	enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")

	if interactive {
		if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -550,7 +567,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {

	// Use experimental agent loop with tools
	if isExperimental {
		return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
		return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
	}

	return generateInteractive(cmd, opts)
@@ -656,7 +673,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {

		bar, ok := bars[resp.Digest]
		if !ok {
			bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
			msg := resp.Status
			if msg == "" {
				msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
			}
			bar = progress.NewBar(msg, resp.Total, resp.Completed)
			bars[resp.Digest] = bar
			p.Add(resp.Digest, bar)
		}
@@ -1765,7 +1786,11 @@ func NewCLI() *cobra.Command {
	runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
	runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
	runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
	runCmd.Flags().BoolP("yolo", "y", false, "Skip all tool approval prompts (use with caution)")
	runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
	runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")

	// Image generation flags (width, height, steps, seed, etc.)
	imagegen.RegisterFlags(runCmd)

	stopCmd := &cobra.Command{
		Use: "stop MODEL",

@@ -1547,6 +1547,79 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
	}
}

func TestShowInfoImageGen(t *testing.T) {
	var b bytes.Buffer
	err := showInfo(&api.ShowResponse{
		Details: api.ModelDetails{
			Family:            "ZImagePipeline",
			ParameterSize:     "10.3B",
			QuantizationLevel: "FP8",
		},
		Capabilities: []model.Capability{model.CapabilityImageGeneration},
		Requires:     "0.14.0",
	}, false, &b)
	if err != nil {
		t.Fatal(err)
	}

	expect := " Model\n" +
		" architecture ZImagePipeline \n" +
		" parameters 10.3B \n" +
		" quantization FP8 \n" +
		" requires 0.14.0 \n" +
		"\n" +
		" Capabilities\n" +
		" image \n" +
		"\n"
	if diff := cmp.Diff(expect, b.String()); diff != "" {
		t.Errorf("unexpected output (-want +got):\n%s", diff)
	}
}

func TestPushProgressMessage(t *testing.T) {
	tests := []struct {
		name    string
		status  string
		digest  string
		wantMsg string
	}{
		{
			name:    "uses status when provided",
			status:  "uploading model",
			digest:  "sha256:abc123456789def",
			wantMsg: "uploading model",
		},
		{
			name:    "falls back to digest when status empty",
			status:  "",
			digest:  "sha256:abc123456789def",
			wantMsg: "pushing abc123456789...",
		},
		{
			name:    "handles short digest gracefully",
			status:  "",
			digest:  "sha256:abc",
			wantMsg: "pushing sha256:abc...",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			msg := tt.status
			if msg == "" {
				if len(tt.digest) >= 19 {
					msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
				} else {
					msg = fmt.Sprintf("pushing %s...", tt.digest)
				}
			}
			if msg != tt.wantMsg {
				t.Errorf("got %q, want %q", msg, tt.wantMsg)
			}
		})
	}
}

func TestRunOptions_Copy_Independence(t *testing.T) {
	// Test that modifications to original don't affect copy
	originalThink := &api.ThinkValue{Value: "original"}

@@ -14,6 +14,7 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)

### Resources

406
docs/api/anthropic-compatibility.mdx
Normal file
@@ -0,0 +1,406 @@
---
title: Anthropic compatibility
---

Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.

## Recommended models

For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.

Pull a model before use:

```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```

## Usage

### Environment variables

To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:

```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```

### Simple `/v1/messages` example

<CodeGroup dropdown>

```python basic.py
import anthropic

client = anthropic.Anthropic(
    base_url='http://localhost:11434',
    api_key='ollama',  # required but ignored
)

message = client.messages.create(
    model='qwen3-coder',
    max_tokens=1024,
    messages=[
        {'role': 'user', 'content': 'Hello, how are you?'}
    ]
)
print(message.content[0].text)
```

```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";

const anthropic = new Anthropic({
  baseURL: "http://localhost:11434",
  apiKey: "ollama", // required but ignored
});

const message = await anthropic.messages.create({
  model: "qwen3-coder",
  max_tokens: 1024,
  messages: [{ role: "user", content: "Hello, how are you?" }],
});

console.log(message.content[0].text);
```

```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
  -H "Content-Type: application/json" \
  -H "x-api-key: ollama" \
  -H "anthropic-version: 2023-06-01" \
  -d '{
    "model": "qwen3-coder",
    "max_tokens": 1024,
    "messages": [{ "role": "user", "content": "Hello, how are you?" }]
  }'
```

</CodeGroup>

### Streaming example

<CodeGroup dropdown>

```python streaming.py
import anthropic

client = anthropic.Anthropic(
    base_url='http://localhost:11434',
    api_key='ollama',
)

with client.messages.stream(
    model='qwen3-coder',
    max_tokens=1024,
    messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
    for text in stream.text_stream:
        print(text, end='', flush=True)
```

```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";

const anthropic = new Anthropic({
  baseURL: "http://localhost:11434",
  apiKey: "ollama",
});

const stream = await anthropic.messages.stream({
  model: "qwen3-coder",
  max_tokens: 1024,
  messages: [{ role: "user", content: "Count from 1 to 10" }],
});

for await (const event of stream) {
  if (
    event.type === "content_block_delta" &&
    event.delta.type === "text_delta"
  ) {
    process.stdout.write(event.delta.text);
  }
}
```

```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen3-coder",
    "max_tokens": 1024,
    "stream": true,
    "messages": [{ "role": "user", "content": "Count from 1 to 10" }]
  }'
```

</CodeGroup>

### Tool calling example

<CodeGroup dropdown>

```python tools.py
import anthropic

client = anthropic.Anthropic(
    base_url='http://localhost:11434',
    api_key='ollama',
)

message = client.messages.create(
    model='qwen3-coder',
    max_tokens=1024,
    tools=[
        {
            'name': 'get_weather',
            'description': 'Get the current weather in a location',
            'input_schema': {
                'type': 'object',
                'properties': {
                    'location': {
                        'type': 'string',
                        'description': 'The city and state, e.g. San Francisco, CA'
                    }
                },
                'required': ['location']
            }
        }
    ],
    messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)

for block in message.content:
    if block.type == 'tool_use':
        print(f'Tool: {block.name}')
        print(f'Input: {block.input}')
```

```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";

const anthropic = new Anthropic({
  baseURL: "http://localhost:11434",
  apiKey: "ollama",
});

const message = await anthropic.messages.create({
  model: "qwen3-coder",
  max_tokens: 1024,
  tools: [
    {
      name: "get_weather",
      description: "Get the current weather in a location",
      input_schema: {
        type: "object",
        properties: {
          location: {
            type: "string",
            description: "The city and state, e.g. San Francisco, CA",
          },
        },
        required: ["location"],
      },
    },
  ],
  messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});

for (const block of message.content) {
  if (block.type === "tool_use") {
    console.log("Tool:", block.name);
    console.log("Input:", block.input);
  }
}
```

```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen3-coder",
    "max_tokens": 1024,
    "tools": [
      {
        "name": "get_weather",
        "description": "Get the current weather in a location",
        "input_schema": {
          "type": "object",
          "properties": {
            "location": {
              "type": "string",
              "description": "The city and state"
            }
          },
          "required": ["location"]
        }
      }
    ],
    "messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
  }'
```

</CodeGroup>

## Using with Claude Code

[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:

```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```

Or set the environment variables in your shell profile:

```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```

Then run Claude Code with any Ollama model:

```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b

# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```

## Endpoints

### `/v1/messages`

#### Supported features

- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking

#### Supported request fields

- [x] `model`
- [x] `max_tokens`
- [x] `messages`
  - [x] Text `content`
  - [x] Image `content` (base64)
  - [x] Array of content blocks
  - [x] `tool_use` blocks
  - [x] `tool_result` blocks
  - [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`

#### Supported response fields

- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)

#### Streaming events

- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`

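For reference, a short streamed text reply typically arrives as an SSE sequence like the one below. This is an illustrative transcript assembled from the event types listed above, not output captured from a live server; the message id and token counts are representative placeholders:

```text
event: message_start
data: {"type":"message_start","message":{"id":"msg_...","type":"message","role":"assistant","model":"qwen3-coder","content":[],"usage":{"input_tokens":10,"output_tokens":0}}}

event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}

event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}

event: content_block_stop
data: {"type":"content_block_stop","index":0}

event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":5}}

event: message_stop
data: {"type":"message_stop"}
```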
## Models

Ollama supports both local and cloud models.

### Local models

Pull a local model before use:

```shell
ollama pull qwen3-coder
```

Recommended local models:

- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model

### Cloud models

Cloud models are available immediately, without pulling:

- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model

### Default model names

For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model to the expected name:

```shell
ollama cp qwen3-coder claude-3-5-sonnet
```

Afterwards, the new model name can be specified in the `model` field:

```shell
curl http://localhost:11434/v1/messages \
  -H "Content-Type: application/json" \
  -d '{
    "model": "claude-3-5-sonnet",
    "max_tokens": 1024,
    "messages": [
      {
        "role": "user",
        "content": "Hello!"
      }
    ]
  }'
```

## Differences from the Anthropic API

### Behavior differences

- The API key is accepted but not validated
- The `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer

### Not supported

The following Anthropic API features are not currently supported:

| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return an HTTP status instead) |

### Partial support

| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

@@ -32,7 +32,9 @@
      "codeblocks": "system"
    },
    "contextual": {
      "options": ["copy"]
      "options": [
        "copy"
      ]
    },
    "navbar": {
      "links": [
@@ -52,7 +54,9 @@
      "display": "simple"
    },
    "examples": {
      "languages": ["curl"]
      "languages": [
        "curl"
      ]
    }
  },
  "redirects": [
@@ -97,6 +101,7 @@
    {
      "group": "Integrations",
      "pages": [
        "/integrations/claude-code",
        "/integrations/vscode",
        "/integrations/jetbrains",
        "/integrations/codex",
@@ -139,7 +144,8 @@
        "/api/streaming",
        "/api/usage",
        "/api/errors",
        "/api/openai-compatibility"
        "/api/openai-compatibility",
        "/api/anthropic-compatibility"
      ]
    },
    {

69
docs/integrations/claude-code.mdx
Normal file
@@ -0,0 +1,69 @@
---
title: Claude Code
---

## Install

Install [Claude Code](https://code.claude.com/docs/en/overview):

<CodeGroup>

```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```

```powershell Windows
irm https://claude.ai/install.ps1 | iex
```

</CodeGroup>

## Usage with Ollama

Claude Code connects to Ollama using the Anthropic-compatible API.

1. Set the environment variables:

   ```shell
   export ANTHROPIC_BASE_URL=http://localhost:11434
   export ANTHROPIC_API_KEY=ollama
   ```

2. Run Claude Code with an Ollama model:

   ```shell
   claude --model qwen3-coder
   ```

Or run with the environment variables inline:

```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```

## Connecting to ollama.com

1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:

   ```shell
   export ANTHROPIC_BASE_URL=https://ollama.com
   export ANTHROPIC_API_KEY=<your-api-key>
   ```

3. Run Claude Code with a cloud model:

   ```shell
   claude --model glm-4.7:cloud
   ```

## Recommended Models

### Cloud models

- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model

### Local models

- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model

@@ -1,5 +1,5 @@
---
title: Linux
title: "Linux"
---

## Install
@@ -13,15 +13,14 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install

<Note>
If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
</Note>

Download and extract the package:

```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
  | sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
  | sudo tar zx -C /usr
```

Start Ollama:
@@ -41,8 +40,8 @@ ollama -v
If you have an AMD GPU, also download and extract the additional ROCm package:

```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
  | sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
  | sudo tar zx -C /usr
```

### ARM64 install
@@ -50,8 +49,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
Download and extract the ARM64-specific package:

```shell
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
  | sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
  | sudo tar zx -C /usr
```

### Adding Ollama as a startup service (recommended)
@@ -113,11 +112,7 @@ sudo systemctl status ollama
```

<Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
</Note>

## Customizing
@@ -146,8 +141,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Or by re-downloading Ollama:

```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
  | sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
  | sudo tar zx -C /usr
```

## Installing specific versions
@@ -196,4 +191,4 @@ Remove the downloaded models and Ollama service user and group:
sudo userdel ollama
sudo groupdel ollama
```
```

@@ -1,3 +0,0 @@
# Troubleshooting

For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)
152
middleware/anthropic.go
Normal file
@@ -0,0 +1,152 @@
package middleware

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/anthropic"
	"github.com/ollama/ollama/api"
)

// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
type AnthropicWriter struct {
	BaseWriter
	stream    bool
	id        string
	model     string
	converter *anthropic.StreamConverter
}

func (w *AnthropicWriter) writeError(data []byte) (int, error) {
	var errData struct {
		Error string `json:"error"`
	}
	if err := json.Unmarshal(data, &errData); err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
	d, err := json.Marshal(data)
	if err != nil {
		return err
	}
	_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
	if err != nil {
		return err
	}
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
	return nil
}

func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	if w.stream {
		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")

		events := w.converter.Process(chatResponse)
		for _, event := range events {
			if err := w.writeEvent(event.Event, event.Data); err != nil {
				return 0, err
			}
		}
		return len(data), nil
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	response := anthropic.ToMessagesResponse(w.id, chatResponse)
	return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}

func (w *AnthropicWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}

// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req anthropic.MessagesRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if req.Model == "" {
			c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
			return
		}

		if req.MaxTokens <= 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
			return
		}

		if len(req.Messages) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
			return
		}

		chatReq, err := anthropic.FromMessagesRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
			return
		}

		// Relax thinking when serving the Anthropic API so tools like Claude Code work
		c.Set("relax_thinking", true)

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		messageID := anthropic.GenerateMessageID()

		w := &AnthropicWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			id:         messageID,
			model:      req.Model,
			converter:  anthropic.NewStreamConverter(messageID, req.Model),
		}

		if req.Stream {
			c.Writer.Header().Set("Content-Type", "text/event-stream")
			c.Writer.Header().Set("Cache-Control", "no-cache")
			c.Writer.Header().Set("Connection", "keep-alive")
		}

		c.Writer = w

		c.Next()
	}
}
607
middleware/anthropic_test.go
Normal file
@@ -0,0 +1,607 @@
package middleware

import (
	"bytes"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"

	"github.com/ollama/ollama/anthropic"
	"github.com/ollama/ollama/api"
)

func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
	return func(c *gin.Context) {
		bodyBytes, _ := io.ReadAll(c.Request.Body)
		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		_ = json.Unmarshal(bodyBytes, capturedRequest)
		c.Next()
	}
}

// testProps creates a ToolPropertiesMap from a map (convenience function for tests)
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
	props := api.NewToolPropertiesMap()
	for k, v := range m {
		props.Set(k, v)
	}
	return props
}

func TestAnthropicMessagesMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.ChatRequest
		err  anthropic.ErrorResponse
	}

	var capturedRequest *api.ChatRequest
	stream := true

	testCases := []testCase{
		{
			name: "basic message",
			body: `{
				"model": "test-model",
				"max_tokens": 1024,
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "Hello"},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &False,
			},
		},
		{
			name: "with system prompt",
			body: `{
				"model": "test-model",
				"max_tokens": 1024,
				"system": "You are helpful.",
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "system", Content: "You are helpful."},
					{Role: "user", Content: "Hello"},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &False,
			},
		},
		{
			name: "with options",
			body: `{
				"model": "test-model",
				"max_tokens": 2048,
				"temperature": 0.7,
				"top_p": 0.9,
				"top_k": 40,
				"stop_sequences": ["\n", "END"],
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "Hello"},
				},
				Options: map[string]any{
					"num_predict": 2048,
					"temperature": 0.7,
					"top_p":       0.9,
					"top_k":       40,
					"stop":        []string{"\n", "END"},
				},
				Stream: &False,
			},
		},
		{
			name: "streaming",
			body: `{
				"model": "test-model",
				"max_tokens": 1024,
				"stream": true,
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "Hello"},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &stream,
			},
		},
		{
			name: "with tools",
			body: `{
				"model": "test-model",
				"max_tokens": 1024,
				"messages": [
					{"role": "user", "content": "What's the weather?"}
				],
				"tools": [{
					"name": "get_weather",
					"description": "Get current weather",
					"input_schema": {
						"type": "object",
						"properties": {
							"location": {"type": "string"}
						},
						"required": ["location"]
					}
				}]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "What's the weather?"},
				},
				Tools: []api.Tool{
					{
						Type: "function",
						Function: api.ToolFunction{
							Name:        "get_weather",
							Description: "Get current weather",
							Parameters: api.ToolFunctionParameters{
								Type:     "object",
								Required: []string{"location"},
								Properties: testProps(map[string]api.ToolProperty{
									"location": {Type: api.PropertyType{"string"}},
								}),
							},
						},
					},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &False,
			},
		},
		{
			name: "with tool result",
			body: `{
				"model": "test-model",
				"max_tokens": 1024,
				"messages": [
					{"role": "user", "content": "What's the weather?"},
					{"role": "assistant", "content": [
						{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
					]},
					{"role": "user", "content": [
						{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
					]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "What's the weather?"},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								ID: "call_123",
								Function: api.ToolCallFunction{
									Name:      "get_weather",
									Arguments: testArgs(map[string]any{"location": "Paris"}),
								},
							},
						},
					},
					{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &False,
			},
		},
		{
			name: "with thinking enabled",
			body: `{
				"model": "test-model",
				"max_tokens": 1024,
				"thinking": {"type": "enabled", "budget_tokens": 1000},
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "Hello"},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &False,
				Think:   &api.ThinkValue{Value: true},
			},
		},
		{
			name: "missing model error",
			body: `{
				"max_tokens": 1024,
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			err: anthropic.ErrorResponse{
				Type: "error",
				Error: anthropic.Error{
					Type:    "invalid_request_error",
					Message: "model is required",
				},
			},
		},
		{
			name: "missing max_tokens error",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			err: anthropic.ErrorResponse{
				Type: "error",
				Error: anthropic.Error{
					Type:    "invalid_request_error",
					Message: "max_tokens is required and must be positive",
				},
			},
		},
		{
			name: "missing messages error",
			body: `{
				"model": "test-model",
				"max_tokens": 1024
			}`,
			err: anthropic.ErrorResponse{
				Type: "error",
				Error: anthropic.Error{
					Type:    "invalid_request_error",
					Message: "messages is required",
				},
			},
		},
		{
			name: "tool_use missing id error",
			body: `{
				"model": "test-model",
				"max_tokens": 1024,
				"messages": [
					{"role": "assistant", "content": [
						{"type": "tool_use", "name": "test"}
					]}
				]
			}`,
			err: anthropic.ErrorResponse{
				Type: "error",
				Error: anthropic.Error{
					Type:    "invalid_request_error",
					Message: "tool_use block missing required 'id' field",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
	router.Handle(http.MethodPost, "/v1/messages", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")

			defer func() { capturedRequest = nil }()

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			if tc.err.Type != "" {
				// Expect an error response
				if resp.Code == http.StatusOK {
					t.Fatalf("expected error response, got 200 OK")
				}
				var errResp anthropic.ErrorResponse
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatalf("failed to unmarshal error: %v", err)
				}
				if errResp.Type != tc.err.Type {
					t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
				}
				if errResp.Error.Type != tc.err.Error.Type {
					t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
				}
				if errResp.Error.Message != tc.err.Error.Message {
					t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
				}
				return
			}

			if resp.Code != http.StatusOK {
				t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
			}

			if capturedRequest == nil {
				t.Fatal("request was not captured")
			}

			// Compare relevant fields
			if capturedRequest.Model != tc.req.Model {
				t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
			}

			if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
				cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
				t.Errorf("messages mismatch (-want +got):\n%s", diff)
			}

			if tc.req.Stream != nil && capturedRequest.Stream != nil {
				if *tc.req.Stream != *capturedRequest.Stream {
					t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
				}
			}

			if tc.req.Think != nil {
				if capturedRequest.Think == nil {
					t.Error("expected Think to be set")
				} else if capturedRequest.Think.Value != tc.req.Think.Value {
					t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
				}
			}
		})
	}
}

func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
	gin.SetMode(gin.TestMode)

	t.Run("streaming sets correct headers", func(t *testing.T) {
		router := gin.New()
		router.Use(AnthropicMessagesMiddleware())
		router.POST("/v1/messages", func(c *gin.Context) {
			// Check that the headers were set
			if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
				t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
			}
			if c.Writer.Header().Get("Cache-Control") != "no-cache" {
				t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
			}
			c.Status(http.StatusOK)
		})

		body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
		req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
		req.Header.Set("Content-Type", "application/json")

		resp := httptest.NewRecorder()
		router.ServeHTTP(resp, req)
	})
}

func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
	req.Header.Set("Content-Type", "application/json")

	resp := httptest.NewRecorder()
	router.ServeHTTP(resp, req)

	if resp.Code != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", resp.Code)
	}

	var errResp anthropic.ErrorResponse
	if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
		t.Fatalf("failed to unmarshal error: %v", err)
	}

	if errResp.Type != "error" {
		t.Errorf("expected type 'error', got %q", errResp.Type)
	}
	if errResp.Error.Type != "invalid_request_error" {
		t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
	}
}

func TestAnthropicWriter_NonStreaming(t *testing.T) {
	gin.SetMode(gin.TestMode)

	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		// Simulate an Ollama response
		resp := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "Hello there!",
			},
			Done:       true,
			DoneReason: "stop",
			Metrics: api.Metrics{
				PromptEvalCount: 10,
				EvalCount:       5,
			},
		}
		data, _ := json.Marshal(resp)
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(data)
	})

	body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
	req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")

	resp := httptest.NewRecorder()
	router.ServeHTTP(resp, req)

	if resp.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", resp.Code)
	}

	var result anthropic.MessagesResponse
	if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}

	if result.Type != "message" {
		t.Errorf("expected type 'message', got %q", result.Type)
	}
	if result.Role != "assistant" {
		t.Errorf("expected role 'assistant', got %q", result.Role)
	}
	if len(result.Content) != 1 {
		t.Fatalf("expected 1 content block, got %d", len(result.Content))
	}
	if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
		t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
	}
	if result.StopReason != "end_turn" {
		t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
	}
	if result.Usage.InputTokens != 10 {
		t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
	}
	if result.Usage.OutputTokens != 5 {
		t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
	}
}

// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
	gin.SetMode(gin.TestMode)

	tests := []struct {
		name          string
		statusCode    int
		errorPayload  any
		wantErrorType string
		wantMessage   string
	}{
		// routes.go sends errors without StatusCode in the JSON, so we must use the HTTP status
		{
			name:          "404 with gin.H error (model not found)",
			statusCode:    http.StatusNotFound,
			errorPayload:  gin.H{"error": "model 'nonexistent' not found"},
			wantErrorType: "not_found_error",
			wantMessage:   "model 'nonexistent' not found",
		},
		{
			name:          "400 with gin.H error (bad request)",
			statusCode:    http.StatusBadRequest,
			errorPayload:  gin.H{"error": "model is required"},
			wantErrorType: "invalid_request_error",
			wantMessage:   "model is required",
		},
		{
			name:          "500 with gin.H error (internal error)",
			statusCode:    http.StatusInternalServerError,
			errorPayload:  gin.H{"error": "something went wrong"},
			wantErrorType: "api_error",
			wantMessage:   "something went wrong",
		},
		{
			name:       "404 with api.StatusError",
			statusCode: http.StatusNotFound,
			errorPayload: api.StatusError{
				StatusCode:   http.StatusNotFound,
				ErrorMessage: "model not found via StatusError",
			},
			wantErrorType: "not_found_error",
			wantMessage:   "model not found via StatusError",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			router := gin.New()
			router.Use(AnthropicMessagesMiddleware())
			router.POST("/v1/messages", func(c *gin.Context) {
				// Simulate what routes.go does - set the status and write the error JSON
				data, _ := json.Marshal(tt.errorPayload)
				c.Writer.WriteHeader(tt.statusCode)
				_, _ = c.Writer.Write(data)
			})

			body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
			req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
			req.Header.Set("Content-Type", "application/json")

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			if resp.Code != tt.statusCode {
				t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
			}

			var errResp anthropic.ErrorResponse
			if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
				t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
			}

			if errResp.Type != "error" {
				t.Errorf("expected type 'error', got %q", errResp.Type)
			}
			if errResp.Error.Type != tt.wantErrorType {
				t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
			}
			if errResp.Error.Message != tt.wantMessage {
				t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
			}
		})
	}
}

func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var flagSet bool
	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		_, flagSet = c.Get("relax_thinking")
		c.Status(http.StatusOK)
	})

	body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
	req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")

	resp := httptest.NewRecorder()
	router.ServeHTTP(resp, req)

	if !flagSet {
		t.Error("expected relax_thinking flag to be set in context")
	}
}
33
progress/stepbar.go
Normal file
33
progress/stepbar.go
Normal file
@@ -0,0 +1,33 @@
package progress

import (
"fmt"
"strings"
)

// StepBar displays step-based progress (e.g., for image generation steps).
type StepBar struct {
message string
current int
total int
}

func NewStepBar(message string, total int) *StepBar {
return &StepBar{message: message, total: total}
}

func (s *StepBar) Set(current int) {
s.current = current
}

func (s *StepBar) String() string {
percent := float64(s.current) / float64(s.total) * 100
barWidth := s.total
empty := barWidth - s.current

// "Generating 0% ▕ ▏ 0/9"
return fmt.Sprintf("%s %3.0f%% ▕%s%s▏ %d/%d",
s.message, percent,
strings.Repeat("█", s.current), strings.Repeat(" ", empty),
s.current, s.total)
}
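For context, a minimal sketch of how this StepBar might be driven. It assumes progress.Progress.Add accepts any value with a String method, as the spinner usage later in this diff suggests; the nine-step loop and sleep are purely illustrative:

```go
package main

import (
	"os"
	"time"

	"github.com/ollama/ollama/progress"
)

func main() {
	p := progress.NewProgress(os.Stderr)
	bar := progress.NewStepBar("Generating", 9)
	p.Add("", bar)
	for step := 1; step <= 9; step++ {
		bar.Set(step) // one call per completed image generation step
		time.Sleep(100 * time.Millisecond)
	}
	p.StopAndClear()
}
```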
@@ -3,6 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
)

func Execute(args []string) error {
@@ -11,12 +12,19 @@ func Execute(args []string) error {
}

var newRunner bool
if args[0] == "--ollama-engine" {
var imageRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--image-engine" {
args = args[1:]
imageRunner = true
}

if newRunner {
if imageRunner {
return imagerunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)
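A hedged sketch of the resulting dispatch (illustrative only; real invocations pass additional runner flags after the engine selectors):

```go
package main

import (
	"log"

	"github.com/ollama/ollama/runner"
)

func main() {
	// Flags are consumed in order: an optional "--ollama-engine" first, then
	// an optional "--image-engine". --image-engine selects the image runner;
	// otherwise --ollama-engine selects the new Go runner, and neither flag
	// falls through to the llama runner.
	if err := runner.Execute([]string{"--ollama-engine", "--image-engine"}); err != nil {
		log.Fatal(err)
	}
}
```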
@@ -73,7 +73,7 @@ _build_darwin() {
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
done
}
@@ -82,19 +82,19 @@ _sign_darwin() {
status "Creating universal binary..."
mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
chmod +x dist/darwin/ollama
chmod +x dist/darwin/imagegen
chmod +x dist/darwin/ollama-mlx

if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done

# create a temporary zip for notarization
TEMP=$(mktemp -u).zip
ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
xcrun notarytool submit "$TEMP" --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
xcrun notarytool submit "$TEMP" --wait --timeout 20m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
rm -f "$TEMP"
fi
@@ -154,23 +154,25 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
# Copy MLX metallib (architecture-independent, just use arm64 version)
cp dist/darwin-arm64/lib/ollama/*.metallib dist/Ollama.app/Contents/Resources/ 2>/dev/null || true
else
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
chmod a+x dist/Ollama.app/Contents/Resources/ollama

# Sign
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
@@ -178,11 +180,11 @@ _build_macapp() {

rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz

# Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
@@ -206,7 +208,7 @@ _build_macapp() {
rm -f dist/rw*.dmg

codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f stapler) staple dist/Ollama.dmg
else
echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"
@@ -48,53 +48,12 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
.
fi

# Deduplicate CUDA libraries across mlx_* and cuda_* directories
deduplicate_cuda_libs() {
local base_dir="$1"
echo "Deduplicating CUDA libraries in ${base_dir}..."

# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue

# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"

# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue

echo " Checking ${mlx_dir} against ${cuda_dir}..."

# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"

# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue

# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')

if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
}

# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
deduplicate_cuda_libs "./dist/linux_amd64"
deduplicate_cuda_libs "./dist/linux_arm64"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_amd64"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
deduplicate_cuda_libs "./dist"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist"
fi

# buildx behavior changes for single vs. multiplatform
60
scripts/deduplicate_cuda_libs.sh
Executable file
@@ -0,0 +1,60 @@
#!/bin/sh
#
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
# This script finds identical .so* files in mlx_cuda_* directories that exist
# in corresponding cuda_* directories and replaces them with symlinks.
#

set -eu

if [ $# -eq 0 ]; then
echo "ERROR: No directory specified" >&2
echo "Usage: $0 <base_directory>" >&2
exit 1
fi

base_dir="$1"

if [ ! -d "${base_dir}" ]; then
echo "ERROR: Directory ${base_dir} does not exist" >&2
exit 1
fi

echo "Deduplicating CUDA libraries in ${base_dir}..."

# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue

# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"

# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue

echo " Checking ${mlx_dir} against ${cuda_dir}..."

# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"

# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue

# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')

if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done

echo "Deduplication complete"
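For reference, a successful run replaces each duplicated library with a relative symlink into the sibling cuda_* directory, for example (hypothetical library name and version):

```
lib/ollama/mlx_cuda_v13/libcublas.so.13 -> ../cuda_v13/libcublas.so.13
```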
183
server/images.go
@@ -30,6 +30,7 @@ import (
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen/transfer"
)

var (
@@ -73,6 +74,11 @@ type Model struct {
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}

// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImageGeneration}
}

// Check for completion capability
if m.ModelPath != "" {
f, err := gguf.Open(m.ModelPath)
@@ -555,6 +561,24 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}

// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := mp.GetManifestPath()
if err != nil {
return err
}
manifestJSON, err := os.ReadFile(manifestPath)
if err != nil {
return err
}
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}

for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
@@ -620,6 +644,15 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}

// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}

skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
@@ -634,7 +667,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)

fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
@@ -643,13 +675,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
@@ -657,6 +687,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
}

for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)

fn(api.ProgressResponse{Status: "writing manifest"})

manifestJSON, err := json.Marshal(manifest)
@@ -690,6 +725,148 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}

// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []Layer) bool {
for _, layer := range layers {
if layer.MediaType == MediaTypeImageTensor {
return true
}
}
return false
}

// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
}
}

destDir, err := GetBlobsPath("")
if err != nil {
return err
}

base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()

var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}

progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pulling model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}

getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}

if err := transfer.Download(ctx, transfer.DownloadOptions{
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: mp.GetNamespaceRepository(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
}); err != nil {
return err
}

// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}

fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}

return os.WriteFile(fp, manifestJSON, 0o644)
}

// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
From: layer.From,
}
}

srcDir, err := GetBlobsPath("")
if err != nil {
return err
}

base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()

var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}

progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pushing model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}

getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}

return transfer.Upload(ctx, transfer.UploadOptions{
Blobs: blobs,
BaseURL: baseURL,
SrcDir: srcDir,
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
})
}

func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
@@ -47,6 +47,15 @@ func TestModelCapabilities(t *testing.T) {
model Model
expectedCaps []model.Capability
}{
{
name: "model with image generation capability via config",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
},
{
name: "model with completion capability",
model: Model{
@@ -13,9 +13,14 @@ type Layer struct {
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
status string
}

const (
MediaTypeImageTensor = "application/vnd.ollama.image.tensor"
)

func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {
@@ -47,16 +47,40 @@ func (m *Manifest) Remove() error {
}

func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest != "" {
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
ms, err := Manifests(true)
if err != nil {
return err
}

// Build set of digests still in use by other manifests
inUse := make(map[string]struct{})
for _, other := range ms {
for _, layer := range append(other.Layers, other.Config) {
if layer.Digest != "" {
inUse[layer.Digest] = struct{}{}
}
}
}

// Remove layers not used by any other manifest
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == "" {
continue
}
if _, used := inUse[layer.Digest]; used {
continue
}
blob, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
}
}

return nil
}
@@ -50,6 +50,8 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)

const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -162,6 +164,29 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}

// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}

runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}

return runner.llama, nil
}

func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
@@ -189,6 +214,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}

// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}

name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -1093,6 +1124,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
QuantizationLevel: m.Config.FileType,
}

// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
}
}

if req.System != "" {
m.System = req.System
}
@@ -1175,6 +1215,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}

if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
return resp, nil
}

kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
@@ -1544,6 +1588,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)

// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)

// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)

if rc != nil {
// wrap old with new
rs := &registry.Local{
@@ -2022,8 +2072,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
} else {
if req.Think != nil && req.Think.Bool() {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
// Set think to nil when being used with Anthropic API to connect to tools like claude code
if _, ok := c.Get("relax_thinking"); ok {
slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
req.Think = nil
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
}
}
}
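The relax_thinking interplay above can be sketched as follows. The c.Set call is an assumption inferred from the c.Get in the middleware tests earlier in this diff; the stand-in handler logic is illustrative, not the actual ChatHandler:

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/gin-gonic/gin"
)

func main() {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	// Anthropic-compat requests are marked so the chat handler can drop an
	// unsupported "think" option instead of returning 400.
	r.Use(func(c *gin.Context) {
		c.Set("relax_thinking", true) // assumed middleware behavior
		c.Next()
	})
	r.POST("/v1/messages", func(c *gin.Context) {
		thinkRequested := true // stand-in for req.Think != nil && req.Think.Bool()
		if thinkRequested {
			if _, ok := c.Get("relax_thinking"); ok {
				thinkRequested = false // relax instead of erroring
			} else {
				c.JSON(http.StatusBadRequest, gin.H{"error": "model does not support thinking"})
				return
			}
		}
		c.Status(http.StatusOK)
	})

	resp := httptest.NewRecorder()
	req, _ := http.NewRequest(http.MethodPost, "/v1/messages", nil)
	r.ServeHTTP(resp, req)
	fmt.Println(resp.Code) // 200: the request was relaxed, not rejected
}
```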
@@ -21,6 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)

type LlmRequest struct {
@@ -194,6 +195,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}

// Check for image generation model before attempting GGML load
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadImageGen(pending) {
break
}
continue
}

// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -543,6 +552,48 @@ iGPUScan:
return false
}

// loadImageGen loads an image generation model.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Use model name for imagegen (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName)
if err != nil {
req.errCh <- err
return true
}

sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}

runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
llama: server,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
refCount: 1,
}

s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()

// Set up expiration timer
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()

req.useLoadedRunner(runner, s.finishedReqCh)
return true
}

func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
if len(allGpus) == 0 {
return
@@ -6,6 +6,7 @@ import (
"errors"
"log/slog"
"os"
"slices"
"testing"
"time"

@@ -16,6 +17,7 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)

func TestMain(m *testing.M) {
@@ -804,3 +806,61 @@ func (s *mockLlm) GetPort() int { return -
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }

// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))

// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))

// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}

// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()

s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn

// Simulate an image gen runner already loaded
imageGenRunner := &runnerRef{
model: &Model{Name: "z-image", ModelPath: "/fake/image/model"},
modelPath: "/fake/image/model",
llama: &mockLlm{vramSize: 21 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{}},
sessionDuration: 5 * time.Millisecond,
refCount: 0, // idle
}

s.loadedMu.Lock()
s.loaded["/fake/image/model"] = imageGenRunner
s.loadedMu.Unlock()

// Verify the image gen runner is loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()

// findRunnerToUnload should find the idle image gen runner
runner := s.findRunnerToUnload()
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}
@@ -3,12 +3,13 @@ package model
type Capability string

const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImageGeneration = Capability("image")
)

func (c Capability) String() string {
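A small client-side sketch of checking the new capability via the show API. The model name is an example; the Show call and Capabilities field mirror their use elsewhere in this diff:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"slices"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	resp, err := client.Show(context.Background(), &api.ShowRequest{Model: "z-image"})
	if err != nil {
		log.Fatal(err)
	}
	// The new "image" capability surfaces alongside the existing ones.
	if slices.Contains(resp.Capabilities, model.CapabilityImageGeneration) {
		fmt.Println("image generation model")
	}
}
```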
24
x/README.md
@@ -1,24 +0,0 @@
# Experimental Features

## MLX Backend

We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)

Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:

```
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install --component MLX
go build -tags mlx .
```

On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.

## Image Generation

Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:

```
go build -o imagegen ./x/imagegen/cmd/engine
```
@@ -33,7 +33,7 @@ type ApprovalResult struct {
// Option labels for the selector (numbered for quick selection)
var optionLabels = []string{
"1. Execute once",
"2. Always allow",
"2. Allow for this session",
"3. Deny",
}

@@ -41,6 +41,7 @@ var optionLabels = []string{
var toolDisplayNames = map[string]string{
"bash": "Bash",
"web_search": "Web Search",
"web_fetch": "Web Fetch",
}

// ToolDisplayName returns the human-readable display name for a tool.
@@ -494,16 +495,32 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// This prevents buffered input from causing double-press issues
flushStdin(fd)

// Check if bash command targets paths outside cwd
isWarning := false
var warningMsg string
var allowlistInfo string
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
isWarning = isCommandOutsideCwd(cmd)
if isCommandOutsideCwd(cmd) {
isWarning = true
warningMsg = "command targets paths outside project"
}
if prefix := extractBashPrefix(cmd); prefix != "" {
colonIdx := strings.Index(prefix, ":")
if colonIdx != -1 {
cmdName := prefix[:colonIdx]
dirPath := prefix[colonIdx+1:]
if dirPath != "./" {
allowlistInfo = fmt.Sprintf("%s in %s directory (includes subdirs)", cmdName, dirPath)
} else {
allowlistInfo = fmt.Sprintf("%s in %s directory", cmdName, dirPath)
}
}
}
}
}

// Run interactive selector
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning, warningMsg, allowlistInfo)
if err != nil {
term.Restore(fd, oldState)
return ApprovalResult{Decision: ApprovalDeny}, err
@@ -549,6 +566,16 @@ func formatToolDisplay(toolName string, args map[string]any) string {
}
}

// For web fetch, show URL and internet notice
if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("URL: %s\n", url))
sb.WriteString("Uses internet via ollama.com")
return sb.String()
}
}

// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
if len(args) > 0 {
@@ -567,24 +594,28 @@ func formatToolDisplay(toolName string, args map[string]any) string {

// selectorState holds the state for the interactive selector
type selectorState struct {
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command targets paths outside cwd (red box)
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command has warning
warningMessage string // dynamic warning message to display
allowlistInfo string // show what will be allowlisted (for "Allow for this session" option)
}

// runSelector runs the interactive selector and returns the selected index and optional deny reason.
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool, warningMessage string, allowlistInfo string) (int, string, error) {
state := &selectorState{
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
warningMessage: warningMessage,
allowlistInfo: allowlistInfo,
}

// Get terminal size
@@ -771,7 +802,11 @@ func renderSelectorBox(state *selectorState) {

// Draw warning line if needed
if state.isWarning {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
if state.warningMessage != "" {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m %s\033[K\r\n", state.warningMessage)
} else {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
}
fmt.Fprintf(os.Stderr, "\033[K\r\n") // blank line after warning
}

@@ -783,21 +818,27 @@ func renderSelectorBox(state *selectorState) {
// Blank line separator
fmt.Fprintf(os.Stderr, "\033[K\r\n")

// Draw options
for i, label := range optionLabels {
if i == 2 { // Deny option with input
if i == 2 {
denyLabel := "3. Deny: "
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
} else {
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", label)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", label)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
}
}
}
@@ -825,21 +866,27 @@ func updateSelectorOptions(state *selectorState) {
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)

// Redraw options
for i, label := range optionLabels {
if i == 2 { // Deny option
if i == 2 {
denyLabel := "3. Deny: "
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
} else {
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", label)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", label)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
}
}
}
@@ -868,6 +915,9 @@ func updateReasonInput(state *selectorState) {
// Redraw Deny line with reason
denyLabel := "3. Deny: "
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
}
if state.selected == 2 {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
@@ -901,7 +951,7 @@ func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult,
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, toolDisplay)
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Allow for this session [3] Deny")
fmt.Fprint(os.Stderr, "choice: ")

var input string
@@ -950,11 +1000,11 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR

switch result.Decision {
case ApprovalOnce:
label = "approved"
label = "Approved"
case ApprovalAlways:
label = "always allowed"
label = "Always allowed"
case ApprovalDeny:
label = "denied"
label = "Denied"
}

// Format based on tool type
@@ -978,6 +1028,16 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
}
}

if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
// Truncate long URLs
if len(url) > 50 {
url = url[:47] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, url)
}
}

return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
}
385
x/cmd/run.go
@@ -9,6 +9,7 @@ import (
"net/url"
"os"
"os/signal"
"slices"
"strings"
"syscall"
"time"
@@ -24,6 +25,14 @@ import (
"github.com/ollama/ollama/x/tools"
)

// MultilineState tracks the state of multiline input
type MultilineState int

const (
MultilineNone MultilineState = iota
MultilineSystem
)

// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
@@ -130,6 +139,7 @@ type RunOptions struct {
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
Verbose bool

// Agent fields (managed externally for session persistence)
Tools *tools.Registry
@@ -178,6 +188,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
var latest api.ChatResponse

role := "assistant"
messages := opts.Messages
@@ -187,6 +198,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
p.StopAndClear()
}

latest = response
role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
@@ -364,10 +376,11 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}

// Check if command is auto-allowed (safe command)
if agent.IsAutoAllowed(cmd) {
fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
skipApproval = true
}
// TODO(parthsareen): re-enable with tighter scoped allowlist
// if agent.IsAutoAllowed(cmd) {
// fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
// skipApproval = true
// }
}
}

@@ -482,6 +495,10 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Println()
}

if opts.Verbose {
latest.Summary()
}

return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
}

@@ -633,7 +650,8 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
// If yoloMode is true, all tool approvals are skipped.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
// If enableWebsearch is true, the web search tool is registered.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
@@ -658,14 +676,28 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var toolRegistry *tools.Registry
if supportsTools {
toolRegistry = tools.DefaultRegistry()
if toolRegistry.Count() > 0 {
fmt.Fprintf(os.Stderr, "\033[90mtools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))

// Register web search and web fetch tools if enabled via flag
if enableWebsearch {
toolRegistry.RegisterWebSearch()
toolRegistry.RegisterWebFetch()
}

if toolRegistry.Has("bash") {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
fmt.Fprintln(os.Stderr, "Models can read files on your computer, or run commands (after you allow them).")
fmt.Fprintln(os.Stderr)
}

if toolRegistry.Has("web_search") || toolRegistry.Has("web_fetch") {
fmt.Fprintln(os.Stderr, "The \033[1mWeb Search\033[0m and \033[1mWeb Fetch\033[0m tools are enabled. Models can search and fetch web content via ollama.com.")
fmt.Fprintln(os.Stderr)
}

if yoloMode {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
}
} else {
fmt.Fprintf(os.Stderr, "\033[1mnote:\033[0m model does not support tools - running in chat-only mode\n")
}

// Create approval manager for session
@@ -673,6 +705,9 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op

var messages []api.Message
var sb strings.Builder
var format string
var system string
var multiline MultilineState = MultilineNone

for {
line, err := scanner.Readline()
@@ -684,13 +719,39 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
scanner.Prompt.UseAlt = false
sb.Reset()
multiline = MultilineNone
continue
case err != nil:
return err
}

switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}

switch multiline {
case MultilineSystem:
system = sb.String()
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
}

multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/clear"):
@@ -703,6 +764,10 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load Load a different model")
fmt.Fprintln(os.Stderr, " /save Save session as a model")
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
fmt.Fprintln(os.Stderr, " /bye Exit")
@@ -712,6 +777,303 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/set"):
args := strings.Fields(line)
if len(args) > 1 {
switch args[1] {
case "history":
scanner.HistoryEnable()
case "nohistory":
scanner.HistoryDisable()
case "wordwrap":
wordWrap = true
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
wordWrap = false
fmt.Println("Set 'nowordwrap' mode.")
case "verbose":
if err := cmd.Flags().Set("verbose", "true"); err != nil {
return err
}
fmt.Println("Set 'verbose' mode.")
case "quiet":
if err := cmd.Flags().Set("verbose", "false"); err != nil {
return err
}
fmt.Println("Set 'quiet' mode.")
case "think":
thinkValue := api.ThinkValue{Value: true}
var maybeLevel string
if len(args) > 2 {
maybeLevel = args[2]
}
if maybeLevel != "" {
thinkValue.Value = maybeLevel
}
think = &thinkValue
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
if maybeLevel != "" {
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
} else {
fmt.Println("Set 'think' mode.")
}
case "nothink":
think = &api.ThinkValue{Value: false}
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
fmt.Println("Set 'nothink' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else {
format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
}
case "noformat":
format = ""
fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
fmt.Println("Usage: /set parameter <name> <value>")
continue
}
params := args[3:]
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
continue
}

multiline = MultilineSystem

line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}

sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}

system = sb.String()
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
// Replace the last message
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /set <parameter|system|history|format|wordwrap|think|verbose> [value]")
}
continue
case strings.HasPrefix(line, "/show"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
continue
|
||||
}
|
||||
req := &api.ShowRequest{
|
||||
Name: modelName,
|
||||
Options: options,
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model")
|
||||
continue
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "info":
|
||||
fmt.Fprintf(os.Stderr, " Model\n")
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Name", modelName)
|
||||
if resp.Details.Family != "" {
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Family", resp.Details.Family)
|
||||
}
|
||||
if resp.Details.ParameterSize != "" {
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Parameter Size", resp.Details.ParameterSize)
|
||||
}
|
||||
if resp.Details.QuantizationLevel != "" {
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Quantization", resp.Details.QuantizationLevel)
|
||||
}
|
||||
if len(resp.Capabilities) > 0 {
|
||||
caps := make([]string, len(resp.Capabilities))
|
||||
for i, c := range resp.Capabilities {
|
||||
caps[i] = string(c)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Capabilities", strings.Join(caps, ", "))
|
||||
}
|
||||
fmt.Fprintln(os.Stderr)
|
||||
case "license":
|
||||
if resp.License == "" {
|
||||
fmt.Println("No license was specified for this model.")
|
||||
} else {
|
||||
fmt.Println(resp.License)
|
||||
}
|
||||
case "modelfile":
|
||||
fmt.Println(resp.Modelfile)
|
||||
case "parameters":
|
||||
fmt.Println("Model defined parameters:")
|
||||
if resp.Parameters == "" {
|
||||
fmt.Println(" No additional parameters were specified.")
|
||||
} else {
|
||||
for _, l := range strings.Split(resp.Parameters, "\n") {
|
||||
fmt.Printf(" %s\n", l)
|
||||
}
|
||||
}
|
||||
if len(options) > 0 {
|
||||
fmt.Println("\nUser defined parameters:")
|
||||
for k, v := range options {
|
||||
fmt.Printf(" %-30s %v\n", k, v)
|
||||
}
|
||||
}
|
||||
case "system":
|
||||
switch {
|
||||
case system != "":
|
||||
fmt.Println(system + "\n")
|
||||
case resp.System != "":
|
||||
fmt.Println(resp.System + "\n")
|
||||
default:
|
||||
fmt.Println("No system message was specified for this model.")
|
||||
}
|
||||
case "template":
|
||||
if resp.Template != "" {
|
||||
fmt.Println(resp.Template)
|
||||
} else {
|
||||
fmt.Println("No prompt template was specified for this model.")
|
||||
}
|
||||
default:
|
||||
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Usage: /show <info|license|modelfile|parameters|system|template>")
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/load"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage: /load <modelname>")
|
||||
continue
|
||||
}
|
||||
newModelName := args[1]
|
||||
fmt.Printf("Loading model '%s'\n", newModelName)
|
||||
|
||||
// Create progress spinner
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
// Get client
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
p.StopAndClear()
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if model exists and get its info
|
||||
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName})
|
||||
if err != nil {
|
||||
p.StopAndClear()
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("Couldn't find model '%s'\n", newModelName)
|
||||
} else {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// For cloud models, no need to preload
|
||||
if info.RemoteHost == "" {
|
||||
// Preload the model by sending an empty generate request
|
||||
req := &api.GenerateRequest{
|
||||
Model: newModelName,
|
||||
Think: think,
|
||||
}
|
||||
err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
p.StopAndClear()
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("Couldn't find model '%s'\n", newModelName)
|
||||
} else if strings.Contains(err.Error(), "does not support thinking") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("error loading model: %v\n", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
p.StopAndClear()
|
||||
modelName = newModelName
|
||||
messages = []api.Message{}
|
||||
approval.Reset()
|
||||
continue
|
||||
case strings.HasPrefix(line, "/save"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage: /save <modelname>")
|
||||
continue
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
continue
|
||||
}
|
||||
req := &api.CreateRequest{
|
||||
Model: args[1],
|
||||
From: modelName,
|
||||
Parameters: options,
|
||||
Messages: messages,
|
||||
}
|
||||
fn := func(resp api.ProgressResponse) error { return nil }
|
||||
err = client.Create(cmd.Context(), req, fn)
|
||||
if err != nil {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Created new model '%s'\n", args[1])
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
|
||||
continue
|
||||
@@ -719,14 +1081,16 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
sb.WriteString(line)
|
||||
}
|
||||
|
||||
if sb.Len() > 0 {
|
||||
if sb.Len() > 0 && multiline == MultilineNone {
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
verbose, _ := cmd.Flags().GetBool("verbose")
|
||||
opts := RunOptions{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Format: format,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
@@ -734,6 +1098,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
YoloMode: yoloMode,
|
||||
Verbose: verbose,
|
||||
}
|
||||
|
||||
assistant, err := Chat(cmd.Context(), opts)

x/imagegen/README.md

@@ -1,61 +1,250 @@
# Image Generation in Ollama (Experimental)

Generate images from text prompts using local AI models.

## Quick Start

```bash
# Run with a prompt
ollama run z-image "a sunset over mountains"
Generating: step 30/30
Image saved to: /tmp/ollama-image-1704067200.png
```

On macOS, the generated image will automatically open in Preview.

## Supported Models

| Model | VRAM Required | Notes |
|-------|---------------|-------|
| z-image | ~12GB | Based on Flux architecture |

## CLI Usage

```bash
# Generate an image
ollama run z-image "a cat playing piano"

# Check if model is running
ollama ps

# Stop the model
ollama stop z-image
```

## API

### OpenAI-Compatible Endpoint

```bash
POST /v1/images/generations
```

**Request:**

```json
{
  "model": "z-image",
  "prompt": "a sunset over mountains",
  "size": "1024x1024",
  "response_format": "b64_json"
}
```

**Response:**

```json
{
  "created": 1704067200,
  "data": [
    {
      "b64_json": "iVBORw0KGgo..."
    }
  ]
}
```

### Example: cURL

```bash
curl http://localhost:11434/v1/images/generations \
  -H "Content-Type: application/json" \
  -d '{
    "model": "z-image",
    "prompt": "a white cat",
    "size": "1024x1024"
  }'
```

### Example: Save to File

```bash
curl -s http://localhost:11434/v1/images/generations \
  -H "Content-Type: application/json" \
  -d '{
    "model": "z-image",
    "prompt": "a white cat",
    "size": "1024x1024"
  }' | jq -r '.data[0].b64_json' | base64 -d > image.png
```

### Streaming Progress

Enable streaming to receive progress updates via SSE:

```bash
curl http://localhost:11434/v1/images/generations \
  -H "Content-Type: application/json" \
  -d '{"model": "z-image", "prompt": "a sunset", "stream": true}'
```

Events:

```
event: progress
data: {"step": 1, "total": 30}

event: progress
data: {"step": 2, "total": 30}
...

event: done
data: {"created": 1704067200, "data": [{"b64_json": "..."}]}
```
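
Outside the browser, the event stream can be consumed with a plain HTTP client. Below is a minimal Go sketch, assuming the endpoint and event shapes shown above; the oversized scanner buffer is there because the final `done` event carries the entire base64 image on one line:

```go
// Sketch: stream image generation progress over SSE.
package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := strings.NewReader(`{"model": "z-image", "prompt": "a sunset", "stream": true}`)
	resp, err := http.Post("http://localhost:11434/v1/images/generations", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	sc := bufio.NewScanner(resp.Body)
	sc.Buffer(make([]byte, 64*1024), 64*1024*1024) // the done event can be very large
	for sc.Scan() {
		line := sc.Text()
		switch {
		case strings.HasPrefix(line, "event: "):
			fmt.Println("event:", strings.TrimPrefix(line, "event: "))
		case strings.HasPrefix(line, "data: "):
			fmt.Println("data:", strings.TrimPrefix(line, "data: "))
		}
	}
	if err := sc.Err(); err != nil {
		panic(err)
	}
}
```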

## Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| model | string | required | Model name |
| prompt | string | required | Text description of image |
| size | string | "1024x1024" | Image dimensions (WxH) |
| n | int | 1 | Number of images (currently only 1 supported) |
| response_format | string | "b64_json" | "b64_json" or "url" |
| stream | bool | false | Enable progress streaming |

## Requirements

- macOS with Apple Silicon (M1/M2/M3/M4)
- CUDA: tested on CUDA 12 (Blackwell); broader testing coming soon
- Sufficient VRAM (see model table above)
- Ollama built with MLX support

## Limitations

- Primarily macOS (MLX backend); CUDA support is experimental
- Single image per request
- Fixed step count (30 steps)
- Modelfiles not yet supported (use `ollama create` from model directory)

---

# Tensor Model Storage Format

Tensor models store each tensor as a separate blob with metadata in the manifest. This enables faster downloads (parallel fetching) and deduplication (shared tensors are stored once).

## Manifest Structure

The manifest follows the standard ollama format with tensor-specific layer metadata:

```json
{
  "schemaVersion": 2,
  "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
  "config": { "digest": "sha256:...", "size": 1234 },
  "layers": [
    {
      "mediaType": "application/vnd.ollama.image.tensor",
      "digest": "sha256:25b36eed...",
      "size": 49807448,
      "name": "text_encoder/model.layers.0.mlp.down_proj.weight",
      "dtype": "BF16",
      "shape": [2560, 9728]
    },
    {
      "mediaType": "application/vnd.ollama.image.json",
      "digest": "sha256:abc123...",
      "size": 512,
      "name": "text_encoder/config.json"
    }
  ]
}
```

Each tensor layer includes:

- `name`: Path-style tensor name (e.g., `text_encoder/model.layers.0.mlp.down_proj.weight`)
- `dtype`: Data type (BF16, F32, etc.)
- `shape`: Tensor dimensions

Config layers use the same path-style naming (e.g., `tokenizer/tokenizer.json`).
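
For illustration, here is a sketch of Go types that can decode this layer metadata; the type names are hypothetical, not the server's actual definitions:

```go
// Sketch: decode tensor layer metadata from a manifest.
package main

import (
	"encoding/json"
	"fmt"
)

type layer struct {
	MediaType string  `json:"mediaType"`
	Digest    string  `json:"digest"`
	Size      int64   `json:"size"`
	Name      string  `json:"name"`
	Dtype     string  `json:"dtype,omitempty"`
	Shape     []int32 `json:"shape,omitempty"`
}

type manifest struct {
	SchemaVersion int     `json:"schemaVersion"`
	MediaType     string  `json:"mediaType"`
	Layers        []layer `json:"layers"`
}

func main() {
	raw := `{"schemaVersion": 2, "layers": [{"mediaType": "application/vnd.ollama.image.tensor",
	  "digest": "sha256:25b36eed", "size": 49807448,
	  "name": "text_encoder/model.layers.0.mlp.down_proj.weight",
	  "dtype": "BF16", "shape": [2560, 9728]}]}`

	var m manifest
	if err := json.Unmarshal([]byte(raw), &m); err != nil {
		panic(err)
	}
	for _, l := range m.Layers {
		fmt.Println(l.Name, l.Dtype, l.Shape)
	}
}
```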

## Blob Format

Each tensor blob is a minimal safetensors file:

```
[8 bytes: header size (uint64 LE)]
[~80 bytes: JSON header, padded to 8-byte alignment]
[N bytes: raw tensor data]
```

Header contains a single tensor named `"data"`:

```json
{"data":{"dtype":"BF16","shape":[2560,9728],"data_offsets":[0,49807360]}}
```
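
Reading that framing back takes only a few lines; a sketch under the layout above (not the actual loader, which goes through MLX):

```go
// Sketch: parse the 8-byte size prefix and JSON header of a blob.
package main

import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"os"
)

type entry struct {
	Dtype       string   `json:"dtype"`
	Shape       []int64  `json:"shape"`
	DataOffsets [2]int64 `json:"data_offsets"`
}

func main() {
	f, err := os.Open("blob.safetensors") // path is illustrative
	if err != nil {
		panic(err)
	}
	defer f.Close()

	var size uint64
	if err := binary.Read(f, binary.LittleEndian, &size); err != nil {
		panic(err)
	}
	hdr := make([]byte, size)
	if _, err := io.ReadFull(f, hdr); err != nil {
		panic(err)
	}
	var tensors map[string]entry // JSON padding is tolerated by Unmarshal
	if err := json.Unmarshal(hdr, &tensors); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", tensors["data"]) // raw tensor data follows at the current offset
}
```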

## Why Include the Header?

The ~88-byte safetensors header enables MLX's native `mlx_load_safetensors` function, which:

1. **Uses mmap** - Maps the file directly into memory, no copies
2. **Zero-copy to GPU** - MLX reads directly from mapped pages
3. **No custom code** - Standard MLX API, battle-tested

Without the header, we'd need custom C++ code to create MLX arrays from raw mmap'd data. MLX's public API doesn't expose this: it always copies when creating arrays from external pointers.

The overhead is negligible: 88 bytes per tensor is ~100KB total for a 13GB model (0.0007%).

## Why Per-Tensor Blobs?

**Deduplication**: Blobs are content-addressed by SHA256. If two models share identical tensors (same weights, dtype, shape), they share the same blob file.

Example: Model A and Model B both use the same text encoder. The text encoder's 400 tensors are stored once, referenced by both manifests.

```
~/.ollama/models/
  blobs/
    sha256-25b36eed...   <- shared by both models
    sha256-abc123...
  manifests/
    library/model-a/latest   <- references sha256-25b36eed
    library/model-b/latest   <- references sha256-25b36eed
```
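
The dedup key is simply the digest of the blob bytes. A sketch of the content-addressing step, with the `sha256-` filename prefix following the tree above:

```go
// Sketch: derive a blob's store filename from its contents.
package main

import (
	"crypto/sha256"
	"fmt"
)

func blobFilename(blob []byte) string {
	sum := sha256.Sum256(blob)
	return fmt.Sprintf("sha256-%x", sum)
}

func main() {
	a := blobFilename([]byte("identical tensor bytes"))
	b := blobFilename([]byte("identical tensor bytes"))
	fmt.Println(a == b) // true: identical tensors resolve to one blob file
}
```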

## Import Flow

```
cd ./weights/Z-Image-Turbo
ollama create z-image

1. Scan component directories (text_encoder/, transformer/, vae/)
2. For each .safetensors file:
   - Extract individual tensors
   - Wrap each in minimal safetensors format (88B header + data)
   - Write to blob store (SHA256 content-addressed)
   - Add layer entry to manifest with path-style name
3. Copy config files (*.json) as config layers
4. Write manifest
```
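
Step 2's wrapping is mechanical; a sketch of producing the minimal single-tensor framing from the Blob Format section (illustrative, not the importer's actual code):

```go
// Sketch: wrap raw tensor bytes in the minimal safetensors framing.
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func wrapTensor(dtype string, rows, cols int, data []byte) []byte {
	header := fmt.Sprintf(`{"data":{"dtype":%q,"shape":[%d,%d],"data_offsets":[0,%d]}}`,
		dtype, rows, cols, len(data))
	for len(header)%8 != 0 {
		header += " " // pad JSON to 8-byte alignment
	}
	var buf bytes.Buffer
	binary.Write(&buf, binary.LittleEndian, uint64(len(header))) // 8-byte size prefix
	buf.WriteString(header)
	buf.Write(data)
	return buf.Bytes()
}

func main() {
	blob := wrapTensor("BF16", 2, 4, make([]byte, 2*4*2)) // BF16 = 2 bytes/element
	fmt.Println(len(blob), "bytes")
}
```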

## FP8 Quantization

Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.

### Usage

```bash
cd ./weights/Z-Image-Turbo
ollama create z-image-fp8 --quantize fp8
```

This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.

## Memory Management

MLX Python/C++ uses scope-based memory management: arrays are freed when they go out of scope. Go's garbage collector is non-deterministic, so we can't rely on finalizers to free GPU memory promptly.

Instead, arrays are automatically tracked and freed on `Eval()`:

```go
// All arrays are automatically tracked when created
x := mlx.Add(a, b)
y := mlx.Matmul(x, w)

// Eval frees non-kept arrays, evaluates outputs (auto-kept)
mlx.Eval(y)

// After copying to CPU, free the array
data := y.Data()
y.Free()
```

Key points:

- All created arrays are automatically tracked
- `mlx.Eval(outputs...)` frees non-kept arrays, evaluates outputs (outputs auto-kept)
- `mlx.Keep(arrays...)` marks arrays to survive multiple Eval cycles (for weights, caches)
- Call `.Free()` when done with an array
231
x/imagegen/api/handler.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// RunnerScheduler is the interface for scheduling a model runner.
|
||||
// This is implemented by server.Server to avoid circular imports.
|
||||
type RunnerScheduler interface {
|
||||
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the image generation API routes.
|
||||
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
|
||||
r.POST("/v1/images/generations", func(c *gin.Context) {
|
||||
ImageGenerationHandler(c, scheduler)
|
||||
})
|
||||
}
|
||||
|
||||
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
|
||||
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
|
||||
var req ImageGenerationRequest
|
||||
if err := c.BindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
|
||||
return
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
|
||||
return
|
||||
}
|
||||
|
||||
// Apply defaults
|
||||
if req.N == 0 {
|
||||
req.N = 1
|
||||
}
|
||||
if req.Size == "" {
|
||||
req.Size = "1024x1024"
|
||||
}
|
||||
if req.ResponseFormat == "" {
|
||||
req.ResponseFormat = "b64_json"
|
||||
}
|
||||
|
||||
// Verify model exists
|
||||
if imagegen.ResolveModelName(req.Model) == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse size
|
||||
width, height := parseSize(req.Size)
|
||||
|
||||
// Build options - we repurpose NumCtx/NumGPU for width/height
|
||||
opts := api.Options{}
|
||||
opts.NumCtx = int(width)
|
||||
opts.NumGPU = int(height)
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
} else {
|
||||
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
}
|
||||
}
|
||||
|
||||
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var imageBase64 string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imageBase64 = extractBase64(resp.Content)
|
||||
} else {
|
||||
progress := parseProgress(resp.Content)
|
||||
if progress.Total > 0 {
|
||||
c.SSEvent("progress", progress)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.SSEvent("error", gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.SSEvent("done", buildResponse(imageBase64, format))
|
||||
}
|
||||
|
||||
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
var imageBase64 string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imageBase64 = extractBase64(resp.Content)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
|
||||
}
|
||||
|
||||
func parseSize(size string) (int32, int32) {
|
||||
parts := strings.Split(size, "x")
|
||||
if len(parts) != 2 {
|
||||
return 1024, 1024
|
||||
}
|
||||
w, _ := strconv.Atoi(parts[0])
|
||||
h, _ := strconv.Atoi(parts[1])
|
||||
if w == 0 {
|
||||
w = 1024
|
||||
}
|
||||
if h == 0 {
|
||||
h = 1024
|
||||
}
|
||||
return int32(w), int32(h)
|
||||
}
|
||||
|
||||
func extractBase64(content string) string {
|
||||
if strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
return content[13:]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseProgress(content string) ImageProgressEvent {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
return ImageProgressEvent{Step: step, Total: total}
|
||||
}
|
||||
|
||||
func buildResponse(imageBase64, format string) ImageGenerationResponse {
|
||||
resp := ImageGenerationResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: make([]ImageData, 1),
|
||||
}
|
||||
|
||||
if imageBase64 == "" {
|
||||
return resp
|
||||
}
|
||||
|
||||
// The "url" response format is not supported with base64 transfer;
// b64_json is returned in both cases.
resp.Data[0].B64JSON = imageBase64
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
|
||||
// This allows routes.go to delegate image generation with minimal code.
|
||||
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
|
||||
opts := api.Options{}
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
// Stream responses via channel
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
|
||||
ch <- GenerateResponse{
|
||||
Model: modelName,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: resp.Content,
|
||||
Done: resp.Done,
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// The channel is already being consumed; surfacing the error here
// would block the stream, so it is intentionally dropped.
_ = err
|
||||
}
|
||||
}()
|
||||
|
||||
streamFn(c, ch)
|
||||
}
|
||||
|
||||
// GenerateResponse matches api.GenerateResponse structure for streaming.
|
||||
type GenerateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
31
x/imagegen/api/types.go
Normal file
@@ -0,0 +1,31 @@
|
||||
// Package api provides OpenAI-compatible image generation API types.
|
||||
package api
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageData contains the generated image data.
|
||||
type ImageData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
}
|
||||
|
||||
// ImageProgressEvent is sent during streaming to indicate generation progress.
|
||||
type ImageProgressEvent struct {
|
||||
Step int `json:"step"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
197
x/imagegen/cache/teacache.go
vendored
Normal file
@@ -0,0 +1,197 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package cache provides caching mechanisms for diffusion model inference.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
|
||||
// It caches the transformer output and reuses it when timestep values
|
||||
// are similar between consecutive steps.
|
||||
//
|
||||
// For CFG (classifier-free guidance), it caches pos and neg predictions
|
||||
// separately and always computes CFG fresh to avoid error amplification.
|
||||
//
|
||||
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
|
||||
// https://github.com/ali-vilab/TeaCache
|
||||
type TeaCache struct {
|
||||
// Cached transformer output from last computed step (non-CFG mode)
|
||||
cachedOutput *mlx.Array
|
||||
|
||||
// Cached CFG outputs (pos and neg separately)
|
||||
cachedPosOutput *mlx.Array
|
||||
cachedNegOutput *mlx.Array
|
||||
|
||||
// Previous timestep value for difference calculation
|
||||
prevTimestep float32
|
||||
|
||||
// Accumulated difference for rescaling
|
||||
accumulatedDiff float32
|
||||
|
||||
// Configuration
|
||||
threshold float32 // Threshold for recomputation decision
|
||||
rescaleFactor float32 // Model-specific rescaling factor
|
||||
skipEarlySteps int // Number of early steps to never cache
|
||||
|
||||
// Statistics
|
||||
cacheHits int
|
||||
cacheMisses int
|
||||
}
|
||||
|
||||
// TeaCacheConfig holds configuration for TeaCache.
|
||||
type TeaCacheConfig struct {
|
||||
// Threshold for recomputation. Lower = more cache hits, potential quality loss.
|
||||
// Recommended: 0.05-0.15 for image models
|
||||
Threshold float32
|
||||
|
||||
// Rescale factor to adjust timestep embedding differences.
|
||||
// Model-specific, typically 1.0-2.0
|
||||
RescaleFactor float32
|
||||
|
||||
// SkipEarlySteps: number of early steps to always compute (never cache).
|
||||
// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
|
||||
SkipEarlySteps int
|
||||
}
|
||||
|
||||
// DefaultTeaCacheConfig returns default configuration for TeaCache.
|
||||
func DefaultTeaCacheConfig() *TeaCacheConfig {
|
||||
return &TeaCacheConfig{
|
||||
Threshold: 0.1,
|
||||
RescaleFactor: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTeaCache creates a new TeaCache instance.
|
||||
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
|
||||
if cfg == nil {
|
||||
cfg = DefaultTeaCacheConfig()
|
||||
}
|
||||
return &TeaCache{
|
||||
threshold: cfg.Threshold,
|
||||
rescaleFactor: cfg.RescaleFactor,
|
||||
skipEarlySteps: cfg.SkipEarlySteps,
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldCompute determines if we should compute the full forward pass
|
||||
// or reuse the cached output based on timestep similarity.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. First step always computes
|
||||
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
|
||||
// 3. If accumulated difference > threshold, compute new output
|
||||
// 4. Otherwise, reuse cached output
|
||||
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
|
||||
// Always compute early steps (critical for structure)
|
||||
// Check both regular cache and CFG cache
|
||||
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
|
||||
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
|
||||
return true
|
||||
}
|
||||
|
||||
// Compute absolute difference between current and previous timestep
|
||||
diff := timestep - tc.prevTimestep
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
|
||||
// Apply rescaling factor
|
||||
scaledDiff := diff * tc.rescaleFactor
|
||||
|
||||
// Accumulate difference (helps track drift over multiple cached steps)
|
||||
tc.accumulatedDiff += scaledDiff
|
||||
|
||||
// Decision based on accumulated difference
|
||||
if tc.accumulatedDiff > tc.threshold {
|
||||
tc.accumulatedDiff = 0 // Reset accumulator
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
|
||||
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
|
||||
// Free previous cached output
|
||||
if tc.cachedOutput != nil {
|
||||
tc.cachedOutput.Free()
|
||||
}
|
||||
|
||||
// Store new cached values
|
||||
tc.cachedOutput = output
|
||||
tc.prevTimestep = timestep
|
||||
tc.cacheMisses++
|
||||
}
|
||||
|
||||
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
|
||||
// This allows CFG to be computed fresh each step, avoiding error amplification.
|
||||
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
|
||||
// Free previous cached outputs
|
||||
if tc.cachedPosOutput != nil {
|
||||
tc.cachedPosOutput.Free()
|
||||
}
|
||||
if tc.cachedNegOutput != nil {
|
||||
tc.cachedNegOutput.Free()
|
||||
}
|
||||
|
||||
// Store new cached values
|
||||
tc.cachedPosOutput = posOutput
|
||||
tc.cachedNegOutput = negOutput
|
||||
tc.prevTimestep = timestep
|
||||
tc.cacheMisses++
|
||||
}
|
||||
|
||||
// GetCached returns the cached output (non-CFG mode).
|
||||
func (tc *TeaCache) GetCached() *mlx.Array {
|
||||
tc.cacheHits++
|
||||
return tc.cachedOutput
|
||||
}
|
||||
|
||||
// GetCFGCached returns cached pos and neg outputs for CFG mode.
|
||||
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
|
||||
tc.cacheHits++
|
||||
return tc.cachedPosOutput, tc.cachedNegOutput
|
||||
}
|
||||
|
||||
// HasCFGCache returns true if CFG cache is available.
|
||||
func (tc *TeaCache) HasCFGCache() bool {
|
||||
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
|
||||
}
|
||||
|
||||
// Arrays returns all arrays that should be kept alive.
|
||||
func (tc *TeaCache) Arrays() []*mlx.Array {
|
||||
var arrays []*mlx.Array
|
||||
if tc.cachedOutput != nil {
|
||||
arrays = append(arrays, tc.cachedOutput)
|
||||
}
|
||||
if tc.cachedPosOutput != nil {
|
||||
arrays = append(arrays, tc.cachedPosOutput)
|
||||
}
|
||||
if tc.cachedNegOutput != nil {
|
||||
arrays = append(arrays, tc.cachedNegOutput)
|
||||
}
|
||||
return arrays
|
||||
}
|
||||
|
||||
// Stats returns cache hit/miss statistics.
|
||||
func (tc *TeaCache) Stats() (hits, misses int) {
|
||||
return tc.cacheHits, tc.cacheMisses
|
||||
}
|
||||
|
||||
// Free releases all cached arrays.
|
||||
func (tc *TeaCache) Free() {
|
||||
if tc.cachedOutput != nil {
|
||||
tc.cachedOutput.Free()
|
||||
tc.cachedOutput = nil
|
||||
}
|
||||
if tc.cachedPosOutput != nil {
|
||||
tc.cachedPosOutput.Free()
|
||||
tc.cachedPosOutput = nil
|
||||
}
|
||||
if tc.cachedNegOutput != nil {
|
||||
tc.cachedNegOutput.Free()
|
||||
tc.cachedNegOutput = nil
|
||||
}
|
||||
}
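
To make the decision rule above concrete, here is a minimal, self-contained sketch of the same accumulate-and-threshold logic using plain float32 values instead of mlx arrays (illustrative only; not part of the package):

```go
// Standalone sketch of the TeaCache decision rule: accumulate rescaled
// timestep differences and recompute once they exceed the threshold.
package main

import "fmt"

func main() {
	const (
		threshold     = float32(0.1)
		rescaleFactor = float32(1.0)
		skipEarly     = 2
	)
	timesteps := []float32{1.0, 0.9, 0.85, 0.82, 0.6, 0.55}

	var prev, accumulated float32
	cached := false
	for step, t := range timesteps {
		compute := step < skipEarly || !cached
		if !compute {
			diff := t - prev
			if diff < 0 {
				diff = -diff
			}
			accumulated += diff * rescaleFactor
			if accumulated > threshold {
				accumulated = 0 // reset accumulator on recompute
				compute = true
			}
		}
		if compute {
			prev, cached = t, true // corresponds to UpdateCache
		}
		fmt.Printf("step %d (t=%.2f): compute=%v\n", step, t, compute)
	}
}
```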
533
x/imagegen/cli.go
Normal file
@@ -0,0 +1,533 @@
|
||||
// cli.go provides CLI commands for image generation models.
|
||||
//
|
||||
// TODO (jmorganca): Integrate these commands into cmd/cmd.go when stable.
|
||||
// Currently these are separate to keep experimental code isolated.
|
||||
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
)
|
||||
|
||||
// ImageGenOptions holds options for image generation.
|
||||
// These can be set via command-line flags or interactive /set commands.
|
||||
type ImageGenOptions struct {
|
||||
Width int
|
||||
Height int
|
||||
Steps int
|
||||
Seed int
|
||||
NegativePrompt string
|
||||
}
|
||||
|
||||
// DefaultOptions returns the default image generation options.
|
||||
func DefaultOptions() ImageGenOptions {
|
||||
return ImageGenOptions{
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 9,
|
||||
Seed: 0, // 0 means random
|
||||
}
|
||||
}
|
||||
|
||||
// ModelInfo contains metadata about an image generation model.
|
||||
type ModelInfo struct {
|
||||
Architecture string
|
||||
ParameterCount int64
|
||||
Quantization string
|
||||
}
|
||||
|
||||
// GetModelInfo returns metadata about an image generation model.
|
||||
func GetModelInfo(modelName string) (*ModelInfo, error) {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
info := &ModelInfo{}
|
||||
|
||||
// Read model_index.json for architecture, parameter count, and quantization
|
||||
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
|
||||
var index struct {
|
||||
Architecture string `json:"architecture"`
|
||||
ParameterCount int64 `json:"parameter_count"`
|
||||
Quantization string `json:"quantization"`
|
||||
}
|
||||
if json.Unmarshal(data, &index) == nil {
|
||||
info.Architecture = index.Architecture
|
||||
info.ParameterCount = index.ParameterCount
|
||||
info.Quantization = index.Quantization
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: detect quantization from tensor names if not in config
|
||||
if info.Quantization == "" {
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if strings.HasSuffix(layer.Name, ".weight_scale") {
|
||||
info.Quantization = "FP8"
|
||||
break
|
||||
}
|
||||
}
|
||||
if info.Quantization == "" {
|
||||
info.Quantization = "BF16"
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: estimate parameter count if not in config
|
||||
if info.ParameterCount == 0 {
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
}
|
||||
}
|
||||
// Assume BF16 (2 bytes/param) as rough estimate
|
||||
info.ParameterCount = totalSize / 2
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// RegisterFlags adds image generation flags to the given command.
|
||||
// Flags are hidden since they only apply to image generation models.
|
||||
func RegisterFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().Int("width", 1024, "Image width")
|
||||
cmd.Flags().Int("height", 1024, "Image height")
|
||||
cmd.Flags().Int("steps", 9, "Denoising steps")
|
||||
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
|
||||
cmd.Flags().String("negative", "", "Negative prompt")
|
||||
cmd.Flags().MarkHidden("width")
|
||||
cmd.Flags().MarkHidden("height")
|
||||
cmd.Flags().MarkHidden("steps")
|
||||
cmd.Flags().MarkHidden("seed")
|
||||
cmd.Flags().MarkHidden("negative")
|
||||
}
|
||||
|
||||
// RunCLI handles the CLI for image generation models.
|
||||
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
||||
// Supports flags: --width, --height, --steps, --seed, --negative
|
||||
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
|
||||
// Get options from flags, falling back to the defaults above
|
||||
opts := DefaultOptions()
|
||||
if cmd != nil && cmd.Flags() != nil {
|
||||
if v, err := cmd.Flags().GetInt("width"); err == nil && v > 0 {
|
||||
opts.Width = v
|
||||
}
|
||||
if v, err := cmd.Flags().GetInt("height"); err == nil && v > 0 {
|
||||
opts.Height = v
|
||||
}
|
||||
if v, err := cmd.Flags().GetInt("steps"); err == nil && v > 0 {
|
||||
opts.Steps = v
|
||||
}
|
||||
if v, err := cmd.Flags().GetInt("seed"); err == nil && v != 0 {
|
||||
opts.Seed = v
|
||||
}
|
||||
if v, err := cmd.Flags().GetString("negative"); err == nil && v != "" {
|
||||
opts.NegativePrompt = v
|
||||
}
|
||||
}
|
||||
|
||||
if interactive {
|
||||
return runInteractive(cmd, name, keepAlive, opts)
|
||||
}
|
||||
|
||||
// One-shot generation
|
||||
return generateImageWithOptions(cmd, name, prompt, keepAlive, opts)
|
||||
}
|
||||
|
||||
// generateImageWithOptions generates an image with the given options.
|
||||
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build request with image gen options encoded in Options fields
|
||||
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
}
|
||||
|
||||
// Show loading spinner until generation starts
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
var stepBar *progress.StepBar
|
||||
var imageBase64 string
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle final response with base64 image data
|
||||
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
imageBase64 = content[13:]
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
p.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if imageBase64 != "" {
|
||||
// Decode base64 and save to CWD
|
||||
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
// Create filename from prompt
|
||||
safeName := sanitizeFilename(prompt)
|
||||
if len(safeName) > 50 {
|
||||
safeName = safeName[:50]
|
||||
}
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
||||
|
||||
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
|
||||
return fmt.Errorf("failed to save image: %w", err)
|
||||
}
|
||||
|
||||
displayImageInTerminal(filename)
|
||||
fmt.Printf("Image saved to: %s\n", filename)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runInteractive runs an interactive REPL for image generation.
|
||||
func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duration, opts ImageGenOptions) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
Placeholder: "Describe an image to generate (/help for commands)",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if envconfig.NoHistory() {
|
||||
scanner.HistoryDisable()
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
fmt.Println()
|
||||
return nil
|
||||
case errors.Is(err, readline.ErrInterrupt):
|
||||
if line == "" {
|
||||
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
||||
}
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle commands
|
||||
switch {
|
||||
case strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"):
|
||||
printInteractiveHelp(opts)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/set "):
|
||||
if err := handleSetCommand(line[5:], &opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/show"):
|
||||
printCurrentSettings(opts)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generate image with current options
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: line,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
}
|
||||
|
||||
// Show loading spinner until generation starts
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
var stepBar *progress.StepBar
|
||||
var imageBase64 string
|
||||
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle final response with base64 image data
|
||||
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
imageBase64 = content[13:]
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
p.Stop()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Save image to current directory with descriptive name
|
||||
if imageBase64 != "" {
|
||||
// Decode base64 image data
|
||||
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Create filename from prompt (sanitized)
|
||||
safeName := sanitizeFilename(line)
|
||||
if len(safeName) > 50 {
|
||||
safeName = safeName[:50]
|
||||
}
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
||||
|
||||
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
displayImageInTerminal(filename)
|
||||
fmt.Printf("Image saved to: %s\n", filename)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeFilename removes characters that aren't safe for filenames.
|
||||
func sanitizeFilename(s string) string {
|
||||
s = strings.ToLower(s)
|
||||
s = strings.ReplaceAll(s, " ", "-")
|
||||
// Remove any character that's not alphanumeric or hyphen
|
||||
var result strings.Builder
|
||||
for _, r := range s {
|
||||
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// printInteractiveHelp prints help for interactive mode commands.
|
||||
func printInteractiveHelp(opts ImageGenOptions) {
|
||||
fmt.Fprintln(os.Stderr, "Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set width <n> Set image width (current:", opts.Width, ")")
|
||||
fmt.Fprintln(os.Stderr, " /set height <n> Set image height (current:", opts.Height, ")")
|
||||
fmt.Fprintln(os.Stderr, " /set steps <n> Set denoising steps (current:", opts.Steps, ")")
|
||||
fmt.Fprintln(os.Stderr, " /set seed <n> Set random seed (current:", opts.Seed, ", 0=random)")
|
||||
fmt.Fprintln(os.Stderr, " /set negative <s> Set negative prompt")
|
||||
fmt.Fprintln(os.Stderr, " /show Show current settings")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "Or type a prompt to generate an image.")
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
// printCurrentSettings prints the current image generation settings.
|
||||
func printCurrentSettings(opts ImageGenOptions) {
|
||||
fmt.Fprintf(os.Stderr, "Current settings:\n")
|
||||
fmt.Fprintf(os.Stderr, " width: %d\n", opts.Width)
|
||||
fmt.Fprintf(os.Stderr, " height: %d\n", opts.Height)
|
||||
fmt.Fprintf(os.Stderr, " steps: %d\n", opts.Steps)
|
||||
fmt.Fprintf(os.Stderr, " seed: %d (0=random)\n", opts.Seed)
|
||||
if opts.NegativePrompt != "" {
|
||||
fmt.Fprintf(os.Stderr, " negative: %s\n", opts.NegativePrompt)
|
||||
}
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
// handleSetCommand handles /set commands to change options.
|
||||
func handleSetCommand(args string, opts *ImageGenOptions) error {
|
||||
parts := strings.SplitN(args, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
return fmt.Errorf("usage: /set <option> <value>")
|
||||
}
|
||||
|
||||
key := strings.ToLower(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
switch key {
|
||||
case "width", "w":
|
||||
v, err := strconv.Atoi(value)
|
||||
if err != nil || v <= 0 {
|
||||
return fmt.Errorf("width must be a positive integer")
|
||||
}
|
||||
opts.Width = v
|
||||
fmt.Fprintf(os.Stderr, "Set width to %d\n", v)
|
||||
case "height", "h":
|
||||
v, err := strconv.Atoi(value)
|
||||
if err != nil || v <= 0 {
|
||||
return fmt.Errorf("height must be a positive integer")
|
||||
}
|
||||
opts.Height = v
|
||||
fmt.Fprintf(os.Stderr, "Set height to %d\n", v)
|
||||
case "steps", "s":
|
||||
v, err := strconv.Atoi(value)
|
||||
if err != nil || v <= 0 {
|
||||
return fmt.Errorf("steps must be a positive integer")
|
||||
}
|
||||
opts.Steps = v
|
||||
fmt.Fprintf(os.Stderr, "Set steps to %d\n", v)
|
||||
case "seed":
|
||||
v, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("seed must be an integer")
|
||||
}
|
||||
opts.Seed = v
|
||||
fmt.Fprintf(os.Stderr, "Set seed to %d\n", v)
|
||||
case "negative", "neg", "n":
|
||||
opts.NegativePrompt = value
|
||||
if value == "" {
|
||||
fmt.Fprintln(os.Stderr, "Cleared negative prompt")
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Set negative prompt to: %s\n", value)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown option: %s (try /help)", key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// displayImageInTerminal attempts to render an image inline in the terminal.
|
||||
// Supports iTerm2, Kitty, WezTerm, Ghostty, and other terminals with inline image support.
|
||||
// Returns true if the image was displayed, false otherwise.
|
||||
func displayImageInTerminal(imagePath string) bool {
|
||||
// Check if terminal supports inline images
|
||||
termProgram := os.Getenv("TERM_PROGRAM")
|
||||
kittyWindowID := os.Getenv("KITTY_WINDOW_ID")
|
||||
weztermPane := os.Getenv("WEZTERM_PANE")
|
||||
ghostty := os.Getenv("GHOSTTY_RESOURCES_DIR")
|
||||
|
||||
// Read the image file
|
||||
data, err := os.ReadFile(imagePath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
switch {
|
||||
case termProgram == "iTerm.app" || termProgram == "WezTerm" || weztermPane != "":
|
||||
// iTerm2/WezTerm inline image protocol
|
||||
// ESC ] 1337 ; File = [arguments] : base64 BEL
|
||||
fmt.Printf("\033]1337;File=inline=1;preserveAspectRatio=1:%s\a\n", encoded)
|
||||
return true
|
||||
|
||||
case kittyWindowID != "" || ghostty != "" || termProgram == "ghostty":
|
||||
// Kitty graphics protocol (also used by Ghostty)
|
||||
// Send in chunks for large images
|
||||
const chunkSize = 4096
|
||||
for i := 0; i < len(encoded); i += chunkSize {
|
||||
end := min(i+chunkSize, len(encoded))
|
||||
chunk := encoded[i:end]
|
||||
|
||||
if i == 0 {
|
||||
// First chunk: a=T (transmit), f=100 (PNG), m=1 (more chunks follow) or m=0 (last chunk)
|
||||
more := 1
|
||||
if end >= len(encoded) {
|
||||
more = 0
|
||||
}
|
||||
fmt.Printf("\033_Ga=T,f=100,m=%d;%s\033\\", more, chunk)
|
||||
} else if end >= len(encoded) {
|
||||
// Last chunk
|
||||
fmt.Printf("\033_Gm=0;%s\033\\", chunk)
|
||||
} else {
|
||||
// Middle chunk
|
||||
fmt.Printf("\033_Gm=1;%s\033\\", chunk)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
190
x/imagegen/client/create.go
Normal file
@@ -0,0 +1,190 @@
|
||||
// Package client provides client-side model creation for tensor-based models.
|
||||
//
|
||||
// This package is in x/ because the tensor model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
//
|
||||
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
|
||||
// 1. Add proper API endpoints for tensor model creation
|
||||
// 2. Move tensor extraction to server-side
|
||||
// 3. Remove this package
|
||||
// 4. Follow the same client→server pattern as regular model creation
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for image generation models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// CreateModel imports a tensor-based model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
// If quantize is "fp8", weights will be quantized to 8-bit (MLX affine mode) during import.
|
||||
//
|
||||
// TODO (jmorganca): Replace with API-based creation when promoted to production.
|
||||
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
|
||||
if !imagegen.IsTensorModelDir(modelDir) {
|
||||
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
|
||||
}
|
||||
|
||||
status := "importing image generation model"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
|
||||
// Create layer callback for config files
|
||||
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return imagegen.LayerInfo{}, err
|
||||
}
|
||||
layer.Name = name
|
||||
|
||||
return imagegen.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create tensor layer callback for individual tensors
|
||||
// name is path-style: "component/tensor_name"
|
||||
// When quantize is true, returns multiple layers (weight + scales)
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
|
||||
if doQuantize {
|
||||
// Check if quantization is supported
|
||||
if !QuantizeSupported() {
|
||||
return nil, fmt.Errorf("quantization requires MLX support")
|
||||
}
|
||||
|
||||
// Quantize the tensor (affine mode returns weight, scales, qbiases)
|
||||
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales (use _scale suffix convention)
|
||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers := []imagegen.LayerInfo{
|
||||
{
|
||||
Digest: weightLayer.Digest,
|
||||
Size: weightLayer.Size,
|
||||
MediaType: weightLayer.MediaType,
|
||||
Name: name, // Keep original name for weight
|
||||
},
|
||||
{
|
||||
Digest: scalesLayer.Digest,
|
||||
Size: scalesLayer.Size,
|
||||
MediaType: scalesLayer.MediaType,
|
||||
Name: name + "_scale", // Add _scale suffix
|
||||
},
|
||||
}
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, imagegen.LayerInfo{
|
||||
Digest: qbiasLayer.Digest,
|
||||
Size: qbiasLayer.Size,
|
||||
MediaType: qbiasLayer.MediaType,
|
||||
Name: name + "_qbias", // Add _qbias suffix
|
||||
})
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// Non-quantized path: just create a single layer
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []imagegen.LayerInfo{
|
||||
{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create manifest writer callback
|
||||
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create a proper config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"image"},
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
|
||||
serverLayers := make([]server.Layer, len(layers))
|
||||
for i, l := range layers {
|
||||
serverLayers[i] = server.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
Name: l.Name,
|
||||
}
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
|
||||
// Progress callback
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
status = msg
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
}
|
||||
|
||||
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Created image generation model '%s'\n", modelName)
|
||||
return nil
|
||||
}
|
||||
120
x/imagegen/client/quantize.go
Normal file
@@ -0,0 +1,120 @@
//go:build mlx

package client

import (
	"fmt"
	"io"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
// and returns safetensors data for the quantized weights, scales, and biases.
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
	tmpDir := ensureTempDir()

	// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
	tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
	}
	tmpPath := tmpFile.Name()
	defer os.Remove(tmpPath)

	if _, err := io.Copy(tmpFile, r); err != nil {
		tmpFile.Close()
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
	}
	tmpFile.Close()

	// Load the tensor using MLX's native loader
	st, err := mlx.LoadSafetensorsNative(tmpPath)
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
	}
	defer st.Free()

	// Get the tensor (it's stored as "data" in our minimal safetensors format)
	arr := st.Get("data")
	if arr == nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
	}

	// Convert to BFloat16 if needed (quantize expects a float type)
	if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
		arr = mlx.AsType(arr, mlx.DtypeBFloat16)
		mlx.Eval(arr)
	}

	// Quantize with affine mode: group_size=32, bits=8
	// Note: mxfp8 mode doesn't have matmul kernels in MLX; affine mode does
	qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")

	// Eval and make contiguous for data access
	qweight = mlx.Contiguous(qweight)
	scales = mlx.Contiguous(scales)
	if qbiases != nil {
		qbiases = mlx.Contiguous(qbiases)
		mlx.Eval(qweight, scales, qbiases)
	} else {
		mlx.Eval(qweight, scales)
	}

	// Get shapes
	qweightShape = qweight.Shape()
	scalesShape = scales.Shape()

	// Save the quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
	qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
	defer os.Remove(qweightPath)
	if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
	}
	qweightData, err = os.ReadFile(qweightPath)
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
	}

	// Save scales using MLX's native safetensors
	scalesPath := filepath.Join(tmpDir, "scales.safetensors")
	defer os.Remove(scalesPath)
	if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
	}
	scalesData, err = os.ReadFile(scalesPath)
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
	}

	// Affine mode returns qbiases for the zero-point offset
	if qbiases != nil {
		qbiasShape = qbiases.Shape()
		qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
		defer os.Remove(qbiasPath)
		if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
			return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
		}
		qbiasData, err = os.ReadFile(qbiasPath)
		if err != nil {
			return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
		}
	}

	return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
}

// QuantizeSupported returns true if quantization is supported (MLX build)
func QuantizeSupported() bool {
	return true
}

// ensureTempDir creates the temp directory for quantization if it doesn't exist
func ensureTempDir() string {
	tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
	os.MkdirAll(tmpDir, 0755)
	return tmpDir
}
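The build-tag pair above and below lets callers probe for quantization support at runtime. A minimal sketch of that gating (a hypothetical caller, not part of this change; only `QuantizeSupported` is assumed, and it exists in both builds):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/x/imagegen/client"
)

func main() {
	// QuantizeSupported reports whether this binary was built with -tags mlx;
	// the stub below returns false so non-MLX builds can degrade gracefully.
	if client.QuantizeSupported() {
		fmt.Println("fp8 quantization available")
	} else {
		fmt.Println("importing without quantization (rebuild with -tags mlx to enable)")
	}
}
```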
18
x/imagegen/client/quantize_stub.go
Normal file
@@ -0,0 +1,18 @@
//go:build !mlx

package client

import (
	"fmt"
	"io"
)

// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
	return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}

// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
	return false
}
35
x/imagegen/cmd/engine/README.md
Normal file
@@ -0,0 +1,35 @@
# MLX Engine

Experimental MLX backend for running models on Apple Silicon and CUDA.

## Build

```bash
go build -tags mlx -o engine ./x/imagegen/cmd/engine
```

## Text Generation

```bash
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
```

Options:

- `-temperature` - sampling temperature (default 0.7)
- `-top-p` - nucleus sampling (default 0.9)
- `-top-k` - top-k sampling (default 40)

Supports: Llama, Gemma3, GPT-OSS

## Image Generation

```bash
./engine -zimage -model /path/to/z-image -prompt "a cat" -output cat.png
```

Options:

- `-width`, `-height` - image dimensions (default 1024x1024)
- `-steps` - denoising steps (default 9)
- `-seed` - random seed (default 42)
@@ -11,9 +11,11 @@ import (
	"os"
	"path/filepath"
	"runtime/pprof"
	"strings"

	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/models/gemma3"
	"github.com/ollama/ollama/x/imagegen/models/glm_image"
	"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
	"github.com/ollama/ollama/x/imagegen/models/llama"
	"github.com/ollama/ollama/x/imagegen/models/qwen_image"
@@ -61,12 +63,16 @@ func main() {

	// Legacy mode flags
	zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
	glmImageFlag := flag.Bool("glm-image", false, "GLM-Image generation")
	qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
	qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
	var inputImages stringSlice
	flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
	negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
	cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
	teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
	teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
	fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")

	flag.Parse()

@@ -98,14 +104,45 @@ func main() {
			log.Fatal(loadErr)
		}
		var img *mlx.Array
		img, err = m.GenerateFromConfig(&zimage.GenerateConfig{
			Prompt:      *prompt,
			Width:       int32(*width),
			Height:      int32(*height),
			Steps:       *steps,
			Seed:        *seed,
			CapturePath: *gpuCapture,
			LayerCache:  *layerCache,
		img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
			Prompt:            *prompt,
			NegativePrompt:    *negativePrompt,
			CFGScale:          float32(*cfgScale),
			Width:             int32(*width),
			Height:            int32(*height),
			Steps:             *steps,
			Seed:              *seed,
			CapturePath:       *gpuCapture,
			TeaCache:          *teaCache,
			TeaCacheThreshold: float32(*teaCacheThreshold),
			FusedQKV:          *fusedQKV,
		})
		if err == nil {
			err = saveImageArray(img, *out)
		}
	case *glmImageFlag:
		m := &glm_image.Model{}
		// Use LoadFromPath if the model path looks like a directory, otherwise use Load (ollama manifest)
		var loadErr error
		if strings.HasPrefix(*modelPath, ".") || strings.HasPrefix(*modelPath, "/") {
			loadErr = m.LoadFromPath(*modelPath)
		} else {
			loadErr = m.Load(*modelPath)
		}
		if loadErr != nil {
			log.Fatal(loadErr)
		}
		var img *mlx.Array
		img, err = m.GenerateFromConfig(context.Background(), &glm_image.GenerateConfig{
			Prompt:          *prompt,
			Width:           int32(*width),
			Height:          int32(*height),
			Steps:           *steps,
			Seed:            *seed,
			Temperature:     float32(*temperature),
			TopP:            float32(*topP),
			GuidanceScale:   float32(*cfgScale),
			MaxVisualTokens: int32(*maxTokens),
		})
		if err == nil {
			err = saveImageArray(img, *out)
216
x/imagegen/create.go
Normal file
@@ -0,0 +1,216 @@
package imagegen

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"

	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// IsTensorModelDir checks if the directory contains a tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
	_, err := os.Stat(filepath.Join(dir, "model_index.json"))
	return err == nil
}

// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
	Digest    string
	Size      int64
	MediaType string
	Name      string // Path-style name: "component/tensor" or "path/to/config.json"
}

// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)

// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including the component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)

// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error

// CreateModel imports an image generation model from a directory.
// It stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to affine int8.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
	var layers []LayerInfo
	var configLayer LayerInfo
	var totalParams int64 // Count parameters from original tensor shapes

	// Components to process - extract individual tensors from each
	components := []string{"text_encoder", "transformer", "vae", "vision_language_encoder"}

	for _, component := range components {
		componentDir := filepath.Join(modelDir, component)
		if _, err := os.Stat(componentDir); os.IsNotExist(err) {
			continue
		}

		// Find all safetensors files in this component
		entries, err := os.ReadDir(componentDir)
		if err != nil {
			return fmt.Errorf("failed to read %s: %w", component, err)
		}

		for _, entry := range entries {
			if !strings.HasSuffix(entry.Name(), ".safetensors") {
				continue
			}

			stPath := filepath.Join(componentDir, entry.Name())

			// Extract individual tensors from the safetensors file
			extractor, err := safetensors.OpenForExtraction(stPath)
			if err != nil {
				return fmt.Errorf("failed to open %s: %w", stPath, err)
			}

			tensorNames := extractor.ListTensors()
			quantizeMsg := ""
			if quantize == "fp8" && component != "vae" {
				quantizeMsg = ", quantizing to fp8"
			}
			fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))

			for _, tensorName := range tensorNames {
				td, err := extractor.GetTensor(tensorName)
				if err != nil {
					extractor.Close()
					return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
				}

				// Count parameters from the original tensor shape
				if len(td.Shape) > 0 {
					numElements := int64(1)
					for _, dim := range td.Shape {
						numElements *= int64(dim)
					}
					totalParams += numElements
				}

				// Store as minimal safetensors format (88 bytes header overhead).
				// This enables native mmap loading via mlx_load_safetensors.
				// Use a path-style name: "component/tensor_name"
				fullName := component + "/" + tensorName

				// Determine if this tensor should be quantized
				doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)

				// createTensorLayer returns multiple layers if quantizing (weight + scales)
				newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
				if err != nil {
					extractor.Close()
					return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
				}
				layers = append(layers, newLayers...)
			}

			extractor.Close()
		}
	}

	// Import config files
	configFiles := []string{
		"model_index.json",
		"text_encoder/config.json",
		"text_encoder/generation_config.json",
		"transformer/config.json",
		"vae/config.json",
		"vision_language_encoder/config.json",
		"scheduler/scheduler_config.json",
		"tokenizer/tokenizer.json",
		"tokenizer/tokenizer_config.json",
		"tokenizer/vocab.json",
		"processor/tokenizer.json",        // GLM-Image main tokenizer
		"processor/tokenizer_config.json", // GLM-Image tokenizer config
	}

	for _, cfgPath := range configFiles {
		fullPath := filepath.Join(modelDir, cfgPath)
		if _, err := os.Stat(fullPath); os.IsNotExist(err) {
			continue
		}

		fn(fmt.Sprintf("importing config %s", cfgPath))

		var r io.Reader

		// For model_index.json, normalize to Ollama format and add metadata
		if cfgPath == "model_index.json" {
			data, err := os.ReadFile(fullPath)
			if err != nil {
				return fmt.Errorf("failed to read %s: %w", cfgPath, err)
			}

			var cfg map[string]any
			if err := json.Unmarshal(data, &cfg); err != nil {
				return fmt.Errorf("failed to parse %s: %w", cfgPath, err)
			}

			// Rename _class_name to architecture, remove diffusers-specific fields
			if className, ok := cfg["_class_name"]; ok {
				cfg["architecture"] = className
				delete(cfg, "_class_name")
			}
			delete(cfg, "_diffusers_version")

			// Add the parameter count (counted from tensor shapes during import)
			cfg["parameter_count"] = totalParams

			// Add quantization info
			if quantize == "fp8" {
				cfg["quantization"] = "FP8"
			} else {
				cfg["quantization"] = "BF16"
			}

			data, err = json.MarshalIndent(cfg, "", "  ")
			if err != nil {
				return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
			}
			r = bytes.NewReader(data)
		} else {
			f, err := os.Open(fullPath)
			if err != nil {
				return fmt.Errorf("failed to open %s: %w", cfgPath, err)
			}
			defer f.Close()
			r = f
		}

		layer, err := createLayer(r, "application/vnd.ollama.image.json", cfgPath)
		if err != nil {
			return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
		}

		// Use model_index.json as the config layer
		if cfgPath == "model_index.json" {
			configLayer = layer
		}

		layers = append(layers, layer)
	}

	if configLayer.Digest == "" {
		return fmt.Errorf("model_index.json not found in %s", modelDir)
	}

	fn(fmt.Sprintf("writing manifest for %s", modelName))

	if err := writeManifest(modelName, configLayer, layers); err != nil {
		return fmt.Errorf("failed to write manifest: %w", err)
	}

	fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
	return nil
}
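A minimal sketch of wiring CreateModel's callbacks for a dry run that hashes blobs instead of storing them. The `QuantizingTensorLayerCreator` shape here is inferred from the call site above (reader, name, dtype, shape, quantize flag returning a slice of layers) and is an assumption, as is the model path:

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"io"

	"github.com/ollama/ollama/x/imagegen"
)

func main() {
	createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
		// Hash the blob contents instead of writing them to disk.
		h := sha256.New()
		n, err := io.Copy(h, r)
		if err != nil {
			return imagegen.LayerInfo{}, err
		}
		return imagegen.LayerInfo{
			Digest:    fmt.Sprintf("sha256:%x", h.Sum(nil)),
			Size:      n,
			MediaType: mediaType,
			Name:      name,
		}, nil
	}

	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]imagegen.LayerInfo, error) {
		// A real implementation would quantize here when quantize is true;
		// this dry run stores the tensor bytes unmodified.
		layer, err := createLayer(r, "application/vnd.ollama.image.tensor", name)
		if err != nil {
			return nil, err
		}
		return []imagegen.LayerInfo{layer}, nil
	}

	writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
		fmt.Printf("would write manifest for %s: config=%s, %d layers\n", modelName, config.Digest, len(layers))
		return nil
	}

	if err := imagegen.CreateModel("my-model", "/path/to/diffusers-model", "",
		createLayer, createTensorLayer, writeManifest,
		func(status string) { fmt.Println(status) },
	); err != nil {
		fmt.Println("import failed:", err)
	}
}
```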
110
x/imagegen/image.go
Normal file
@@ -0,0 +1,110 @@
//go:build mlx

package imagegen

import (
	"bytes"
	"encoding/base64"
	"fmt"
	"image"
	"image/png"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// SaveImage saves an MLX array as a PNG image file.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func SaveImage(arr *mlx.Array, path string) error {
	img, err := ArrayToImage(arr)
	if err != nil {
		return err
	}

	if filepath.Ext(path) != ".png" {
		path = path + ".png"
	}

	f, err := os.Create(path)
	if err != nil {
		return err
	}
	defer f.Close()

	return png.Encode(f, img)
}

// EncodeImageBase64 encodes an MLX array as a base64-encoded PNG.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func EncodeImageBase64(arr *mlx.Array) (string, error) {
	img, err := ArrayToImage(arr)
	if err != nil {
		return "", err
	}

	var buf bytes.Buffer
	if err := png.Encode(&buf, img); err != nil {
		return "", err
	}

	return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}

// ArrayToImage converts an MLX array to a Go image.RGBA.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
	shape := arr.Shape()
	if len(shape) != 4 {
		return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
	}

	// Transform to [H, W, C] for image conversion.
	// Free intermediate arrays to avoid a memory leak.
	squeezed := mlx.Squeeze(arr, 0)
	transposed := mlx.Transpose(squeezed, 1, 2, 0)
	squeezed.Free()
	img := mlx.Contiguous(transposed)
	transposed.Free()
	mlx.Eval(img)

	imgShape := img.Shape()
	H := int(imgShape[0])
	W := int(imgShape[1])
	C := int(imgShape[2])

	if C != 3 {
		img.Free()
		return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
	}

	// Copy to CPU and free GPU memory
	data := img.Data()
	img.Free()

	// Write directly to the Pix slice (faster than SetRGBA)
	goImg := image.NewRGBA(image.Rect(0, 0, W, H))
	pix := goImg.Pix
	for y := 0; y < H; y++ {
		for x := 0; x < W; x++ {
			srcIdx := (y*W + x) * C
			dstIdx := (y*W + x) * 4
			pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
			pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
			pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
			pix[dstIdx+3] = 255
		}
	}

	return goImg, nil
}

func clampF(v, min, max float32) float32 {
	if v < min {
		return min
	}
	if v > max {
		return max
	}
	return v
}
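A minimal usage sketch (mlx builds only): a solid red 2x2 image pushed through the [B, C, H, W] contract that ArrayToImage documents, with channel-major float32 values in [0, 1]. The output filename is an arbitrary example:

```go
package main

import (
	"log"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
)

func main() {
	// Channel-major layout: all R values, then all G, then all B.
	data := []float32{
		1, 1, 1, 1, // R
		0, 0, 0, 0, // G
		0, 0, 0, 0, // B
	}
	arr := mlx.NewArray(data, []int32{1, 3, 2, 2})
	if err := imagegen.SaveImage(arr, "red.png"); err != nil {
		log.Fatal(err)
	}
}
```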
19
x/imagegen/imagegen.md
Normal file
@@ -0,0 +1,19 @@
# Image generation models (experimental)

Experimental image generation models are available for **macOS** in Ollama.

## Available models

- [Z-Image-Turbo](https://ollama.com/x/z-image-turbo)

```
ollama run x/z-image-turbo
```

> **Note**: [`x`](https://ollama.com/x) is a username on ollama.com where the maintainer team uploads experimental models.

More models coming soon:

1. Qwen-Image-2512
2. Qwen-Image-Edit-2511
3. GLM-Image
177
x/imagegen/manifest.go
Normal file
@@ -0,0 +1,177 @@
package imagegen

import (
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"
)

// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
	MediaType string `json:"mediaType"`
	Digest    string `json:"digest"`
	Size      int64  `json:"size"`
	Name      string `json:"name,omitempty"` // Path-style name: "component/tensor" or "path/to/config.json"
}

// Manifest represents the manifest JSON structure.
type Manifest struct {
	SchemaVersion int             `json:"schemaVersion"`
	MediaType     string          `json:"mediaType"`
	Config        ManifestLayer   `json:"config"`
	Layers        []ManifestLayer `json:"layers"`
}

// ModelManifest holds a parsed manifest with helper methods.
type ModelManifest struct {
	Manifest *Manifest
	BlobDir  string
}

// DefaultBlobDir returns the default blob storage directory.
// The blob layout is the same on darwin, linux, and windows.
func DefaultBlobDir() string {
	home, err := os.UserHomeDir()
	if err != nil {
		home = "."
	}
	return filepath.Join(home, ".ollama", "models", "blobs")
}

// DefaultManifestDir returns the default manifest storage directory.
func DefaultManifestDir() string {
	home, err := os.UserHomeDir()
	if err != nil {
		home = "."
	}
	return filepath.Join(home, ".ollama", "models", "manifests")
}

// LoadManifest loads a manifest for the given model name.
// Model name format: "modelname", "modelname:tag", or "host/namespace/name:tag"
func LoadManifest(modelName string) (*ModelManifest, error) {
	manifestPath := resolveManifestPath(modelName)

	data, err := os.ReadFile(manifestPath)
	if err != nil {
		return nil, fmt.Errorf("read manifest: %w", err)
	}

	var manifest Manifest
	if err := json.Unmarshal(data, &manifest); err != nil {
		return nil, fmt.Errorf("parse manifest: %w", err)
	}

	return &ModelManifest{
		Manifest: &manifest,
		BlobDir:  DefaultBlobDir(),
	}, nil
}

// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
	// Parse the model name into components.
	// Default: registry.ollama.ai/library/<name>/<tag>
	host := "registry.ollama.ai"
	namespace := "library"
	name := modelName
	tag := "latest"

	// Handle an explicit tag
	if idx := strings.LastIndex(name, ":"); idx != -1 {
		tag = name[idx+1:]
		name = name[:idx]
	}

	// Handle a full path like "host/namespace/name"
	parts := strings.Split(name, "/")
	switch len(parts) {
	case 3:
		host = parts[0]
		namespace = parts[1]
		name = parts[2]
	case 2:
		namespace = parts[0]
		name = parts[1]
	}

	return filepath.Join(DefaultManifestDir(), host, namespace, name, tag)
}

// BlobPath returns the full path to a blob given its digest.
func (m *ModelManifest) BlobPath(digest string) string {
	// Convert "sha256:abc123" to "sha256-abc123"
	blobName := strings.Replace(digest, ":", "-", 1)
	return filepath.Join(m.BlobDir, blobName)
}

// GetTensorLayers returns all tensor layers for a given component.
// Component should be "text_encoder", "transformer", or "vae".
// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight").
func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
	prefix := component + "/"
	var layers []ManifestLayer
	for _, layer := range m.Manifest.Layers {
		if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) {
			layers = append(layers, layer)
		}
	}
	return layers
}

// GetConfigLayer returns the config layer for a given path.
func (m *ModelManifest) GetConfigLayer(configPath string) *ManifestLayer {
	for _, layer := range m.Manifest.Layers {
		if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
			return &layer
		}
	}
	return nil
}

// ReadConfig reads and returns the content of a config file.
func (m *ModelManifest) ReadConfig(configPath string) ([]byte, error) {
	layer := m.GetConfigLayer(configPath)
	if layer == nil {
		return nil, fmt.Errorf("config %q not found in manifest", configPath)
	}

	blobPath := m.BlobPath(layer.Digest)
	return os.ReadFile(blobPath)
}

// ReadConfigJSON reads and unmarshals a config file.
func (m *ModelManifest) ReadConfigJSON(configPath string, v any) error {
	data, err := m.ReadConfig(configPath)
	if err != nil {
		return err
	}
	return json.Unmarshal(data, v)
}

// OpenBlob opens a blob for reading.
func (m *ModelManifest) OpenBlob(digest string) (io.ReadCloser, error) {
	return os.Open(m.BlobPath(digest))
}

// HasTensorLayers returns true if the manifest has any tensor layers.
func (m *ModelManifest) HasTensorLayers() bool {
	for _, layer := range m.Manifest.Layers {
		if layer.MediaType == "application/vnd.ollama.image.tensor" {
			return true
		}
	}
	return false
}
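A minimal sketch of reading an imported model back through these helpers. The model name is an arbitrary example, and the `architecture` key is the one CreateModel writes into the normalized model_index.json:

```go
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/x/imagegen"
)

func main() {
	m, err := imagegen.LoadManifest("x/z-image-turbo")
	if err != nil {
		log.Fatal(err)
	}

	// Path-style layer names let us pull per-component tensors and configs.
	fmt.Println("transformer tensors:", len(m.GetTensorLayers("transformer")))

	var index struct {
		Architecture string `json:"architecture"`
	}
	if err := m.ReadConfigJSON("model_index.json", &index); err != nil {
		log.Fatal(err)
	}
	fmt.Println("architecture:", index.Architecture)
}
```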
103
x/imagegen/memory.go
Normal file
@@ -0,0 +1,103 @@
// Package imagegen provides experimental image generation capabilities for Ollama.
//
// This package is in x/ because the tensor model storage format is under development.
// The goal is to integrate these capabilities into the main Ollama packages once
// the format is stable.
//
// TODO (jmorganca): Integrate into main packages when stable:
//   - CLI commands → cmd/
//   - API endpoints → api/
//   - Model creation → server/
package imagegen

import (
	"encoding/json"
	"fmt"
	"runtime"
)

// GB is a convenience constant for gigabytes.
const GB = 1024 * 1024 * 1024

// SupportedBackends lists the backends that support image generation.
var SupportedBackends = []string{"metal", "cuda", "cpu"}

// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
var modelVRAMEstimates = map[string]uint64{
	"ZImagePipeline":    21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
	"FluxPipeline":      21 * GB, // ~21GB for Flux (same architecture)
	"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements; using a conservative estimate for now
	"GlmImagePipeline":  80 * GB, // ~34GB weights + ~46GB working memory for the 9B+7B hybrid model
}

// CheckPlatformSupport validates that image generation is supported on the current platform.
// Returns nil if supported, or an error describing why it's not supported.
func CheckPlatformSupport() error {
	switch runtime.GOOS {
	case "darwin":
		// macOS: Metal is supported via MLX
		if runtime.GOARCH != "arm64" {
			return fmt.Errorf("image generation on macOS requires Apple Silicon (arm64), got %s", runtime.GOARCH)
		}
		return nil
	case "linux", "windows":
		// Linux/Windows: CUDA support (requires an mlx or cuda build).
		// The actual backend availability is checked at runtime.
		return nil
	default:
		return fmt.Errorf("image generation is not supported on %s", runtime.GOOS)
	}
}

// CheckMemoryRequirements validates that there's enough memory for image generation.
// Returns nil if memory is sufficient, or an error if not.
func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
	required := EstimateVRAM(modelName)
	if availableMemory < required {
		return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
			required/GB, availableMemory/GB)
	}
	return nil
}

// ResolveModelName checks if a model name is a known image generation model.
// Returns the normalized model name if found, or an empty string otherwise.
func ResolveModelName(modelName string) string {
	manifest, err := LoadManifest(modelName)
	if err == nil && manifest.HasTensorLayers() {
		return modelName
	}
	return ""
}

// EstimateVRAM returns the estimated VRAM needed for an image generation model.
// Returns a conservative default of 21GB if the model type cannot be determined.
func EstimateVRAM(modelName string) uint64 {
	manifest, err := LoadManifest(modelName)
	if err != nil {
		return 21 * GB
	}

	data, err := manifest.ReadConfig("model_index.json")
	if err != nil {
		return 21 * GB
	}

	// Parse just the class name
	var index struct {
		ClassName string `json:"_class_name"`
	}
	if err := json.Unmarshal(data, &index); err != nil {
		return 21 * GB
	}

	if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
		return estimate
	}
	return 21 * GB
}

// HasTensorLayers checks if the given model has tensor layers.
func HasTensorLayers(modelName string) bool {
	return ResolveModelName(modelName) != ""
}
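A minimal sketch of the gating flow these helpers support: platform check first, then the VRAM estimate against a caller-supplied budget. The 24GB budget and model name are arbitrary examples:

```go
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/x/imagegen"
)

func main() {
	// Fail fast on unsupported platforms (e.g., macOS on x86_64).
	if err := imagegen.CheckPlatformSupport(); err != nil {
		log.Fatal(err)
	}

	const available = 24 * imagegen.GB
	if err := imagegen.CheckMemoryRequirements("x/z-image-turbo", available); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("ok: estimated need is %d GB\n", imagegen.EstimateVRAM("x/z-image-turbo")/imagegen.GB)
}
```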
110
x/imagegen/memory_test.go
Normal file
@@ -0,0 +1,110 @@
package imagegen

import (
	"runtime"
	"testing"
)

func TestCheckPlatformSupport(t *testing.T) {
	err := CheckPlatformSupport()

	switch runtime.GOOS {
	case "darwin":
		if runtime.GOARCH == "arm64" {
			if err != nil {
				t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
			}
		} else {
			if err == nil {
				t.Error("Expected error on darwin/non-arm64")
			}
		}
	case "linux", "windows":
		if err != nil {
			t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
		}
	default:
		if err == nil {
			t.Errorf("Expected error on unsupported platform %s", runtime.GOOS)
		}
	}
}

func TestCheckMemoryRequirements(t *testing.T) {
	tests := []struct {
		name            string
		availableMemory uint64
		wantErr         bool
	}{
		{
			name:            "sufficient memory",
			availableMemory: 32 * GB,
			wantErr:         false,
		},
		{
			name:            "exactly enough memory",
			availableMemory: 21 * GB,
			wantErr:         false,
		},
		{
			name:            "insufficient memory",
			availableMemory: 16 * GB,
			wantErr:         true,
		},
		{
			name:            "zero memory",
			availableMemory: 0,
			wantErr:         true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Use a non-existent model name, which defaults to the 21GB estimate
			err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
			if (err != nil) != tt.wantErr {
				t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

func TestModelVRAMEstimates(t *testing.T) {
	// Verify the VRAM estimates map has the expected entries
	expected := map[string]uint64{
		"ZImagePipeline":    21 * GB,
		"FluxPipeline":      21 * GB,
		"QwenImagePipeline": 80 * GB,
	}

	for name, expectedVRAM := range expected {
		if actual, ok := modelVRAMEstimates[name]; !ok {
			t.Errorf("Missing VRAM estimate for %s", name)
		} else if actual != expectedVRAM {
			t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
		}
	}
}

func TestEstimateVRAMDefault(t *testing.T) {
	// A non-existent model should return the default 21GB
	vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
	if vram != 21*GB {
		t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
	}
}

func TestHasTensorLayers(t *testing.T) {
	// A non-existent model should return false
	if HasTensorLayers("nonexistent-model") {
		t.Error("HasTensorLayers() should return false for non-existent model")
	}
}

func TestResolveModelName(t *testing.T) {
	// A non-existent model should return an empty string
	result := ResolveModelName("nonexistent-model")
	if result != "" {
		t.Errorf("ResolveModelName() = %q, want empty string", result)
	}
}
@@ -11,6 +11,10 @@ package mlx
#include "mlx/c/mlx.h"
#include <stdlib.h>
#include <stdint.h>
#include <string.h>

// Forward declare cpu_stream
static mlx_stream cpu_stream();

// Cached default GPU stream for all ops
static mlx_stream _default_stream = {0};
@@ -603,6 +607,11 @@ func (a *Array) Valid() bool {
	return a != nil && a.c.ctx != nil
}

// Kept returns true if the array is marked to survive Eval() cleanup.
func (a *Array) Kept() bool {
	return a != nil && a.kept
}

func int32ToCInt(s []int32) *C.int {
	if len(s) == 0 {
		return nil
@@ -1026,10 +1035,11 @@ func View(a *Array, dtype int) *Array {
	return newArray(res)
}

// Contiguous returns a contiguous copy of the array
// Contiguous returns a contiguous copy of the array (row-major)
func Contiguous(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_contiguous(&res, a.c, true, C.default_stream())
	// Use allow_col=false to force row-major contiguous layout
	C.mlx_contiguous(&res, a.c, false, C.default_stream())
	return newArray(res)
}

@@ -1475,6 +1485,44 @@ func (a *Array) ItemInt32() int32 {
	return int32(val)
}

// Bytes copies the raw bytes out of the array without type conversion.
// Works with common dtypes (float32, int32, uint32, uint8).
// For non-contiguous arrays, call Contiguous() first.
// Note: triggers cleanup of non-kept arrays.
func (a *Array) Bytes() []byte {
	cleanup()
	nbytes := a.Nbytes()
	if nbytes == 0 {
		return nil
	}

	// Get a raw pointer based on dtype
	var ptr unsafe.Pointer
	switch a.Dtype() {
	case DtypeFloat32:
		ptr = unsafe.Pointer(C.mlx_array_data_float32(a.c))
	case DtypeInt32:
		ptr = unsafe.Pointer(C.mlx_array_data_int32(a.c))
	case DtypeUint32:
		ptr = unsafe.Pointer(C.mlx_array_data_uint32(a.c))
	case DtypeUint8:
		ptr = unsafe.Pointer(C.mlx_array_data_uint8(a.c))
	default:
		// For other types (bf16, f16, etc.), convert to float32
		arr := AsType(a, DtypeFloat32)
		arr.Eval()
		ptr = unsafe.Pointer(C.mlx_array_data_float32(arr.c))
		nbytes = arr.Nbytes()
	}

	if ptr == nil {
		return nil
	}
	data := make([]byte, nbytes)
	copy(data, unsafe.Slice((*byte)(ptr), nbytes))
	return data
}

// ============ Utility ============

// String returns a string representation
@@ -1653,6 +1701,34 @@ func (s *SafetensorsFile) Free() {
	C.mlx_map_string_to_string_free(s.metadata)
}

// SaveSafetensors saves arrays to a safetensors file using MLX's native implementation.
// This correctly handles all dtypes, including uint32 for quantized weights.
func SaveSafetensors(path string, arrays map[string]*Array) error {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))

	// Create the map
	cArrays := C.mlx_map_string_to_array_new()
	defer C.mlx_map_string_to_array_free(cArrays)

	// Add each array to the map
	for name, arr := range arrays {
		cName := C.CString(name)
		C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
		C.free(unsafe.Pointer(cName))
	}

	// Create empty metadata (optional)
	cMeta := C.mlx_map_string_to_string_new()
	defer C.mlx_map_string_to_string_free(cMeta)

	// Save
	if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
		return fmt.Errorf("failed to save safetensors: %s", path)
	}
	return nil
}

// ============ NPY Loading ============

// LoadNpy loads a numpy array from an npy file
@@ -1762,11 +1838,16 @@ func RandomCategorical(logits *Array, axis int, numSamples int) *Array {
	return RandomCategoricalWithKey(logits, key2, axis, numSamples)
}

// RandomNormal creates a random normal (Gaussian) tensor
// RandomNormal creates a random normal (Gaussian) tensor in float32
func RandomNormal(shape []int32, seed uint64) *Array {
	return RandomNormalWithDtype(shape, seed, DtypeFloat32)
}

// RandomNormalWithDtype creates a random normal (Gaussian) tensor with the specified dtype
func RandomNormalWithDtype(shape []int32, seed uint64, dtype Dtype) *Array {
	key := RandomKey(seed)
	res := C.mlx_array_new()
	C.mlx_random_normal(&res, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, 0.0, 1.0, key.c, C.default_stream())
	C.mlx_random_normal(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dtype), 0.0, 1.0, key.c, C.default_stream())
	return newArray(res)
}

@@ -1976,7 +2057,8 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
// Returns (quantized_weights, scales, biases).
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
// mode: "affine" (default) or "mxfp4"
// mode: "affine" (default), "mxfp4", or "mxfp8"
// Note: mxfp8 mode returns nil biases (only weights and scales)
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
@@ -1985,14 +2067,21 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
	res := C.mlx_vector_array_new()
	C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())

	// Result is a vector of 3 arrays: [weights, scales, biases]
	// Result is a vector of arrays: [weights, scales, biases?]
	// mxfp8 mode returns only 2 elements (no biases)
	vecSize := int(C.mlx_vector_array_size(res))
	var w0, w1, w2 C.mlx_array
	C.mlx_vector_array_get(&w0, res, 0)
	C.mlx_vector_array_get(&w1, res, 1)
	C.mlx_vector_array_get(&w2, res, 2)
	if vecSize >= 3 {
		C.mlx_vector_array_get(&w2, res, 2)
	}
	C.mlx_vector_array_free(res)

	return newArray(w0), newArray(w1), newArray(w2)
	if vecSize >= 3 {
		return newArray(w0), newArray(w1), newArray(w2)
	}
	return newArray(w0), newArray(w1), nil
}

// Dequantize reconstructs weights from quantized form.
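A minimal sketch of the new nil-biases contract (mlx builds only): affine mode returns three arrays while mxfp8 returns two, so callers should nil-check before Eval, mirroring what client/quantize.go does:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

func main() {
	w := mlx.RandomNormal([]int32{64, 64}, 0)
	qw, scales, biases := mlx.Quantize(w, 32, 8, "affine")
	if biases != nil {
		mlx.Eval(qw, scales, biases) // affine: zero-point biases present
	} else {
		mlx.Eval(qw, scales) // mxfp8-style: weights and scales only
	}
	fmt.Println("quantized shape:", qw.Shape())
}
```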
693
x/imagegen/models/glm_image/glm_image.go
Normal file
@@ -0,0 +1,693 @@
//go:build mlx

// Package glm_image implements the GLM-Image hybrid AR + diffusion model.
package glm_image

import (
	"context"
	"fmt"
	"math"
	"path/filepath"
	"time"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
)

// ByT5Tokenizer is a simple byte-level tokenizer for ByT5.
// ByT5 uses bytes as tokens: each byte (0-255) maps to token ID (3-258).
// Special tokens: 0=pad, 1=eos, 2=unk
type ByT5Tokenizer struct {
	PadTokenID int32
	EOSTokenID int32
	UNKTokenID int32
}

// NewByT5Tokenizer creates a new ByT5 tokenizer
func NewByT5Tokenizer() *ByT5Tokenizer {
	return &ByT5Tokenizer{
		PadTokenID: 0,
		EOSTokenID: 1,
		UNKTokenID: 2,
	}
}

// Encode converts a string to token IDs
func (t *ByT5Tokenizer) Encode(text string) []int32 {
	bytes := []byte(text)
	tokens := make([]int32, len(bytes))
	for i, b := range bytes {
		// Standard ByT5 tokenization: bytes 0-255 map to tokens 3-258
		// (tokens 0, 1, 2 are PAD, EOS, UNK)
		tokens[i] = int32(b) + 3
	}
	return tokens
}

// Decode converts token IDs back to a string
func (t *ByT5Tokenizer) Decode(tokens []int32) string {
	bytes := make([]byte, 0, len(tokens))
	for _, tok := range tokens {
		if tok >= 3 && tok < 259 {
			bytes = append(bytes, byte(tok-3))
		}
	}
	return string(bytes)
}

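// Illustrative note (not part of the original file): because the mapping is a
// fixed +3 byte offset, Decode(Encode(s)) == s for any string s, e.g.:
//
//	tok := NewByT5Tokenizer()
//	ids := tok.Encode("hi") // []int32{107, 108} ('h'=104+3, 'i'=105+3)
//	_ = tok.Decode(ids)     // "hi"
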
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
	Prompt         string
	NegativePrompt string       // For CFG (optional, not typically used with GLM-Image)
	GuidanceScale  float32      // Guidance scale (default: 1.5)
	Width          int32        // Image width (default: 1024, must be divisible by 32)
	Height         int32        // Image height (default: 1024, must be divisible by 32)
	Steps          int          // Diffusion denoising steps (default: 50)
	Seed           int64        // Random seed
	Progress       ProgressFunc // Optional progress callback

	// AR generation options
	MaxVisualTokens int32   // Max visual tokens to generate (default: 256)
	Temperature     float32 // AR sampling temperature (default: 0.9)
	TopP            float32 // Nucleus sampling (default: 0.75)
}

// ProgressFunc is called during generation with stage and step progress.
type ProgressFunc func(stage string, step, totalSteps int)

// Model represents a GLM-Image hybrid model.
type Model struct {
	ModelName             string
	Tokenizer             *ByT5Tokenizer // For the T5 text encoder (glyph embeddings)
	GLMTokenizer          *GLMTokenizer  // For the AR model (visual token generation)
	TextEncoder           *T5TextEncoder
	VisionLanguageEncoder *VisionLanguageEncoder
	Transformer           *DiffusionTransformer
	VAEDecoder            *VAEDecoder
}

// Load loads the GLM-Image model from ollama blob storage.
func (m *Model) Load(modelName string) error {
	fmt.Printf("Loading GLM-Image model from manifest: %s...\n", modelName)
	start := time.Now()

	if mlx.GPUIsAvailable() {
		mlx.SetDefaultDeviceGPU()
		mlx.EnableCompile()
	}

	m.ModelName = modelName

	// Load manifest
	manifest, err := imagegen.LoadManifest(modelName)
	if err != nil {
		return fmt.Errorf("load manifest: %w", err)
	}

	// Create the ByT5 tokenizer (byte-level, no vocabulary file needed).
	// Used for the T5 text encoder (glyph embeddings).
	fmt.Print("  Creating ByT5 tokenizer... ")
	m.Tokenizer = NewByT5Tokenizer()
	fmt.Println("✓")

	// Load the GLM tokenizer for the AR model (visual token generation)
	fmt.Print("  Loading GLM tokenizer... ")
	glmTok, err := NewGLMTokenizer(manifest)
	if err != nil {
		return fmt.Errorf("glm tokenizer: %w", err)
	}
	m.GLMTokenizer = glmTok
	fmt.Println("✓")

	// Load the T5 text encoder (~830MB)
	m.TextEncoder = &T5TextEncoder{}
	if err := m.TextEncoder.Load(manifest); err != nil {
		return fmt.Errorf("text encoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.TextEncoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load the vision-language encoder (~19GB, 9B params)
	m.VisionLanguageEncoder = &VisionLanguageEncoder{}
	if err := m.VisionLanguageEncoder.Load(manifest); err != nil {
		return fmt.Errorf("vision language encoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load the diffusion transformer (~13GB, 7B params)
	m.Transformer = &DiffusionTransformer{}
	if err := m.Transformer.Load(manifest); err != nil {
		return fmt.Errorf("transformer: %w", err)
	}
	mlx.Eval(mlx.Collect(m.Transformer)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load the VAE decoder (~775MB)
	m.VAEDecoder = &VAEDecoder{}
	if err := m.VAEDecoder.Load(manifest); err != nil {
		return fmt.Errorf("VAE decoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.VAEDecoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	mem := mlx.MetalGetActiveMemory()
	fmt.Printf("  Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))

	return nil
}

// LoadFromPath loads the model from a directory path (not an ollama manifest)
func (m *Model) LoadFromPath(modelPath string) error {
	fmt.Printf("Loading GLM-Image model from path: %s...\n", modelPath)
	start := time.Now()

	if mlx.GPUIsAvailable() {
		mlx.SetDefaultDeviceGPU()
		mlx.EnableCompile()
	}

	m.ModelName = modelPath

	// Create the ByT5 tokenizer (byte-level, no vocabulary file needed)
	fmt.Print("  Creating ByT5 tokenizer... ")
	m.Tokenizer = NewByT5Tokenizer()
	fmt.Println("✓")

	// Load the GLM tokenizer for the AR model (visual token generation)
	fmt.Print("  Loading GLM tokenizer... ")
	glmTok, err := NewGLMTokenizerFromPath(modelPath)
	if err != nil {
		return fmt.Errorf("glm tokenizer: %w", err)
	}
	m.GLMTokenizer = glmTok
	fmt.Println("✓")

	// Load the T5 text encoder
	m.TextEncoder = &T5TextEncoder{}
	if err := m.TextEncoder.LoadFromPath(filepath.Join(modelPath, "text_encoder")); err != nil {
		return fmt.Errorf("text encoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.TextEncoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load the vision-language encoder
	m.VisionLanguageEncoder = &VisionLanguageEncoder{}
	if err := m.VisionLanguageEncoder.LoadFromPath(filepath.Join(modelPath, "vision_language_encoder")); err != nil {
		return fmt.Errorf("vision language encoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load the diffusion transformer
	m.Transformer = &DiffusionTransformer{}
	if err := m.Transformer.LoadFromPath(filepath.Join(modelPath, "transformer")); err != nil {
		return fmt.Errorf("transformer: %w", err)
	}
	mlx.Eval(mlx.Collect(m.Transformer)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load the VAE decoder
	m.VAEDecoder = &VAEDecoder{}
	if err := m.VAEDecoder.LoadFromPath(filepath.Join(modelPath, "vae")); err != nil {
		return fmt.Errorf("VAE decoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.VAEDecoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	mem := mlx.MetalGetActiveMemory()
	fmt.Printf("  Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))

	return nil
}

// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
	return m.GenerateFromConfig(context.Background(), &GenerateConfig{
		Prompt: prompt,
		Width:  width,
		Height: height,
		Steps:  steps,
		Seed:   seed,
	})
}

// GenerateWithProgress creates an image with a progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
	return m.GenerateFromConfig(context.Background(), &GenerateConfig{
		Prompt:   prompt,
		Width:    width,
		Height:   height,
		Steps:    steps,
		Seed:     seed,
		Progress: progress,
	})
}

// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
	start := time.Now()
	result, err := m.generate(ctx, cfg)
	if err != nil {
		return nil, err
	}
	fmt.Printf("Generated in %.2fs (%d diffusion steps)\n", time.Since(start).Seconds(), cfg.Steps)
	return result, nil
}

// GenerateImage implements the model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
	return m.Generate(prompt, width, height, steps, seed)
}

// generate is the internal generation pipeline.
|
||||
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.GuidanceScale <= 0 {
|
||||
cfg.GuidanceScale = 1.5
|
||||
}
|
||||
// Calculate MaxVisualTokens based on image dimensions
|
||||
// GLM-Image generates TWO grids of visual tokens:
|
||||
// 1. First: prev (small) grid - prevTokenH × prevTokenW tokens
|
||||
// 2. Then: target (large) grid - tokenH × tokenW tokens
|
||||
// After generation, we extract only the TARGET grid tokens for diffusion.
|
||||
factor := int32(32)
|
||||
tokenH := cfg.Height / factor
|
||||
tokenW := cfg.Width / factor
|
||||
targetGridTokens := tokenH * tokenW
|
||||
|
||||
// Compute prev grid dimensions using diffusers formula:
|
||||
// ratio = token_h / token_w
|
||||
// prev_token_h = int(sqrt(ratio) * 16)
|
||||
// prev_token_w = int(sqrt(1/ratio) * 16)
|
||||
ratio := float64(tokenH) / float64(tokenW)
|
||||
prevTokenH := int32(math.Sqrt(ratio) * 16)
|
||||
prevTokenW := int32(math.Sqrt(1/ratio) * 16)
|
||||
prevGridTokens := prevTokenH * prevTokenW
|
||||
|
||||
// Total tokens to generate = prev grid + target grid
|
||||
// (diffusers does max_new_tokens = total + 1 for EOS, but we stop on EOS anyway)
|
||||
cfg.MaxVisualTokens = prevGridTokens + targetGridTokens
|
||||
if cfg.Temperature <= 0 {
|
||||
cfg.Temperature = 0.9
|
||||
}
|
||||
if cfg.TopP <= 0 {
|
||||
cfg.TopP = 0.75
|
||||
}
|
||||
|
||||
// Ensure dimensions are divisible by 32
|
||||
cfg.Width = (cfg.Width / 32) * 32
|
||||
cfg.Height = (cfg.Height / 32) * 32
|
||||
|
||||
tcfg := m.Transformer.Config
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
|
||||
// Progress callback helper
|
||||
progress := func(stage string, step, total int) {
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(stage, step, total)
|
||||
}
|
||||
}
|
||||
|
||||
// === PHASE 1: T5 Text Encoding ===
|
||||
fmt.Println("[T5] Encoding glyph text...")
|
||||
progress("text_encoding", 0, 1)
|
||||
textEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
|
||||
mlx.Keep(textEmbed)
|
||||
mlx.Eval(textEmbed)
|
||||
fmt.Printf("[T5] Done, shape: %v\n", textEmbed.Shape())
|
||||
progress("text_encoding", 1, 1)
|
||||
|
||||
// === PHASE 2: AR Visual Token Generation ===
|
||||
fmt.Printf("[AR] Generating %d visual tokens...\n", cfg.MaxVisualTokens)
|
||||
progress("ar_generation", 0, int(cfg.MaxVisualTokens))
|
||||
visualTokens := m.VisionLanguageEncoder.Generate(
|
||||
cfg.Prompt,
|
||||
m.GLMTokenizer,
|
||||
cfg.MaxVisualTokens,
|
||||
cfg.Temperature,
|
||||
cfg.TopP,
|
||||
cfg.Seed,
|
||||
cfg.Height,
|
||||
cfg.Width,
|
||||
func(step int) {
|
||||
if step%100 == 0 || step < 10 {
|
||||
fmt.Printf("[AR] Step %d/%d\n", step, cfg.MaxVisualTokens)
|
||||
}
|
||||
progress("ar_generation", step, int(cfg.MaxVisualTokens))
|
||||
},
|
||||
)
|
||||
mlx.Keep(visualTokens)
|
||||
mlx.Eval(visualTokens)
|
||||
fmt.Printf("[AR] Done generating visual tokens\n")
|
||||
progress("ar_generation", int(cfg.MaxVisualTokens), int(cfg.MaxVisualTokens))
|
||||
|
||||
vtShape := visualTokens.Shape()
|
||||
totalGenerated := vtShape[1]
|
||||
fmt.Printf("[AR] Generated %d tokens total\n", totalGenerated)
|
||||
|
||||
// Extract only the TARGET grid tokens (skip the prev grid tokens)
|
||||
// diffusers: large_image_tokens = outputs[input_length + large_image_start_offset : ...]
|
||||
// large_image_start_offset = prev_grid_size
|
||||
var targetGridVisualTokens *mlx.Array
|
||||
if totalGenerated >= prevGridTokens+targetGridTokens {
|
||||
// Full generation completed - extract target grid
|
||||
targetGridVisualTokens = mlx.Slice(visualTokens,
|
||||
[]int32{0, prevGridTokens},
|
||||
[]int32{1, prevGridTokens + targetGridTokens})
|
||||
mlx.Keep(targetGridVisualTokens)
|
||||
mlx.Eval(targetGridVisualTokens)
|
||||
} else if totalGenerated > prevGridTokens {
|
||||
// Partial target grid - take what we have
|
||||
actualTargetTokens := totalGenerated - prevGridTokens
|
||||
targetGridVisualTokens = mlx.Slice(visualTokens,
|
||||
[]int32{0, prevGridTokens},
|
||||
[]int32{1, totalGenerated})
|
||||
mlx.Keep(targetGridVisualTokens)
|
||||
mlx.Eval(targetGridVisualTokens)
|
||||
fmt.Printf("WARNING: Partial target grid: got %d/%d target tokens\n",
|
||||
actualTargetTokens, targetGridTokens)
|
||||
} else {
|
||||
// Not enough tokens - EOS came too early
|
||||
return nil, fmt.Errorf("AR generation stopped too early: got %d tokens, need at least %d (prev grid) + 1",
|
||||
totalGenerated, prevGridTokens)
|
||||
}
|
||||
|
||||
    // === PHASE 3: Diffusion Decoding ===
    // Set up the scheduler with a dynamic shift based on image size
    scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
    imgSeqLen := (latentH / tcfg.PatchSize) * (latentW / tcfg.PatchSize)
    scheduler.SetTimestepsWithDynamicShift(cfg.Steps, imgSeqLen)

    // Initialize noise latents [B, C, H, W]
    latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
    mlx.Eval(latents)

    // Upsample the TARGET grid visual tokens 2x to match the patch count (matching diffusers):
    // target_grid tokens -> 2x upsample -> patch_count
    // e.g., 32x32=1024 tokens -> 64x64=4096 patches for 1024x1024
    visualTokensUpsampled := upsampleTokens(targetGridVisualTokens, tokenH, tokenW, 2)

    // Prepare prior embeddings from the upsampled visual tokens (VQ codebook lookup + projection)
    priorEmbed := m.Transformer.EmbedPriorTokens(visualTokensUpsampled)
    mlx.Keep(priorEmbed)
    mlx.Eval(priorEmbed)

    // Prepare text conditioning (project T5 embeddings)
    textCond := m.Transformer.ProjectTextEmbeddings(textEmbed)
    mlx.Keep(textCond)
    mlx.Eval(textCond)

    // === CFG Setup ===
    // For classifier-free guidance, we need unconditional (negative) text embeddings.
    // GLM-Image uses the empty string "" as the negative prompt.
    doCFG := cfg.GuidanceScale > 1.0
    var negativeTextCond *mlx.Array
    if doCFG {
        // Encode the empty string as the negative prompt
        negativeTextEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, "")
        mlx.Keep(negativeTextEmbed)
        mlx.Eval(negativeTextEmbed)
        negativeTextCond = m.Transformer.ProjectTextEmbeddings(negativeTextEmbed)
        mlx.Keep(negativeTextCond)
        mlx.Eval(negativeTextCond)
        negativeTextEmbed.Free()
    }

    // Prepare conditioning inputs
    targetSize := mlx.NewArray([]float32{float32(cfg.Height), float32(cfg.Width)}, []int32{1, 2})
    cropCoords := mlx.NewArray([]float32{0, 0}, []int32{1, 2}) // Default: no crop offset
    targetSize = mlx.ToBFloat16(targetSize)
    cropCoords = mlx.ToBFloat16(cropCoords)
    mlx.Keep(targetSize)
    mlx.Keep(cropCoords)
    mlx.Eval(targetSize, cropCoords)

    pH := latentH / tcfg.PatchSize
    pW := latentW / tcfg.PatchSize

    // Denoising loop
    fmt.Printf("[Diffusion] Starting %d denoising steps...\n", cfg.Steps)
    progress("diffusion", 0, cfg.Steps)
    for i := 0; i < cfg.Steps; i++ {
        fmt.Printf("[Diffusion] Step %d/%d (timestep=%.1f)\n", i+1, cfg.Steps, scheduler.Timesteps[i]-1)
        // Check for cancellation
        if ctx != nil {
            select {
            case <-ctx.Done():
                textEmbed.Free()
                visualTokens.Free()
                // visualTokensUpsampled is a fresh array built in upsampleTokens, not a view of visualTokens
                priorEmbed.Free()
                textCond.Free()
                latents.Free()
                return nil, ctx.Err()
            default:
            }
        }

        // Get the timestep value for the transformer.
        // scheduler.Timesteps contains raw timestep values (1000 down to 1).
        // Pass timestep - 1 to match diffusers: timestep = t.expand(latents.shape[0]) - 1
        timestepVal := scheduler.Timesteps[i] - 1
        timestep := mlx.ToBFloat16(mlx.NewArray([]float32{timestepVal}, []int32{1}))

        // Patchify latents [B, C, H, W] -> [B, L, C*p*p]
        patches := PatchifyLatents(latents, tcfg.PatchSize)

        // Transformer forward with the MMDiT architecture.
        // Conditional pass (with text + prior embeddings)
        outputCond := m.Transformer.ForwardWithPriorDrop(
            patches,
            priorEmbed,
            textCond,
            timestep,
            targetSize,
            cropCoords,
            pH,
            pW,
            false, // priorTokenDrop = false for the conditional pass
        )

        // Unpatchify [B, L, C*p*p] -> [B, C, H, W]
        noisePredCond := UnpatchifyLatents(outputCond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)

        var noisePred *mlx.Array
        if doCFG {
            // Unconditional pass (empty text, dropped prior embeddings)
            outputUncond := m.Transformer.ForwardWithPriorDrop(
                patches,
                priorEmbed, // Still passed, but ignored because priorTokenDrop=true
                negativeTextCond,
                timestep,
                targetSize,
                cropCoords,
                pH,
                pW,
                true, // priorTokenDrop = true for the unconditional pass
            )
            noisePredUncond := UnpatchifyLatents(outputUncond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)

            // CFG formula: noise_pred = uncond + guidance_scale * (cond - uncond)
            diff := mlx.Sub(noisePredCond, noisePredUncond)
            scaled := mlx.MulScalar(diff, cfg.GuidanceScale)
            noisePred = mlx.Add(noisePredUncond, scaled)
        } else {
            noisePred = noisePredCond
        }
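        // Note: with guidance_scale <= 1.0 CFG is skipped entirely; at exactly
        // 1.0 the formula above would reduce to noisePredCond, and larger scales
        // push the prediction further away from the unconditional output.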
        // Scheduler step
        oldLatents := latents
        latents = scheduler.Step(noisePred, latents, i)
        mlx.Eval(latents)
        oldLatents.Free()

        progress("diffusion", i+1, cfg.Steps)
    }

    // Clean up intermediate arrays
    textEmbed.Free()
    visualTokens.Free()
    // visualTokensUpsampled is a fresh array built in upsampleTokens, not a view of visualTokens
    priorEmbed.Free()
    textCond.Free()
    if negativeTextCond != nil {
        negativeTextCond.Free()
    }
    targetSize.Free()
    cropCoords.Free()

    // === PHASE 4: VAE Decode ===
    progress("vae_decode", 0, 1)
    decoded := m.VAEDecoder.Decode(latents)
    mlx.Eval(decoded)
    latents.Free()
    progress("vae_decode", 1, 1)

    return decoded, nil
}
// upsampleTokens performs nearest-neighbor upsampling of visual tokens.
// It converts from prev_grid (e.g., 16x16) to target_grid (e.g., 32x32 for 2x, 64x64 for 4x).
// scale must be 2 or 4.
//
// Handles early EOS gracefully: if tokens has fewer than prevH*prevW elements,
// missing tokens are padded with 0 (the visual token padding value).
func upsampleTokens(tokens *mlx.Array, prevH, prevW int32, scale int32) *mlx.Array {
    // tokens: [1, N] where N <= prevH*prevW (may be shorter if early EOS)
    // Each token at (i, j) becomes scale*scale tokens in the output

    mlx.Eval(tokens)
    tokenData := tokens.DataInt32()
    numTokens := int32(len(tokenData))
    expectedTokens := prevH * prevW

    // Warn if we got fewer tokens than expected (early EOS)
    if numTokens < expectedTokens {
        fmt.Printf("WARNING: upsampleTokens got %d tokens, expected %d (padding with 0)\n",
            numTokens, expectedTokens)
    }

    targetH := prevH * scale
    targetW := prevW * scale
    upsampled := make([]int32, targetH*targetW)

    for i := int32(0); i < prevH; i++ {
        for j := int32(0); j < prevW; j++ {
            srcIdx := i*prevW + j

            // Handle early EOS: use 0 (padding) for missing tokens
            var val int32
            if srcIdx < numTokens {
                val = tokenData[srcIdx]
            } else {
                val = 0 // Padding token
            }

            // Place in scale*scale positions
            dstI := i * scale
            dstJ := j * scale
            for di := int32(0); di < scale; di++ {
                for dj := int32(0); dj < scale; dj++ {
                    upsampled[(dstI+di)*targetW+(dstJ+dj)] = val
                }
            }
        }
    }

    return mlx.NewArrayInt32(upsampled, []int32{1, targetH * targetW})
}
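// Worked example: with prevH = prevW = 2 and scale = 2, the 2x2 grid
//   [a b]        [a a b b]
//   [c d]  -->   [a a b b]
//                [c c d d]
//                [c c d d]
// i.e. each source token fills a scale x scale block of the output.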
// PatchifyLatents converts [B, C, H, W] to [B, L, C*p*p]
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
    shape := latents.Shape()
    B := shape[0]
    C := shape[1]
    H := shape[2]
    W := shape[3]

    pH := H / patchSize
    pW := W / patchSize

    // Reshape: [B, C, H, W] -> [B, C, pH, p, pW, p]
    x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
    // Transpose: -> [B, pH, pW, C, p, p]
    x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
    // Flatten: -> [B, pH*pW, C*p*p]
    return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
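// Shape example (values assumed for illustration): a [1, 16, 128, 128] latent
// with patchSize = 2 becomes [1, 64*64, 16*2*2] = [1, 4096, 64].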
// UnpatchifyLatents converts [B, L, C*p*p] back to [B, C, H, W]
func UnpatchifyLatents(patches *mlx.Array, H, W, patchSize, channels int32) *mlx.Array {
    shape := patches.Shape()
    B := shape[0]

    pH := H / patchSize
    pW := W / patchSize

    // Reshape: [B, L, C*p*p] -> [B, pH, pW, C, p, p]
    x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
    // Transpose: -> [B, C, pH, p, pW, p]
    x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
    // Reshape: -> [B, C, H, W]
    return mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
}
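// UnpatchifyLatents exactly inverts PatchifyLatents: for any x of shape
// [B, C, H, W], UnpatchifyLatents(PatchifyLatents(x, p), H, W, p, C) == x,
// since the transpose axes (0, 3, 1, 4, 2, 5) undo (0, 2, 4, 1, 3, 5).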
// CalculateShift computes the dynamic shift for flow matching based on image sequence length.
func CalculateShift(imgSeqLen int32) float32 {
    cfg := DefaultSchedulerConfig()
    if !cfg.UseDynamicShifting {
        return 0
    }

    // Sqrt-based shift calculation (matches diffusers)
    m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
    return m*cfg.MaxShift + cfg.BaseShift
}
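// Numeric example with the default config (base_image_seq_len=256,
// max_shift=0.75, base_shift=0.25): for imgSeqLen = 4096,
// m = sqrt(4096/256) = 4, so the shift is 4*0.75 + 0.25 = 3.25.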
// UpsampleTokens2x upsamples token IDs by 2x using nearest-neighbor interpolation:
// tokens: [B, H*W] -> [B, (H*2)*(W*2)]
// This matches diffusers' _upsample_token_ids function.
func UpsampleTokens2x(tokens *mlx.Array, gridH, gridW int32) *mlx.Array {
    shape := tokens.Shape()
    B := shape[0]

    // Reshape to [B, 1, H, W] for interpolation
    tokens = mlx.Reshape(tokens, B, 1, gridH, gridW)

    // Convert to float for interpolation
    tokensFloat := mlx.AsType(tokens, mlx.DtypeFloat32)

    // 2x nearest-neighbor upsample:
    // [B, 1, H, W] -> [B, 1, H*2, W*2]
    upsampled := nearestUpsample2x(tokensFloat)

    // Convert back to int and reshape to [B, H*2*W*2]
    upsampled = mlx.AsType(upsampled, mlx.DtypeInt32)
    return mlx.Reshape(upsampled, B, gridH*2*gridW*2)
}

// nearestUpsample2x performs 2x nearest-neighbor upsampling on an NCHW tensor
func nearestUpsample2x(x *mlx.Array) *mlx.Array {
    shape := x.Shape()
    B := shape[0]
    C := shape[1]
    H := shape[2]
    W := shape[3]

    // Repeat each element 2x2:
    // [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2]
    x = mlx.Reshape(x, B, C, H, 1, W, 1)

    // Tile to repeat each pixel 2x2
    x = mlx.Tile(x, []int32{1, 1, 1, 2, 1, 2})

    // Reshape to the final size
    return mlx.Reshape(x, B, C, H*2, W*2)
}
358
x/imagegen/models/glm_image/glm_tokenizer.go
Normal file
@@ -0,0 +1,358 @@
//go:build mlx

package glm_image

import (
    "encoding/json"
    "fmt"
    "os"
    "path/filepath"
    "sort"
    "strings"

    "github.com/ollama/ollama/x/imagegen"
)

// GLMTokenizer implements the GLM tokenizer for the AR model.
// This is a BPE-style tokenizer with ignore_merges=true, meaning it does
// greedy longest-match tokenization from the vocab without runtime merging.
type GLMTokenizer struct {
    Vocab         map[string]int32 // token string -> token ID
    VocabReverse  map[int32]string // token ID -> token string
    SpecialTokens map[string]int32 // special token strings -> IDs

    // Special token IDs
    SopTokenID int32 // <sop> = grid_bos_token (167845)
    EopTokenID int32 // <eop> = grid_eos_token (167846)
    BosTokenID int32 // <|dit_token_16384|> = visual BOS (16384)
    EosTokenID int32 // <|dit_token_16385|> = visual EOS (16385)
    PadTokenID int32

    // Vocab keys sorted by length (longest first) for greedy matching
    sortedTokens []string
}

// tokenizerJSON represents the structure of tokenizer.json
type tokenizerJSON struct {
    Model struct {
        Vocab map[string]int32 `json:"vocab"`
    } `json:"model"`
    AddedTokens []struct {
        ID      int32  `json:"id"`
        Content string `json:"content"`
        Special bool   `json:"special"`
    } `json:"added_tokens"`
}

// NewGLMTokenizer creates a GLM tokenizer from the model manifest
func NewGLMTokenizer(manifest *imagegen.ModelManifest) (*GLMTokenizer, error) {
    // Read tokenizer.json from the processor directory in the manifest
    data, err := manifest.ReadConfig("processor/tokenizer.json")
    if err != nil {
        return nil, fmt.Errorf("failed to read tokenizer.json from manifest: %w", err)
    }

    var tj tokenizerJSON
    if err := json.Unmarshal(data, &tj); err != nil {
        return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
    }

    tok := &GLMTokenizer{
        Vocab:         make(map[string]int32),
        VocabReverse:  make(map[int32]string),
        SpecialTokens: make(map[string]int32),
    }

    // Load the vocab from the model section
    for token, id := range tj.Model.Vocab {
        tok.Vocab[token] = id
        tok.VocabReverse[id] = token
    }

    // Load added tokens (special tokens including dit_tokens)
    for _, at := range tj.AddedTokens {
        tok.Vocab[at.Content] = at.ID
        tok.VocabReverse[at.ID] = at.Content
        if at.Special {
            tok.SpecialTokens[at.Content] = at.ID
        }
    }

    // Set special token IDs
    tok.SopTokenID = 167845 // <sop>
    tok.EopTokenID = 167846 // <eop>
    tok.BosTokenID = 16384  // <|dit_token_16384|>
    tok.EosTokenID = 16385  // <|dit_token_16385|>
    tok.PadTokenID = 16385  // Same as EOS

    // Build the sorted token list for greedy matching (longest first)
    tok.sortedTokens = make([]string, 0, len(tok.Vocab))
    for token := range tok.Vocab {
        tok.sortedTokens = append(tok.sortedTokens, token)
    }
    sort.Slice(tok.sortedTokens, func(i, j int) bool {
        return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
    })

    fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))

    return tok, nil
}

// NewGLMTokenizerFromPath creates a GLM tokenizer from a directory path
func NewGLMTokenizerFromPath(modelPath string) (*GLMTokenizer, error) {
    // Read tokenizer.json from the processor directory
    tokenizerPath := filepath.Join(modelPath, "processor", "tokenizer.json")
    data, err := os.ReadFile(tokenizerPath)
    if err != nil {
        return nil, fmt.Errorf("failed to read tokenizer.json: %w", err)
    }

    var tj tokenizerJSON
    if err := json.Unmarshal(data, &tj); err != nil {
        return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
    }

    tok := &GLMTokenizer{
        Vocab:         make(map[string]int32),
        VocabReverse:  make(map[int32]string),
        SpecialTokens: make(map[string]int32),
    }

    // Load the vocab from the model section
    for token, id := range tj.Model.Vocab {
        tok.Vocab[token] = id
        tok.VocabReverse[id] = token
    }

    // Load added tokens (special tokens including dit_tokens)
    for _, at := range tj.AddedTokens {
        tok.Vocab[at.Content] = at.ID
        tok.VocabReverse[at.ID] = at.Content
        if at.Special {
            tok.SpecialTokens[at.Content] = at.ID
        }
    }

    // Set special token IDs
    tok.SopTokenID = 167845 // <sop>
    tok.EopTokenID = 167846 // <eop>
    tok.BosTokenID = 16384  // <|dit_token_16384|>
    tok.EosTokenID = 16385  // <|dit_token_16385|>
    tok.PadTokenID = 16385  // Same as EOS

    // Build the sorted token list for greedy matching (longest first)
    tok.sortedTokens = make([]string, 0, len(tok.Vocab))
    for token := range tok.Vocab {
        tok.sortedTokens = append(tok.sortedTokens, token)
    }
    sort.Slice(tok.sortedTokens, func(i, j int) bool {
        return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
    })

    fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))

    return tok, nil
}

// Encode tokenizes a string into token IDs.
// It uses greedy longest-match tokenization with GPT-2 style space handling.
func (t *GLMTokenizer) Encode(text string) []int32 {
    if text == "" {
        return []int32{}
    }

    var tokens []int32

    // First, collect the special tokens that occur in the text so the
    // main loop below can match them inline
    specialReplacements := make(map[string]int32)
    for special, id := range t.SpecialTokens {
        if strings.Contains(text, special) {
            specialReplacements[special] = id
        }
    }

    // Process the text character by character with special token handling
    i := 0
    isFirstToken := true

    for i < len(text) {
        // Check for special tokens first
        foundSpecial := false
        for special, id := range specialReplacements {
            if strings.HasPrefix(text[i:], special) {
                tokens = append(tokens, id)
                i += len(special)
                isFirstToken = false
                foundSpecial = true
                break
            }
        }
        if foundSpecial {
            continue
        }

        // Handle regular text with GPT-2 style space prefixes:
        // "Ġ" (U+0120) represents a space before a token
        remaining := text[i:]

        // Try to find the longest matching token
        matched := false
        for _, token := range t.sortedTokens {
            // Skip special tokens in regular matching
            if _, isSpecial := t.SpecialTokens[token]; isSpecial {
                continue
            }

            // Check if this token matches
            tokenText := token

            // Handle the Ġ prefix (represents a space)
            if strings.HasPrefix(token, "Ġ") {
                // This token expects a leading space
                if i > 0 || !isFirstToken {
                    // Check if the remaining text starts with space + token content
                    tokenContent := token[len("Ġ"):]
                    if strings.HasPrefix(remaining, " "+tokenContent) {
                        tokens = append(tokens, t.Vocab[token])
                        i += 1 + len(tokenContent) // space + content
                        isFirstToken = false
                        matched = true
                        break
                    }
                }
            } else {
                // Regular token without a space prefix
                if strings.HasPrefix(remaining, tokenText) {
                    tokens = append(tokens, t.Vocab[token])
                    i += len(tokenText)
                    isFirstToken = false
                    matched = true
                    break
                }
            }
        }

        if !matched {
            // No token found - skip this character (or use UNK);
            // for now, just skip unknown characters
            i++
        }
    }

    return tokens
}
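// Illustration (token strings assumed to exist in the vocab): encoding
// "hello world" greedily matches "hello" and then the space-prefixed
// "Ġworld", producing two token IDs rather than per-character fallbacks.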
// EncodeForGeneration encodes a prompt with grid tokens for image generation.
// Format: {prompt}<sop>{token_h} {token_w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
//
// Uses GPT-2 style tokenization where " 32" becomes "Ġ32" (a single token with
// a space prefix), matching the HuggingFace tokenizer behavior.
func (t *GLMTokenizer) EncodeForGeneration(prompt string, targetHeight, targetWidth int32) []int32 {
    // Calculate grid dimensions
    factor := int32(32)
    height := (targetHeight / factor) * factor
    width := (targetWidth / factor) * factor
    tokenH := height / factor
    tokenW := width / factor

    // Calculate previous grid dimensions
    ratio := float64(tokenH) / float64(tokenW)
    prevTokenH := int32(sqrt(ratio) * 16)
    prevTokenW := int32(sqrt(1.0/ratio) * 16)

    // Encode the prompt text
    promptTokens := t.Encode(prompt)

    // Build the full sequence:
    // [prompt tokens] <sop> [tokenH] [Ġ+tokenW] <eop> <sop> [prevH] [Ġ+prevW] <eop> <bos>
    // Note: the HF tokenizer treats " 32" as "Ġ32" (a single token), not "Ġ" + "32"
    var tokens []int32
    tokens = append(tokens, promptTokens...)

    // First grid: <sop> H W <eop>
    // The first number has no space prefix; the second number has a space prefix (Ġ)
    tokens = append(tokens, t.SopTokenID)
    tokens = append(tokens, t.encodeNumber(tokenH)...)
    tokens = append(tokens, t.encodeSpaceNumber(tokenW)...) // " W" as Ġ+W
    tokens = append(tokens, t.EopTokenID)

    // Second grid: <sop> prevH prevW <eop>
    tokens = append(tokens, t.SopTokenID)
    tokens = append(tokens, t.encodeNumber(prevTokenH)...)
    tokens = append(tokens, t.encodeSpaceNumber(prevTokenW)...) // " prevW" as Ġ+prevW
    tokens = append(tokens, t.EopTokenID)

    // BOS token (start of image generation)
    tokens = append(tokens, t.BosTokenID)

    return tokens
}
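// Worked example: for a 1024x1024 target, tokenH = tokenW = 1024/32 = 32 and,
// since the aspect ratio is 1, prevTokenH = prevTokenW = 16, so the suffix is
// <sop>32 32<eop><sop>16 16<eop><|dit_token_16384|>.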
// encodeNumber encodes a number - it first tries the whole number as a single
// token, then falls back to digit-by-digit encoding
func (t *GLMTokenizer) encodeNumber(n int32) []int32 {
    s := fmt.Sprintf("%d", n)
    // First try: look up the whole number as a single token
    if id, ok := t.Vocab[s]; ok {
        return []int32{id}
    }
    // Fallback: encode digit by digit
    var tokens []int32
    for _, c := range s {
        if id, ok := t.Vocab[string(c)]; ok {
            tokens = append(tokens, id)
        }
    }
    return tokens
}

// encodeSpaceNumber encodes " N" as "ĠN" (a space-prefixed number), matching the HF tokenizer.
// GPT-2 style: " 32" becomes the single token "Ġ32", not "Ġ" + "32".
func (t *GLMTokenizer) encodeSpaceNumber(n int32) []int32 {
    s := fmt.Sprintf("%d", n)

    // First try: look up "Ġ{number}" as a single token (e.g., "Ġ32")
    spaceToken := "Ġ" + s
    if id, ok := t.Vocab[spaceToken]; ok {
        return []int32{id}
    }

    // Fallback: a bare space Ġ + number tokens
    var tokens []int32
    if spaceID, ok := t.Vocab["Ġ"]; ok {
        tokens = append(tokens, spaceID)
    }
    tokens = append(tokens, t.encodeNumber(n)...)
    return tokens
}

// sqrt is a helper for float64 square roots
func sqrt(x float64) float64 {
    if x <= 0 {
        return 0
    }
    // Newton's method
    z := x
    for i := 0; i < 10; i++ {
        z = z - (z*z-x)/(2*z)
    }
    return z
}

// Decode converts token IDs back to a string
func (t *GLMTokenizer) Decode(tokens []int32) string {
    var sb strings.Builder
    for _, id := range tokens {
        if token, ok := t.VocabReverse[id]; ok {
            // Handle the Ġ prefix (convert back to a space)
            if strings.HasPrefix(token, "Ġ") {
                sb.WriteString(" ")
                sb.WriteString(token[len("Ġ"):])
            } else {
                sb.WriteString(token)
            }
        }
    }
    return sb.String()
}
159
x/imagegen/models/glm_image/scheduler.go
Normal file
@@ -0,0 +1,159 @@
//go:build mlx

package glm_image

import (
    "math"

    "github.com/ollama/ollama/x/imagegen/mlx"
)

// FlowMatchSchedulerConfig holds the scheduler configuration
type FlowMatchSchedulerConfig struct {
    NumTrainTimesteps  int32   `json:"num_train_timesteps"`  // 1000
    BaseShift          float32 `json:"base_shift"`           // 0.25
    MaxShift           float32 `json:"max_shift"`            // 0.75
    BaseImageSeqLen    int32   `json:"base_image_seq_len"`   // 256
    MaxImageSeqLen     int32   `json:"max_image_seq_len"`    // 4096
    UseDynamicShifting bool    `json:"use_dynamic_shifting"` // true
    TimeShiftType      string  `json:"time_shift_type"`      // "linear"
}

// DefaultSchedulerConfig returns the default config for GLM-Image
func DefaultSchedulerConfig() *FlowMatchSchedulerConfig {
    return &FlowMatchSchedulerConfig{
        NumTrainTimesteps:  1000,
        BaseShift:          0.25,
        MaxShift:           0.75,
        BaseImageSeqLen:    256,
        MaxImageSeqLen:     4096,
        UseDynamicShifting: true,
        TimeShiftType:      "linear",
    }
}

// FlowMatchScheduler implements a FlowMatchEulerDiscreteScheduler
type FlowMatchScheduler struct {
    Config    *FlowMatchSchedulerConfig
    Timesteps []float32 // Raw timesteps for transformer conditioning (unshifted)
    Sigmas    []float32 // Shifted sigmas for the Euler step calculation
    NumSteps  int
}

// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchScheduler {
    return &FlowMatchScheduler{Config: cfg}
}

// SetTimestepsWithDynamicShift sets timesteps with dynamic shifting based on image size.
// Following diffusers: raw timesteps are used for conditioning, shifted sigmas for the step calculation.
func (s *FlowMatchScheduler) SetTimestepsWithDynamicShift(numSteps int, imgSeqLen int32) {
    s.NumSteps = numSteps

    // Calculate the shift (mu) based on the image sequence length
    mu := s.calculateShift(imgSeqLen)

    // Create timesteps: linspace from sigma_max_t down to sigma_min_t
    // (sigma_max = 1.0, sigma_min = 0.001 - near 0 but not exactly 0),
    // then apply the time shift and append the terminal sigma = 0
    s.Timesteps = make([]float32, numSteps)
    s.Sigmas = make([]float32, numSteps+1) // +1 for the terminal sigma

    numTrainTimesteps := float32(s.Config.NumTrainTimesteps)

    // Create base sigmas: linspace from 1.0 down to a small value (matching diffusers)
    for i := 0; i < numSteps; i++ {
        // linspace from 1000 down to 1 (sigma_min * num_train_timesteps)
        tRaw := numTrainTimesteps - float32(i)*(numTrainTimesteps-1.0)/float32(numSteps-1)
        s.Timesteps[i] = tRaw

        // Convert to a sigma in [0, 1]
        sigma := tRaw / numTrainTimesteps

        // Apply the time shift if enabled
        if s.Config.UseDynamicShifting && mu > 0 {
            sigma = s.applyShift(mu, sigma)
        }

        s.Sigmas[i] = sigma
    }

    // Append the terminal sigma = 0 (the final clean image)
    s.Sigmas[numSteps] = 0
}
// calculateShift computes the dynamic shift based on image sequence length.
// It uses the sqrt-based formula from diffusers:
//
//	m = (image_seq_len / base_seq_len) ** 0.5
//	mu = m * max_shift + base_shift
func (s *FlowMatchScheduler) calculateShift(imgSeqLen int32) float32 {
    cfg := s.Config

    if !cfg.UseDynamicShifting {
        return 0
    }

    // Sqrt-based shift calculation (matches diffusers pipeline_glm_image.py)
    m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
    mu := m*cfg.MaxShift + cfg.BaseShift
    return mu
}

// applyShift applies the time shift transformation.
// mu: the computed shift value
// t: a sigma value in [0, 1]
func (s *FlowMatchScheduler) applyShift(mu float32, t float32) float32 {
    if t <= 0 {
        return 0
    }
    if t >= 1 {
        return 1
    }

    // sigma=1.0 for both shift types
    sigma := float32(1.0)

    if s.Config.TimeShiftType == "linear" {
        // Linear: mu / (mu + (1/t - 1)^sigma)
        return mu / (mu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
    }

    // Exponential (default): exp(mu) / (exp(mu) + (1/t - 1)^sigma)
    expMu := float32(math.Exp(float64(mu)))
    return expMu / (expMu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}
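// Numeric example for the linear shift: with mu = 3.25 and t = 0.5,
// (1/t - 1) = 1, so the shifted sigma is 3.25 / (3.25 + 1) ~= 0.765 -
// the schedule spends more of its step budget at high noise levels.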
// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, stepIdx int) *mlx.Array {
    sigma := s.Sigmas[stepIdx]
    sigmaNext := s.Sigmas[stepIdx+1]

    // Euler step: x_{t-dt} = x_t + dt * v_t
    dt := sigmaNext - sigma // Negative (going from noise to clean)

    scaledOutput := mlx.MulScalar(modelOutput, dt)
    return mlx.Add(sample, scaledOutput)
}
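// Example: if Sigmas[0] = 1.0 and Sigmas[1] = 0.8, then dt = -0.2 and the
// sample moves 0.2 * v along the predicted velocity toward the clean image;
// the final step integrates down to the appended terminal sigma of 0.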
// InitNoise creates the initial noise
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
    return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}

// AddNoise adds noise to clean samples for a given timestep (for img2img).
// In flow matching: x_t = (1-sigma) * x_0 + sigma * noise
func (s *FlowMatchScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
    // Use the (shifted) sigmas for the interpolation
    sigma := s.Sigmas[timestepIdx]
    oneMinusSigma := 1.0 - sigma

    scaledClean := mlx.MulScalar(cleanSample, oneMinusSigma)
    scaledNoise := mlx.MulScalar(noise, sigma)

    return mlx.Add(scaledClean, scaledNoise)
}
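// Endpoints: sigma = 1 yields pure noise and sigma = 0 returns the clean
// sample unchanged; intermediate sigmas linearly blend the two.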
// GetTimesteps returns all timesteps
func (s *FlowMatchScheduler) GetTimesteps() []float32 {
    return s.Timesteps
}
497
x/imagegen/models/glm_image/text_encoder.go
Normal file
@@ -0,0 +1,497 @@
//go:build mlx

package glm_image

import (
    "encoding/json"
    "fmt"
    "math"
    "os"
    "path/filepath"
    "regexp"

    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/imagegen/mlx"
    "github.com/ollama/ollama/x/imagegen/nn"
    "github.com/ollama/ollama/x/imagegen/safetensors"
)

// T5Config holds the T5 encoder configuration
type T5Config struct {
    DModel       int32   `json:"d_model"`            // 1472
    DFF          int32   `json:"d_ff"`               // 3584
    DKV          int32   `json:"d_kv"`               // 64
    NumHeads     int32   `json:"num_heads"`          // 6
    NumLayers    int32   `json:"num_layers"`         // 12
    VocabSize    int32   `json:"vocab_size"`         // 384 (byte-level)
    LayerNormEps float32 `json:"layer_norm_epsilon"` // 1e-6
    IsGatedAct   bool    `json:"is_gated_act"`       // true (gated-gelu)

    // Relative position bias
    RelativeAttentionNumBuckets  int32 `json:"relative_attention_num_buckets"`  // 32
    RelativeAttentionMaxDistance int32 `json:"relative_attention_max_distance"` // 128
}

// T5TextEncoder is the T5 encoder used for text conditioning
type T5TextEncoder struct {
    Config *T5Config

    // Embedding (shared for ByT5)
    SharedEmbed *nn.Embedding `weight:"shared"`

    // Encoder layers
    Layers []*T5Block `weight:"encoder.block"`

    // Final layer norm
    FinalNorm *T5LayerNorm `weight:"encoder.final_layer_norm"`

    // Relative position bias (from the first layer, shared across all layers)
    RelativeAttentionBias *mlx.Array `weight:"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"`
}

// T5Block is a single T5 encoder block
type T5Block struct {
    // Self attention
    Layer0 *T5LayerSelfAttention `weight:"layer.0"`
    // FFN
    Layer1 *T5LayerFF `weight:"layer.1"`
}

// T5LayerSelfAttention is T5's self-attention layer
type T5LayerSelfAttention struct {
    SelfAttention *T5Attention `weight:"SelfAttention"`
    LayerNorm     *T5LayerNorm `weight:"layer_norm"`
}

// T5Attention implements T5's relative attention
type T5Attention struct {
    Q *mlx.Array `weight:"q.weight"` // No bias in T5
    K *mlx.Array `weight:"k.weight"`
    V *mlx.Array `weight:"v.weight"`
    O *mlx.Array `weight:"o.weight"`

    NHeads int32
    DKV    int32
    Scale  float32
}

// T5LayerFF is T5's feedforward layer with gated-gelu
type T5LayerFF struct {
    DenseReluDense *T5DenseGatedGelu `weight:"DenseReluDense"`
    LayerNorm      *T5LayerNorm      `weight:"layer_norm"`
}

// T5DenseGatedGelu is T5's gated-gelu FFN
type T5DenseGatedGelu struct {
    Wi0 *mlx.Array `weight:"wi_0.weight"` // gate projection
    Wi1 *mlx.Array `weight:"wi_1.weight"` // up projection
    Wo  *mlx.Array `weight:"wo.weight"`   // down projection
}

// T5LayerNorm is T5's RMSNorm variant (no bias, no mean subtraction)
type T5LayerNorm struct {
    Weight *mlx.Array `weight:"weight"`
    Eps    float32
}

// Load loads the T5 text encoder from the manifest
func (m *T5TextEncoder) Load(manifest *imagegen.ModelManifest) error {
    fmt.Print(" Loading T5 text encoder... ")

    // Load the config
    var cfg T5Config
    if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
        return fmt.Errorf("config: %w", err)
    }
    m.Config = &cfg

    // Pre-allocate layers
    m.Layers = make([]*T5Block, cfg.NumLayers)

    // Load weights
    weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
    if err != nil {
        return fmt.Errorf("weights: %w", err)
    }
    if err := weights.Load(0); err != nil {
        return fmt.Errorf("load weights: %w", err)
    }
    defer weights.ReleaseAll()

    if err := safetensors.LoadModule(m, weights, ""); err != nil {
        return fmt.Errorf("load module: %w", err)
    }

    m.initComputedFields()
    fmt.Println("✓")
    return nil
}

// LoadFromPath loads the T5 text encoder from a directory path
func (m *T5TextEncoder) LoadFromPath(path string) error {
    fmt.Print(" Loading T5 text encoder... ")

    // Load the config
    var cfg T5Config
    configPath := filepath.Join(path, "config.json")
    data, err := os.ReadFile(configPath)
    if err != nil {
        return fmt.Errorf("read config: %w", err)
    }
    if err := json.Unmarshal(data, &cfg); err != nil {
        return fmt.Errorf("parse config: %w", err)
    }
    m.Config = &cfg

    // Pre-allocate layers
    m.Layers = make([]*T5Block, cfg.NumLayers)

    // Load weights from safetensors files
    weights, err := safetensors.LoadModelWeights(path)
    if err != nil {
        return fmt.Errorf("weights: %w", err)
    }
    if err := weights.Load(0); err != nil {
        return fmt.Errorf("load weights: %w", err)
    }
    defer weights.ReleaseAll()

    if err := safetensors.LoadModule(m, weights, ""); err != nil {
        return fmt.Errorf("load module: %w", err)
    }

    m.initComputedFields()
    fmt.Println("✓")
    return nil
}

func (m *T5TextEncoder) initComputedFields() {
    cfg := m.Config
    m.FinalNorm.Eps = cfg.LayerNormEps
    for _, block := range m.Layers {
        attn := block.Layer0.SelfAttention
        attn.NHeads = cfg.NumHeads
        attn.DKV = cfg.DKV
        attn.Scale = float32(1.0 / math.Sqrt(float64(cfg.DKV)))

        block.Layer0.LayerNorm.Eps = cfg.LayerNormEps
        block.Layer1.LayerNorm.Eps = cfg.LayerNormEps
    }
}

// Forward encodes text tokens
func (m *T5TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
    cfg := m.Config

    // Get the embeddings
    h := m.SharedEmbed.Forward(tokens)

    // Compute the relative position bias once
    seqLen := tokens.Shape()[1]
    posBias := m.computeRelativePositionBias(seqLen)

    // Forward through the layers
    for _, block := range m.Layers {
        h = block.Forward(h, posBias, cfg.LayerNormEps)
    }

    // Final norm
    h = m.FinalNorm.Forward(h)

    return h
}
// extractGlyphTexts extracts quoted text (glyphs) from the prompt.
// This matches diffusers' get_glyph_texts from pipeline_glm_image.py.
// Glyph texts are used for text rendering guidance in the generated image.
func extractGlyphTexts(prompt string) []string {
    var glyphTexts []string

    // Extract text in single quotes: 'text'
    re1 := regexp.MustCompile(`'([^']*)'`)
    for _, match := range re1.FindAllStringSubmatch(prompt, -1) {
        if len(match) > 1 {
            glyphTexts = append(glyphTexts, match[1])
        }
    }

    // Extract text in Unicode curly double quotes: “text”
    re2 := regexp.MustCompile(`“([^“”]*)”`)
    for _, match := range re2.FindAllStringSubmatch(prompt, -1) {
        if len(match) > 1 {
            glyphTexts = append(glyphTexts, match[1])
        }
    }

    // Extract text in ASCII double quotes: "text"
    re3 := regexp.MustCompile(`"([^"]*)"`)
    for _, match := range re3.FindAllStringSubmatch(prompt, -1) {
        if len(match) > 1 {
            glyphTexts = append(glyphTexts, match[1])
        }
    }

    // Extract text in Japanese quotes: 「text」
    re4 := regexp.MustCompile(`「([^「」]*)」`)
    for _, match := range re4.FindAllStringSubmatch(prompt, -1) {
        if len(match) > 1 {
            glyphTexts = append(glyphTexts, match[1])
        }
    }

    return glyphTexts
}
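// Example: the prompt `a neon sign that says 'OPEN' above a door` yields
// glyphTexts = ["OPEN"], which is what gets encoded below instead of the
// full prompt text.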
// EncodePrompt encodes the prompt text using the ByT5 tokenizer and encoder.
// This provides text conditioning for the diffusion transformer via the glyph projector.
//
// IMPORTANT: This encodes only the GLYPH TEXTS (quoted strings in the prompt), not the
// full prompt. Glyph texts are used for text rendering guidance in the generated image.
// Multiple glyph texts are encoded and concatenated to form the conditioning signal.
// This matches diffusers' _get_glyph_embeds() behavior.
func (m *T5TextEncoder) EncodePrompt(tok *ByT5Tokenizer, prompt string) *mlx.Array {
    // Extract glyph texts from the prompt (text in quotes)
    glyphTexts := extractGlyphTexts(prompt)

    // If no glyph texts were found, encode the empty string (matches diffusers: [""] fallback)
    if len(glyphTexts) == 0 {
        glyphTexts = []string{""}
    }

    // Encode each glyph text and collect the token sequences,
    // matching diffusers' _get_glyph_embeds() which batches all glyph texts
    var allTokenSeqs [][]int32

    for _, glyphText := range glyphTexts {
        // ByT5 uses byte-level encoding: each byte (0-255) -> token (3-258)
        tokens := tok.Encode(glyphText)

        // Add the EOS token (1) at the end to match the HuggingFace tokenizer behavior
        tokens = append(tokens, tok.EOSTokenID)

        allTokenSeqs = append(allTokenSeqs, tokens)
    }

    // Process each glyph text through the encoder
    var allEmbeddings []*mlx.Array
    for _, tokens := range allTokenSeqs {
        tokenLen := len(tokens)
        if tokenLen == 0 {
            continue
        }

        // Create the token array [1, L]
        tokensArr := mlx.NewArrayInt32(tokens, []int32{1, int32(tokenLen)})

        // Forward through the encoder
        output := m.Forward(tokensArr)
        mlx.Eval(output)

        allEmbeddings = append(allEmbeddings, output)
    }

    // Concatenate all glyph embeddings along the sequence dimension
    var output *mlx.Array
    if len(allEmbeddings) == 0 {
        // Fallback: return a single zero embedding
        output = mlx.Zeros([]int32{1, 1, m.Config.DModel}, mlx.DtypeBFloat16)
    } else if len(allEmbeddings) == 1 {
        output = allEmbeddings[0]
    } else {
        output = mlx.Concatenate(allEmbeddings, 1)
    }
    mlx.Eval(output)

    return output
}
// computeRelativePositionBias computes T5's relative position encoding
func (m *T5TextEncoder) computeRelativePositionBias(seqLen int32) *mlx.Array {
    cfg := m.Config

    // Create the relative position matrix:
    // for each (query_pos, key_pos) pair, compute the bucketed relative position
    numBuckets := cfg.RelativeAttentionNumBuckets
    maxDistance := cfg.RelativeAttentionMaxDistance

    // Create position indices
    contextPos := make([]int32, seqLen*seqLen)
    memoryPos := make([]int32, seqLen*seqLen)
    for i := int32(0); i < seqLen; i++ {
        for j := int32(0); j < seqLen; j++ {
            contextPos[i*seqLen+j] = i
            memoryPos[i*seqLen+j] = j
        }
    }

    // Compute the relative positions and bucket them
    buckets := make([]int32, seqLen*seqLen)
    for i := int32(0); i < seqLen*seqLen; i++ {
        relPos := memoryPos[i] - contextPos[i]
        buckets[i] = relativePositionBucket(relPos, numBuckets, maxDistance, false)
    }

    // Create the bucket indices array
    bucketsArr := mlx.NewArrayInt32(buckets, []int32{seqLen, seqLen})

    // Look up the bias: RelativeAttentionBias shape is [numBuckets, numHeads] = [32, 6].
    // Take along axis 0 (the buckets dimension) -> [seqLen, seqLen, numHeads]
    bias := mlx.Take(m.RelativeAttentionBias, bucketsArr, 0) // [seqLen, seqLen, numHeads]

    // Transpose to [numHeads, seqLen, seqLen]
    bias = mlx.Transpose(bias, 2, 0, 1) // [numHeads, seqLen, seqLen]
    bias = mlx.ExpandDims(bias, 0)      // [1, numHeads, seqLen, seqLen]

    return bias
}

// relativePositionBucket computes the bucket for a relative position
func relativePositionBucket(relativePosition int32, numBuckets int32, maxDistance int32, bidirectional bool) int32 {
    var bucket int32 = 0
    var n int32 = -relativePosition

    if bidirectional {
        numBuckets /= 2
        if n < 0 {
            bucket += numBuckets
            n = -n
        }
    } else {
        if n < 0 {
            n = 0
        }
    }

    // Half of the buckets are for exact positions, half are log-spaced
    maxExact := numBuckets / 2
    if n < maxExact {
        bucket += n
    } else {
        // Log-spaced buckets
        logVal := math.Log(float64(n)/float64(maxExact)) / math.Log(float64(maxDistance)/float64(maxExact))
        bucket += maxExact + int32(logVal*float64(numBuckets-maxExact))
        if bucket > numBuckets-1 {
            bucket = numBuckets - 1
        }
    }

    return bucket
}
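// Bucket example (numBuckets=32, maxDistance=128, unidirectional): distances
// 0..15 map one-to-one to buckets 0..15, distances 16..127 share the
// log-spaced buckets 16..31, and anything at or beyond maxDistance clamps to 31.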
// Forward for T5Block
func (b *T5Block) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
    // Self attention with residual
    h := b.Layer0.Forward(x, posBias, eps)

    // FFN with residual
    h = b.Layer1.Forward(h, eps)

    return h
}

// Forward for T5LayerSelfAttention
func (l *T5LayerSelfAttention) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
    // Pre-norm
    normed := l.LayerNorm.Forward(x)

    // Attention
    attnOut := l.SelfAttention.Forward(normed, posBias)

    // Residual
    return mlx.Add(x, attnOut)
}

// Forward for T5Attention
func (attn *T5Attention) Forward(x *mlx.Array, posBias *mlx.Array) *mlx.Array {
    shape := x.Shape()
    B := shape[0]
    L := shape[1]

    // Q, K, V projections (no bias).
    // Weights are [out_features, in_features], so we matmul with the transpose.
    q := mlx.Matmul(x, mlx.Transpose(attn.Q, 1, 0))
    k := mlx.Matmul(x, mlx.Transpose(attn.K, 1, 0))
    v := mlx.Matmul(x, mlx.Transpose(attn.V, 1, 0))

    // Reshape to [B, L, nheads, d_kv]
    q = mlx.Reshape(q, B, L, attn.NHeads, attn.DKV)
    k = mlx.Reshape(k, B, L, attn.NHeads, attn.DKV)
    v = mlx.Reshape(v, B, L, attn.NHeads, attn.DKV)

    // Transpose to [B, nheads, L, d_kv]
    q = mlx.Transpose(q, 0, 2, 1, 3)
    k = mlx.Transpose(k, 0, 2, 1, 3)
    v = mlx.Transpose(v, 0, 2, 1, 3)

    // Attention scores with the relative position bias.
    // T5 uses UNSCALED dot-product attention: scores = q @ k.T + pos_bias
    // (there is no 1/sqrt(d_k) scale factor as in standard transformers)
    scores := mlx.Matmul(q, mlx.Transpose(k, 0, 1, 3, 2))
    scores = mlx.Add(scores, posBias)

    // Softmax
    attnWeights := mlx.Softmax(scores, -1)

    // Attend to the values
    out := mlx.Matmul(attnWeights, v)

    // Transpose back [B, nheads, L, d_kv] -> [B, L, nheads, d_kv]
    out = mlx.Transpose(out, 0, 2, 1, 3)
    // Reshape to [B, L, D]
    out = mlx.Reshape(out, B, L, attn.NHeads*attn.DKV)

    // Output projection
    return mlx.Matmul(out, mlx.Transpose(attn.O, 1, 0))
}

// Forward for T5LayerFF
func (l *T5LayerFF) Forward(x *mlx.Array, eps float32) *mlx.Array {
    // Pre-norm
    normed := l.LayerNorm.Forward(x)

    // FFN
    ffOut := l.DenseReluDense.Forward(normed)

    // Residual
    return mlx.Add(x, ffOut)
}

// geluNew implements the GELU activation with the tanh approximation (gelu_new).
// This matches HuggingFace transformers' gelu_new / the OpenAI GPT implementation.
// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
func geluNew(x *mlx.Array) *mlx.Array {
    sqrt2OverPi := float32(0.7978845608) // sqrt(2/π)
    coeff := float32(0.044715)

    x3 := mlx.Mul(mlx.Mul(x, x), x)
    inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi)
    return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}

// Forward for T5DenseGatedGelu (gated-gelu activation)
func (d *T5DenseGatedGelu) Forward(x *mlx.Array) *mlx.Array {
    // Gate projection with GELU activation (T5 v1.1/ByT5 uses gelu_new)
    gate := mlx.Matmul(x, mlx.Transpose(d.Wi0, 1, 0))
    gate = geluNew(gate)

    // Up projection
    up := mlx.Matmul(x, mlx.Transpose(d.Wi1, 1, 0))

    // Gated output
    h := mlx.Mul(gate, up)

    // Down projection
    return mlx.Matmul(h, mlx.Transpose(d.Wo, 1, 0))
}

// Forward for T5LayerNorm (the RMSNorm variant)
func (ln *T5LayerNorm) Forward(x *mlx.Array) *mlx.Array {
    // T5 uses RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
    variance := mlx.Mean(mlx.Square(x), -1, true)
    x = mlx.Mul(x, mlx.RSqrt(mlx.AddScalar(variance, ln.Eps)))
    return mlx.Mul(x, ln.Weight)
}
1255
x/imagegen/models/glm_image/transformer.go
Normal file
File diff suppressed because it is too large
477
x/imagegen/models/glm_image/vae.go
Normal file
@@ -0,0 +1,477 @@
//go:build mlx

package glm_image

import (
    "encoding/json"
    "fmt"
    "os"
    "path/filepath"

    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/imagegen/mlx"
    "github.com/ollama/ollama/x/imagegen/safetensors"
)

// VAEConfig holds the VAE decoder configuration
type VAEConfig struct {
    InChannels       int32     `json:"in_channels"`        // 3
    OutChannels      int32     `json:"out_channels"`       // 3
    LatentChannels   int32     `json:"latent_channels"`    // 16
    BlockOutChannels []int32   `json:"block_out_channels"` // [128, 512, 1024, 1024]
    LayersPerBlock   int32     `json:"layers_per_block"`   // 3
    NormNumGroups    int32     `json:"norm_num_groups"`    // 32
    ScalingFactor    float32   `json:"scaling_factor"`     // 0.18215
    ShiftFactor      *float32  `json:"shift_factor"`       // null
    LatentsMean      []float32 `json:"latents_mean"`       // [16 values]
    LatentsStd       []float32 `json:"latents_std"`        // [16 values]
}

// VAEDecoder is the VAE latent decoder
type VAEDecoder struct {
    Config *VAEConfig

    // Decoder components
    ConvIn      *VAEConv2d    `weight:"decoder.conv_in"`
    MidBlock    *VAEMidBlock  `weight:"decoder.mid_block"`
    UpBlocks    []*VAEUpBlock `weight:"decoder.up_blocks"`
    ConvNormOut *GroupNorm    `weight:"decoder.conv_norm_out"`
    ConvOut     *VAEConv2d    `weight:"decoder.conv_out"`
}

// VAEConv2d is a 2D convolution layer
type VAEConv2d struct {
    Weight  *mlx.Array `weight:"weight"`
    Bias    *mlx.Array `weight:"bias"`
    Stride  int32
    Padding int32
}

// GroupNorm is group normalization
type GroupNorm struct {
    Weight    *mlx.Array `weight:"weight"`
    Bias      *mlx.Array `weight:"bias"`
    NumGroups int32
    Eps       float32
}

// VAEMidBlock is the middle block of the VAE
type VAEMidBlock struct {
    Resnets []*VAEResnetBlock `weight:"resnets"`
}

// VAEUpBlock is an upsampling block
type VAEUpBlock struct {
    Resnets    []*VAEResnetBlock `weight:"resnets"`
    Upsamplers []*VAEUpsampler   `weight:"upsamplers"`
}

// VAEResnetBlock is a residual block
type VAEResnetBlock struct {
    Norm1        *GroupNorm `weight:"norm1"`
    Conv1        *VAEConv2d `weight:"conv1"`
    Norm2        *GroupNorm `weight:"norm2"`
    Conv2        *VAEConv2d `weight:"conv2"`
    ConvShortcut *VAEConv2d `weight:"conv_shortcut,optional"` // Optional, for channel mismatches
}

// VAEUpsampler is an upsampling layer
type VAEUpsampler struct {
    Conv *VAEConv2d `weight:"conv"`
}

// Load loads the VAE decoder from the manifest
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
    fmt.Print(" Loading VAE decoder... ")

    // Load the config
    var cfg VAEConfig
    if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
        return fmt.Errorf("config: %w", err)
    }
    m.Config = &cfg

    // Initialize the structure based on the config
    numBlocks := len(cfg.BlockOutChannels)
    m.UpBlocks = make([]*VAEUpBlock, numBlocks)

    // Pre-allocate the MidBlock resnets (a VAE mid_block typically has 2 resnets)
    m.MidBlock = &VAEMidBlock{
        Resnets: make([]*VAEResnetBlock, 2),
    }

    // Pre-allocate the UpBlocks with their resnets and upsamplers.
    // The VAE decoder has layers_per_block+1 resnets per up_block (to match the encoder),
    // and every up_block except the last has an upsampler.
    for i := 0; i < numBlocks; i++ {
        numResnets := cfg.LayersPerBlock + 1 // typically 4 resnets
        m.UpBlocks[i] = &VAEUpBlock{
            Resnets: make([]*VAEResnetBlock, numResnets),
        }
        // Every block except the last has upsamplers
        if i < numBlocks-1 {
            m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
        }
    }

    // Load weights
    weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
    if err != nil {
        return fmt.Errorf("weights: %w", err)
    }
    if err := weights.Load(mlx.DtypeBFloat16); err != nil {
        return fmt.Errorf("load weights: %w", err)
    }
    defer weights.ReleaseAll()

    if err := safetensors.LoadModule(m, weights, ""); err != nil {
        return fmt.Errorf("load module: %w", err)
    }

    // Initialize the GroupNorm parameters
    m.initGroupNorms()

    fmt.Println("✓")
    return nil
}

// LoadFromPath loads the VAE decoder from a directory path
func (m *VAEDecoder) LoadFromPath(path string) error {
    fmt.Print(" Loading VAE decoder... ")

    // Load the config
    var cfg VAEConfig
    configPath := filepath.Join(path, "config.json")
    data, err := os.ReadFile(configPath)
    if err != nil {
        return fmt.Errorf("read config: %w", err)
    }
    if err := json.Unmarshal(data, &cfg); err != nil {
        return fmt.Errorf("parse config: %w", err)
    }
    m.Config = &cfg

    // Initialize the structure based on the config
    numBlocks := len(cfg.BlockOutChannels)
    m.UpBlocks = make([]*VAEUpBlock, numBlocks)

    // Pre-allocate the MidBlock resnets (a VAE mid_block typically has 2 resnets)
    m.MidBlock = &VAEMidBlock{
        Resnets: make([]*VAEResnetBlock, 2),
    }

    // Pre-allocate the UpBlocks with their resnets and upsamplers
    for i := 0; i < numBlocks; i++ {
        numResnets := cfg.LayersPerBlock + 1
        m.UpBlocks[i] = &VAEUpBlock{
            Resnets: make([]*VAEResnetBlock, numResnets),
        }
        if i < numBlocks-1 {
            m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
        }
    }

    // Load weights from safetensors files
    weights, err := safetensors.LoadModelWeights(path)
    if err != nil {
        return fmt.Errorf("weights: %w", err)
    }
    if err := weights.Load(mlx.DtypeBFloat16); err != nil {
        return fmt.Errorf("load weights: %w", err)
    }
    defer weights.ReleaseAll()

    if err := safetensors.LoadModule(m, weights, ""); err != nil {
        return fmt.Errorf("load module: %w", err)
    }

    // Initialize the GroupNorm parameters
    m.initGroupNorms()

    fmt.Println("✓")
    return nil
}

func (m *VAEDecoder) initGroupNorms() {
    cfg := m.Config
    numGroups := cfg.NormNumGroups
    eps := float32(1e-6) // Must match the diffusers VAE (1e-6, not 1e-5)

    if m.ConvNormOut != nil {
        m.ConvNormOut.NumGroups = numGroups
        m.ConvNormOut.Eps = eps
    }

    if m.MidBlock != nil {
        for _, resnet := range m.MidBlock.Resnets {
            if resnet.Norm1 != nil {
                resnet.Norm1.NumGroups = numGroups
                resnet.Norm1.Eps = eps
            }
            if resnet.Norm2 != nil {
                resnet.Norm2.NumGroups = numGroups
                resnet.Norm2.Eps = eps
            }
        }
    }

    for _, upBlock := range m.UpBlocks {
        if upBlock == nil {
            continue
        }
        for _, resnet := range upBlock.Resnets {
            if resnet == nil {
                continue
            }
            if resnet.Norm1 != nil {
                resnet.Norm1.NumGroups = numGroups
                resnet.Norm1.Eps = eps
            }
            if resnet.Norm2 != nil {
                resnet.Norm2.NumGroups = numGroups
                resnet.Norm2.Eps = eps
            }
        }
    }
}
// Decode decodes latents to an image
func (m *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
    cfg := m.Config

    // Apply latent denormalization if mean/std are provided.
    // This matches diffusers GLM-Image: latents = latents * std + mean
    // Note: GLM-Image does NOT divide by scaling_factor (unlike standard SD VAEs)
    if len(cfg.LatentsMean) > 0 && len(cfg.LatentsStd) > 0 {
        latents = m.denormalizeLatents(latents)
    }

    // Convert from NCHW to NHWC for processing:
    // [B, C, H, W] -> [B, H, W, C]
    x := mlx.Transpose(latents, 0, 2, 3, 1)

    // Initial convolution
    x = m.ConvIn.Forward(x)

    // Mid block
    x = m.MidBlock.Forward(x)

    // Up blocks (in forward order - index 0 is at the lowest resolution/highest channel count)
    for i := 0; i < len(m.UpBlocks); i++ {
        if m.UpBlocks[i] != nil {
            x = m.UpBlocks[i].Forward(x)
        }
    }

    // Final normalization and convolution
    x = m.ConvNormOut.Forward(x)
    x = mlx.SiLU(x)
    x = m.ConvOut.Forward(x)

    // Convert back to NCHW:
    // [B, H, W, C] -> [B, C, H, W]
    x = mlx.Transpose(x, 0, 3, 1, 2)

    // Clamp to the valid range and convert to [0, 1]
    x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
    x = mlx.AddScalar(x, 1.0)
    x = mlx.DivScalar(x, 2.0)

    return x
}

// denormalizeLatents applies the latent mean/std denormalization
func (m *VAEDecoder) denormalizeLatents(latents *mlx.Array) *mlx.Array {
    cfg := m.Config

    // Create mean and std arrays [1, C, 1, 1] for broadcasting
    mean := mlx.NewArray(cfg.LatentsMean, []int32{1, int32(len(cfg.LatentsMean)), 1, 1})
    std := mlx.NewArray(cfg.LatentsStd, []int32{1, int32(len(cfg.LatentsStd)), 1, 1})

    // Denormalize: latents * std + mean
    latents = mlx.Mul(latents, std)
    latents = mlx.Add(latents, mean)

    return latents
}
// Forward for VAEConv2d
func (c *VAEConv2d) Forward(x *mlx.Array) *mlx.Array {
    // x: [B, H, W, C_in] (NHWC)
    // PyTorch weight: [C_out, C_in, kH, kW] (OIHW)
    // MLX conv2d expects weight: [C_out, kH, kW, C_in] (OHWI),
    // so we need to transpose from OIHW to OHWI

    stride := c.Stride
    if stride == 0 {
        stride = 1
    }
    padding := c.Padding
    if padding == 0 {
        // Default to "same" padding for 3x3 kernels
        wShape := c.Weight.Shape()
        if len(wShape) >= 3 && wShape[2] == 3 {
            padding = 1
        }
    }

    // Transpose the weight from OIHW [out, in, h, w] to OHWI [out, h, w, in]
    weight := mlx.Transpose(c.Weight, 0, 2, 3, 1)

    out := mlx.Conv2d(x, weight, stride, padding)
    if c.Bias != nil {
        // Bias: [C_out] -> [1, 1, 1, C_out]
        bias := mlx.Reshape(c.Bias, 1, 1, 1, -1)
        out = mlx.Add(out, bias)
    }
    return out
}
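// Layout example (shapes assumed for illustration): a PyTorch weight of shape
// [128, 16, 3, 3] (OIHW) becomes [128, 3, 3, 16] (OHWI) before the conv, so
// with padding 1 a [1, H, W, 16] input keeps its H x W extent and gains 128
// output channels.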
// Forward for GroupNorm
|
||||
func (gn *GroupNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, H, W, C] (NHWC)
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
numGroups := gn.NumGroups
|
||||
if numGroups == 0 {
|
||||
numGroups = 32
|
||||
}
|
||||
groupSize := C / numGroups
|
||||
|
||||
// Reshape to [B, H, W, groups, groupSize]
|
||||
x = mlx.Reshape(x, B, H, W, numGroups, groupSize)
|
||||
|
||||
// Compute mean and variance per group
|
||||
mean := mlx.Mean(x, 1, true)
|
||||
mean = mlx.Mean(mean, 2, true)
|
||||
mean = mlx.Mean(mean, 4, true)
|
||||
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), 1, true)
|
||||
variance = mlx.Mean(variance, 2, true)
|
||||
variance = mlx.Mean(variance, 4, true)
|
||||
|
||||
// Normalize
|
||||
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
|
||||
|
||||
// Reshape back
|
||||
xNorm = mlx.Reshape(xNorm, B, H, W, C)
|
||||
|
||||
// Scale and shift
|
||||
if gn.Weight != nil {
|
||||
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if gn.Bias != nil {
|
||||
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
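
// Editor's aside (illustrative sketch, not part of this change; needs "math"):
// the three Mean reductions above average over H, W, and the within-group
// channel axis, leaving one mean/variance per (batch, group). A pure-Go
// reference for a single HWC image, assuming C is divisible by groups:
func groupNormRef(x []float32, H, W, C, groups int, eps float32) []float32 {
	gs := C / groups
	out := make([]float32, len(x))
	for g := 0; g < groups; g++ {
		lo, hi := g*gs, (g+1)*gs
		var sum float64
		for p := 0; p < H*W; p++ {
			for c := lo; c < hi; c++ {
				sum += float64(x[p*C+c])
			}
		}
		n := float64(H * W * gs)
		mean := sum / n
		var sq float64
		for p := 0; p < H*W; p++ {
			for c := lo; c < hi; c++ {
				d := float64(x[p*C+c]) - mean
				sq += d * d
			}
		}
		inv := 1.0 / math.Sqrt(sq/n+float64(eps)) // biased variance, as above
		for p := 0; p < H*W; p++ {
			for c := lo; c < hi; c++ {
				out[p*C+c] = float32((float64(x[p*C+c]) - mean) * inv)
			}
		}
	}
	return out
}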

// Forward for VAEMidBlock
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
	for _, resnet := range mb.Resnets {
		x = resnet.Forward(x)
	}
	return x
}

// Forward for VAEUpBlock
func (ub *VAEUpBlock) Forward(x *mlx.Array) *mlx.Array {
	// Apply resnets
	for _, resnet := range ub.Resnets {
		if resnet != nil {
			x = resnet.Forward(x)
		}
	}

	// Apply upsamplers
	for _, upsampler := range ub.Upsamplers {
		if upsampler != nil {
			x = upsampler.Forward(x)
		}
	}

	return x
}

// Forward for VAEResnetBlock
func (rb *VAEResnetBlock) Forward(x *mlx.Array) *mlx.Array {
	residual := x

	// First norm + activation + conv
	h := rb.Norm1.Forward(x)
	h = mlx.SiLU(h)
	h = rb.Conv1.Forward(h)

	// Second norm + activation + conv
	h = rb.Norm2.Forward(h)
	h = mlx.SiLU(h)
	h = rb.Conv2.Forward(h)

	// Shortcut for channel mismatch
	if rb.ConvShortcut != nil {
		residual = rb.ConvShortcut.Forward(residual)
	}

	return mlx.Add(h, residual)
}

// Forward for VAEUpsampler (2x nearest neighbor upsample + conv)
func (us *VAEUpsampler) Forward(x *mlx.Array) *mlx.Array {
	// x: [B, H, W, C]
	// 2x nearest neighbor upsample
	x = upsample2x(x)

	// Conv
	if us.Conv != nil {
		x = us.Conv.Forward(x)
	}

	return x
}

// upsample2x performs 2x nearest neighbor upsampling.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
func upsample2x(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
	hIndices := make([]int32, H*2)
	for i := int32(0); i < H; i++ {
		hIndices[i*2] = i
		hIndices[i*2+1] = i
	}
	wIndices := make([]int32, W*2)
	for i := int32(0); i < W; i++ {
		wIndices[i*2] = i
		wIndices[i*2+1] = i
	}

	hIdx := mlx.NewArrayInt32(hIndices, []int32{H * 2})
	wIdx := mlx.NewArrayInt32(wIndices, []int32{W * 2})

	// Take along width axis
	x = mlx.Reshape(x, B*H, W, C)
	x = mlx.Take(x, wIdx, 1) // [B*H, W*2, C]
	x = mlx.Reshape(x, B, H, W*2, C)

	// Take along height axis - transpose to [B, W*2, H, C], take, transpose back
	x = mlx.Transpose(x, 0, 2, 1, 3) // [B, W*2, H, C]
	x = mlx.Reshape(x, B*(W*2), H, C)
	x = mlx.Take(x, hIdx, 1) // [B*(W*2), H*2, C]
	x = mlx.Reshape(x, B, W*2, H*2, C)
	x = mlx.Transpose(x, 0, 2, 1, 3) // [B, H*2, W*2, C]

	return x
}
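
// Editor's aside (illustrative sketch, not part of this change): the
// Take-based index duplication above is equivalent to this direct pure-Go
// nearest-neighbor upsample of one [H, W, C] image:
func upsample2xRef(x []float32, H, W, C int) []float32 {
	out := make([]float32, 2*H*2*W*C)
	for y := 0; y < 2*H; y++ {
		for xx := 0; xx < 2*W; xx++ {
			src := ((y/2)*W + xx/2) * C // source pixel at half resolution
			dst := (y*2*W + xx) * C
			copy(out[dst:dst+C], x[src:src+C])
		}
	}
	return out
}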

982	x/imagegen/models/glm_image/vision_language_encoder.go	Normal file
@@ -0,0 +1,982 @@
//go:build mlx

package glm_image

import (
	"encoding/json"
	"fmt"
	"math"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// VisionLanguageConfig holds GLM-Image AR generator configuration
type VisionLanguageConfig struct {
	// Text model config
	HiddenSize        int32   `json:"hidden_size"`         // 4096
	NumHiddenLayers   int32   `json:"num_hidden_layers"`   // 40
	IntermediateSize  int32   `json:"intermediate_size"`   // 13696
	NumAttentionHeads int32   `json:"num_attention_heads"` // 32
	NumKeyValueHeads  int32   `json:"num_key_value_heads"` // 2
	VocabSize         int32   `json:"vocab_size"`          // 168064
	RMSNormEps        float32 `json:"rms_norm_eps"`        // 1e-5

	// RoPE config
	RopeTheta           float32 `json:"rope_theta"`            // 10000
	PartialRotaryFactor float32 `json:"partial_rotary_factor"` // 0.5
	MRoPESection        []int32 `json:"mrope_section"`         // [8, 12, 12]

	// Visual token config
	VisionVocabSize   int32 `json:"vision_vocab_size"`    // 16512
	ImageStartTokenID int32 `json:"image_start_token_id"` // 16384
	ImageEndTokenID   int32 `json:"image_end_token_id"`   // 16385
	ImageTokenID      int32 `json:"image_token_id"`       // 167855

	// Computed
	HeadDim int32
}

// VisionLanguageEncoder is the 9B AR generator
type VisionLanguageEncoder struct {
	Config *VisionLanguageConfig

	// Embedding
	EmbedTokens *nn.Embedding `weight:"model.language_model.embed_tokens"`

	// Transformer layers
	Layers []*GLMBlock `weight:"model.language_model.layers"`

	// Final norm
	FinalNorm *nn.RMSNorm `weight:"model.language_model.norm"`

	// LM Head
	LMHead *mlx.Array `weight:"lm_head.weight"`
}

// GLMBlock is a single transformer block in GLM-4 style
type GLMBlock struct {
	// Pre-attention norm (GLM uses post-LN variant)
	InputLayerNorm    *nn.RMSNorm `weight:"input_layernorm"`
	PostSelfAttnNorm  *nn.RMSNorm `weight:"post_self_attn_layernorm"`
	PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
	PostMLPLayerNorm  *nn.RMSNorm `weight:"post_mlp_layernorm"`

	// Attention
	SelfAttn *GLMAttention `weight:"self_attn"`

	// MLP (fused gate_up)
	MLP *GLMMLP `weight:"mlp"`
}

// GLMAttention implements GQA with partial rotary and MRoPE
type GLMAttention struct {
	QProj *mlx.Array `weight:"q_proj.weight"`
	KProj *mlx.Array `weight:"k_proj.weight"`
	VProj *mlx.Array `weight:"v_proj.weight"`
	OProj *mlx.Array `weight:"o_proj.weight"`

	// QKV have biases in GLM
	QBias *mlx.Array `weight:"q_proj.bias"`
	KBias *mlx.Array `weight:"k_proj.bias"`
	VBias *mlx.Array `weight:"v_proj.bias"`

	// Computed
	NHeads        int32
	NKVHeads      int32
	HeadDim       int32
	Scale         float32
	PartialRotary float32 // Only rotate this fraction of head_dim
	RopeTheta     float32
	MRoPESection  []int32 // [8, 12, 12] - frequency pairs per dimension (temporal, height, width)
}

// ARCache holds KV caches for all layers using the shared cache implementation
type ARCache struct {
	Layers []cache.Cache
}

// NewARCache creates a new cache for the given number of layers
func NewARCache(numLayers int32) *ARCache {
	layers := make([]cache.Cache, numLayers)
	for i := range layers {
		layers[i] = cache.NewKVCache()
	}
	return &ARCache{Layers: layers}
}

// Free releases all cached tensors
func (c *ARCache) Free() {
	for _, layer := range c.Layers {
		for _, arr := range layer.State() {
			if arr != nil {
				arr.Free()
			}
		}
	}
}

// GLMMLP implements fused gate_up SwiGLU MLP
type GLMMLP struct {
	// GLM uses fused gate_up_proj: [hidden, 2*intermediate]
	GateUpProj *mlx.Array `weight:"gate_up_proj.weight"`
	DownProj   *mlx.Array `weight:"down_proj.weight"`
}

// Load loads the vision-language encoder from manifest
func (m *VisionLanguageEncoder) Load(manifest *imagegen.ModelManifest) error {
	fmt.Print(" Loading vision-language encoder... ")

	// Load config
	var rawCfg struct {
		TextConfig struct {
			HiddenSize        int32   `json:"hidden_size"`
			NumHiddenLayers   int32   `json:"num_hidden_layers"`
			IntermediateSize  int32   `json:"intermediate_size"`
			NumAttentionHeads int32   `json:"num_attention_heads"`
			NumKeyValueHeads  int32   `json:"num_key_value_heads"`
			VocabSize         int32   `json:"vocab_size"`
			RMSNormEps        float32 `json:"rms_norm_eps"`
			VisionVocabSize   int32   `json:"vision_vocab_size"`
			RopeParameters    struct {
				RopeTheta           float32 `json:"rope_theta"`
				PartialRotaryFactor float32 `json:"partial_rotary_factor"`
				MRoPESection        []int32 `json:"mrope_section"`
			} `json:"rope_parameters"`
		} `json:"text_config"`
		ImageStartTokenID int32 `json:"image_start_token_id"`
		ImageEndTokenID   int32 `json:"image_end_token_id"`
		ImageTokenID      int32 `json:"image_token_id"`
	}

	if err := manifest.ReadConfigJSON("vision_language_encoder/config.json", &rawCfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}

	cfg := &VisionLanguageConfig{
		HiddenSize:          rawCfg.TextConfig.HiddenSize,
		NumHiddenLayers:     rawCfg.TextConfig.NumHiddenLayers,
		IntermediateSize:    rawCfg.TextConfig.IntermediateSize,
		NumAttentionHeads:   rawCfg.TextConfig.NumAttentionHeads,
		NumKeyValueHeads:    rawCfg.TextConfig.NumKeyValueHeads,
		VocabSize:           rawCfg.TextConfig.VocabSize,
		RMSNormEps:          rawCfg.TextConfig.RMSNormEps,
		VisionVocabSize:     rawCfg.TextConfig.VisionVocabSize,
		RopeTheta:           rawCfg.TextConfig.RopeParameters.RopeTheta,
		PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
		MRoPESection:        rawCfg.TextConfig.RopeParameters.MRoPESection,
		ImageStartTokenID:   rawCfg.ImageStartTokenID,
		ImageEndTokenID:     rawCfg.ImageEndTokenID,
		ImageTokenID:        rawCfg.ImageTokenID,
	}

	cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
	m.Config = cfg

	// Pre-allocate layers
	m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)

	// Load weights
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "vision_language_encoder")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(mlx.DtypeBFloat16); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}

	m.initComputedFields()
	fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
	return nil
}

// LoadFromPath loads the vision-language encoder from a directory path
func (m *VisionLanguageEncoder) LoadFromPath(path string) error {
	fmt.Print(" Loading vision-language encoder... ")

	// Load config
	var rawCfg struct {
		TextConfig struct {
			HiddenSize        int32   `json:"hidden_size"`
			NumHiddenLayers   int32   `json:"num_hidden_layers"`
			IntermediateSize  int32   `json:"intermediate_size"`
			NumAttentionHeads int32   `json:"num_attention_heads"`
			NumKeyValueHeads  int32   `json:"num_key_value_heads"`
			VocabSize         int32   `json:"vocab_size"`
			RMSNormEps        float32 `json:"rms_norm_eps"`
			VisionVocabSize   int32   `json:"vision_vocab_size"`
			RopeParameters    struct {
				RopeTheta           float32 `json:"rope_theta"`
				PartialRotaryFactor float32 `json:"partial_rotary_factor"`
				MRoPESection        []int32 `json:"mrope_section"`
			} `json:"rope_parameters"`
		} `json:"text_config"`
		ImageStartTokenID int32 `json:"image_start_token_id"`
		ImageEndTokenID   int32 `json:"image_end_token_id"`
		ImageTokenID      int32 `json:"image_token_id"`
	}

	configPath := filepath.Join(path, "config.json")
	data, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("read config: %w", err)
	}
	if err := json.Unmarshal(data, &rawCfg); err != nil {
		return fmt.Errorf("parse config: %w", err)
	}

	cfg := &VisionLanguageConfig{
		HiddenSize:          rawCfg.TextConfig.HiddenSize,
		NumHiddenLayers:     rawCfg.TextConfig.NumHiddenLayers,
		IntermediateSize:    rawCfg.TextConfig.IntermediateSize,
		NumAttentionHeads:   rawCfg.TextConfig.NumAttentionHeads,
		NumKeyValueHeads:    rawCfg.TextConfig.NumKeyValueHeads,
		VocabSize:           rawCfg.TextConfig.VocabSize,
		RMSNormEps:          rawCfg.TextConfig.RMSNormEps,
		VisionVocabSize:     rawCfg.TextConfig.VisionVocabSize,
		RopeTheta:           rawCfg.TextConfig.RopeParameters.RopeTheta,
		PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
		MRoPESection:        rawCfg.TextConfig.RopeParameters.MRoPESection,
		ImageStartTokenID:   rawCfg.ImageStartTokenID,
		ImageEndTokenID:     rawCfg.ImageEndTokenID,
		ImageTokenID:        rawCfg.ImageTokenID,
	}

	cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
	m.Config = cfg

	// Pre-allocate layers
	m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)

	// Load weights
	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(mlx.DtypeBFloat16); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}

	m.initComputedFields()
	fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
	return nil
}

func (m *VisionLanguageEncoder) initComputedFields() {
	cfg := m.Config
	for _, block := range m.Layers {
		block.SelfAttn.NHeads = cfg.NumAttentionHeads
		block.SelfAttn.NKVHeads = cfg.NumKeyValueHeads
		block.SelfAttn.HeadDim = cfg.HeadDim
		block.SelfAttn.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
		block.SelfAttn.PartialRotary = cfg.PartialRotaryFactor
		block.SelfAttn.RopeTheta = cfg.RopeTheta
		block.SelfAttn.MRoPESection = cfg.MRoPESection

		// Set norm eps
		block.InputLayerNorm.Eps = cfg.RMSNormEps
		block.PostSelfAttnNorm.Eps = cfg.RMSNormEps
		block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
		block.PostMLPLayerNorm.Eps = cfg.RMSNormEps
	}
	m.FinalNorm.Eps = cfg.RMSNormEps
}
// Generate autoregressively generates visual tokens with KV caching
|
||||
func (m *VisionLanguageEncoder) Generate(
|
||||
prompt string,
|
||||
tok *GLMTokenizer,
|
||||
maxTokens int32,
|
||||
temperature float32,
|
||||
topP float32,
|
||||
seed int64,
|
||||
targetHeight, targetWidth int32,
|
||||
progressFn func(int),
|
||||
) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Encode prompt with grid tokens using GLM tokenizer
|
||||
// Format: {prompt}<sop>{h} {w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
|
||||
tokens := tok.EncodeForGeneration(prompt, targetHeight, targetWidth)
|
||||
|
||||
// Calculate grid dimensions for MRoPE position IDs
|
||||
factor := int32(32)
|
||||
tokenH := targetHeight / factor
|
||||
tokenW := targetWidth / factor
|
||||
ratio := float64(tokenH) / float64(tokenW)
|
||||
prevTokenH := int32(math.Sqrt(ratio) * 16)
|
||||
prevTokenW := int32(math.Sqrt(1.0/ratio) * 16)
|
||||
prevGridSize := prevTokenH * prevTokenW
|
||||
|
||||
// Create KV cache for all layers
|
||||
cache := NewARCache(cfg.NumHiddenLayers)
|
||||
defer cache.Free()
|
||||
|
||||
// ===== PREFILL PHASE =====
|
||||
// Process entire prompt at once, populate cache
|
||||
promptLen := int32(len(tokens))
|
||||
tokenArr := mlx.NewArrayInt32(tokens, []int32{1, promptLen})
|
||||
h := m.EmbedTokens.Forward(tokenArr)
|
||||
tokenArr.Free()
|
||||
|
||||
mlx.Eval(h)
|
||||
|
||||
// Compute position IDs for prefill (text tokens use same position for all dims)
|
||||
prefillPositions := make([][]int32, 3)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
prefillPositions[dim] = make([]int32, promptLen)
|
||||
for i := int32(0); i < promptLen; i++ {
|
||||
prefillPositions[dim][i] = i
|
||||
}
|
||||
}
|
||||
|
||||
// Forward through layers (prefill)
|
||||
for i, layer := range m.Layers {
|
||||
oldH := h
|
||||
h = layer.ForwardWithCache(h, promptLen, 0, cfg.RMSNormEps, cache.Layers[i], prefillPositions)
|
||||
if i > 0 {
|
||||
oldH.Free()
|
||||
}
|
||||
}
|
||||
// Eval h and cache arrays together so cache is materialized
|
||||
evalArgs := []*mlx.Array{h}
|
||||
for _, lc := range cache.Layers {
|
||||
evalArgs = append(evalArgs, lc.State()...)
|
||||
}
|
||||
mlx.Eval(evalArgs...)
|
||||
|
||||
// Final norm and get logits for last position
|
||||
preNormH := h
|
||||
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
|
||||
preNormH.Free()
|
||||
|
||||
lastH := mlx.Slice(h, []int32{0, promptLen - 1, 0}, []int32{1, promptLen, cfg.HiddenSize})
|
||||
h.Free()
|
||||
lastH = mlx.Reshape(lastH, 1, cfg.HiddenSize)
|
||||
logits := mlx.Matmul(lastH, mlx.Transpose(m.LMHead, 1, 0))
|
||||
lastH.Free()
|
||||
|
||||
// Sample first token
|
||||
var sampleCounter int64 = 0
|
||||
nextToken := sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
|
||||
logits.Free()
|
||||
|
||||
// AR generation loop with caching
|
||||
// Visual tokens are stored as VQ codebook indices [0, 16383]
|
||||
// The LM head outputs indices [0, 16511] where:
|
||||
// - [0, 16383] are VQ codes
|
||||
// - 16384 is BOS
|
||||
// - 16385 is EOS
|
||||
visualTokens := make([]int32, 0, maxTokens)
|
||||
posOffset := promptLen
|
||||
visualTokenIdx := int32(0) // Index within visual token sequence for grid position calculation
|
||||
|
||||
// Preallocate slice for old cache state to reuse
|
||||
oldCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
|
||||
|
||||
for i := int32(0); i < maxTokens; i++ {
|
||||
if progressFn != nil {
|
||||
progressFn(int(i))
|
||||
}
|
||||
|
||||
// Check for end token (EOS = 16385)
|
||||
if nextToken == cfg.ImageEndTokenID {
|
||||
break
|
||||
}
|
||||
|
||||
// Skip BOS token (16384), only store actual VQ codes [0, 16383]
|
||||
if nextToken == cfg.ImageStartTokenID {
|
||||
// BOS token - skip storing but continue generation
|
||||
} else if nextToken < cfg.ImageStartTokenID {
|
||||
// This is an actual VQ code [0, 16383] - store it
|
||||
visualTokens = append(visualTokens, nextToken)
|
||||
}
|
||||
// Tokens >= 16386 are other special tokens, skip them
|
||||
|
||||
// ===== DECODE PHASE =====
|
||||
// Save old cache state before forward (to free after eval)
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, lc := range cache.Layers {
|
||||
oldCacheState = append(oldCacheState, lc.State()...)
|
||||
}
|
||||
|
||||
// Only process the new token, use cached K,V
|
||||
tokenArr := mlx.NewArrayInt32([]int32{nextToken}, []int32{1, 1})
|
||||
h := m.EmbedTokens.Forward(tokenArr)
|
||||
tokenArr.Free()
|
||||
|
||||
// Compute MRoPE position IDs for this visual token
|
||||
// Visual tokens are arranged in two grids: prev grid then target grid
|
||||
// Position dimensions: [temporal, height, width]
|
||||
decodePositions := computeVisualTokenPositions(
|
||||
visualTokenIdx, posOffset, promptLen,
|
||||
prevTokenH, prevTokenW, prevGridSize,
|
||||
tokenH, tokenW,
|
||||
)
|
||||
|
||||
// Forward through layers (decode with cache)
|
||||
for j, layer := range m.Layers {
|
||||
oldH := h
|
||||
h = layer.ForwardWithCache(h, 1, posOffset, cfg.RMSNormEps, cache.Layers[j], decodePositions)
|
||||
if j > 0 { // Don't free the embedding on first layer
|
||||
oldH.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Eval h and new cache state
|
||||
newCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
|
||||
for _, lc := range cache.Layers {
|
||||
newCacheState = append(newCacheState, lc.State()...)
|
||||
}
|
||||
mlx.Eval(append([]*mlx.Array{h}, newCacheState...)...)
|
||||
|
||||
// Free old cache state (now that new state is evaluated)
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Final norm
|
||||
preNormH := h
|
||||
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
|
||||
preNormH.Free()
|
||||
|
||||
// Get logits (h is already [1, 1, hidden_size])
|
||||
h = mlx.Reshape(h, 1, cfg.HiddenSize)
|
||||
logits := mlx.Matmul(h, mlx.Transpose(m.LMHead, 1, 0))
|
||||
h.Free()
|
||||
|
||||
// Sample next token
|
||||
nextToken = sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
|
||||
logits.Free()
|
||||
|
||||
posOffset++
|
||||
visualTokenIdx++
|
||||
|
||||
// Periodically clear cache to release intermediate memory
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
if len(visualTokens) == 0 {
|
||||
// Return at least one token to avoid empty tensor issues
|
||||
visualTokens = append(visualTokens, 0)
|
||||
}
|
||||
|
||||
return mlx.NewArrayInt32(visualTokens, []int32{1, int32(len(visualTokens))})
|
||||
}
|
||||
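
// Editor's aside (illustrative, not part of this change): worked example of
// the grid arithmetic at the top of Generate. For a 1024x1024 target with
// factor 32: tokenH = tokenW = 32, ratio = 1.0, prevTokenH = prevTokenW =
// sqrt(1.0)*16 = 16, prevGridSize = 256, i.e. the model first emits a 16x16
// preview grid, then the 32x32 target grid (1024 tokens).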

// computeVisualTokenPositions computes MRoPE position IDs for a visual token.
// Returns [3][1] position IDs for temporal, height, and width dimensions.
//
// MRoPE position encoding for GLM-Image visual tokens:
//   - temporal: CONSTANT within each grid (= decode_pos at grid start)
//   - height: decode_pos + row index within grid
//   - width: decode_pos + column index within grid
//
// Between grids, decode_pos advances by max(grid_h, grid_w) to ensure
// sufficient positional separation.
func computeVisualTokenPositions(
	visualIdx int32, absPos int32, promptLen int32,
	prevH, prevW, prevSize int32,
	targetH, targetW int32,
) [][]int32 {
	positions := make([][]int32, 3)
	for dim := 0; dim < 3; dim++ {
		positions[dim] = make([]int32, 1)
	}

	// First grid (prev grid) starts at decode_pos = promptLen
	prevGridDecodePos := promptLen

	// Second grid (target grid) starts after first grid:
	// next_pos = prev_decode_pos + max(prevH, prevW)
	maxPrev := prevH
	if prevW > maxPrev {
		maxPrev = prevW
	}
	targetGridDecodePos := prevGridDecodePos + maxPrev

	// Compute position IDs based on which grid the token is in
	if visualIdx < prevSize {
		// Token is in the prev grid (prev_token_h × prev_token_w)
		row := visualIdx / prevW
		col := visualIdx % prevW

		// temporal is CONSTANT for all tokens in this grid
		positions[0][0] = prevGridDecodePos
		// height and width are relative to grid's decode_pos
		positions[1][0] = prevGridDecodePos + row
		positions[2][0] = prevGridDecodePos + col
	} else {
		// Token is in the target grid (token_h × token_w)
		targetIdx := visualIdx - prevSize
		row := targetIdx / targetW
		col := targetIdx % targetW

		// temporal is CONSTANT for all tokens in this grid
		positions[0][0] = targetGridDecodePos
		// height and width are relative to grid's decode_pos
		positions[1][0] = targetGridDecodePos + row
		positions[2][0] = targetGridDecodePos + col
	}

	_ = targetH // Used for documentation clarity
	_ = absPos  // No longer used - kept for API compatibility
	return positions
}
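
// Editor's aside (illustrative, not part of this change): with promptLen=100
// and a 16x16 prev grid, token visualIdx=17 falls in the prev grid at
// row=1, col=1, giving temporal=100, height=101, width=101; the first
// target-grid token (visualIdx=256) starts the second grid at decode_pos
// 100+max(16,16)=116, giving temporal=height=width=116.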

// sampleVisualToken samples from the visual vocabulary using top-p (nucleus) sampling.
// Note: For GLM-Image, greedy decoding is not allowed as it may cause repetitive outputs.
// Returns a visual token ID in range [0, 16511] which directly indexes into the embedding table.
// sampleCounter is incremented for each call to ensure different random values.
func sampleVisualToken(logits *mlx.Array, temperature float32, topP float32, cfg *VisionLanguageConfig, seed int64, sampleCounter *int64) int32 {
	// The LMHead outputs logits for visual tokens only (shape [1, 16512]).
	// Output index directly corresponds to vocab ID [0, 16511].
	// No offset needed - the visual tokens are at vocab IDs [0, 16511].
	visualLogits := logits

	// Apply temperature
	if temperature != 1.0 && temperature > 0 {
		visualLogits = mlx.DivScalar(visualLogits, temperature)
	}

	// Apply softmax to get probabilities
	probs := mlx.Softmax(visualLogits, -1)
	mlx.Eval(probs)

	// Get the sampled index using top-p sampling.
	// This directly gives us the vocab ID in [0, 16511].
	// Special tokens: 16384 = BOS, 16385 = EOS.
	// Use seed + counter for reproducible but different random values.
	effectiveSeed := seed + *sampleCounter
	*sampleCounter++
	return sampleTopP(probs, topP, effectiveSeed)
}

// sampleTopP implements nucleus (top-p) sampling.
// probs: [1, vocab_size] probability distribution
// topP: cumulative probability threshold (e.g., 0.75)
// seed: random seed for reproducible sampling
func sampleTopP(probs *mlx.Array, topP float32, seed int64) int32 {
	// Negate probs for descending sort (Argsort only does ascending)
	negProbs := mlx.MulScalar(probs, -1)
	sortedIndices := mlx.Argsort(negProbs, -1)
	sortedProbs := mlx.TakeAlongAxis(probs, sortedIndices, -1)
	cumProbs := mlx.Cumsum(sortedProbs, -1)
	mlx.Eval(sortedIndices, sortedProbs, cumProbs)

	// Find cutoff index where cumulative probability exceeds topP
	probsData := sortedProbs.Data()
	cumProbsData := cumProbs.Data()
	indicesData := sortedIndices.DataInt32()

	// Calculate cutoff and renormalize
	var cutoffIdx int
	var totalProb float32
	for i, cp := range cumProbsData {
		totalProb += probsData[i]
		if cp >= topP {
			cutoffIdx = i + 1 // Include this token
			break
		}
	}
	if cutoffIdx == 0 {
		cutoffIdx = len(probsData) // Use all tokens if topP is very high
	}

	// Sample from the truncated distribution:
	// renormalize the truncated probabilities
	truncatedProbs := make([]float32, cutoffIdx)
	for i := 0; i < cutoffIdx; i++ {
		truncatedProbs[i] = probsData[i] / totalProb
	}

	// Sample using a random number with the provided seed for reproducibility
	r := mlx.RandomUniform([]int32{1}, uint64(seed))
	mlx.Eval(r)
	randVal := r.Data()[0]

	// Find the sampled token
	var cumulative float32
	for i := 0; i < cutoffIdx; i++ {
		cumulative += truncatedProbs[i]
		if randVal < cumulative {
			return indicesData[i]
		}
	}

	// Fallback to the last token in the truncated set
	return indicesData[cutoffIdx-1]
}
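
// Editor's aside (illustrative sketch, not part of this change): the same
// nucleus cutoff on a plain, descending-sorted probability slice. Assumes the
// probabilities sum to 1; returns how many tokens survive truncation.
func nucleusCutoff(sortedProbs []float32, topP float32) int {
	var cum float32
	for i, p := range sortedProbs {
		cum += p
		if cum >= topP {
			return i + 1 // include the token that crosses the threshold
		}
	}
	return len(sortedProbs)
}

// e.g. nucleusCutoff([]float32{0.5, 0.3, 0.15, 0.05}, 0.75) == 2.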

// Forward for GLMBlock
func (b *GLMBlock) Forward(x *mlx.Array, seqLen int32, eps float32) *mlx.Array {
	return b.ForwardWithCache(x, seqLen, 0, eps, nil, nil)
}

// ForwardWithCache performs block forward with optional KV caching and MRoPE.
// positionIDs: [3][L] - position indices for MRoPE (nil = use sequential positions)
func (b *GLMBlock) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, eps float32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
	// Pre-attention norm
	normed := b.InputLayerNorm.Forward(x, eps)

	// Self-attention with RoPE/MRoPE and cache
	attnOut := b.SelfAttn.ForwardWithCache(normed, seqLen, posOffset, kvcache, positionIDs)

	// Post-attention norm (GLM-4 style)
	attnOut = b.PostSelfAttnNorm.Forward(attnOut, eps)

	// Residual connection
	x = mlx.Add(x, attnOut)

	// Post-attention layer norm
	normed = b.PostAttnLayerNorm.Forward(x, eps)

	// MLP
	mlpOut := b.MLP.Forward(normed)

	// Post-MLP norm
	mlpOut = b.PostMLPLayerNorm.Forward(mlpOut, eps)

	// Residual connection
	x = mlx.Add(x, mlpOut)

	return x
}

// Forward for GLMAttention (without cache - used for prefill)
func (attn *GLMAttention) Forward(x *mlx.Array, seqLen int32) *mlx.Array {
	return attn.ForwardWithCache(x, seqLen, 0, nil, nil)
}

// ForwardWithCache performs attention with optional KV caching and MRoPE.
// posOffset is the position offset for RoPE (0 for prefill, cached_len for decode).
// positionIDs: [3][L] - if nil, uses sequential positions for all dims (text mode).
// kvcache is updated in-place if provided.
func (attn *GLMAttention) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]

	// Q, K, V projections
	q := mlx.Matmul(x, mlx.Transpose(attn.QProj, 1, 0))
	k := mlx.Matmul(x, mlx.Transpose(attn.KProj, 1, 0))
	v := mlx.Matmul(x, mlx.Transpose(attn.VProj, 1, 0))

	// Add biases
	if attn.QBias != nil {
		q = mlx.Add(q, attn.QBias)
	}
	if attn.KBias != nil {
		k = mlx.Add(k, attn.KBias)
	}
	if attn.VBias != nil {
		v = mlx.Add(v, attn.VBias)
	}

	// Reshape to [B, L, nheads, head_dim]
	q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
	k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
	v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)

	// Apply partial RoPE or MRoPE
	rotaryDim := int32(float32(attn.HeadDim) * attn.PartialRotary)
	if len(attn.MRoPESection) == 3 && positionIDs != nil {
		// Use MRoPE with explicit position IDs
		q = applyPartialMRoPE(q, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
		k = applyPartialMRoPE(k, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
	} else if len(attn.MRoPESection) == 3 {
		// Use MRoPE with sequential positions (same for all dims - text mode)
		seqPositions := make([][]int32, 3)
		for dim := 0; dim < 3; dim++ {
			seqPositions[dim] = make([]int32, L)
			for i := int32(0); i < L; i++ {
				seqPositions[dim][i] = i + posOffset
			}
		}
		q = applyPartialMRoPE(q, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
		k = applyPartialMRoPE(k, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
	} else {
		// Fallback to standard RoPE
		q = applyPartialRoPEWithOffset(q, L, posOffset, rotaryDim, attn.RopeTheta)
		k = applyPartialRoPEWithOffset(k, L, posOffset, rotaryDim, attn.RopeTheta)
	}

	// Transpose to [B, nheads, L, head_dim]
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	// Update cache and get full K, V for attention
	if kvcache != nil {
		k, v = kvcache.Update(k, v, int(L))
	}

	// Repeat KV for GQA
	kExpanded := k
	vExpanded := v
	if attn.NKVHeads < attn.NHeads {
		repeats := attn.NHeads / attn.NKVHeads
		kExpanded = repeatKV(k, repeats)
		vExpanded = repeatKV(v, repeats)
	}

	// Scaled dot-product attention with causal mask
	out := mlx.ScaledDotProductAttention(q, kExpanded, vExpanded, attn.Scale, true)

	// Transpose back [B, nheads, L, head_dim] -> [B, L, nheads, head_dim]
	out = mlx.Transpose(out, 0, 2, 1, 3)
	// Reshape to [B, L, hidden_size]
	out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)

	// Output projection
	out = mlx.Matmul(out, mlx.Transpose(attn.OProj, 1, 0))

	return out
}
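
// Editor's aside (illustrative, not part of this change): with the config
// above (32 attention heads, 2 KV heads, head_dim 128), repeats = 32/2 = 16,
// so K and V expand from [B, 2, L, 128] to [B, 32, L, 128] for attention
// while the KV cache keeps only the 2-head tensors.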

// applyPartialRoPE applies RoPE to only the first rotaryDim dimensions
func applyPartialRoPE(x *mlx.Array, seqLen int32, rotaryDim int32, theta float32) *mlx.Array {
	return applyPartialRoPEWithOffset(x, seqLen, 0, rotaryDim, theta)
}

// applyPartialRoPEWithOffset applies RoPE with a position offset
func applyPartialRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, rotaryDim int32, theta float32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	H := shape[2]
	D := shape[3]

	if rotaryDim <= 0 || rotaryDim > D {
		rotaryDim = D
	}

	// Split into rotary and pass-through parts
	xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
	xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})

	// Apply RoPE to rotary part with position offset
	xRot = applyRoPEWithOffset(xRot, L, posOffset, theta)

	// Concatenate back
	return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}

// applyPartialMRoPE applies Multi-dimensional RoPE (MRoPE) to the first rotaryDim dimensions.
// positionIDs: [3, L] - position indices for each dimension (temporal, height, width)
// mrope_section: [8, 12, 12] - frequency pairs per dimension
// For text tokens: all 3 dimensions have the same sequential position.
// For image tokens: temporal=seq_idx, height=row, width=col.
func applyPartialMRoPE(x *mlx.Array, positionIDs [][]int32, rotaryDim int32, theta float32, mropeSection []int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	H := shape[2]
	D := shape[3]

	if rotaryDim <= 0 || rotaryDim > D {
		rotaryDim = D
	}

	// Split into rotary and pass-through parts
	xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
	xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})

	// Apply MRoPE to rotary part
	xRot = applyMRoPE(xRot, positionIDs, theta, mropeSection)

	// Concatenate back
	return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}

// applyMRoPE applies multi-dimensional rotary position embedding.
// x: [B, L, H, D] where D is the rotary dimension
// positionIDs: [3][L] - positions for temporal, height, width dimensions
// mropeSection: [8, 12, 12] - frequency pairs per dimension
func applyMRoPE(x *mlx.Array, positionIDs [][]int32, theta float32, mropeSection []int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	H := shape[2]
	D := shape[3]
	half := D / 2

	// Validate mrope_section sums to half (number of frequency pairs)
	var totalPairs int32
	for _, s := range mropeSection {
		totalPairs += s
	}
	if totalPairs != half {
		// Fallback to standard RoPE if section doesn't match
		return applyRoPEWithOffset(x, L, 0, theta)
	}

	// Build angles for each position dimension (matching Python's MRoPE approach).
	// Python: compute freqs for all dims, then apply_mrope selects freq ranges, then duplicate.
	// Order: [temporal_8, height_12, width_12] -> duplicate -> [t8, h12, w12, t8, h12, w12]
	angleVals := make([]*mlx.Array, 3)

	freqOffset := int32(0)
	for dim := 0; dim < 3; dim++ {
		numPairs := mropeSection[dim]
		if numPairs == 0 {
			continue
		}

		// Compute inverse frequencies for this section.
		// Each dimension uses DIFFERENT frequency ranges:
		//   - Temporal: frequencies 0 to section[0]-1
		//   - Height: frequencies section[0] to section[0]+section[1]-1
		//   - Width: frequencies section[0]+section[1] to sum(section)-1
		freqsArr := make([]float32, numPairs)
		for i := int32(0); i < numPairs; i++ {
			globalIdx := freqOffset + i
			freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*globalIdx)/float64(D)))
		}
		freqs := mlx.NewArray(freqsArr, []int32{numPairs})

		// Position indices for this dimension
		posArr := make([]float32, L)
		for i := int32(0); i < L; i++ {
			posArr[i] = float32(positionIDs[dim][i])
		}
		pos := mlx.NewArray(posArr, []int32{L})

		// Compute angles: [L, numPairs] = outer(pos, freqs)
		posExpanded := mlx.Reshape(pos, L, 1)
		freqsExpanded := mlx.Reshape(freqs, 1, numPairs)
		angleVals[dim] = mlx.Mul(posExpanded, freqsExpanded)

		freqOffset += numPairs
	}

	// Concatenate all sections: [L, half] = [L, 32]
	allAngles := mlx.Concatenate(angleVals, 1)

	// Duplicate AFTER concatenation: [L, D] = [L, 64].
	// This gives: [temporal_8, height_12, width_12, temporal_8, height_12, width_12]
	allAngles = mlx.Concatenate([]*mlx.Array{allAngles, allAngles}, 1)

	// Compute cos/sin
	allCos := mlx.Cos(allAngles)
	allSin := mlx.Sin(allAngles)

	// Reshape for broadcasting: [1, L, 1, D] to match x [B, L, H, D]
	allCos = mlx.Reshape(allCos, 1, L, 1, D)
	allSin = mlx.Reshape(allSin, 1, L, 1, D)

	// x_rotated = cat([-x_imag, x_real], dim=-1)
	x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
	x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
	x2Neg := mlx.MulScalar(x2, -1)                                  // -x_imag
	xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3)         // [-x_imag, x_real]

	// out = x * cos + x_rotated * sin
	return mlx.Add(mlx.Mul(x, allCos), mlx.Mul(xRotated, allSin))
}
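
// Editor's aside (illustrative sketch, not part of this change): with
// head_dim 128 and partial_rotary_factor 0.5, the rotary slice is D=64 wide,
// i.e. half=32 frequency pairs, which is exactly sum([8, 12, 12]): pairs
// 0..7 rotate by the temporal position, 8..19 by the row, 20..31 by the
// column. A tiny helper mapping a pair index to its owning dimension:
func mropePairOwner(section []int32, pair int32) int {
	var end int32
	for dim, s := range section {
		end += s
		if pair < end {
			return dim // 0=temporal, 1=height, 2=width
		}
	}
	return -1
}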

// applyRoPE applies rotary position embedding
func applyRoPE(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
	return applyRoPEWithOffset(x, seqLen, 0, theta)
}

// applyRoPEWithOffset applies rotary position embedding with a position offset.
// Uses the split-half approach (matches diffusers GLM-Image with use_real_unbind_dim=-2)
func applyRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, theta float32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	H := shape[2]
	D := shape[3]
	half := D / 2

	// Compute inverse frequencies: 1 / (theta^(2i/d))
	freqsArr := make([]float32, half)
	for i := int32(0); i < half; i++ {
		freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(D)))
	}
	freqs := mlx.NewArray(freqsArr, []int32{half})

	// Position indices with offset
	posArr := make([]float32, L)
	for i := int32(0); i < L; i++ {
		posArr[i] = float32(i + posOffset)
	}
	pos := mlx.NewArray(posArr, []int32{L})

	// Compute angles: [L, half] = outer(pos, freqs)
	posExpanded := mlx.Reshape(pos, L, 1)
	freqsExpanded := mlx.Reshape(freqs, 1, half)
	angles := mlx.Mul(posExpanded, freqsExpanded)

	// Duplicate angles to match diffusers: cat([angles, angles], dim=-1) -> [L, D]
	anglesDup := mlx.Concatenate([]*mlx.Array{angles, angles}, 1)

	// Cos and sin: [L, 1, D] for broadcasting to [B, L, H, D]
	cosVals := mlx.Cos(anglesDup)
	sinVals := mlx.Sin(anglesDup)
	cosVals = mlx.Reshape(cosVals, L, 1, D)
	sinVals = mlx.Reshape(sinVals, L, 1, D)

	// x_rotated = cat([-x_imag, x_real], dim=-1) where x_real=x[..., :half], x_imag=x[..., half:]
	x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
	x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
	x2Neg := mlx.MulScalar(x2, -1)                                  // -x_imag
	xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3)         // [-x_imag, x_real]

	// out = x * cos + x_rotated * sin
	return mlx.Add(mlx.Mul(x, cosVals), mlx.Mul(xRotated, sinVals))
}
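
// Editor's aside (illustrative sketch, not part of this change; needs "math"):
// the split-half form above reduces, for one (real, imag) pair at angle t, to
// a plain 2-D rotation, since out = x*cos + [-imag, real]*sin:
func rotatePair(re, im, t float64) (float64, float64) {
	c, s := math.Cos(t), math.Sin(t)
	return re*c - im*s, im*c + re*s
}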

// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
	if repeats == 1 {
		return x
	}
	shape := x.Shape()
	// x: [B, nkvheads, L, head_dim]
	x = mlx.ExpandDims(x, 2)
	// x: [B, nkvheads, 1, L, head_dim]
	x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
	// x: [B, nkvheads, repeats, L, head_dim]
	return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}

// Forward for GLMMLP (fused gate_up SwiGLU)
func (m *GLMMLP) Forward(x *mlx.Array) *mlx.Array {
	// gate_up_proj outputs [gate, up] concatenated
	gateUp := mlx.Matmul(x, mlx.Transpose(m.GateUpProj, 1, 0))

	shape := gateUp.Shape()
	halfDim := shape[len(shape)-1] / 2

	// Split into gate and up
	gate := mlx.Slice(gateUp, []int32{0, 0, 0}, []int32{shape[0], shape[1], halfDim})
	up := mlx.Slice(gateUp, []int32{0, 0, halfDim}, []int32{shape[0], shape[1], shape[2]})

	// SwiGLU: silu(gate) * up
	gate = mlx.SiLU(gate)
	h := mlx.Mul(gate, up)

	// Down projection
	return mlx.Matmul(h, mlx.Transpose(m.DownProj, 1, 0))
}

@@ -222,6 +222,14 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
 		mlx.Keep(posEmb, negEmb)
 	}
 
+	// Pre-compute batched embeddings for CFG (single forward pass optimization)
+	var batchedEmb *mlx.Array
+	if useCFG {
+		batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
+		mlx.Keep(batchedEmb)
+		mlx.Eval(batchedEmb)
+	}
+
 	// Scheduler
 	scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
 	scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
@@ -264,10 +272,19 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
 
 		var output *mlx.Array
 		if useCFG {
-			// True CFG: run twice and combine with norm rescaling
-			posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
-			negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
+			// CFG Batching: single forward pass with batch=2
+			// Note: layer caching with CFG is not supported yet (would need 2 caches)
+			batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
+			batchedTimestep := mlx.Tile(timestep, []int32{2})
+
+			// Single batched forward pass
+			batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
+
+			// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
+			L := batchedOutput.Shape()[1]
+			D := batchedOutput.Shape()[2]
+			posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
+			negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
 
 			diff := mlx.Sub(posOutput, negOutput)
 			scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
@@ -305,6 +322,9 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
 	if negEmb != nil {
 		negEmb.Free()
 	}
+	if batchedEmb != nil {
+		batchedEmb.Free()
+	}
 	ropeCache.ImgFreqs.Free()
 	ropeCache.TxtFreqs.Free()
 	if stepCache != nil {
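
A minimal sketch of the guidance combine used above (assumptions: standard CFG
output = neg + scale*(pos-neg), plus a global norm rescale toward the positive
branch; the actual applyCFGWithNormRescale in this change may differ in detail,
and the helper name below is hypothetical; needs "math"):

func cfgNormRescaleRef(pos, neg []float32, scale float32) []float32 {
	out := make([]float32, len(pos))
	var normPos, normOut float64
	for i := range pos {
		out[i] = neg[i] + scale*(pos[i]-neg[i])
		normPos += float64(pos[i]) * float64(pos[i])
		normOut += float64(out[i]) * float64(out[i])
	}
	if normOut > 0 {
		r := float32(math.Sqrt(normPos / normOut))
		for i := range out {
			out[i] *= r // rescale so the guided output keeps pos's magnitude
		}
	}
	return out
}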

@@ -241,6 +241,14 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
 		mlx.Eval(posEmb, negEmb)
 	}
 
+	// Pre-compute batched embeddings for CFG (single forward pass optimization)
+	var batchedEmb *mlx.Array
+	if useCFG {
+		batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
+		mlx.Keep(batchedEmb)
+		mlx.Eval(batchedEmb)
+	}
+
 	// Encode all input images to latents and concatenate
 	fmt.Println("Encoding images to latents...")
 	allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
@@ -291,11 +299,18 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
 
 		var output *mlx.Array
 		if useCFG {
-			posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
-			negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
-
-			posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
-			negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
+			// CFG Batching: single forward pass with batch=2
+			// Tile inputs: [1, L, D] -> [2, L, D]
+			batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
+			batchedTimestep := mlx.Tile(timestep, []int32{2})
+
+			// Single batched forward pass
+			batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
+
+			// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
+			D := batchedOutput.Shape()[2]
+			posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
+			negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
 
 			output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
 		} else {
@@ -317,6 +332,9 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
 	if negEmb != nil {
 		negEmb.Free()
 	}
+	if batchedEmb != nil {
+		batchedEmb.Free()
+	}
 	ropeCache.ImgFreqs.Free()
 	ropeCache.TxtFreqs.Free()
 	imageLatentsPacked.Free()

@@ -128,14 +128,9 @@ func (s *FlowMatchEulerScheduler) AddNoise(cleanSample, noise *mlx.Array, timest
 	return mlx.Add(scaledClean, scaledNoise)
 }
 
-// InitNoise creates initial noise for sampling
+// InitNoise creates initial noise for sampling (BFloat16 for GPU efficiency)
 func (s *FlowMatchEulerScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
 	return RandomNormal(shape, seed)
 }
 
 // RandomNormal creates a random normal tensor using MLX
 func RandomNormal(shape []int32, seed int64) *mlx.Array {
-	return mlx.RandomNormal(shape, uint64(seed))
+	return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
 }
 
 // GetLatentShape returns the latent shape for a given image size

@@ -3,12 +3,10 @@
 package zimage
 
 import (
-	"encoding/json"
 	"fmt"
 	"math"
-	"os"
-	"path/filepath"
 
+	"github.com/ollama/ollama/x/imagegen"
 	"github.com/ollama/ollama/x/imagegen/mlx"
 	"github.com/ollama/ollama/x/imagegen/nn"
 	"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -28,27 +26,14 @@ type Qwen3Config struct {
 	HeadDim int32 `json:"head_dim"`
 }
 
-// loadQwen3Config loads text encoder config from a JSON file
-func loadQwen3Config(path string) (*Qwen3Config, error) {
-	data, err := os.ReadFile(path)
-	if err != nil {
-		return nil, fmt.Errorf("read config: %w", err)
-	}
-	var cfg Qwen3Config
-	if err := json.Unmarshal(data, &cfg); err != nil {
-		return nil, fmt.Errorf("parse config: %w", err)
-	}
-	return &cfg, nil
-}
-
 // Qwen3Attention implements Qwen3 attention with QK norms
 type Qwen3Attention struct {
-	QProj *nn.Linear  `weight:"q_proj"`
-	KProj *nn.Linear  `weight:"k_proj"`
-	VProj *nn.Linear  `weight:"v_proj"`
-	OProj *nn.Linear  `weight:"o_proj"`
-	QNorm *nn.RMSNorm `weight:"q_norm"`
-	KNorm *nn.RMSNorm `weight:"k_norm"`
+	QProj nn.LinearLayer `weight:"q_proj"`
+	KProj nn.LinearLayer `weight:"k_proj"`
+	VProj nn.LinearLayer `weight:"v_proj"`
+	OProj nn.LinearLayer `weight:"o_proj"`
+	QNorm *nn.RMSNorm    `weight:"q_norm"`
+	KNorm *nn.RMSNorm    `weight:"k_norm"`
 	// Computed fields
 	NHeads   int32
 	NKVHeads int32
@@ -151,9 +136,9 @@ func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
 
 // Qwen3MLP implements Qwen3 SwiGLU MLP
 type Qwen3MLP struct {
-	GateProj *nn.Linear `weight:"gate_proj"`
-	UpProj   *nn.Linear `weight:"up_proj"`
-	DownProj *nn.Linear `weight:"down_proj"`
+	GateProj nn.LinearLayer `weight:"gate_proj"`
+	UpProj   nn.LinearLayer `weight:"up_proj"`
+	DownProj nn.LinearLayer `weight:"down_proj"`
 }
 
 // Forward applies the MLP
@@ -194,33 +179,44 @@ type Qwen3TextEncoder struct {
 	*Qwen3Config
 }
 
-// Load loads the Qwen3 text encoder from a directory
-func (m *Qwen3TextEncoder) Load(path string) error {
-	fmt.Println("Loading Qwen3 text encoder...")
+// Load loads the Qwen3 text encoder from ollama blob storage.
+func (m *Qwen3TextEncoder) Load(manifest *imagegen.ModelManifest) error {
+	fmt.Print(" Loading text encoder... ")
 
-	// Load config
-	cfg, err := loadQwen3Config(filepath.Join(path, "config.json"))
-	if err != nil {
+	// Load config from blob
+	var cfg Qwen3Config
+	if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
 		return fmt.Errorf("config: %w", err)
 	}
-	m.Qwen3Config = cfg
-
-	// Pre-allocate layers slice
+	m.Qwen3Config = &cfg
 	m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
 
-	// Load weights
-	weights, err := safetensors.LoadModelWeights(path)
+	// Load weights from tensor blobs
+	weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
 	if err != nil {
 		return fmt.Errorf("weights: %w", err)
 	}
+	if err := weights.Load(0); err != nil {
+		return fmt.Errorf("load weights: %w", err)
+	}
+	defer weights.ReleaseAll()
 
-	fmt.Print(" Loading weights via struct tags... ")
+	return m.loadWeights(weights)
+}
+
+// loadWeights loads weights from any WeightSource into the model
+func (m *Qwen3TextEncoder) loadWeights(weights safetensors.WeightSource) error {
 	if err := safetensors.LoadModule(m, weights, ""); err != nil {
 		return fmt.Errorf("load module: %w", err)
 	}
+	m.initComputedFields()
+	fmt.Println("✓")
+	return nil
+}
 
-// Initialize computed fields
+// initComputedFields initializes computed fields after loading weights
+func (m *Qwen3TextEncoder) initComputedFields() {
 	cfg := m.Qwen3Config
 	m.FinalNorm.Eps = cfg.RMSNormEps
 	for _, block := range m.Layers {
 		// Attention
@@ -235,9 +231,6 @@ func (m *Qwen3TextEncoder) Load(path string) error {
 		block.InputLayerNorm.Eps = cfg.RMSNormEps
 		block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
 	}
-
-	weights.ReleaseAll()
-	return nil
 }
 
 // Forward encodes text tokens

@@ -4,12 +4,10 @@
 package zimage
 
 import (
-	"encoding/json"
 	"fmt"
 	"math"
-	"os"
-	"path/filepath"
 
+	"github.com/ollama/ollama/x/imagegen"
 	"github.com/ollama/ollama/x/imagegen/cache"
 	"github.com/ollama/ollama/x/imagegen/mlx"
 	"github.com/ollama/ollama/x/imagegen/nn"
@@ -38,8 +36,8 @@ type TransformerConfig struct {
 // TimestepEmbedder creates sinusoidal timestep embeddings
 // Output dimension is 256 (fixed), used for AdaLN modulation
 type TimestepEmbedder struct {
-	Linear1 *nn.Linear `weight:"mlp.0"`
-	Linear2 *nn.Linear `weight:"mlp.2"`
+	Linear1 nn.LinearLayer `weight:"mlp.0"`
+	Linear2 nn.LinearLayer `weight:"mlp.2"`
 	FreqEmbedSize int32 // 256 (computed)
 }
 
@@ -76,7 +74,7 @@ func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
 
 // XEmbedder embeds image patches to model dimension
 type XEmbedder struct {
-	Linear *nn.Linear `weight:"2-1"`
+	Linear nn.LinearLayer `weight:"2-1"`
 }
 
 // Forward embeds patchified image latents
@@ -88,7 +86,7 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
 // CapEmbedder projects caption features to model dimension
 type CapEmbedder struct {
 	Norm     *nn.RMSNorm `weight:"0"`
-	Linear *nn.Linear `weight:"1"`
+	Linear   nn.LinearLayer `weight:"1"`
 	PadToken *mlx.Array // loaded separately at root level
 }
 
@@ -102,12 +100,13 @@ func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
 
 // FeedForward implements SwiGLU FFN
 type FeedForward struct {
-	W1 *nn.Linear `weight:"w1"` // gate projection
-	W2 *nn.Linear `weight:"w2"` // down projection
-	W3 *nn.Linear `weight:"w3"` // up projection
+	W1 nn.LinearLayer `weight:"w1"` // gate projection
+	W2 nn.LinearLayer `weight:"w2"` // down projection
+	W3 nn.LinearLayer `weight:"w3"` // up projection
 	OutDim int32 // computed from W2
 }
 
+
 // Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
 func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
 	shape := x.Shape()
@@ -117,6 +116,7 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
 
 	// Reshape for matmul
 	x = mlx.Reshape(x, B*L, D)
+
 	gate := ff.W1.Forward(x)
 	gate = mlx.SiLU(gate)
 	up := ff.W3.Forward(x)
@@ -128,17 +128,69 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// Attention implements multi-head attention with QK norm
|
||||
type Attention struct {
|
||||
ToQ *nn.Linear `weight:"to_q"`
|
||||
ToK *nn.Linear `weight:"to_k"`
|
||||
ToV *nn.Linear `weight:"to_v"`
|
||||
ToOut *nn.Linear `weight:"to_out.0"`
|
||||
ToQ nn.LinearLayer `weight:"to_q"`
|
||||
ToK nn.LinearLayer `weight:"to_k"`
|
||||
ToV nn.LinearLayer `weight:"to_v"`
|
||||
ToOut nn.LinearLayer `weight:"to_out.0"`
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Dim int32
|
||||
Scale float32
|
||||
// Fused QKV (computed at init time for efficiency, not loaded from weights)
|
||||
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
|
||||
Fused bool `weight:"-"` // Whether to use fused QKV path
|
||||
// Computed fields (not loaded from weights)
|
||||
NHeads int32 `weight:"-"`
|
||||
HeadDim int32 `weight:"-"`
|
||||
Dim int32 `weight:"-"`
|
||||
Scale float32 `weight:"-"`
|
||||
}
|
||||
|
||||
// FuseQKV creates a fused QKV projection by concatenating weights.
|
||||
// This reduces 3 matmuls to 1 for a ~5-10% speedup.
|
||||
// Note: Fusion is skipped for quantized weights as it would require complex
|
||||
// dequant-concat-requant operations. The FP8 memory bandwidth savings outweigh
|
||||
// the ~5% fusion benefit.
|
||||
func (attn *Attention) FuseQKV() {
|
||||
if attn.ToQ == nil || attn.ToK == nil || attn.ToV == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip fusion for quantized weights - type assert to check
|
||||
toQ, qOk := attn.ToQ.(*nn.Linear)
|
||||
toK, kOk := attn.ToK.(*nn.Linear)
|
||||
toV, vOk := attn.ToV.(*nn.Linear)
|
||||
if !qOk || !kOk || !vOk {
|
||||
// One or more are QuantizedLinear, skip fusion
|
||||
return
|
||||
}
|
||||
|
||||
if toQ.Weight == nil || toK.Weight == nil || toV.Weight == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Concatenate weights: [dim, dim] x 3 -> [3*dim, dim]
|
||||
// Weight shapes: ToQ.Weight [out_dim, in_dim], etc.
|
||||
qWeight := toQ.Weight
|
||||
kWeight := toK.Weight
|
||||
vWeight := toV.Weight
|
||||
|
||||
// Concatenate along output dimension (axis 0)
|
||||
fusedWeight := mlx.Concatenate([]*mlx.Array{qWeight, kWeight, vWeight}, 0)
|
||||
|
||||
// Evaluate fused weight to ensure it's materialized
|
||||
mlx.Eval(fusedWeight)
|
||||
|
||||
// Create fused linear layer
|
||||
fusedLinear := &nn.Linear{Weight: fusedWeight}
|
||||
|
||||
// Handle bias if present
|
||||
if toQ.Bias != nil && toK.Bias != nil && toV.Bias != nil {
|
||||
fusedBias := mlx.Concatenate([]*mlx.Array{toQ.Bias, toK.Bias, toV.Bias}, 0)
|
||||
mlx.Eval(fusedBias)
|
||||
fusedLinear.Bias = fusedBias
|
||||
}
|
||||
|
||||
attn.ToQKV = fusedLinear
|
||||
attn.Fused = true
|
||||
}
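
// Illustrative sketch, not part of the change above: for y = W·x with W
// stored as [out, in], stacking Wq, Wk, Wv along axis 0 yields a single
// [3*out, in] matrix whose product with x is exactly [q; k; v], so one
// matmul replaces three. A typical call site (setup names hypothetical):
//
//  attn := &Attention{ToQ: q, ToK: k, ToV: v} // q, k, v are *nn.Linear
//  attn.FuseQKV()                             // no-op if any are quantized
//  // attn.Fused == true -> Forward takes the fused path and slices the result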

// Forward computes attention
@@ -148,11 +200,24 @@ func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
    L := shape[1]
    D := shape[2]

    // Project Q, K, V
    xFlat := mlx.Reshape(x, B*L, D)
    q := attn.ToQ.Forward(xFlat)
    k := attn.ToK.Forward(xFlat)
    v := attn.ToV.Forward(xFlat)

    var q, k, v *mlx.Array
    if attn.Fused && attn.ToQKV != nil {
        // Fused QKV path: single matmul then split
        qkv := attn.ToQKV.Forward(xFlat) // [B*L, 3*dim]

        // Split into Q, K, V along last dimension
        // Each has shape [B*L, dim]
        q = mlx.Slice(qkv, []int32{0, 0}, []int32{B * L, attn.Dim})
        k = mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{B * L, 2 * attn.Dim})
        v = mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{B * L, 3 * attn.Dim})
    } else {
        // Separate Q, K, V projections
        q = attn.ToQ.Forward(xFlat)
        k = attn.ToK.Forward(xFlat)
        v = attn.ToV.Forward(xFlat)
    }

    // Reshape to [B, L, nheads, head_dim]
    q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
@@ -229,7 +294,7 @@ type TransformerBlock struct {
    AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
    FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
    FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
    AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
    AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
    // Computed fields
    HasModulation bool
    Dim int32
@@ -283,8 +348,8 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml

// FinalLayer outputs the denoised patches
type FinalLayer struct {
    AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
    Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
    AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
    Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
    OutDim int32 // computed from Output
}

@@ -335,43 +400,50 @@ type Transformer struct {
    *TransformerConfig
}

// Load loads the Z-Image transformer from a directory
func (m *Transformer) Load(path string) error {
    fmt.Println("Loading Z-Image transformer...")
// Load loads the Z-Image transformer from ollama blob storage.
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
    fmt.Print(" Loading transformer... ")

    // Load config
    cfg, err := loadTransformerConfig(filepath.Join(path, "config.json"))
    if err != nil {
    // Load config from blob
    var cfg TransformerConfig
    if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
        return fmt.Errorf("config: %w", err)
    }
    m.TransformerConfig = cfg

    // Pre-allocate slices for loader
    if len(cfg.AllPatchSize) > 0 {
        cfg.PatchSize = cfg.AllPatchSize[0]
    }
    m.TransformerConfig = &cfg
    m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
    m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
    m.Layers = make([]*TransformerBlock, cfg.NLayers)

    // Load weights
    weights, err := safetensors.LoadModelWeights(path)
    weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
    if err != nil {
        return fmt.Errorf("weights: %w", err)
    }

    fmt.Print(" Loading weights as bf16... ")
    if err := weights.Load(mlx.DtypeBFloat16); err != nil {
    if err := weights.Load(0); err != nil {
        return fmt.Errorf("load weights: %w", err)
    }
    fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
    defer weights.ReleaseAll()

    fmt.Print(" Loading weights via struct tags... ")
    return m.loadWeights(weights)
}

// loadWeights loads weights from any WeightSource into the model
func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
    if err := safetensors.LoadModule(m, weights, ""); err != nil {
        return fmt.Errorf("load module: %w", err)
    }
    m.initComputedFields()
    fmt.Println("✓")
    return nil
}

// Initialize computed fields
// initComputedFields initializes computed fields after loading weights
func (m *Transformer) initComputedFields() {
    cfg := m.TransformerConfig
    m.TEmbed.FreqEmbedSize = 256
    m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
    m.FinalLayer.OutDim = m.FinalLayer.Output.OutputDim()
    m.CapEmbed.Norm.Eps = 1e-6

    for _, block := range m.NoiseRefiners {
@@ -383,26 +455,20 @@ func (m *Transformer) Load(path string) error {
    for _, block := range m.Layers {
        initTransformerBlock(block, cfg)
    }

    weights.ReleaseAll()
    return nil
}

// loadTransformerConfig loads transformer config from a JSON file
func loadTransformerConfig(path string) (*TransformerConfig, error) {
    data, err := os.ReadFile(path)
    if err != nil {
        return nil, fmt.Errorf("read config: %w", err)
// FuseAllQKV fuses QKV projections in all attention layers for efficiency.
// This reduces 3 matmuls to 1 per attention layer, providing ~5-10% speedup.
func (m *Transformer) FuseAllQKV() {
    for _, block := range m.NoiseRefiners {
        block.Attention.FuseQKV()
    }
    var cfg TransformerConfig
    if err := json.Unmarshal(data, &cfg); err != nil {
        return nil, fmt.Errorf("parse config: %w", err)
    for _, block := range m.ContextRefiners {
        block.Attention.FuseQKV()
    }
    // Extract PatchSize from array
    if len(cfg.AllPatchSize) > 0 {
        cfg.PatchSize = cfg.AllPatchSize[0]
    for _, block := range m.Layers {
        block.Attention.FuseQKV()
    }
    return &cfg, nil
}

// initTransformerBlock sets computed fields on a transformer block
@@ -418,7 +484,7 @@ func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
    attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))

    // Init feedforward OutDim
    block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
    block.FeedForward.OutDim = block.FeedForward.W2.OutputDim()

    // Set eps on all RMSNorm layers
    block.AttentionNorm1.Eps = cfg.NormEps
@@ -437,6 +503,8 @@ type RoPECache struct {
    UnifiedSin *mlx.Array
    ImgLen int32
    CapLen int32
    GridH int32 // Image token grid height
    GridW int32 // Image token grid width
}

// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
@@ -470,6 +538,8 @@ func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
    UnifiedSin: unifiedSin,
    ImgLen: imgLen,
    CapLen: capLen,
    GridH: hTok,
    GridW: wTok,
    }
}

@@ -3,12 +3,10 @@
package zimage

import (
    "encoding/json"
    "fmt"
    "math"
    "os"
    "path/filepath"

    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/imagegen/mlx"
    "github.com/ollama/ollama/x/imagegen/safetensors"
)
@@ -25,19 +23,6 @@ type VAEConfig struct {
    ShiftFactor float32 `json:"shift_factor"`
}

// loadVAEConfig loads VAE config from a JSON file
func loadVAEConfig(path string) (*VAEConfig, error) {
    data, err := os.ReadFile(path)
    if err != nil {
        return nil, fmt.Errorf("read config: %w", err)
    }
    var cfg VAEConfig
    if err := json.Unmarshal(data, &cfg); err != nil {
        return nil, fmt.Errorf("parse config: %w", err)
    }
    return &cfg, nil
}

// GroupNormLayer implements group normalization
type GroupNormLayer struct {
    Weight *mlx.Array
@@ -57,49 +42,189 @@ func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
}

// Forward applies group normalization
// Input and output are in NHWC format [B, H, W, C]
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
    // x: [B, C, H, W]
    // x: [B, H, W, C] (NHWC format)
    shape := x.Shape()
    B := shape[0]
    C := shape[1]
    H := shape[2]
    W := shape[3]
    H := shape[1]
    W := shape[2]
    C := shape[3]

    // Reshape to [B, groups, C/groups, H, W]
    // For large spatial sizes, use tiled computation to avoid CUDA grid limits
    // CUDA grid.y max is 65535, so H*W/16 must be <= 65535, meaning H*W <= ~1M
    // To be safe, tile when H*W > 512*512 = 262144
    if H*W > 512*512 {
        return gn.forwardTiled(x, B, H, W, C)
    }

    return gn.forwardSmall(x, B, H, W, C)
}
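
// Back-of-envelope check for the threshold above (illustrative): a
// 1024x1024 activation has H*W = 1,048,576 > 262,144, so it takes the
// tiled path; a 512x512 activation (H*W = 262,144) still fits the
// standard kernel since the comparison is strict.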

// forwardSmall is the standard GroupNorm for tensors that fit within CUDA grid limits
func (gn *GroupNormLayer) forwardSmall(x *mlx.Array, B, H, W, C int32) *mlx.Array {
    // Reshape to [B, H, W, groups, C/groups]
    groupSize := C / gn.NumGroups
    x = mlx.Reshape(x, B, gn.NumGroups, groupSize, H, W)
    x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)

    // Compute mean and variance per group
    mean := mlx.Mean(x, 2, true)
    mean = mlx.Mean(mean, 3, true)
    // Compute mean and variance per group (over H, W, and C/groups dimensions)
    mean := mlx.Mean(x, 1, true)
    mean = mlx.Mean(mean, 2, true)
    mean = mlx.Mean(mean, 4, true)

    xCentered := mlx.Sub(x, mean)
    variance := mlx.Mean(mlx.Square(xCentered), 2, true)
    variance = mlx.Mean(variance, 3, true)

    // Variance over same axes
    sq := mlx.Square(xCentered)
    variance := mlx.Mean(sq, 1, true)
    variance = mlx.Mean(variance, 2, true)
    variance = mlx.Mean(variance, 4, true)

    // Normalize
    xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))

    // Reshape back to [B, C, H, W]
    xNorm = mlx.Reshape(xNorm, B, C, H, W)
    // Reshape back to [B, H, W, C]
    xNorm = mlx.Reshape(xNorm, B, H, W, C)

    // Scale and shift (weight and bias are [C])
    if gn.Weight != nil {
        weight := mlx.Reshape(gn.Weight, 1, C, 1, 1)
        weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
        xNorm = mlx.Mul(xNorm, weight)
    }
    if gn.Bias != nil {
        bias := mlx.Reshape(gn.Bias, 1, C, 1, 1)
        bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
        xNorm = mlx.Add(xNorm, bias)
    }

    return xNorm
}
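
// For reference, forwardSmall computes the standard GroupNorm:
//
//  y = (x - mean_g) / sqrt(var_g + eps) * weight + bias
//
// where mean_g and var_g are taken per (batch, group) over H, W, and the
// C/groups channels in that group, and weight/bias broadcast over [C].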

// forwardTiled handles large tensors by processing in H-tiles to avoid CUDA grid limits
func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Array {
    groupSize := C / gn.NumGroups

    // Keep the input - we need it for slicing tiles later
    // Track if we were the ones who kept it, so we can restore state after
    wasKept := x.Kept()
    mlx.Keep(x)

    // Compute per-group mean and variance using flattened spatial dimensions
    // Build the entire compute graph first, then eval once
    // Reshape to [B, H*W, groups, groupSize]
    xFlat := mlx.Reshape(x, B, H*W, gn.NumGroups, groupSize)

    // Mean over spatial (axis 1) and groupSize (axis 3) dimensions
    // Result shape: [B, 1, groups, 1]
    mean1 := mlx.Mean(xFlat, 1, true)
    mean := mlx.Mean(mean1, 3, true)

    // Variance using E[X^2] - E[X]^2
    xSq := mlx.Square(xFlat)
    meanSq1 := mlx.Mean(xSq, 1, true)
    meanSq := mlx.Mean(meanSq1, 3, true)
    meanSquared := mlx.Square(mean)
    variance := mlx.Sub(meanSq, meanSquared)

    // invStd = 1/sqrt(var + eps)
    varPlusEps := mlx.AddScalar(variance, gn.Eps)
    stdDev := mlx.Sqrt(varPlusEps)
    one := mlx.Full(1.0, 1)
    invStd := mlx.Div(one, stdDev)

    // Eval mean and invStd together - these are what we need for the tile loop
    mlx.Keep(mean, invStd)
    mlx.Eval(mean, invStd)

    // Tile along H dimension
    tileH := int32(512 * 512 / W)
    if tileH < 1 {
        tileH = 1
    }
    if tileH > H {
        tileH = H
    }

    // Prepare weight and bias reshaped for 4D broadcast [1, 1, groups, groupSize]
    var weightGN, biasGN *mlx.Array
    if gn.Weight != nil {
        weightGN = mlx.Reshape(gn.Weight, 1, 1, gn.NumGroups, groupSize)
        mlx.Keep(weightGN)
        mlx.Eval(weightGN)
    }
    if gn.Bias != nil {
        biasGN = mlx.Reshape(gn.Bias, 1, 1, gn.NumGroups, groupSize)
        mlx.Keep(biasGN)
        mlx.Eval(biasGN)
    }

    var tiles []*mlx.Array
    for hStart := int32(0); hStart < H; hStart += tileH {
        hEnd := hStart + tileH
        if hEnd > H {
            hEnd = H
        }
        tileHeight := hEnd - hStart
        spatialSize := tileHeight * W

        // Build the compute graph for this tile (no intermediate Evals)
        // Extract tile and flatten spatial dims: [B, tileH*W, groups, groupSize]
        tile := mlx.Slice(x, []int32{0, hStart, 0, 0}, []int32{B, hEnd, W, C})
        tileFlat := mlx.Reshape(tile, B, spatialSize, gn.NumGroups, groupSize)

        // Normalize: (x - mean) * invStd
        tileCentered := mlx.Sub(tileFlat, mean)
        tileNorm := mlx.Mul(tileCentered, invStd)

        // Apply scale and shift in 4D space
        if weightGN != nil {
            tileNorm = mlx.Mul(tileNorm, weightGN)
        }
        if biasGN != nil {
            tileNorm = mlx.Add(tileNorm, biasGN)
        }

        // Reshape back to [B, tileH, W, C]
        tileOut := mlx.Reshape(tileNorm, B, tileHeight, W, C)

        // Now eval and keep this tile
        mlx.Keep(tileOut)
        mlx.Eval(tileOut)

        tiles = append(tiles, tileOut)
    }

    // Concatenate tiles along H axis
    var result *mlx.Array
    if len(tiles) == 1 {
        result = tiles[0]
    } else {
        result = mlx.Concatenate(tiles, 1)
        mlx.Eval(result)
        // Free the individual tiles now that they're concatenated
        for _, t := range tiles {
            t.Free()
        }
    }

    // Clean up kept arrays
    // Restore x's kept state - only free if we were the ones who kept it
    if !wasKept {
        x.Free()
    }
    mean.Free()
    invStd.Free()
    if weightGN != nil {
        weightGN.Free()
    }
    if biasGN != nil {
        biasGN.Free()
    }

    return result
}
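
// Note on the statistics used above: forwardTiled relies on the identity
// Var(X) = E[X^2] - (E[X])^2, which lets the per-group mean and variance
// be computed in one pass over the full input before any tile is sliced,
// so every tile is normalized with the same global statistics. The
// identity can lose precision when E[X^2] and (E[X])^2 are close, which
// is tolerable here since the result is immediately divided by
// sqrt(var + eps).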

// Conv2D represents a 2D convolution layer
// MLX uses NHWC format, but we store weights in OHWI format for MLX conv
// Works natively in NHWC format (MLX's native format)
type Conv2D struct {
    Weight *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
    Bias *mlx.Array // [out_channels]
@@ -123,21 +248,17 @@ func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
}

// Forward applies convolution
// Input x is in NCHW format, we convert to NHWC for MLX, then back to NCHW
// Input and output are in NHWC format [N, H, W, C]
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
    // x: [N, C, H, W] -> [N, H, W, C]
    xNHWC := mlx.Transpose(x, 0, 2, 3, 1)

    // Conv in NHWC format
    outNHWC := mlx.Conv2d(xNHWC, conv.Weight, conv.Stride, conv.Padding)

    // Convert back to NCHW: [N, H, W, C] -> [N, C, H, W]
    out := mlx.Transpose(outNHWC, 0, 3, 1, 2)
    // Conv in NHWC format (MLX native)
    out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)

    if conv.Bias != nil {
        bias := mlx.Reshape(conv.Bias, 1, conv.Bias.Dim(0), 1, 1)
        // Bias is [C], reshape to [1, 1, 1, C] for NHWC broadcast
        bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
        out = mlx.Add(out, bias)
    }

    return out
}
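
// Shape flow through Forward, for illustration: with x of shape
// [N, H, W, C_in] and Weight of shape [C_out, kH, kW, C_in] (OHWI),
// the convolution yields [N, H', W', C_out], where for stride s and
// padding p, H' = (H + 2p - kH)/s + 1 (and likewise for W'); the
// [1, 1, 1, C_out] bias then broadcasts over batch and spatial axes.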

@@ -151,7 +272,7 @@ type ResnetBlock2D struct {
}

// NewResnetBlock2D creates a ResNet block
func NewResnetBlock2D(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*ResnetBlock2D, error) {
func NewResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
    norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
    if err != nil {
        return nil, err
@@ -216,13 +337,13 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {

    // Stage 1: norm1
    {
        h = rb.Norm1.Forward(x)
        mlx.Eval(h)
    }

    // Stage 2: silu + conv1
    {
        prev := h
        h = mlx.SiLU(h)
        h = rb.Conv1.Forward(h)
        prev.Free()
@@ -231,7 +352,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {

    // Stage 3: norm2
    {
        prev := h
        h = rb.Norm2.Forward(h)
        prev.Free()
        mlx.Eval(h)
@@ -239,7 +360,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {

    // Stage 4: silu + conv2
    {
        prev := h
        h = mlx.SiLU(h)
        h = rb.Conv2.Forward(h)
        prev.Free()
@@ -248,7 +369,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {

    // Residual connection
    {
        prev := h
        if rb.ConvShortcut != nil {
            shortcut := rb.ConvShortcut.Forward(x)
            h = mlx.Add(h, shortcut)
@@ -277,7 +398,7 @@ type VAEAttentionBlock struct {
}

// NewVAEAttentionBlock creates an attention block
func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
func NewVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
    normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
    if err != nil {
        return nil, err
@@ -338,20 +459,20 @@ func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numG
}

// Forward applies attention with staged evaluation
// Input and output are in NHWC format [B, H, W, C]
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
    residual := x
    shape := x.Shape()
    B := shape[0]
    C := shape[1]
    H := shape[2]
    W := shape[3]
    H := shape[1]
    W := shape[2]
    C := shape[3]

    var h *mlx.Array

    // Stage 1: GroupNorm + reshape
    // Stage 1: GroupNorm + reshape to [B, H*W, C]
    {
        h = ab.GroupNorm.Forward(x)
        h = mlx.Transpose(h, 0, 2, 3, 1)
        h = ab.GroupNorm.Forward(x)
        h = mlx.Reshape(h, B, H*W, C)
        mlx.Eval(h)
    }
@@ -360,7 +481,7 @@ func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {

    // Stage 2: Q, K, V projections + attention
    {
        q := mlx.Linear(h, ab.ToQWeight)
        q = mlx.Add(q, ab.ToQBias)
        k := mlx.Linear(h, ab.ToKWeight)
        k = mlx.Add(k, ab.ToKBias)
@@ -380,11 +501,10 @@ func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {

    // Stage 3: Output projection + reshape + residual
    {
        prev := out
        out = mlx.Linear(out, ab.ToOutWeight)
        out = mlx.Add(out, ab.ToOutBias)
        out = mlx.Reshape(out, B, H, W, C)
        out = mlx.Transpose(out, 0, 3, 1, 2)
        out = mlx.Add(out, residual)
        prev.Free()
        mlx.Eval(out)
@@ -400,7 +520,7 @@ type UpDecoderBlock2D struct {
}

// NewUpDecoderBlock2D creates an up decoder block
func NewUpDecoderBlock2D(weights *safetensors.ModelWeights, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
func NewUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
    resnets := make([]*ResnetBlock2D, numLayers)
    for i := int32(0); i < numLayers; i++ {
        resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
@@ -467,7 +587,7 @@ type VAEMidBlock struct {
}

// NewVAEMidBlock creates the mid block
func NewVAEMidBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEMidBlock, error) {
func NewVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
    resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
    if err != nil {
        return nil, err
@@ -518,22 +638,31 @@ type VAEDecoder struct {
    ConvOut *Conv2D
}

// Load loads the VAE decoder from a directory
func (m *VAEDecoder) Load(path string) error {
    fmt.Println("Loading VAE decoder...")

    // Load config
    cfg, err := loadVAEConfig(filepath.Join(path, "config.json"))
    if err != nil {
// Load loads the VAE decoder from ollama blob storage.
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
    // Load config from blob
    var cfg VAEConfig
    if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
        return fmt.Errorf("config: %w", err)
    }
    m.Config = cfg
    m.Config = &cfg

    // Load weights
    weights, err := safetensors.LoadModelWeights(path)
    // Load weights from tensor blobs
    weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
    if err != nil {
        return fmt.Errorf("weights: %w", err)
    }
    if err := weights.Load(0); err != nil {
        return fmt.Errorf("load weights: %w", err)
    }
    defer weights.ReleaseAll()

    return m.loadWeights(weights, &cfg)
}

// loadWeights loads VAE weights from any WeightSource
func (m *VAEDecoder) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
    var err error

    // Load conv_in
    fmt.Print(" Loading conv_in... ")
@@ -596,57 +725,79 @@ func (m *VAEDecoder) Load(path string) error {
    m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
    fmt.Println("✓")

    weights.ReleaseAll()
    return nil
}

// Decode decodes latents to images.
// Uses staged pools to free intermediate arrays and reduce peak memory.
// Input latents are in NCHW format, output is in NCHW format.
// Internally uses NHWC format (MLX native) for all operations.
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
    var h *mlx.Array
    {
        z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
        z = mlx.AddScalar(z, vae.Config.ShiftFactor)
        h = vae.ConvIn.Forward(z)
        mlx.Eval(h)
    }
    // Scale latents
    z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
    z = mlx.AddScalar(z, vae.Config.ShiftFactor)
    // Convert NCHW -> NHWC for internal processing
    z = mlx.Transpose(z, 0, 2, 3, 1)
    h := vae.ConvIn.Forward(z)
    mlx.Eval(h)

    prev := h
    h = vae.MidBlock.Forward(h)
    prev.Free()

    for _, upBlock := range vae.UpBlocks {
        prev = h
        h = upBlock.Forward(h)
        prev.Free()
    }

    {
        prev := h
        h = vae.ConvNormOut.Forward(h)
        h = mlx.SiLU(h)
        h = vae.ConvOut.Forward(h)
        // VAE outputs [-1, 1], convert to [0, 1]
        h = mlx.AddScalar(mlx.MulScalar(h, 0.5), 0.5)
        h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
        prev.Free()
        mlx.Eval(h)
    }
    prev = h
    h = vae.ConvNormOut.Forward(h)
    mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
    prev.Free()

    prev = h
    h = mlx.SiLU(h)
    h = vae.ConvOut.Forward(h)
    mlx.Eval(h)
    prev.Free()

    // VAE outputs [-1, 1], convert to [0, 1]
    h = mlx.MulScalar(h, 0.5)
    h = mlx.AddScalar(h, 0.5)
    h = mlx.ClipScalar(h, 0.0, 1.0, true, true)

    // Convert NHWC -> NCHW for output
    h = mlx.Transpose(h, 0, 3, 1, 2)
    mlx.Eval(h)

    return h
}
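
// The [-1, 1] -> [0, 1] mapping above is just h*0.5 + 0.5: -1 maps to 0,
// 0 to 0.5, and 1 to 1, with ClipScalar guarding against values the VAE
// produces slightly outside [-1, 1].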

// Upsample2x performs 2x nearest neighbor upsampling using broadcast.
// x: [B, C, H, W] -> [B, C, H*2, W*2]
// Upsample2x performs 2x nearest neighbor upsampling using Take.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
// Uses Take with repeated indices to produce contiguous output.
func Upsample2x(x *mlx.Array) *mlx.Array {
    shape := x.Shape()
    B := shape[0]
    C := shape[1]
    H := shape[2]
    W := shape[3]
    H := shape[1]
    W := shape[2]

    // [B, C, H, W] -> [B, C, H, 1, W, 1]
    x = mlx.Reshape(x, B, C, H, 1, W, 1)
    // Broadcast to [B, C, H, 2, W, 2]
    x = mlx.BroadcastTo(x, []int32{B, C, H, 2, W, 2})
    // Reshape to [B, C, H*2, W*2]
    x = mlx.Reshape(x, B, C, H*2, W*2)
    // Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
    // For H dimension
    hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
    hIdx = mlx.Reshape(hIdx, H, 1)
    hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
    hIdx = mlx.Reshape(hIdx, H*2)

    // For W dimension
    wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
    wIdx = mlx.Reshape(wIdx, W, 1)
    wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
    wIdx = mlx.Reshape(wIdx, W*2)

    // Take along H axis (axis 1 in NHWC)
    x = mlx.Take(x, hIdx, 1)
    // Take along W axis (axis 2 in NHWC)
    x = mlx.Take(x, wIdx, 2)

    return x
}
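
// Index pattern, for illustration: with H = 3 the broadcast produces
// hIdx = [0, 0, 1, 1, 2, 2], so Take along axis 1 repeats each row twice;
// the same applies to wIdx along axis 2, giving nearest-neighbor 2x
// upsampling with a contiguous result.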

@@ -6,9 +6,9 @@ package zimage
import (
    "context"
    "fmt"
    "path/filepath"
    "time"

    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/imagegen/cache"
    "github.com/ollama/ollama/x/imagegen/mlx"
    "github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -26,10 +26,12 @@ type GenerateConfig struct {
    Progress ProgressFunc // Optional progress callback
    CapturePath string // GPU capture path (debug)

    // Layer caching options (speedup via shallow layer reuse)
    LayerCache bool // Enable layer caching (default: false)
    CacheInterval int // Refresh cache every N steps (default: 3)
    CacheLayers int // Number of shallow layers to cache (default: 15)
    // TeaCache options (timestep embedding aware caching)
    TeaCache bool // TeaCache is always enabled for faster inference
    TeaCacheThreshold float32 // Threshold for cache reuse (default: 0.1, lower = more aggressive)

    // Fused QKV (fuse Q/K/V projections into single matmul)
    FusedQKV bool // Enable fused QKV projection (default: false)
}

// ProgressFunc is called during generation with step progress.
@@ -37,16 +39,17 @@ type ProgressFunc func(step, totalSteps int)

// Model represents a Z-Image diffusion model.
type Model struct {
    ModelPath string
    ModelName string
    Tokenizer *tokenizer.Tokenizer
    TextEncoder *Qwen3TextEncoder
    Transformer *Transformer
    VAEDecoder *VAEDecoder
    qkvFused bool // Track if QKV has been fused (do only once)
}

// Load loads the Z-Image model from a directory.
func (m *Model) Load(modelPath string) error {
    fmt.Println("Loading Z-Image model...")
// Load loads the Z-Image model from ollama blob storage.
func (m *Model) Load(modelName string) error {
    fmt.Printf("Loading Z-Image model from manifest: %s...\n", modelName)
    start := time.Now()

    if mlx.GPUIsAvailable() {
@@ -54,12 +57,34 @@ func (m *Model) Load(modelPath string) error {
        mlx.EnableCompile()
    }

    m.ModelPath = modelPath
    m.ModelName = modelName

    // Load tokenizer
    // Load manifest
    manifest, err := imagegen.LoadManifest(modelName)
    if err != nil {
        return fmt.Errorf("load manifest: %w", err)
    }

    // Load tokenizer from manifest with config
    fmt.Print(" Loading tokenizer... ")
    tokenizerPath := filepath.Join(modelPath, "tokenizer", "tokenizer.json")
    tok, err := tokenizer.Load(tokenizerPath)
    tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
    if err != nil {
        return fmt.Errorf("tokenizer: %w", err)
    }

    // Try to read tokenizer config files from manifest
    tokConfig := &tokenizer.TokenizerConfig{}
    if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
        tokConfig.TokenizerConfigJSON = data
    }
    if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
        tokConfig.GenerationConfigJSON = data
    }
    if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
        tokConfig.SpecialTokensMapJSON = data
    }

    tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
    if err != nil {
        return fmt.Errorf("tokenizer: %w", err)
    }
@@ -68,7 +93,7 @@ func (m *Model) Load(modelPath string) error {

    // Load text encoder
    m.TextEncoder = &Qwen3TextEncoder{}
    if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
    if err := m.TextEncoder.Load(manifest); err != nil {
        return fmt.Errorf("text encoder: %w", err)
    }
    mlx.Eval(mlx.Collect(m.TextEncoder)...)
@@ -78,7 +103,7 @@ func (m *Model) Load(modelPath string) error {

    // Load transformer
    m.Transformer = &Transformer{}
    if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
    if err := m.Transformer.Load(manifest); err != nil {
        return fmt.Errorf("transformer: %w", err)
    }
    mlx.Eval(mlx.Collect(m.Transformer)...)
@@ -88,7 +113,7 @@ func (m *Model) Load(modelPath string) error {

    // Load VAE decoder
    m.VAEDecoder = &VAEDecoder{}
    if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
    if err := m.VAEDecoder.Load(manifest); err != nil {
        return fmt.Errorf("VAE decoder: %w", err)
    }
    mlx.Eval(mlx.Collect(m.VAEDecoder)...)
@@ -104,7 +129,7 @@ func (m *Model) Load(modelPath string) error {

// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
    return m.GenerateFromConfig(&GenerateConfig{
    return m.GenerateFromConfig(context.Background(), &GenerateConfig{
        Prompt: prompt,
        Width: width,
        Height: height,
@@ -115,7 +140,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int

// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
    return m.GenerateFromConfig(&GenerateConfig{
    return m.GenerateFromConfig(context.Background(), &GenerateConfig{
        Prompt: prompt,
        Width: width,
        Height: height,
@@ -127,7 +152,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i

// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
    return m.GenerateFromConfig(&GenerateConfig{
    return m.GenerateFromConfig(context.Background(), &GenerateConfig{
        Prompt: prompt,
        NegativePrompt: negativePrompt,
        CFGScale: cfgScale,
@@ -140,9 +165,9 @@ func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int
}

// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
    start := time.Now()
    result, err := m.generate(cfg)
    result, err := m.generate(ctx, cfg)
    if err != nil {
        return nil, err
    }
@@ -160,7 +185,7 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
}

// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
    // Apply defaults
    if cfg.Width <= 0 {
        cfg.Width = 1024
@@ -174,13 +199,17 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
    if cfg.CFGScale <= 0 {
        cfg.CFGScale = 4.0
    }
    if cfg.LayerCache {
        if cfg.CacheInterval <= 0 {
            cfg.CacheInterval = 3
        }
        if cfg.CacheLayers <= 0 {
            cfg.CacheLayers = 15 // Half of 30 layers
        }
    // TeaCache enabled by default
    cfg.TeaCache = true
    if cfg.TeaCacheThreshold <= 0 {
        cfg.TeaCacheThreshold = 0.15
    }

    // Enable fused QKV if requested (only fuse once)
    if cfg.FusedQKV && !m.qkvFused {
        m.Transformer.FuseAllQKV()
        m.qkvFused = true
        fmt.Println(" Fused QKV enabled")
    }

    useCFG := cfg.NegativePrompt != ""
@@ -238,20 +267,71 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
        mlx.Eval(ropeCache.UnifiedCos)
    }

    // Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
    var stepCache *cache.StepCache
    if cfg.LayerCache {
        stepCache = cache.NewStepCache(cfg.CacheLayers)
        fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
            cfg.CacheLayers, cfg.CacheInterval)
    // Pre-compute batched embeddings for CFG (outside the loop for efficiency)
    var batchedEmb *mlx.Array
    if useCFG {
        // Concatenate embeddings once: [1, L, D] + [1, L, D] -> [2, L, D]
        batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
        mlx.Keep(batchedEmb)
        mlx.Eval(batchedEmb)
    }

    // TeaCache for timestep-aware caching
    // For CFG mode, we cache pos/neg separately, skip early steps, and always compute CFG fresh
    var teaCache *cache.TeaCache
    if cfg.TeaCache {
        skipEarly := 0
        if useCFG {
            skipEarly = 3 // Skip first 3 steps for CFG to preserve structure
        }
        teaCache = cache.NewTeaCache(&cache.TeaCacheConfig{
            Threshold: cfg.TeaCacheThreshold,
            RescaleFactor: 1.0,
            SkipEarlySteps: skipEarly,
        })
        if useCFG {
            fmt.Printf(" TeaCache enabled (CFG mode): threshold=%.2f, skip first %d steps\n", cfg.TeaCacheThreshold, skipEarly)
        } else {
            fmt.Printf(" TeaCache enabled: threshold=%.2f\n", cfg.TeaCacheThreshold)
        }
    }
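
    // Classifier-free guidance, as applied in the loop below, uses the
    // standard formulation:
    //
    //  noisePred = negPred + cfgScale * (posPred - negPred)
    //
    // cfgScale = 1 reproduces the conditional prediction; larger values
    // push the sample further from the negative-prompt branch.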

    // cleanup frees all kept arrays when we need to abort early
    cleanup := func() {
        posEmb.Free()
        if negEmb != nil {
            negEmb.Free()
        }
        ropeCache.ImgCos.Free()
        ropeCache.ImgSin.Free()
        ropeCache.CapCos.Free()
        ropeCache.CapSin.Free()
        ropeCache.UnifiedCos.Free()
        ropeCache.UnifiedSin.Free()
        if batchedEmb != nil {
            batchedEmb.Free()
        }
        if teaCache != nil {
            teaCache.Free()
        }
        latents.Free()
    }

    // Denoising loop
    if cfg.Progress != nil {
        cfg.Progress(0, cfg.Steps) // Start at 0%
    }
    for i := 0; i < cfg.Steps; i++ {
        stepStart := time.Now()
        if cfg.Progress != nil {
            cfg.Progress(i+1, cfg.Steps)
        // Check for cancellation
        if ctx != nil {
            select {
            case <-ctx.Done():
                cleanup()
                return nil, ctx.Err()
            default:
            }
        }
        stepStart := time.Now()

        // GPU capture on step 2 if requested
        if cfg.CapturePath != "" && i == 1 {
@@ -259,49 +339,77 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
        }

        tCurr := scheduler.Timesteps[i]
        timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
        var noisePred *mlx.Array

        patches := PatchifyLatents(latents, tcfg.PatchSize)
        // TeaCache: check if we should compute or reuse cached output
        shouldCompute := teaCache == nil || teaCache.ShouldCompute(i, tCurr)

        var output *mlx.Array
        if stepCache != nil {
            // Use layer caching for faster inference
        if shouldCompute {
            timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
            patches := PatchifyLatents(latents, tcfg.PatchSize)

            var output *mlx.Array
            if useCFG {
                posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
                    stepCache, i, cfg.CacheInterval)
                // Note: CFG with layer cache shares the cache between pos/neg
                // This is approximate but fast - neg prompt uses same cached shallow layers
                negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
                    stepCache, i, cfg.CacheInterval)
                diff := mlx.Sub(posOutput, negOutput)
                // CFG Batching: single forward pass with batch=2
                // Tile patches: [1, L, D] -> [2, L, D]
                batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
                // Tile timestep: [1] -> [2]
                batchedTimestep := mlx.Tile(timestep, []int32{2})

                // Single batched forward pass (RoPE broadcasts from [1,L,H,D] to [2,L,H,D])
                batchedOutput := m.Transformer.Forward(batchedPatches, batchedTimestep, batchedEmb, ropeCache)

                // Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
                outputShape := batchedOutput.Shape()
                L := outputShape[1]
                D := outputShape[2]
                posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
                negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})

                // Convert to noise predictions (unpatchify and negate)
                posPred := UnpatchifyLatents(posOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
                posPred = mlx.Neg(posPred)
                negPred := UnpatchifyLatents(negOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
                negPred = mlx.Neg(negPred)

                // Cache pos/neg separately for TeaCache
                if teaCache != nil {
                    teaCache.UpdateCFGCache(posPred, negPred, tCurr)
                    mlx.Keep(teaCache.Arrays()...)
                }

                // Apply CFG: noisePred = neg + scale * (pos - neg)
                diff := mlx.Sub(posPred, negPred)
                scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
                output = mlx.Add(negOutput, scaledDiff)
            } else {
                output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
                    stepCache, i, cfg.CacheInterval)
            }
        } else {
            // Standard forward without caching
            if useCFG {
                posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
                negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
                diff := mlx.Sub(posOutput, negOutput)
                scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
                output = mlx.Add(negOutput, scaledDiff)
                noisePred = mlx.Add(negPred, scaledDiff)
            } else {
                // Non-CFG forward pass
                output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
                noisePred = UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
                noisePred = mlx.Neg(noisePred)

                // Update TeaCache
                if teaCache != nil {
                    teaCache.UpdateCache(noisePred, tCurr)
                    mlx.Keep(teaCache.Arrays()...)
                }
            }
        } else if useCFG && teaCache != nil && teaCache.HasCFGCache() {
            // CFG mode: get cached pos/neg and compute CFG fresh
            posPred, negPred := teaCache.GetCFGCached()
            diff := mlx.Sub(posPred, negPred)
            scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
            noisePred = mlx.Add(negPred, scaledDiff)
            fmt.Printf(" [TeaCache: reusing cached pos/neg outputs]\n")
        } else {
            // Non-CFG mode: reuse cached noise prediction
            noisePred = teaCache.GetCached()
            fmt.Printf(" [TeaCache: reusing cached output]\n")
        }

        noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
        noisePred = mlx.Neg(noisePred)
        oldLatents := latents
        latents = scheduler.Step(noisePred, latents, i)

        // Keep latents and any cached arrays
        if stepCache != nil {
            mlx.Keep(stepCache.Arrays()...)
        }
        mlx.Eval(latents)
        oldLatents.Free()

@@ -313,6 +421,10 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
        peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
        fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n",
            i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem)

        if cfg.Progress != nil {
            cfg.Progress(i+1, cfg.Steps) // Report completed step
        }
    }

    // Free denoising temporaries before VAE decode
@@ -326,8 +438,14 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
    ropeCache.CapSin.Free()
    ropeCache.UnifiedCos.Free()
    ropeCache.UnifiedSin.Free()
    if stepCache != nil {
        stepCache.Free()
    if batchedEmb != nil {
        batchedEmb.Free()
    }
    if teaCache != nil {
        hits, misses := teaCache.Stats()
        fmt.Printf(" TeaCache stats: %d hits, %d misses (%.1f%% cache rate)\n",
            hits, misses, float64(hits)/float64(hits+misses)*100)
        teaCache.Free()
    }

    // VAE decode

@@ -10,6 +10,13 @@ type Layer interface {
    Forward(x *mlx.Array) *mlx.Array
}

// LinearLayer is an interface for linear layers (both regular and quantized).
// This allows swapping between Linear and QuantizedLinear at runtime.
type LinearLayer interface {
    Forward(x *mlx.Array) *mlx.Array
    OutputDim() int32 // Returns the output dimension of the layer
}
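
// Illustrative use of the interface (argument values are hypothetical):
// a field typed as LinearLayer can hold either implementation, so callers
// stay agnostic to quantization.
//
//  var layer LinearLayer = &Linear{Weight: w}
//  layer = layer.(*Linear).ToQuantized(32, 8, "mxfp8") // illustrative args
//  out := layer.Forward(x) // same call either way
//  dim := layer.OutputDim()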

// Linear applies an affine transformation: y = x @ W.T + b
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
type Linear struct {
@@ -49,6 +56,11 @@ func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
    return mlx.Linear(x, w)
}

// OutputDim returns the output dimension of the linear layer.
func (l *Linear) OutputDim() int32 {
    return l.Weight.Shape()[0]
}

// ToQuantized converts this Linear to a QuantizedLinear.
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
    qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
@@ -84,6 +96,13 @@ func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
    return out
}

// OutputDim returns the output dimension of the quantized linear layer.
// For mxfp8/mxfp4, quantized weight shape is [out_features, in_features / group_size].
// The output dimension is the first dimension of the weight.
func (ql *QuantizedLinear) OutputDim() int32 {
    return ql.Weight.Shape()[0]
}

// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
    Weight *mlx.Array `weight:"weight"`

22
x/imagegen/quantize.go
Normal file
@@ -0,0 +1,22 @@
package imagegen

import (
    "io"
    "strings"
)

// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)

// ShouldQuantize returns true if a tensor should be quantized.
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
func ShouldQuantize(name, component string) bool {
    if component == "vae" {
        return false
    }
    if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
        return false
    }
    return strings.HasSuffix(name, ".weight")
}
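
// Expected behavior, for illustration (tensor names are hypothetical):
//
//  ShouldQuantize("layers.0.attention.to_q.weight", "transformer") // true
//  ShouldQuantize("decoder.conv_in.weight", "vae")                 // false: VAE is skipped
//  ShouldQuantize("cap_embedder.0.weight", "transformer")          // false: contains "embed"
//  ShouldQuantize("layers.0.attention.to_q.bias", "transformer")   // false: not ".weight"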

267
x/imagegen/runner/runner.go
Normal file
@@ -0,0 +1,267 @@
//go:build mlx

// Package runner provides a subprocess server for image generation.
// It listens on a port and handles HTTP requests for image generation.
package runner

import (
    "context"
    "encoding/json"
    "flag"
    "fmt"
    "log/slog"
    "net/http"
    "os"
    "os/signal"
    "sync"
    "syscall"
    "time"

    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/imagegen/mlx"
    "github.com/ollama/ollama/x/imagegen/models/glm_image"
    "github.com/ollama/ollama/x/imagegen/models/zimage"
)

// ImageModel is the interface for image generation models
type ImageModel interface {
    GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error)
}

// Request is the image generation request format
type Request struct {
    Prompt string `json:"prompt"`
    Width int32 `json:"width,omitempty"`
    Height int32 `json:"height,omitempty"`
    Steps int `json:"steps,omitempty"`
    Seed int64 `json:"seed,omitempty"`
}

// Response is streamed back for each progress update
type Response struct {
    Content string `json:"content,omitempty"`
    Image string `json:"image,omitempty"` // Base64-encoded PNG
    Done bool `json:"done"`
}
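
// Example exchange, for illustration (field values are hypothetical):
//
//  POST /completion
//  {"prompt": "a watercolor fox", "width": 1024, "height": 1024, "steps": 9, "seed": 42}
//
//  -> streamed NDJSON, final line:
//  {"image": "<base64 PNG>", "done": true}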

// Server holds the model and handles requests
type Server struct {
    mu sync.Mutex
    model ImageModel
    modelName string
    modelType string // "zimage" or "glm_image"
}

// Execute is the entry point for the image runner subprocess
func Execute(args []string) error {
    fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
    modelName := fs.String("model", "", "path to image model")
    port := fs.Int("port", 0, "port to listen on")

    if err := fs.Parse(args); err != nil {
        return err
    }

    if *modelName == "" {
        return fmt.Errorf("--model is required")
    }
    if *port == 0 {
        return fmt.Errorf("--port is required")
    }

    slog.Info("starting image runner", "model", *modelName, "port", *port)

    // Check memory requirements before loading
    requiredMemory := imagegen.EstimateVRAM(*modelName)
    availableMemory := mlx.GetMemoryLimit()
    if availableMemory > 0 && availableMemory < requiredMemory {
        return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
            requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
    }

    // Detect model type and load appropriate model
    modelType, err := detectModelType(*modelName)
    if err != nil {
        return fmt.Errorf("failed to detect model type: %w", err)
    }

    var model ImageModel
    switch modelType {
    case "GlmImagePipeline":
        slog.Info("loading GLM-Image model")
        m := &glm_image.Model{}
        if err := m.Load(*modelName); err != nil {
            return fmt.Errorf("failed to load GLM-Image model: %w", err)
        }
        model = m
    default:
        // Default to zimage for ZImagePipeline, FluxPipeline, and unknown types
        slog.Info("loading Z-Image model")
        m := &zimage.Model{}
        if err := m.Load(*modelName); err != nil {
            return fmt.Errorf("failed to load Z-Image model: %w", err)
        }
        model = m
    }

    server := &Server{
        model: model,
        modelName: *modelName,
        modelType: modelType,
    }

    // Set up HTTP handlers
    mux := http.NewServeMux()
    mux.HandleFunc("/health", server.healthHandler)
    mux.HandleFunc("/completion", server.completionHandler)

    httpServer := &http.Server{
        Addr: fmt.Sprintf("127.0.0.1:%d", *port),
        Handler: mux,
    }

    // Handle shutdown
    done := make(chan struct{})
    go func() {
        sigCh := make(chan os.Signal, 1)
        signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
        <-sigCh
        slog.Info("shutting down image runner")
        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
        defer cancel()
        httpServer.Shutdown(ctx)
        close(done)
    }()

    slog.Info("image runner listening", "addr", httpServer.Addr)
    if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
        return err
    }

    <-done
    return nil
}

func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
    w.WriteHeader(http.StatusOK)
    json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}

func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodPost {
        http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
        return
    }

    var req Request
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        http.Error(w, err.Error(), http.StatusBadRequest)
        return
    }

    // Serialize generation requests - MLX model may not handle concurrent generation
    s.mu.Lock()
    defer s.mu.Unlock()

    // Apply defaults
    if req.Width <= 0 {
        req.Width = 1024
    }
    if req.Height <= 0 {
        req.Height = 1024
    }
    if req.Steps <= 0 {
        // Default steps depend on model type
        switch s.modelType {
        case "GlmImagePipeline":
            req.Steps = 50 // GLM-Image default
        default:
            req.Steps = 9 // Z-Image turbo default
        }
    }
    if req.Seed <= 0 {
        req.Seed = time.Now().UnixNano()
    }

    // Set up streaming response
    w.Header().Set("Content-Type", "application/x-ndjson")
    w.Header().Set("Transfer-Encoding", "chunked")
    flusher, ok := w.(http.Flusher)
    if !ok {
        http.Error(w, "streaming not supported", http.StatusInternalServerError)
        return
    }

    // Generate image using interface method
    ctx := r.Context()
    img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed)

    if err != nil {
        // Don't send error for cancellation
        if ctx.Err() != nil {
            return
        }
        resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
        data, _ := json.Marshal(resp)
        w.Write(data)
        w.Write([]byte("\n"))
        return
    }

    // Encode image as base64 PNG
    imageData, err := imagegen.EncodeImageBase64(img)
    if err != nil {
        resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
        data, _ := json.Marshal(resp)
        w.Write(data)
        w.Write([]byte("\n"))
        return
    }

    // Free the generated image array and clean up MLX state
    img.Free()
    mlx.ClearCache()
    mlx.MetalResetPeakMemory()

    // Send final response with image data
    resp := Response{
        Image: imageData,
        Done: true,
    }
    data, _ := json.Marshal(resp)
    w.Write(data)
    w.Write([]byte("\n"))
    flusher.Flush()
}

// detectModelType reads the model manifest and returns the pipeline class name
func detectModelType(modelName string) (string, error) {
    manifest, err := imagegen.LoadManifest(modelName)
    if err != nil {
        return "", err
    }

    data, err := manifest.ReadConfig("model_index.json")
    if err != nil {
        return "ZImagePipeline", nil // Default to Z-Image
    }

    // Try both _class_name (diffusers format) and architecture (ollama format)
    var index struct {
        ClassName string `json:"_class_name"`
        Architecture string `json:"architecture"`
    }
    if err := json.Unmarshal(data, &index); err != nil {
        return "ZImagePipeline", nil
    }

    // Prefer _class_name, fall back to architecture
    className := index.ClassName
    if className == "" {
        className = index.Architecture
    }
    if className == "" {
        return "ZImagePipeline", nil
    }
    return className, nil
}
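
// Example model_index.json inputs, for illustration: a diffusers-style
// {"_class_name": "GlmImagePipeline", ...} selects the GLM-Image loader,
// while a missing file, unparseable JSON, or an empty class name all fall
// back to "ZImagePipeline".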

10
x/imagegen/runner/runner_stub.go
Normal file
@@ -0,0 +1,10 @@
//go:build !mlx

package runner

import "errors"

// Execute returns an error when not built with MLX support.
func Execute(args []string) error {
    return errors.New("image generation not available: build with mlx tag")
}

176
x/imagegen/safetensors/extractor.go
Normal file
@@ -0,0 +1,176 @@
package safetensors

import (
    "bytes"
    "encoding/binary"
    "encoding/json"
    "fmt"
    "io"
    "os"
    "sort"
)

// tensorInfo holds tensor metadata from safetensors headers.
// This avoids depending on safetensors.go which requires the mlx tag.
type tensorInfo struct {
    Dtype string `json:"dtype"`
    Shape []int32 `json:"shape"`
    DataOffsets [2]int `json:"data_offsets"`
}

// TensorExtractor extracts individual tensors from a safetensors file.
// It provides io.Reader interfaces for each tensor's raw data, enabling
// streaming writes to blobs without loading entire tensors into memory.
type TensorExtractor struct {
    file *os.File
    dataOffset int64 // Start of tensor data region
    header map[string]tensorInfo
}

// TensorData holds tensor metadata and a reader for its raw bytes.
type TensorData struct {
    Name string
    Dtype string
    Shape []int32
    Size int64
    reader *io.SectionReader
}

// Reader returns an io.Reader for the tensor's raw bytes.
func (td *TensorData) Reader() io.Reader {
    return td.reader
}

// SafetensorsReader returns a reader that outputs the tensor wrapped in
// minimal safetensors format. This allows using mlx_load_safetensors on
// individual tensor blobs for native zero-copy loading.
func (td *TensorData) SafetensorsReader() io.Reader {
    // Build minimal safetensors header with tensor named "data"
    header := map[string]tensorInfo{
        "data": {
            Dtype: td.Dtype,
            Shape: td.Shape,
            DataOffsets: [2]int{0, int(td.Size)},
        },
    }
    headerJSON, _ := json.Marshal(header)

    // Pad header to 8-byte alignment
    padding := (8 - len(headerJSON)%8) % 8
    headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)

    // Build header with size prefix
    headerBuf := new(bytes.Buffer)
    binary.Write(headerBuf, binary.LittleEndian, uint64(len(headerJSON)))
    headerBuf.Write(headerJSON)

    // Return multi-reader: header + tensor data
    td.reader.Seek(0, io.SeekStart)
    return io.MultiReader(headerBuf, td.reader)
}
|
||||
|
||||
// SafetensorsSize returns the total size of the safetensors-wrapped tensor.
|
||||
func (td *TensorData) SafetensorsSize() int64 {
|
||||
header := map[string]tensorInfo{
|
||||
"data": {
|
||||
Dtype: td.Dtype,
|
||||
Shape: td.Shape,
|
||||
DataOffsets: [2]int{0, int(td.Size)},
|
||||
},
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
padding := (8 - len(headerJSON)%8) % 8
|
||||
return 8 + int64(len(headerJSON)) + int64(padding) + td.Size
|
||||
}
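
// Worked size example (a sketch, assuming the safetensors dtype string "F32"):
// a float32 tensor of shape [2,3] has Size = 24 bytes, its minimal header JSON
// {"data":{"dtype":"F32","shape":[2,3],"data_offsets":[0,24]}} is 60 bytes and
// pads to 64 for 8-byte alignment, so SafetensorsSize() = 8 + 64 + 24 = 96.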

// OpenForExtraction opens a safetensors file for tensor extraction.
// The caller must call Close() when done.
func OpenForExtraction(path string) (*TensorExtractor, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to open file: %w", err)
	}

	var headerSize uint64
	if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
		f.Close()
		return nil, fmt.Errorf("failed to read header size: %w", err)
	}

	headerBytes := make([]byte, headerSize)
	if _, err := io.ReadFull(f, headerBytes); err != nil {
		f.Close()
		return nil, fmt.Errorf("failed to read header: %w", err)
	}

	var header map[string]tensorInfo
	if err := json.Unmarshal(headerBytes, &header); err != nil {
		f.Close()
		return nil, fmt.Errorf("failed to parse header: %w", err)
	}

	delete(header, "__metadata__")

	return &TensorExtractor{
		file:       f,
		dataOffset: 8 + int64(headerSize), // 8 bytes for header size + header content
		header:     header,
	}, nil
}

// GetTensor returns tensor metadata and a reader for extracting a single tensor.
func (te *TensorExtractor) GetTensor(name string) (*TensorData, error) {
	info, ok := te.header[name]
	if !ok {
		return nil, fmt.Errorf("tensor %q not found", name)
	}

	start := te.dataOffset + int64(info.DataOffsets[0])
	size := int64(info.DataOffsets[1] - info.DataOffsets[0])

	return &TensorData{
		Name:   name,
		Dtype:  info.Dtype,
		Shape:  info.Shape,
		Size:   size,
		reader: io.NewSectionReader(te.file, start, size),
	}, nil
}

// ListTensors returns all tensor names in sorted order.
func (te *TensorExtractor) ListTensors() []string {
	names := make([]string, 0, len(te.header))
	for name := range te.header {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// TensorCount returns the number of tensors in the file.
func (te *TensorExtractor) TensorCount() int {
	return len(te.header)
}

// Close closes the underlying file.
func (te *TensorExtractor) Close() error {
	return te.file.Close()
}

// ExtractAll returns TensorData for all tensors in the file.
// Each TensorData has a reader that reads from the original file.
// The caller must call Close() on the TensorExtractor when done.
func (te *TensorExtractor) ExtractAll() ([]*TensorData, error) {
	names := te.ListTensors()
	tensors := make([]*TensorData, 0, len(names))

	for _, name := range names {
		td, err := te.GetTensor(name)
		if err != nil {
			return nil, err
		}
		tensors = append(tensors, td)
	}

	return tensors, nil
}
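
A minimal caller sketch (the file path is illustrative; everything else is the API defined above):

package main

import (
	"io"
	"log"

	"github.com/ollama/ollama/x/imagegen/safetensors"
)

func main() {
	ex, err := safetensors.OpenForExtraction("model.safetensors") // illustrative path
	if err != nil {
		log.Fatal(err)
	}
	defer ex.Close()

	for _, name := range ex.ListTensors() {
		td, err := ex.GetTensor(name)
		if err != nil {
			log.Fatal(err)
		}
		// Stream the raw tensor bytes; a real caller would copy them into a
		// blob, or use td.SafetensorsReader() for a single-tensor envelope.
		if _, err := io.Copy(io.Discard, td.Reader()); err != nil {
			log.Fatal(err)
		}
	}
}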

@@ -8,8 +8,17 @@ import (
	"strings"

	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
)

// WeightSource is an interface for loading weights.
// Both ModelWeights (directory-based) and ManifestWeights (blob-based) implement this.
type WeightSource interface {
	GetTensor(name string) (*mlx.Array, error)
	ListTensors() []string
	HasTensor(name string) bool
}

// LoadModule loads weights into a struct using reflection and struct tags.
//
// Struct tags use the format: `weight:"path[,optional]"`
@@ -31,7 +40,7 @@ import (
// }
//
// err := LoadModule(&attn, weights, "model.layers.0")
func LoadModule(dst any, weights *ModelWeights, prefix string) error {
func LoadModule(dst any, weights WeightSource, prefix string) error {
	v := reflect.ValueOf(dst)
	if v.Kind() != reflect.Ptr || v.IsNil() {
		return fmt.Errorf("LoadModule: dst must be a non-nil pointer")
@@ -51,7 +60,7 @@ func LoadModule(dst any, weights *ModelWeights, prefix string) error {
}

// loadStruct recursively loads weights into a struct value.
func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string, parentOptional bool) {
func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]string, parentOptional bool) {
	t := v.Type()

	for i := 0; i < t.NumField(); i++ {
@@ -94,6 +103,22 @@ func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]s
			}
		}

		// Handle nn.LinearLayer interface fields specially
		if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
			if !hasTag {
				continue // no tag = skip
			}
			layer, err := LoadLinearLayer(weights, fullPath)
			if err != nil {
				if !optional {
					*errs = append(*errs, fullPath+": "+err.Error())
				}
				continue
			}
			fieldVal.Set(reflect.ValueOf(layer))
			continue
		}

		// Handle by kind
		switch fieldVal.Kind() {
		case reflect.Ptr:
@@ -136,7 +161,7 @@ func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]s
}

// hasWeightsWithPrefix checks if any weights exist with the given prefix.
func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
func hasWeightsWithPrefix(weights WeightSource, prefix string) bool {
	for _, name := range weights.ListTensors() {
		if strings.HasPrefix(name, prefix+".") || name == prefix {
			return true
@@ -146,7 +171,7 @@ func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
}

// loadSlice loads weights into each element of a slice of struct pointers.
func loadSlice(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string) {
func loadSlice(v reflect.Value, weights WeightSource, prefix string, errs *[]string) {
	elemStructType := v.Type().Elem().Elem()

	for i := 0; i < v.Len(); i++ {
@@ -168,3 +193,64 @@ func joinPath(prefix, suffix string) string {
	}
	return prefix + "." + suffix
}

// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, the layer is loaded quantized (and dequantized when Metal is unavailable).
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
	// Check if this is a quantized layer by looking for scale tensor
	scalePath := path + ".weight_scale"
	if weights.HasTensor(scalePath) {
		weight, err := weights.GetTensor(path + ".weight")
		if err != nil {
			return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
		}

		scales, err := weights.GetTensor(scalePath)
		if err != nil {
			return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
		}

		// Bias is optional
		var bias *mlx.Array
		biasPath := path + ".bias"
		if weights.HasTensor(biasPath) {
			bias, _ = weights.GetTensor(biasPath)
		}

		var qbiases *mlx.Array
		qbiasPath := path + ".weight_qbias"
		if weights.HasTensor(qbiasPath) {
			qbiases, _ = weights.GetTensor(qbiasPath)
		}

		if mlx.MetalIsAvailable() {
			return &nn.QuantizedLinear{
				Weight:    weight,
				Scales:    scales,
				QBiases:   qbiases,
				Bias:      bias,
				GroupSize: 32,
				Bits:      8,
				Mode:      "affine",
			}, nil
		}

		dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
		return nn.NewLinear(dequantized, bias), nil
	}

	// Load as regular Linear
	weight, err := weights.GetTensor(path + ".weight")
	if err != nil {
		return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
	}

	// Bias is optional
	var bias *mlx.Array
	biasPath := path + ".bias"
	if weights.HasTensor(biasPath) {
		bias, _ = weights.GetTensor(biasPath)
	}

	return nn.NewLinear(weight, bias), nil
}
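
To make the quantization-aware path concrete, here is an in-package sketch (the Attention struct and tensor paths are hypothetical; the tag format and loading behavior come from the code above):

// Sketch only: struct and weight paths are illustrative.
type Attention struct {
	// Resolves to "model.layers.0.attn.q_proj.weight", with optional
	// ".weight_scale"/".weight_qbias"/".bias" companions alongside it.
	QProj nn.LinearLayer `weight:"attn.q_proj"`
}

func loadAttention(weights WeightSource) (*Attention, error) {
	var attn Attention
	// If attn.q_proj.weight_scale exists, QProj becomes a QuantizedLinear
	// (or is dequantized up front when Metal is unavailable); otherwise it
	// loads as a plain Linear.
	if err := LoadModule(&attn, weights, "model.layers.0"); err != nil {
		return nil, err
	}
	return &attn, nil
}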

@@ -118,6 +118,34 @@ func LoadModelWeights(dir string) (*ModelWeights, error) {
	return mw, nil
}

// LoadModelWeightsFromPaths loads weights from specific safetensor file paths.
// Used for loading from blob storage where files are not in a directory.
func LoadModelWeightsFromPaths(paths []string) (*ModelWeights, error) {
	mw := &ModelWeights{
		tensorFiles: make(map[string]string),
		tensorInfo:  make(map[string]TensorInfo),
		nativeCache: make(map[string]*mlx.SafetensorsFile),
	}

	for _, path := range paths {
		header, err := parseSafetensorHeader(path)
		if err != nil {
			return nil, fmt.Errorf("failed to parse %s: %w", path, err)
		}

		for name, info := range header {
			mw.tensorFiles[name] = path
			mw.tensorInfo[name] = info
		}
	}

	if len(mw.tensorFiles) == 0 {
		return nil, fmt.Errorf("no tensors found in provided paths")
	}

	return mw, nil
}

// Load loads all tensors into cache with the specified dtype.
// If dtype is 0, tensors are loaded in their original dtype.
// Automatically uses streaming (memory-efficient) when dtype conversion is needed,

395 x/imagegen/server.go Normal file
@@ -0,0 +1,395 @@
package imagegen

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"math/rand"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/ml"
)

// Server wraps an image generation subprocess to implement llm.LlamaServer.
type Server struct {
	mu          sync.Mutex
	cmd         *exec.Cmd
	port        int
	modelName   string
	vramSize    uint64
	done        chan error
	client      *http.Client
	lastErr     string // Last stderr line for error reporting
	lastErrLock sync.Mutex
}

// completionRequest is sent to the subprocess.
type completionRequest struct {
	Prompt string `json:"prompt"`
	Width  int32  `json:"width,omitempty"`
	Height int32  `json:"height,omitempty"`
	Steps  int    `json:"steps,omitempty"`
	Seed   int64  `json:"seed,omitempty"`
}

// completionResponse is received from the subprocess.
type completionResponse struct {
	Content string `json:"content,omitempty"`
	Image   string `json:"image,omitempty"`
	Done    bool   `json:"done"`
}

// NewServer spawns a new image generation subprocess and waits until it's ready.
func NewServer(modelName string) (*Server, error) {
	// Validate platform support before attempting to start
	if err := CheckPlatformSupport(); err != nil {
		return nil, err
	}

	// Find a free port
	port := 0
	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
		if l, err := net.ListenTCP("tcp", a); err == nil {
			port = l.Addr().(*net.TCPAddr).Port
			l.Close()
		}
	}
	if port == 0 {
		port = rand.Intn(65535-49152) + 49152
	}

	// Get the ollama-mlx executable path (in same directory as current executable)
	exe, err := os.Executable()
	if err != nil {
		return nil, fmt.Errorf("unable to lookup executable path: %w", err)
	}
	if eval, err := filepath.EvalSymlinks(exe); err == nil {
		exe = eval
	}
	mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx")

	// Spawn subprocess: ollama-mlx runner --image-engine --model <path> --port <port>
	cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
	cmd.Env = os.Environ()

	// On Linux, set LD_LIBRARY_PATH to include MLX library directories
	if runtime.GOOS == "linux" {
		// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
		libraryPaths := []string{ml.LibOllamaPath}
		if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
			libraryPaths = append(libraryPaths, mlxDirs...)
		}

		// Append existing LD_LIBRARY_PATH if set
		if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
			libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
		}

		pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))

		// Update or add LD_LIBRARY_PATH in cmd.Env
		found := false
		for i := range cmd.Env {
			if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
				cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
				found = true
				break
			}
		}
		if !found {
			cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
		}
		slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
	}

	s := &Server{
		cmd:       cmd,
		port:      port,
		modelName: modelName,
		vramSize:  EstimateVRAM(modelName),
		done:      make(chan error, 1),
		client:    &http.Client{Timeout: 10 * time.Minute},
	}

	// Forward subprocess stdout/stderr to server logs
	stdout, _ := cmd.StdoutPipe()
	stderr, _ := cmd.StderrPipe()
	go func() {
		scanner := bufio.NewScanner(stdout)
		for scanner.Scan() {
			slog.Info("image-runner", "msg", scanner.Text())
		}
	}()
	go func() {
		scanner := bufio.NewScanner(stderr)
		for scanner.Scan() {
			line := scanner.Text()
			slog.Warn("image-runner", "msg", line)
			// Capture last error line for better error reporting
			s.lastErrLock.Lock()
			s.lastErr = line
			s.lastErrLock.Unlock()
		}
	}()

	slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port)
	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("failed to start image runner: %w", err)
	}

	// Reap subprocess when it exits
	go func() {
		err := cmd.Wait()
		s.done <- err
	}()

	// Wait for subprocess to be ready
	if err := s.waitUntilRunning(); err != nil {
		s.Close()
		return nil, err
	}

	return s, nil
}

// ModelPath returns the path to the model.
func (s *Server) ModelPath() string {
	return s.modelName
}

// Load is called by the scheduler after the server is created.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
	return nil, nil
}

// Ping checks if the subprocess is healthy.
func (s *Server) Ping(ctx context.Context) error {
	url := fmt.Sprintf("http://127.0.0.1:%d/health", s.port)
	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
	if err != nil {
		return err
	}
	resp, err := s.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("health check failed: %d", resp.StatusCode)
	}
	return nil
}

// waitUntilRunning waits for the subprocess to be ready.
func (s *Server) waitUntilRunning() error {
	ctx := context.Background()
	timeout := time.After(2 * time.Minute)
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case err := <-s.done:
			// Include last stderr line for better error context
			s.lastErrLock.Lock()
			lastErr := s.lastErr
			s.lastErrLock.Unlock()
			if lastErr != "" {
				return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
			}
			return fmt.Errorf("image runner exited unexpectedly: %w", err)
		case <-timeout:
			s.lastErrLock.Lock()
			lastErr := s.lastErr
			s.lastErrLock.Unlock()
			if lastErr != "" {
				return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
			}
			return errors.New("timeout waiting for image runner to start")
		case <-ticker.C:
			if err := s.Ping(ctx); err == nil {
				slog.Info("image runner is ready", "port", s.port)
				return nil
			}
		}
	}
}

// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
func (s *Server) WaitUntilRunning(ctx context.Context) error {
	return nil
}

// Completion generates an image from the prompt via the subprocess.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
	// Build request
	creq := completionRequest{
		Prompt: req.Prompt,
		Width:  1024,
		Height: 1024,
		Steps:  9,
		Seed:   time.Now().UnixNano(),
	}

	if req.Options != nil {
		if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
			creq.Width = int32(req.Options.NumCtx)
		}
		if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
			creq.Height = int32(req.Options.NumGPU)
		}
		if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
			creq.Steps = req.Options.NumPredict
		}
		if req.Options.Seed > 0 {
			creq.Seed = int64(req.Options.Seed)
		}
	}

	// Encode request body
	body, err := json.Marshal(creq)
	if err != nil {
		return err
	}

	// Send request to subprocess
	url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
	if err != nil {
		return err
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := s.client.Do(httpReq)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("completion request failed: %d", resp.StatusCode)
	}

	// Stream responses - use large buffer for base64 image data
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
	for scanner.Scan() {
		var cresp completionResponse
		if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
			continue
		}

		content := cresp.Content
		// If this is the final response with an image, encode it in the content
		if cresp.Done && cresp.Image != "" {
			content = "IMAGE_BASE64:" + cresp.Image
		}

		fn(llm.CompletionResponse{
			Content: content,
			Done:    cresp.Done,
		})
		if cresp.Done {
			break
		}
	}

	return scanner.Err()
}

// Close terminates the subprocess.
func (s *Server) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.cmd != nil && s.cmd.Process != nil {
		slog.Info("stopping image runner subprocess", "pid", s.cmd.Process.Pid)
		s.cmd.Process.Signal(os.Interrupt)

		// Wait briefly for graceful shutdown
		select {
		case <-s.done:
		case <-time.After(5 * time.Second):
			s.cmd.Process.Kill()
		}
		s.cmd = nil
	}
	return nil
}

// VRAMSize returns the estimated VRAM usage.
func (s *Server) VRAMSize() uint64 {
	return s.vramSize
}

// TotalSize returns the total memory usage.
func (s *Server) TotalSize() uint64 {
	return s.vramSize
}

// VRAMByGPU returns VRAM usage for a specific GPU.
func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
	return s.vramSize
}

// Embedding is not supported for image generation models.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
	return nil, 0, errors.New("embedding not supported for image generation models")
}

// Tokenize is not supported for image generation models.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
	return nil, errors.New("tokenize not supported for image generation models")
}

// Detokenize is not supported for image generation models.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
	return "", errors.New("detokenize not supported for image generation models")
}

// Pid returns the subprocess PID.
func (s *Server) Pid() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.cmd != nil && s.cmd.Process != nil {
		return s.cmd.Process.Pid
	}
	return -1
}

// GetPort returns the subprocess port.
func (s *Server) GetPort() int {
	return s.port
}

// GetDeviceInfos returns nil since we don't track GPU info.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
	return nil
}

// HasExited returns true if the subprocess has exited.
func (s *Server) HasExited() bool {
	select {
	case <-s.done:
		return true
	default:
		return false
	}
}

// Ensure Server implements llm.LlamaServer
var _ llm.LlamaServer = (*Server)(nil)
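
For reference, the wire format between this wrapper and the subprocess, as implied by the structs above (values illustrative): the request is one JSON object such as {"prompt":"a watercolor fox","width":1024,"height":768,"steps":9,"seed":42}, where width, height, and steps are repurposed from req.Options (NumCtx, NumGPU, NumPredict) and seed passes through. The reply is newline-delimited JSON in which progress lines look like {"content":"...","done":false} and the final line carries {"image":"<base64>","done":true}, which Completion surfaces to callers as IMAGE_BASE64:<base64> in the content field.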

82 x/imagegen/server_test.go Normal file
@@ -0,0 +1,82 @@
package imagegen

import (
	"runtime"
	"testing"
)

// TestPlatformSupport verifies platform validation works correctly.
func TestPlatformSupport(t *testing.T) {
	err := CheckPlatformSupport()

	switch runtime.GOOS {
	case "darwin":
		if runtime.GOARCH == "arm64" {
			// Apple Silicon should be supported
			if err != nil {
				t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
			}
		} else {
			// Intel Mac should fail
			if err == nil {
				t.Error("Expected error on darwin/amd64 (Intel), got nil")
			}
			if err != nil && err.Error() == "" {
				t.Error("Expected meaningful error message for unsupported platform")
			}
		}
	case "linux", "windows":
		// Linux/Windows are allowed (CUDA support checked at runtime)
		if err != nil {
			t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
		}
	default:
		// Other platforms should fail
		if err == nil {
			t.Errorf("Expected error on unsupported platform %s, got nil", runtime.GOOS)
		}
	}
}

// TestMemoryRequirementsError verifies memory check returns clear error.
func TestMemoryRequirementsError(t *testing.T) {
	// Test with insufficient memory
	err := CheckMemoryRequirements("test-model", 8*GB)
	if err == nil {
		t.Error("Expected error for insufficient memory (8GB < 21GB default)")
	}

	// Test with sufficient memory
	err = CheckMemoryRequirements("test-model", 32*GB)
	if err != nil {
		t.Errorf("Expected no error for sufficient memory (32GB), got: %v", err)
	}
}

// TestEstimateVRAMReturnsReasonableDefaults verifies VRAM estimates are sensible.
func TestEstimateVRAMReturnsReasonableDefaults(t *testing.T) {
	// Unknown model should return default (21GB)
	vram := EstimateVRAM("unknown-model")
	if vram < 10*GB || vram > 100*GB {
		t.Errorf("VRAM estimate %d GB is outside reasonable range (10-100 GB)", vram/GB)
	}

	// Verify known pipeline estimates exist and are reasonable
	for name, estimate := range modelVRAMEstimates {
		if estimate < 10*GB {
			t.Errorf("VRAM estimate for %s (%d GB) is suspiciously low", name, estimate/GB)
		}
		if estimate > 200*GB {
			t.Errorf("VRAM estimate for %s (%d GB) is suspiciously high", name, estimate/GB)
		}
	}
}

// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
// This is a compile-time check but we document it as a test.
func TestServerInterfaceCompliance(t *testing.T) {
	// The var _ llm.LlamaServer = (*Server)(nil) line in server.go
	// ensures compile-time interface compliance.
	// This test documents that requirement.
	t.Log("Server implements llm.LlamaServer interface (compile-time checked)")
}

@@ -256,6 +256,164 @@ func rewritePatternForRE2(pattern string) string {
	return pattern
}

// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
// This is useful when loading from blob storage where the file content is already in memory.
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
	return loadFromTokenizerJSON(data, "")
}

// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
type TokenizerConfig struct {
	TokenizerConfigJSON  []byte // tokenizer_config.json content
	GenerationConfigJSON []byte // generation_config.json content
	SpecialTokensMapJSON []byte // special_tokens_map.json content
	ConfigJSON           []byte // config.json content
}

// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
// This is useful when loading from blob storage where companion config files are also blobs.
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
	t, err := loadFromTokenizerJSON(data, "")
	if err != nil {
		return nil, err
	}

	if config == nil {
		return t, nil
	}

	// Apply special token configs from provided data
	loadSpecialTokenConfigFromBytes(t, config)

	return t, nil
}

// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
	// Helper to parse eos_token_id which can be int or []int
	parseTokenIDs := func(v interface{}) []int32 {
		switch val := v.(type) {
		case float64:
			return []int32{int32(val)}
		case []interface{}:
			ids := make([]int32, 0, len(val))
			for _, id := range val {
				if f, ok := id.(float64); ok {
					ids = append(ids, int32(f))
				}
			}
			return ids
		}
		return nil
	}

	// Priority 1: generation_config.json
	if len(config.GenerationConfigJSON) > 0 {
		var genConfig struct {
			EOSTokenID interface{} `json:"eos_token_id"`
			BOSTokenID interface{} `json:"bos_token_id"`
		}
		if err := json.Unmarshal(config.GenerationConfigJSON, &genConfig); err == nil {
			if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
				t.vocab.EOS = ids
			}
			if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
				t.vocab.BOS = ids[0]
			}
		}
	}

	// Priority 2: config.json
	if len(config.ConfigJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
		var modelConfig struct {
			EOSTokenID interface{} `json:"eos_token_id"`
			BOSTokenID interface{} `json:"bos_token_id"`
		}
		if err := json.Unmarshal(config.ConfigJSON, &modelConfig); err == nil {
			if len(t.vocab.EOS) == 0 {
				if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
					t.vocab.EOS = ids
				}
			}
			if t.vocab.BOS < 0 {
				if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
					t.vocab.BOS = ids[0]
				}
			}
		}
	}

	// Priority 3: tokenizer_config.json
	if len(config.TokenizerConfigJSON) > 0 {
		var tokConfig struct {
			BOSToken    interface{} `json:"bos_token"`
			EOSToken    interface{} `json:"eos_token"`
			PADToken    interface{} `json:"pad_token"`
			AddBOSToken *bool       `json:"add_bos_token"`
			AddEOSToken *bool       `json:"add_eos_token"`
		}
		if err := json.Unmarshal(config.TokenizerConfigJSON, &tokConfig); err == nil {
			if t.vocab.BOS < 0 {
				if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
					if id, ok := t.specialTokens[bosStr]; ok {
						t.vocab.BOS = id
					}
				}
			}
			if len(t.vocab.EOS) == 0 {
				if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
					if id, ok := t.specialTokens[eosStr]; ok {
						t.vocab.EOS = []int32{id}
					}
				}
			}
			if t.vocab.PAD < 0 {
				if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
					if id, ok := t.specialTokens[padStr]; ok {
						t.vocab.PAD = id
					}
				}
			}
			if tokConfig.AddBOSToken != nil {
				t.vocab.AddBOS = *tokConfig.AddBOSToken
			}
			if tokConfig.AddEOSToken != nil {
				t.vocab.AddEOS = *tokConfig.AddEOSToken
			}
		}
	}

	// Priority 4: special_tokens_map.json
	if len(config.SpecialTokensMapJSON) > 0 {
		var tokensMap map[string]interface{}
		if err := json.Unmarshal(config.SpecialTokensMapJSON, &tokensMap); err == nil {
			if t.vocab.BOS < 0 {
				if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
					if id, ok := t.specialTokens[bosStr]; ok {
						t.vocab.BOS = id
					}
				}
			}
			if len(t.vocab.EOS) == 0 {
				if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
					if id, ok := t.specialTokens[eosStr]; ok {
						t.vocab.EOS = []int32{id}
					}
				}
			}
			if t.vocab.PAD < 0 {
				if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
					if id, ok := t.specialTokens[padStr]; ok {
						t.vocab.PAD = id
					}
				}
			}
		}
	}
}

// Load loads a tokenizer from a path which can be:
// - A tokenizer.json file
// - A directory containing tokenizer.json or vocab.json + merges.txt
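
A minimal in-package sketch of the blob-loading path (the helper and its inputs are hypothetical; the exported API is as defined above):

// Sketch: build a tokenizer from raw tokenizer.json bytes plus an optional
// generation_config.json blob. genJSON may be nil, in which case Priority 1
// simply doesn't apply and later priorities fill in EOS/BOS/PAD.
func loadTokenizerFromBlobs(tokJSON, genJSON []byte) (*Tokenizer, error) {
	return LoadFromBytesWithConfig(tokJSON, &TokenizerConfig{
		GenerationConfigJSON: genJSON,
	})
}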

329 x/imagegen/transfer/download.go Normal file
@@ -0,0 +1,329 @@
package transfer

import (
	"cmp"
	"context"
	"crypto/sha256"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"slices"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/sync/errgroup"
	"golang.org/x/sync/semaphore"
)

var (
	errStalled = errors.New("download stalled")
	errSlow    = errors.New("download too slow")
)

type downloader struct {
	client       *http.Client
	baseURL      string
	destDir      string
	repository   string // Repository path for blob URLs (e.g., "library/model")
	token        *string
	getToken     func(context.Context, AuthChallenge) (string, error)
	userAgent    string
	stallTimeout time.Duration
	progress     *progressTracker
	speeds       *speedTracker
	logger       *slog.Logger
}

func download(ctx context.Context, opts DownloadOptions) error {
	if len(opts.Blobs) == 0 {
		return nil
	}

	// Calculate total from all blobs (for accurate progress reporting on resume)
	var total int64
	for _, b := range opts.Blobs {
		total += b.Size
	}

	// Filter out already-downloaded blobs and track completed bytes
	var blobs []Blob
	var alreadyCompleted int64
	for _, b := range opts.Blobs {
		if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
			if opts.Logger != nil {
				opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
			}
			alreadyCompleted += b.Size
			continue
		}
		blobs = append(blobs, b)
	}
	if len(blobs) == 0 {
		return nil
	}

	token := opts.Token
	progress := newProgressTracker(total, opts.Progress)
	progress.add(alreadyCompleted) // Report already-downloaded bytes upfront

	d := &downloader{
		client:       cmp.Or(opts.Client, defaultClient),
		baseURL:      opts.BaseURL,
		destDir:      opts.DestDir,
		repository:   cmp.Or(opts.Repository, "library/_"),
		token:        &token,
		getToken:     opts.GetToken,
		userAgent:    cmp.Or(opts.UserAgent, defaultUserAgent),
		stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
		progress:     progress,
		speeds:       &speedTracker{},
		logger:       opts.Logger,
	}

	concurrency := cmp.Or(opts.Concurrency, DefaultDownloadConcurrency)
	sem := semaphore.NewWeighted(int64(concurrency))

	g, ctx := errgroup.WithContext(ctx)
	for _, blob := range blobs {
		g.Go(func() error {
			if err := sem.Acquire(ctx, 1); err != nil {
				return err
			}
			defer sem.Release(1)
			return d.download(ctx, blob)
		})
	}
	return g.Wait()
}

func (d *downloader) download(ctx context.Context, blob Blob) error {
	var lastErr error
	var slowRetries int
	attempt := 0

	for attempt < maxRetries {
		if attempt > 0 {
			if err := backoff(ctx, attempt, time.Second<<uint(attempt-1)); err != nil {
				return err
			}
		}

		start := time.Now()
		n, err := d.downloadOnce(ctx, blob)
		if err == nil {
			if s := time.Since(start).Seconds(); s > 0 {
				d.speeds.record(float64(blob.Size) / s)
			}
			return nil
		}

		d.progress.add(-n) // rollback

		switch {
		case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
			return err
		case errors.Is(err, errStalled):
			// Don't count stall retries against limit
		case errors.Is(err, errSlow):
			if slowRetries++; slowRetries >= 3 {
				attempt++ // Only count after 3 slow retries
			}
		default:
			attempt++
		}
		lastErr = err
	}
	return fmt.Errorf("%w: %v", errMaxRetriesExceeded, lastErr)
}

func (d *downloader) downloadOnce(ctx context.Context, blob Blob) (int64, error) {
	if d.logger != nil {
		d.logger.Debug("downloading blob", "digest", blob.Digest, "size", blob.Size)
	}

	baseURL, _ := url.Parse(d.baseURL)
	u, err := d.resolve(ctx, fmt.Sprintf("%s/v2/%s/blobs/%s", d.baseURL, d.repository, blob.Digest))
	if err != nil {
		return 0, err
	}

	req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
	req.Header.Set("User-Agent", d.userAgent)
	// Add auth only for same-host (not CDN)
	if u.Host == baseURL.Host && *d.token != "" {
		req.Header.Set("Authorization", "Bearer "+*d.token)
	}

	resp, err := d.client.Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return 0, fmt.Errorf("status %d", resp.StatusCode)
	}

	return d.save(ctx, blob, resp.Body)
}

func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader) (int64, error) {
	dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
	tmp := dest + ".tmp"
	os.MkdirAll(filepath.Dir(dest), 0o755)

	f, err := os.Create(tmp)
	if err != nil {
		return 0, err
	}
	defer f.Close()
	setSparse(f)

	h := sha256.New()
	n, err := d.copy(ctx, f, r, h)
	if err != nil {
		os.Remove(tmp)
		return n, err
	}
	f.Close()

	if got := fmt.Sprintf("sha256:%x", h.Sum(nil)); got != blob.Digest {
		os.Remove(tmp)
		return n, fmt.Errorf("digest mismatch")
	}
	if n != blob.Size {
		os.Remove(tmp)
		return n, fmt.Errorf("size mismatch")
	}
	return n, os.Rename(tmp, dest)
}

func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h io.Writer) (int64, error) {
	var n int64
	var lastRead atomic.Int64
	lastRead.Store(time.Now().UnixNano())
	start := time.Now()

	ctx, cancel := context.WithCancelCause(ctx)
	defer cancel(nil)

	go func() {
		tick := time.NewTicker(time.Second)
		defer tick.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-tick.C:
				if time.Since(time.Unix(0, lastRead.Load())) > d.stallTimeout {
					cancel(errStalled)
					return
				}
				if e := time.Since(start); e > 5*time.Second {
					if m := d.speeds.median(); m > 0 && float64(n)/e.Seconds() < m*0.1 {
						cancel(errSlow)
						return
					}
				}
			}
		}
	}()

	buf := make([]byte, 32*1024)
	for {
		if err := ctx.Err(); err != nil {
			if c := context.Cause(ctx); c != nil {
				return n, c
			}
			return n, err
		}

		nr, err := src.Read(buf)
		if nr > 0 {
			lastRead.Store(time.Now().UnixNano())
			if _, werr := dst.Write(buf[:nr]); werr != nil {
				return n, werr
			}
			h.Write(buf[:nr])
			d.progress.add(int64(nr))
			n += int64(nr)
		}
		if err == io.EOF {
			return n, nil
		}
		if err != nil {
			return n, err
		}
	}
}

func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, error) {
	u, _ := url.Parse(rawURL)
	for range 10 {
		req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
		req.Header.Set("User-Agent", d.userAgent)
		if *d.token != "" {
			req.Header.Set("Authorization", "Bearer "+*d.token)
		}

		resp, err := d.client.Do(req)
		if err != nil {
			return nil, err
		}
		resp.Body.Close()

		switch resp.StatusCode {
		case http.StatusOK:
			return u, nil
		case http.StatusUnauthorized:
			if d.getToken == nil {
				return nil, fmt.Errorf("unauthorized")
			}
			ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
			if *d.token, err = d.getToken(ctx, ch); err != nil {
				return nil, err
			}
		case http.StatusTemporaryRedirect, http.StatusFound, http.StatusMovedPermanently:
			loc, _ := resp.Location()
			if loc.Host != u.Host {
				return loc, nil
			}
			u = loc
		default:
			return nil, fmt.Errorf("status %d", resp.StatusCode)
		}
	}
	return nil, fmt.Errorf("too many redirects")
}

type speedTracker struct {
	mu     sync.Mutex
	speeds []float64
}

func (s *speedTracker) record(v float64) {
	s.mu.Lock()
	s.speeds = append(s.speeds, v)
	if len(s.speeds) > 30 {
		s.speeds = s.speeds[1:]
	}
	s.mu.Unlock()
}

func (s *speedTracker) median() float64 {
	s.mu.Lock()
	defer s.mu.Unlock()
	if len(s.speeds) < 5 {
		return 0
	}
	sorted := make([]float64, len(s.speeds))
	copy(sorted, s.speeds)
	slices.Sort(sorted)
	return sorted[len(sorted)/2]
}

const defaultStallTimeout = 10 * time.Second

12 x/imagegen/transfer/sparse_other.go Normal file
@@ -0,0 +1,12 @@
//go:build !windows

package transfer

import "os"

// setSparse is a no-op on non-Windows platforms.
// On Windows, this sets the FSCTL_SET_SPARSE attribute which allows the OS
// to not allocate disk blocks for zero-filled regions. This is useful for
// partial downloads where not all data has been written yet. On Unix-like
// systems, filesystems typically handle this automatically (sparse by default).
func setSparse(_ *os.File) {}

31 x/imagegen/transfer/sparse_windows.go Normal file
@@ -0,0 +1,31 @@
//go:build windows

package transfer

import (
	"os"

	"golang.org/x/sys/windows"
)

// setSparse sets the FSCTL_SET_SPARSE attribute on Windows files.
// This allows the OS to not allocate disk blocks for zero-filled regions,
// which is useful for large files that may not be fully written (e.g., partial
// downloads). Without this, Windows may pre-allocate disk space for the full
// file size even if most of it is zeros.
//
// Note: Errors are intentionally ignored because:
//  1. The file will still work correctly without sparse support
//  2. Not all Windows filesystems support sparse files (e.g., FAT32)
//  3. This is an optimization, not a requirement
func setSparse(file *os.File) {
	var bytesReturned uint32
	_ = windows.DeviceIoControl(
		windows.Handle(file.Fd()),
		windows.FSCTL_SET_SPARSE,
		nil, 0,
		nil, 0,
		&bytesReturned,
		nil,
	)
}

216 x/imagegen/transfer/transfer.go Normal file
@@ -0,0 +1,216 @@
// Package transfer provides minimal, fast blob transfer for tensor-based models.
//
// This package is in x/ because the tensor model storage format is under development.
// It provides optimized transfer for models with many small blobs (tensor models)
// rather than few large blobs (typical LLMs).
//
// TODO (jmorganca): Integrate into server/download.go and server/upload.go when stable.
//
// Design Philosophy:
// This package is intentionally simpler than the main server's download/upload code.
// Key simplifications for many-small-blob workloads:
//
//   - Whole-blob transfers: No part-based chunking. Each blob downloads/uploads as one unit.
//   - No resume: If a transfer fails, it restarts from scratch (fine for small blobs).
//   - Inline hashing: SHA256 computed during streaming, not asynchronously after parts complete.
//   - Stall and speed detection: Cancels on no data (stall) or speed below 10% of median.
//
// For large models (multi-GB), use the server's download/upload code which has:
//   - Part-based transfers with 64MB chunks
//   - Resumable downloads with JSON state files
//   - Async streamHasher that hashes from OS page cache as parts complete
//   - Speed tracking with rolling median to detect and restart slow parts
package transfer

import (
	"context"
	"errors"
	"log/slog"
	"math/rand/v2"
	"net/http"
	"strings"
	"sync/atomic"
	"time"
)

// Blob represents a content-addressed blob to transfer.
type Blob struct {
	Digest string // sha256:...
	Size   int64

	// From enables cross-repository blob mounting (upload only).
	// When set, the upload will first attempt to mount the blob from this source
	// repository instead of uploading the data. This is a Docker Registry v2 API
	// feature that avoids re-uploading blobs that already exist elsewhere.
	//
	// Example: From="library/source-model" will add ?mount=<digest>&from=library/source-model
	// to the POST /blobs/uploads/ request. If the registry returns 201 Created,
	// the blob was mounted successfully and no upload is needed.
	//
	// See: https://distribution.github.io/distribution/spec/api/#cross-repository-blob-mount
	From string
}

// DownloadOptions configures a parallel download operation.
type DownloadOptions struct {
	Blobs        []Blob                       // Blobs to download
	BaseURL      string                       // Registry base URL
	DestDir      string                       // Destination directory for blobs
	Repository   string                       // Repository path for blob URLs (e.g., "library/model")
	Concurrency  int                          // Max parallel downloads (default 64)
	Progress     func(completed, total int64) // Progress callback (optional)
	Client       *http.Client                 // HTTP client (optional, uses default)
	Token        string                       // Auth token (optional)
	GetToken     func(ctx context.Context, challenge AuthChallenge) (string, error) // Token refresh callback
	Logger       *slog.Logger                 // Optional structured logger
	UserAgent    string                       // User-Agent header (optional, has default)
	StallTimeout time.Duration                // Timeout for stall detection (default 10s)
}

// UploadOptions configures a parallel upload operation.
type UploadOptions struct {
	Blobs       []Blob                       // Blobs to upload
	BaseURL     string                       // Registry base URL
	SrcDir      string                       // Source directory containing blobs
	Concurrency int                          // Max parallel uploads (default 32)
	Progress    func(completed, total int64) // Progress callback (optional)
	Client      *http.Client                 // HTTP client (optional, uses default)
	Token       string                       // Auth token (optional)
	GetToken    func(ctx context.Context, challenge AuthChallenge) (string, error) // Token refresh callback
	Logger      *slog.Logger                 // Optional structured logger
	UserAgent   string                       // User-Agent header (optional, has default)

	// Manifest fields (optional) - if set, manifest is pushed after all blobs complete
	Manifest    []byte // Raw manifest JSON to push
	ManifestRef string // Tag or digest for the manifest (e.g., "latest", "sha256:...")
	Repository  string // Repository path for manifest URL (e.g., "library/model")
}

// AuthChallenge represents a parsed WWW-Authenticate challenge.
type AuthChallenge struct {
	Realm   string
	Service string
	Scope   string
}

// Default concurrency limits and settings
const (
	DefaultDownloadConcurrency = 64
	DefaultUploadConcurrency   = 32
	maxRetries                 = 6
	defaultUserAgent           = "ollama-transfer/1.0"
)

var errMaxRetriesExceeded = errors.New("max retries exceeded")

// defaultClient is a shared HTTP client with connection pooling.
var defaultClient = &http.Client{
	Transport: &http.Transport{
		MaxIdleConns:        100,
		MaxIdleConnsPerHost: 100,
		IdleConnTimeout:     90 * time.Second,
	},
	CheckRedirect: func(req *http.Request, via []*http.Request) error {
		return http.ErrUseLastResponse
	},
}

// progressTracker aggregates progress across concurrent operations.
type progressTracker struct {
	completed atomic.Int64
	total     int64
	callback  func(completed, total int64)
}

func newProgressTracker(total int64, callback func(completed, total int64)) *progressTracker {
	return &progressTracker{
		total:    total,
		callback: callback,
	}
}

func (p *progressTracker) add(n int64) {
	if p == nil || p.callback == nil {
		return
	}
	completed := p.completed.Add(n)
	p.callback(completed, p.total)
}

// Download downloads blobs in parallel with streaming hash verification.
func Download(ctx context.Context, opts DownloadOptions) error {
	return download(ctx, opts)
}

// Upload uploads blobs in parallel.
func Upload(ctx context.Context, opts UploadOptions) error {
	return upload(ctx, opts)
}

// digestToPath converts sha256:abc123 to sha256-abc123.
func digestToPath(digest string) string {
	if len(digest) > 7 && digest[6] == ':' {
		return digest[:6] + "-" + digest[7:]
	}
	return digest
}

// parseAuthChallenge parses a WWW-Authenticate header value.
// Example: Bearer realm="https://auth.example.com",service="registry",scope="repository:foo:pull"
func parseAuthChallenge(header string) AuthChallenge {
	header = strings.TrimPrefix(header, "Bearer ")

	getValue := func(key string) string {
		startIdx := strings.Index(header, key+"=")
		if startIdx == -1 {
			return ""
		}
		startIdx += len(key) + 1
		if startIdx >= len(header) {
			return ""
		}

		// Handle quoted values
		if header[startIdx] == '"' {
			startIdx++
			endIdx := strings.Index(header[startIdx:], "\"")
			if endIdx == -1 {
				return header[startIdx:]
			}
			return header[startIdx : startIdx+endIdx]
		}

		// Unquoted value - ends at comma or end of string
		endIdx := strings.Index(header[startIdx:], ",")
		if endIdx == -1 {
			return header[startIdx:]
		}
		return header[startIdx : startIdx+endIdx]
	}

	return AuthChallenge{
		Realm:   getValue("realm"),
		Service: getValue("service"),
		Scope:   getValue("scope"),
	}
}

// backoff sleeps with exponential backoff and jitter, respecting ctx cancellation.
func backoff(ctx context.Context, attempt int, maxBackoff time.Duration) error {
	if ctx.Err() != nil {
		return ctx.Err()
	}

	// n^2 backoff with jitter
	d := min(time.Duration(attempt*attempt)*10*time.Millisecond, maxBackoff)
	d = time.Duration(float64(d) * (rand.Float64() + 0.5))

	t := time.NewTimer(d)
	defer t.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-t.C:
		return nil
	}
}
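
For intuition, before jitter the delay for attempt n is n² × 10ms capped at maxBackoff and then scaled by a uniform factor in [0.5, 1.5), so attempts 1, 3, and 6 wait roughly 10ms, 90ms, and 360ms.

And a minimal caller sketch (registry URL, repository, digest, and paths are illustrative; only the exported API comes from this package):

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/x/imagegen/transfer"
)

func main() {
	err := transfer.Download(context.Background(), transfer.DownloadOptions{
		BaseURL:    "https://registry.example.com", // illustrative registry
		Repository: "library/some-model",
		DestDir:    "/tmp/blobs",
		Blobs: []transfer.Blob{
			// Digest and size are placeholders; real values come from a manifest.
			{Digest: "sha256:<hex digest>", Size: 4096},
		},
		Progress: func(completed, total int64) {
			fmt.Printf("\r%d/%d bytes", completed, total)
		},
	})
	if err != nil {
		log.Fatal(err)
	}
}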

1777 x/imagegen/transfer/transfer_test.go Normal file
File diff suppressed because it is too large

346 x/imagegen/transfer/upload.go Normal file
@@ -0,0 +1,346 @@
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
type uploader struct {
|
||||
client *http.Client
|
||||
baseURL string
|
||||
srcDir string
|
||||
repository string // Repository path for blob URLs (e.g., "library/model")
|
||||
token *string
|
||||
getToken func(context.Context, AuthChallenge) (string, error)
|
||||
userAgent string
|
||||
progress *progressTracker
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func upload(ctx context.Context, opts UploadOptions) error {
|
||||
if len(opts.Blobs) == 0 && len(opts.Manifest) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
token := opts.Token
|
||||
u := &uploader{
|
||||
client: cmp.Or(opts.Client, defaultClient),
|
||||
baseURL: opts.BaseURL,
|
||||
srcDir: opts.SrcDir,
|
||||
repository: cmp.Or(opts.Repository, "library/_"),
|
||||
token: &token,
|
||||
getToken: opts.GetToken,
|
||||
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
|
||||
logger: opts.Logger,
|
||||
}
|
||||
|
||||
if len(opts.Blobs) > 0 {
|
||||
// Phase 1: Fast parallel HEAD checks to find which blobs need uploading
|
||||
needsUpload := make([]bool, len(opts.Blobs))
|
||||
{
|
||||
sem := semaphore.NewWeighted(128) // High concurrency for HEAD checks
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
for i, blob := range opts.Blobs {
|
||||
g.Go(func() error {
|
||||
if err := sem.Acquire(gctx, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
defer sem.Release(1)
|
||||
exists, err := u.exists(gctx, blob)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
needsUpload[i] = true
|
||||
} else if u.logger != nil {
|
||||
u.logger.Debug("blob exists", "digest", blob.Digest)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Filter to only blobs that need uploading
|
||||
var toUpload []Blob
|
||||
var total int64
|
||||
for i, blob := range opts.Blobs {
|
||||
if needsUpload[i] {
|
||||
toUpload = append(toUpload, blob)
|
||||
total += blob.Size
|
||||
}
|
||||
}
|
||||
|
||||
if len(toUpload) == 0 {
|
||||
if u.logger != nil {
|
||||
u.logger.Debug("all blobs exist, nothing to upload")
|
||||
}
|
||||
} else {
|
||||
// Phase 2: Upload blobs that don't exist
|
||||
u.progress = newProgressTracker(total, opts.Progress)
|
||||
concurrency := cmp.Or(opts.Concurrency, DefaultUploadConcurrency)
|
||||
sem := semaphore.NewWeighted(int64(concurrency))
|
||||
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
for _, blob := range toUpload {
|
||||
g.Go(func() error {
|
||||
if err := sem.Acquire(gctx, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
defer sem.Release(1)
|
||||
return u.upload(gctx, blob)
|
||||
})
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(opts.Manifest) > 0 && opts.ManifestRef != "" && opts.Repository != "" {
|
||||
return u.pushManifest(ctx, opts.Repository, opts.ManifestRef, opts.Manifest)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *uploader) upload(ctx context.Context, blob Blob) error {
|
||||
var lastErr error
|
||||
var n int64
|
||||
|
||||
for attempt := range maxRetries {
|
||||
if attempt > 0 {
|
||||
if err := backoff(ctx, attempt, time.Second<<uint(attempt-1)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
n, err = u.uploadOnce(ctx, blob)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
}
|
||||
|
||||
u.progress.add(-n)
|
||||
lastErr = err
|
||||
}
|
||||
return fmt.Errorf("%w: %v", errMaxRetriesExceeded, lastErr)
|
||||
}
|
||||
|
||||
func (u *uploader) uploadOnce(ctx context.Context, blob Blob) (int64, error) {
|
||||
if u.logger != nil {
|
||||
u.logger.Debug("uploading blob", "digest", blob.Digest, "size", blob.Size)
|
||||
}
|
||||
|
||||
// Init upload
|
||||
uploadURL, err := u.initUpload(ctx, blob)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Open file
|
||||
f, err := os.Open(filepath.Join(u.srcDir, digestToPath(blob.Digest)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// PUT blob
|
||||
return u.put(ctx, uploadURL, f, blob.Size)
|
||||
}

func (u *uploader) exists(ctx context.Context, blob Blob) (bool, error) {
	req, _ := http.NewRequestWithContext(ctx, http.MethodHead, fmt.Sprintf("%s/v2/%s/blobs/%s", u.baseURL, u.repository, blob.Digest), nil)
	req.Header.Set("User-Agent", u.userAgent)
	if *u.token != "" {
		req.Header.Set("Authorization", "Bearer "+*u.token)
	}

	resp, err := u.client.Do(req)
	if err != nil {
		return false, err
	}
	resp.Body.Close()

	if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
		ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
		if *u.token, err = u.getToken(ctx, ch); err != nil {
			return false, err
		}
		return u.exists(ctx, blob)
	}

	return resp.StatusCode == http.StatusOK, nil
}
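
parseAuthChallenge is not shown in this diff. Registry Bearer challenges look like WWW-Authenticate: Bearer realm="…",service="…",scope="…"; a naive sketch of the parsing idea (hypothetical name, and it ignores commas inside quoted values):

// Sketch only: split a Bearer challenge into key/value parameters.
func parseBearerParams(header string) map[string]string {
	params := make(map[string]string)
	for _, part := range strings.Split(strings.TrimPrefix(header, "Bearer "), ",") {
		if k, v, ok := strings.Cut(strings.TrimSpace(part), "="); ok {
			params[k] = strings.Trim(v, `"`)
		}
	}
	return params
}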

func (u *uploader) initUpload(ctx context.Context, blob Blob) (string, error) {
	endpoint, _ := url.Parse(fmt.Sprintf("%s/v2/%s/blobs/uploads/", u.baseURL, u.repository))
	q := endpoint.Query()
	q.Set("digest", blob.Digest)
	endpoint.RawQuery = q.Encode()

	req, _ := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), nil)
	req.Header.Set("User-Agent", u.userAgent)
	if *u.token != "" {
		req.Header.Set("Authorization", "Bearer "+*u.token)
	}

	resp, err := u.client.Do(req)
	if err != nil {
		return "", err
	}
	resp.Body.Close()

	if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
		ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
		if *u.token, err = u.getToken(ctx, ch); err != nil {
			return "", err
		}
		return u.initUpload(ctx, blob)
	}

	if resp.StatusCode != http.StatusAccepted {
		return "", fmt.Errorf("init: status %d", resp.StatusCode)
	}

	loc := resp.Header.Get("Docker-Upload-Location")
	if loc == "" {
		loc = resp.Header.Get("Location")
	}
	if loc == "" {
		return "", fmt.Errorf("no upload location")
	}

	locURL, err := url.Parse(loc)
	if err != nil {
		return "", fmt.Errorf("parse upload location: %w", err)
	}
	if !locURL.IsAbs() {
		base, _ := url.Parse(u.baseURL)
		locURL = base.ResolveReference(locURL)
	}
	q = locURL.Query()
	q.Set("digest", blob.Digest)
	locURL.RawQuery = q.Encode()

	return locURL.String(), nil
}

func (u *uploader) put(ctx context.Context, uploadURL string, f *os.File, size int64) (int64, error) {
	pr := &progressReader{reader: f, tracker: u.progress}

	req, _ := http.NewRequestWithContext(ctx, http.MethodPut, uploadURL, pr)
	req.ContentLength = size
	req.Header.Set("Content-Type", "application/octet-stream")
	req.Header.Set("User-Agent", u.userAgent)
	if *u.token != "" {
		req.Header.Set("Authorization", "Bearer "+*u.token)
	}

	resp, err := u.client.Do(req)
	if err != nil {
		return pr.n, err
	}
	defer resp.Body.Close()

	// Handle auth retry
	if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
		ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
		if *u.token, err = u.getToken(ctx, ch); err != nil {
			return pr.n, err
		}
		if _, err := f.Seek(0, io.SeekStart); err != nil {
			return pr.n, err
		}
		u.progress.add(-pr.n)
		return u.put(ctx, uploadURL, f, size)
	}

	// Handle redirect to CDN
	if resp.StatusCode == http.StatusTemporaryRedirect {
		loc, err := resp.Location()
		if err != nil {
			return pr.n, fmt.Errorf("redirect without location: %w", err)
		}
		if _, err := f.Seek(0, io.SeekStart); err != nil {
			return pr.n, err
		}
		u.progress.add(-pr.n)
		pr2 := &progressReader{reader: f, tracker: u.progress}

		req2, _ := http.NewRequestWithContext(ctx, http.MethodPut, loc.String(), pr2)
		req2.ContentLength = size
		req2.Header.Set("Content-Type", "application/octet-stream")
		req2.Header.Set("User-Agent", u.userAgent)

		resp2, err := u.client.Do(req2)
		if err != nil {
			return pr2.n, err
		}
		defer resp2.Body.Close()

		if resp2.StatusCode != http.StatusCreated && resp2.StatusCode != http.StatusAccepted {
			body, _ := io.ReadAll(resp2.Body)
			return pr2.n, fmt.Errorf("status %d: %s", resp2.StatusCode, body)
		}
		return pr2.n, nil
	}

	if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusAccepted {
		body, _ := io.ReadAll(resp.Body)
		return pr.n, fmt.Errorf("status %d: %s", resp.StatusCode, body)
	}
	return pr.n, nil
}
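
Why the manual replay: Go's http.Client only auto-follows a 307/308 that carries a request body when req.GetBody is set; with a plain io.Reader body (like the progressReader here) the client hands the redirect back to the caller. An alternative sketch that lets net/http re-send the body itself, set before client.Do(req); note it would bypass the progressReader on the retry, which is why the manual approach above also rewinds the progress counter:

// Sketch only: allow the client to replay the body on 307/308 redirects.
req.GetBody = func() (io.ReadCloser, error) {
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		return nil, err
	}
	return io.NopCloser(f), nil
}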

func (u *uploader) pushManifest(ctx context.Context, repo, ref string, manifest []byte) error {
	req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fmt.Sprintf("%s/v2/%s/manifests/%s", u.baseURL, repo, ref), bytes.NewReader(manifest))
	req.Header.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
	req.Header.Set("User-Agent", u.userAgent)
	if *u.token != "" {
		req.Header.Set("Authorization", "Bearer "+*u.token)
	}

	resp, err := u.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
		ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
		if *u.token, err = u.getToken(ctx, ch); err != nil {
			return err
		}
		return u.pushManifest(ctx, repo, ref, manifest)
	}

	if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("status %d: %s", resp.StatusCode, body)
	}
	return nil
}

type progressReader struct {
	reader  io.Reader
	tracker *progressTracker
	n       int64
}

func (r *progressReader) Read(p []byte) (int, error) {
	n, err := r.reader.Read(p)
	if n > 0 {
		r.n += int64(n)
		r.tracker.add(int64(n))
	}
	return n, err
}
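
newProgressTracker and its add method are outside this hunk. A minimal sketch consistent with the call sites above, under the assumption that opts.Progress is a callback of the form func(completed, total int64):

// Sketch only: an atomic byte counter that reports progress to a callback.
type progressTracker struct {
	total     int64
	completed atomic.Int64
	fn        func(completed, total int64)
}

func newProgressTracker(total int64, fn func(completed, total int64)) *progressTracker {
	return &progressTracker{total: total, fn: fn}
}

func (t *progressTracker) add(n int64) {
	done := t.completed.Add(n)
	if t.fn != nil {
		t.fn(done, t.total)
	}
}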

x/imagegen/weights.go (new file, 116 lines)
@@ -0,0 +1,116 @@
//go:build mlx

package imagegen

import (
	"fmt"
	"sort"
	"strings"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// ManifestWeights provides fast weight loading from tensor blobs.
// Uses native mmap loading with synthetic safetensors headers for zero-copy.
type ManifestWeights struct {
	manifest    *ModelManifest
	component   string
	tensors     map[string]ManifestLayer // name -> layer
	cache       map[string]*mlx.Array    // name -> loaded array
	nativeCache []*mlx.SafetensorsFile   // keep native handles alive
}

// LoadWeightsFromManifest creates a weight loader for a component from manifest storage.
func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*ManifestWeights, error) {
	layers := manifest.GetTensorLayers(component)
	if len(layers) == 0 {
		return nil, fmt.Errorf("no tensor layers found for component %q", component)
	}

	// Strip component prefix from tensor names for model loading,
	// e.g. "text_encoder/model.embed_tokens.weight" -> "model.embed_tokens.weight"
	prefix := component + "/"
	tensors := make(map[string]ManifestLayer, len(layers))
	for _, layer := range layers {
		tensorName := strings.TrimPrefix(layer.Name, prefix)
		tensors[tensorName] = layer
	}

	return &ManifestWeights{
		manifest:  manifest,
		component: component,
		tensors:   tensors,
		cache:     make(map[string]*mlx.Array),
	}, nil
}

// Load loads all tensor blobs using native mmap (zero-copy).
// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
// If dtype is non-zero, tensors are converted to the specified dtype.
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
	for name, layer := range mw.tensors {
		path := mw.manifest.BlobPath(layer.Digest)

		// Load blob as safetensors (native mmap, zero-copy)
		sf, err := mlx.LoadSafetensorsNative(path)
		if err != nil {
			return fmt.Errorf("load %s: %w", name, err)
		}

		// Blob contains a single tensor named "data"
		arr := sf.Get("data")
		if arr == nil {
			sf.Free()
			return fmt.Errorf("tensor 'data' not found in blob for %s", name)
		}

		// Convert dtype if needed
		if dtype != 0 && arr.Dtype() != dtype {
			arr = mlx.AsType(arr, dtype)
		}
		// Always make a contiguous copy to ensure independence from the mmap
		arr = mlx.Contiguous(arr)
		mlx.Eval(arr)
		mw.cache[name] = arr
		sf.Free() // safe to free: arr is now an independent copy
	}

	return nil
}

// GetTensor returns a tensor from cache. Call Load() first.
func (mw *ManifestWeights) GetTensor(name string) (*mlx.Array, error) {
	if mw.cache == nil {
		return nil, fmt.Errorf("cache not initialized: call Load() first")
	}
	arr, ok := mw.cache[name]
	if !ok {
		return nil, fmt.Errorf("tensor %q not found", name)
	}
	return arr, nil
}

// ListTensors returns all tensor names in sorted order.
func (mw *ManifestWeights) ListTensors() []string {
	names := make([]string, 0, len(mw.tensors))
	for name := range mw.tensors {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// HasTensor checks if a tensor exists.
func (mw *ManifestWeights) HasTensor(name string) bool {
	_, ok := mw.tensors[name]
	return ok
}

// ReleaseAll frees all native handles and clears the tensor cache.
func (mw *ManifestWeights) ReleaseAll() {
	for _, sf := range mw.nativeCache {
		sf.Free()
	}
	mw.nativeCache = nil
	mw.cache = nil
}
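
A usage sketch of the loader above; the manifest setup is elided, and the component and tensor names follow the examples in the comments:

weights, err := LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
	return err
}
defer weights.ReleaseAll()

// Pass dtype 0 to keep each tensor's stored dtype.
if err := weights.Load(0); err != nil {
	return err
}
embed, err := weights.GetTensor("model.embed_tokens.weight")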

@@ -38,6 +38,32 @@ func (r *Registry) Register(tool Tool) {
	r.tools[tool.Name()] = tool
}

// Unregister removes a tool from the registry by name.
func (r *Registry) Unregister(name string) {
	delete(r.tools, name)
}

// Has checks if a tool with the given name is registered.
func (r *Registry) Has(name string) bool {
	_, ok := r.tools[name]
	return ok
}

// RegisterBash adds the bash tool to the registry.
func (r *Registry) RegisterBash() {
	r.Register(&BashTool{})
}

// RegisterWebSearch adds the web search tool to the registry.
func (r *Registry) RegisterWebSearch() {
	r.Register(&WebSearchTool{})
}

// RegisterWebFetch adds the web fetch tool to the registry.
func (r *Registry) RegisterWebFetch() {
	r.Register(&WebFetchTool{})
}

// Get retrieves a tool by name.
func (r *Registry) Get(name string) (Tool, bool) {
	tool, ok := r.tools[name]
@@ -94,9 +120,10 @@ func (r *Registry) Count() int {
// - OLLAMA_AGENT_DISABLE_BASH=1 disables bash
func DefaultRegistry() *Registry {
	r := NewRegistry()
-	if os.Getenv("OLLAMA_AGENT_DISABLE_WEBSEARCH") == "" {
-		r.Register(&WebSearchTool{})
-	}
+	// TODO(parthsareen): re-enable web search once it's ready for release
+	// if os.Getenv("OLLAMA_AGENT_DISABLE_WEBSEARCH") == "" {
+	// 	r.Register(&WebSearchTool{})
+	// }
	if os.Getenv("OLLAMA_AGENT_DISABLE_BASH") == "" {
		r.Register(&BashTool{})
	}
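
A short sketch of how the registry helpers compose; the tool set returned by DefaultRegistry depends on the environment variables above:

r := DefaultRegistry() // currently bash only, unless OLLAMA_AGENT_DISABLE_BASH=1
r.RegisterWebFetch()   // opt in to web fetch explicitly

if r.Has("web_fetch") {
	tool, _ := r.Get("web_fetch")
	_ = tool // pass to the agent loop
}
r.Unregister("bash") // remove a tool by name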

@@ -93,19 +93,14 @@ func TestRegistry_Execute(t *testing.T) {
func TestDefaultRegistry(t *testing.T) {
	r := DefaultRegistry()

-	if r.Count() != 2 {
-		t.Errorf("expected 2 tools in default registry, got %d", r.Count())
+	if r.Count() != 1 {
+		t.Errorf("expected 1 tool in default registry, got %d", r.Count())
	}

	_, ok := r.Get("bash")
	if !ok {
		t.Error("expected bash tool in default registry")
	}
-
-	_, ok = r.Get("web_search")
-	if !ok {
-		t.Error("expected web_search tool in default registry")
-	}
}

func TestDefaultRegistry_DisableWebsearch(t *testing.T) {
@@ -133,18 +128,8 @@ func TestDefaultRegistry_DisableBash(t *testing.T) {

	r := DefaultRegistry()

-	if r.Count() != 1 {
-		t.Errorf("expected 1 tool with bash disabled, got %d", r.Count())
-	}
-
-	_, ok := r.Get("web_search")
-	if !ok {
-		t.Error("expected web_search tool in registry")
-	}
-
-	_, ok = r.Get("bash")
-	if ok {
-		t.Error("expected bash to be disabled")
+	if r.Count() != 0 {
+		t.Errorf("expected 0 tools with bash disabled, got %d", r.Count())
	}
}
|
||||
@@ -192,3 +177,47 @@ func TestWebSearchTool_Schema(t *testing.T) {
|
||||
t.Error("expected 'query' property in schema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Unregister(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&BashTool{})
|
||||
|
||||
if r.Count() != 1 {
|
||||
t.Errorf("expected 1 tool, got %d", r.Count())
|
||||
}
|
||||
|
||||
r.Unregister("bash")
|
||||
|
||||
if r.Count() != 0 {
|
||||
t.Errorf("expected 0 tools after unregister, got %d", r.Count())
|
||||
}
|
||||
|
||||
_, ok := r.Get("bash")
|
||||
if ok {
|
||||
t.Error("expected bash tool to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Has(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
if r.Has("bash") {
|
||||
t.Error("expected Has to return false for unregistered tool")
|
||||
}
|
||||
|
||||
r.Register(&BashTool{})
|
||||
|
||||
if !r.Has("bash") {
|
||||
t.Error("expected Has to return true for registered tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_RegisterBash(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
r.RegisterBash()
|
||||
|
||||
if !r.Has("bash") {
|
||||
t.Error("expected bash tool to be registered")
|
||||
}
|
||||
}

x/tools/webfetch.go (new file, 162 lines)
@@ -0,0 +1,162 @@
package tools

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/auth"
)

const (
	webFetchAPI     = "https://ollama.com/api/web_fetch"
	webFetchTimeout = 30 * time.Second
)

// ErrWebFetchAuthRequired is returned when web fetch requires authentication.
var ErrWebFetchAuthRequired = errors.New("web fetch requires authentication")

// WebFetchTool implements web page fetching using Ollama's hosted API.
type WebFetchTool struct{}

// Name returns the tool name.
func (w *WebFetchTool) Name() string {
	return "web_fetch"
}

// Description returns a description of the tool.
func (w *WebFetchTool) Description() string {
	return "Fetch and extract text content from a web page. Use this to read the full content of a URL found in search results or provided by the user."
}

// Schema returns the tool's parameter schema.
func (w *WebFetchTool) Schema() api.ToolFunction {
	props := api.NewToolPropertiesMap()
	props.Set("url", api.ToolProperty{
		Type:        api.PropertyType{"string"},
		Description: "The URL to fetch and extract content from",
	})
	return api.ToolFunction{
		Name:        w.Name(),
		Description: w.Description(),
		Parameters: api.ToolFunctionParameters{
			Type:       "object",
			Properties: props,
			Required:   []string{"url"},
		},
	}
}

// webFetchRequest is the request body for the web fetch API.
type webFetchRequest struct {
	URL string `json:"url"`
}

// webFetchResponse is the response from the web fetch API.
type webFetchResponse struct {
	Title   string   `json:"title"`
	Content string   `json:"content"`
	Links   []string `json:"links,omitempty"`
}

// Execute fetches content from a web page.
// Uses Ollama key signing for authentication - this makes requests via the ollama.com API.
func (w *WebFetchTool) Execute(args map[string]any) (string, error) {
	urlStr, ok := args["url"].(string)
	if !ok || urlStr == "" {
		return "", fmt.Errorf("url parameter is required")
	}

	// Validate URL
	if _, err := url.Parse(urlStr); err != nil {
		return "", fmt.Errorf("invalid URL: %w", err)
	}

	// Prepare request
	reqBody := webFetchRequest{
		URL: urlStr,
	}

	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("marshaling request: %w", err)
	}

	// Parse URL and add a timestamp for signing
	fetchURL, err := url.Parse(webFetchAPI)
	if err != nil {
		return "", fmt.Errorf("parsing fetch URL: %w", err)
	}

	q := fetchURL.Query()
	q.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
	fetchURL.RawQuery = q.Encode()

	// Sign the request using the Ollama key (~/.ollama/id_ed25519)
	ctx := context.Background()
	data := fmt.Appendf(nil, "%s,%s", http.MethodPost, fetchURL.RequestURI())
	signature, err := auth.Sign(ctx, data)
	if err != nil {
		return "", fmt.Errorf("signing request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fetchURL.String(), bytes.NewBuffer(jsonBody))
	if err != nil {
		return "", fmt.Errorf("creating request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")
	if signature != "" {
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
	}

	// Send request
	client := &http.Client{Timeout: webFetchTimeout}
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("sending request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("reading response: %w", err)
	}

	if resp.StatusCode == http.StatusUnauthorized {
		return "", ErrWebFetchAuthRequired
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("web fetch API returned status %d: %s", resp.StatusCode, string(body))
	}

	// Parse response
	var fetchResp webFetchResponse
	if err := json.Unmarshal(body, &fetchResp); err != nil {
		return "", fmt.Errorf("parsing response: %w", err)
	}

	// Format result
	var sb strings.Builder
	if fetchResp.Title != "" {
		sb.WriteString(fmt.Sprintf("Title: %s\n\n", fetchResp.Title))
	}

	if fetchResp.Content != "" {
		sb.WriteString("Content:\n")
		sb.WriteString(fetchResp.Content)
	} else {
		sb.WriteString("No content could be extracted from the page.")
	}

	return sb.String(), nil
}
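
A usage sketch of the tool in isolation, using only the exported API above:

tool := &WebFetchTool{}
out, err := tool.Execute(map[string]any{"url": "https://example.com"})
if errors.Is(err, ErrWebFetchAuthRequired) {
	// The local Ollama key is missing or not recognized by ollama.com.
}
if err == nil {
	fmt.Println(out)
}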