Compare commits


19 Commits

Author SHA1 Message Date
jmorganca
6b4fb18235 undo readme 2026-01-11 23:06:01 -08:00
jmorganca
73ad90bf08 fix 2026-01-11 22:49:30 -08:00
jmorganca
47b8def07e x/imagegen: add TeaCache and FP8 quantization support
TeaCache:
- Timestep embedding similarity caching for diffusion models
- Polynomial rescaling with configurable thresholds
- Reduces transformer forward passes by ~30-50%

FP8 quantization:
- Support for FP8 quantized models (8-bit weights with scales)
- QuantizedMatmul on Metal, Dequantize on CUDA
- Client-side quantization via ollama create --quantize fp8

Other improvements:
- Fix Show API for image generation models
- Server properly returns model info (architecture, parameters, quantization)
- Memory allocation optimizations
- CLI improvements for image generation
2026-01-11 22:48:20 -08:00
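
For context on the TeaCache change above: the technique skips a transformer forward pass when the timestep embedding has barely moved since the last computed step, accumulating a polynomially rescaled relative change until it crosses the configured threshold. A minimal Go sketch of that gating decision, under those assumptions (the names below are hypothetical, not the actual x/imagegen API):

  package teacache

  import "math"

  // teaCache decides whether a diffusion step can reuse the previous
  // transformer output instead of running a new forward pass.
  type teaCache struct {
      threshold float64   // skip threshold (configurable per model)
      coeffs    []float64 // polynomial rescaling coefficients, low order first
      acc       float64   // accumulated rescaled change since the last recompute
      prevEmb   []float32 // timestep embedding at the last computed step
  }

  // rescale evaluates the fitted polynomial at x.
  func (c *teaCache) rescale(x float64) float64 {
      y, pow := 0.0, 1.0
      for _, a := range c.coeffs {
          y += a * pow
          pow *= x
      }
      return y
  }

  // shouldSkip accumulates the rescaled relative change of the timestep
  // embedding and reports whether this step's forward pass can be skipped.
  func (c *teaCache) shouldSkip(emb []float32) bool {
      if c.prevEmb == nil {
          c.prevEmb = emb
          return false // always compute the first step
      }
      var diff, norm float64
      for i := range emb {
          d := float64(emb[i] - c.prevEmb[i])
          diff += d * d
          norm += float64(c.prevEmb[i]) * float64(c.prevEmb[i])
      }
      rel := math.Sqrt(diff) / math.Max(math.Sqrt(norm), 1e-8)
      c.acc += c.rescale(rel)
      c.prevEmb = emb
      if c.acc < c.threshold {
          return true // embedding barely moved: reuse the cached output
      }
      c.acc = 0 // recompute this step and reset the accumulator
      return false
  }
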
Patrick Devine
7e2496e88e Fix cmake install command in README (#13678)
Update installation command for MLX component in README.
2026-01-11 13:16:42 -08:00
WhatToPutHere
5b84e29882 docs: fix troubleshooting page (#13674)
Updated the link in the log output description to point to the correct troubleshooting guide format.
2026-01-11 00:58:07 -08:00
Jeffrey Morgan
7cc2a653f2 dockerfile: remove unused COPY command (#13664) 2026-01-09 23:07:15 -08:00
Jeffrey Morgan
2584940016 Add z-image image generation prototype (#13659) 2026-01-09 21:09:46 -08:00
Michael
c6d4c0c7f2 Documentation edits made through Mintlify web editor 2026-01-09 21:29:03 -05:00
Parth Sareen
1ef4241727 x: request access for all commands, add welcome message (#13662) 2026-01-09 18:20:39 -08:00
Parth Sareen
68fafd3002 x: improve approval selector with clearer labels (#13663) 2026-01-09 17:08:12 -08:00
Parth Sareen
2b2cda7a2b api: implement anthropic api (#13600)
* api: add Anthropic Messages API compatibility layer

Add middleware to support the Anthropic Messages API format at /v1/messages.
This enables tools like Claude Code to work with local and cloud Ollama models through the
Anthropic API interface.
2026-01-09 11:53:36 -08:00
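
As a quick illustration, a minimal Go client can exercise the new endpoint directly. This assumes a local server on the default port 11434 and that "llama3.2" (any local model works) is available; the payload shape follows the MessagesRequest type in anthropic/anthropic.go below:

  package main

  import (
      "bytes"
      "fmt"
      "io"
      "net/http"
  )

  func main() {
      // Minimal Anthropic Messages request; model and max_tokens are required.
      body := []byte(`{
          "model": "llama3.2",
          "max_tokens": 256,
          "messages": [{"role": "user", "content": "Hello!"}]
      }`)

      // Ollama serves the compatibility layer at /v1/messages.
      resp, err := http.Post("http://localhost:11434/v1/messages",
          "application/json", bytes.NewReader(body))
      if err != nil {
          panic(err)
      }
      defer resp.Body.Close()

      out, _ := io.ReadAll(resp.Body)
      fmt.Println(string(out)) // {"id":"msg_...","type":"message","role":"assistant",...}
  }
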
Daniel Hiltgen
3cfe9fe146 docker: add missing deps (#13654)
The new MLX library has extra dependencies.
2026-01-09 07:34:40 -08:00
Parth Sareen
a23b559b4c x: disable web search tool registration (#13656) 2026-01-09 01:42:20 -08:00
Daniel Hiltgen
33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow-up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00
Daniel Hiltgen
34d0c55ea5 Linux: switch to zstd compression (#13651)
With the upcoming addition of MLX, the Linux bundle will exceed GitHub's
maximum artifact size of 2 GB. This change brings the size back down.

The install.sh changes are backwards compatible with prior versions, so they
should be safe to merge concurrently with this change.
2026-01-08 15:47:32 -08:00
Parth Sareen
53a5a9e9ae x: redesign agent UI with minimal styling (#13650) 2026-01-08 15:40:07 -08:00
Parth Sareen
e30e08a7d6 x: remove Ctrl+O tool output expansion feature (#13640) 2026-01-07 15:34:08 -08:00
Parth Sareen
12e2b3514a x: agent loop ux improvements (#13635) 2026-01-07 01:27:15 -08:00
Devon Rifkin
626af2d809 template: fix args-as-json rendering (#13636)
In #13525, I accidentally broke templates' ability to automatically
render tool call function arguments as JSON.

We do need these to be proper maps because we need templates to be able
to call range, which can't be done on custom types.
2026-01-06 18:33:57 -08:00
198 changed files with 38890 additions and 4446 deletions


@@ -13,7 +13,7 @@ body:
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
render: shell
validations:
required: false


@@ -68,6 +68,7 @@ jobs:
name: bundles-darwin
path: |
dist/*.tgz
dist/*.tar.zst
dist/*.zip
dist/*.dmg
@@ -392,13 +393,13 @@ jobs:
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
done
- uses: actions/upload-artifact@v4
with:
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: |
*.tgz
*.tar.zst
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push:
@@ -531,7 +532,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!


@@ -2,6 +2,22 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
# host architecture, but downstream projects (like MLX) use it to detect the
# target architecture.
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
# Single architecture specified
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
set(CMAKE_SYSTEM_PROCESSOR "arm64")
endif()
endif()
include(CheckLanguage)
include(GNUInstallDirs)
@@ -12,7 +28,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX require gnu++17 extensions to compile properly
set(GGML_BUILD ON)
set(GGML_SHARED ON)
@@ -147,14 +163,48 @@ if(CMAKE_HIP_COMPILER)
endif()
endif()
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
)
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()


@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 2",
"CMAKE_CUDA_FLAGS": "-t 4",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
@@ -83,6 +83,28 @@
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan"
}
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
],
"buildPresets": [
@@ -140,6 +162,21 @@
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
}
]
}


@@ -131,8 +131,36 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -153,6 +181,7 @@ FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -171,7 +200,7 @@ COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin

anthropic/anthropic.go (new file, 778 lines)

@@ -0,0 +1,778 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so the field is emitted (even when empty) only when
// set; the SDK requires the field to be present to accumulate streaming deltas.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
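
To tie the streaming pieces together, a handler would drive StreamConverter over incoming chat chunks and emit each event as SSE, roughly as follows. This is a hypothetical helper; the actual server middleware is not part of this file:

  package anthropic

  import (
      "encoding/json"
      "fmt"
      "net/http"

      "github.com/ollama/ollama/api"
  )

  // streamAnthropic drains Ollama chat chunks and writes Anthropic-style
  // SSE events to the client.
  func streamAnthropic(w http.ResponseWriter, ch <-chan api.ChatResponse, model string) {
      w.Header().Set("Content-Type", "text/event-stream")
      flusher, _ := w.(http.Flusher)

      conv := NewStreamConverter(GenerateMessageID(), model)
      for r := range ch {
          for _, ev := range conv.Process(r) {
              data, err := json.Marshal(ev.Data)
              if err != nil {
                  continue // skip events that fail to serialize
              }
              // Named SSE event followed by its JSON payload.
              fmt.Fprintf(w, "event: %s\ndata: %s\n\n", ev.Event, data)
              if flusher != nil {
                  flusher.Flush()
              }
          }
      }
  }
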

anthropic/anthropic_test.go (new file, 953 lines)

@@ -0,0 +1,953 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}


@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 512 * format.KiloByte
const maxBufferSize = 8 * format.MegaByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader
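
The constant above presumably sizes the bufio.Scanner that stream uses to split newline-delimited JSON chunks; the jump from 512 KB to 8 MB allows much larger single messages (for example, responses carrying base64-encoded images). A sketch under that assumption, with illustrative names:

  package api

  import (
      "bufio"
      "io"
  )

  // scanStream splits a streaming response body into newline-delimited JSON
  // messages. bufio.Scanner rejects tokens larger than its buffer, so the
  // cap below bounds the largest single chunk the client will accept.
  func scanStream(body io.Reader, fn func([]byte) error) error {
      const maxBufferSize = 8 * 1024 * 1024 // 8 MB, matching the change above

      scanner := bufio.NewScanner(body)
      scanner.Buffer(make([]byte, 0, 64*1024), maxBufferSize)
      for scanner.Scan() {
          if err := fn(scanner.Bytes()); err != nil {
              return err
          }
      }
      return scanner.Err()
  }
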


@@ -19,12 +19,6 @@ import (
"github.com/ollama/ollama/types/model"
)
// SkillRef is an alias for model.SkillRef representing a skill reference.
type SkillRef = model.SkillRef
// MCPRef is an alias for model.MCPRef representing an MCP server reference.
type MCPRef = model.MCPRef
// StatusError is an error with an HTTP status code and message.
type StatusError struct {
StatusCode int
@@ -696,18 +690,6 @@ type CreateRequest struct {
// Requires is the minimum version of Ollama required by the model.
Requires string `json:"requires,omitempty"`
// Skills is a list of skill references for the agent (local paths or registry refs)
Skills []SkillRef `json:"skills,omitempty"`
// MCPs is a list of MCP server references for the agent
MCPs []MCPRef `json:"mcps,omitempty"`
// AgentType defines the type of agent (e.g., "conversational", "task-based")
AgentType string `json:"agent_type,omitempty"`
// Entrypoint specifies an external command to run instead of the built-in chat loop
Entrypoint string `json:"entrypoint,omitempty"`
// Info is a map of additional information for the model
Info map[string]any `json:"info,omitempty"`
@@ -759,10 +741,6 @@ type ShowResponse struct {
Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
Requires string `json:"requires,omitempty"`
Skills []SkillRef `json:"skills,omitempty"`
MCPs []MCPRef `json:"mcps,omitempty"`
AgentType string `json:"agent_type,omitempty"`
Entrypoint string `json:"entrypoint,omitempty"`
}
// CopyRequest is the request passed to [Client.Copy].


@@ -1,402 +0,0 @@
package cmd
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
// TestToolMessage verifies that tool messages are constructed correctly
// with ToolName and ToolCallID preserved from the tool call.
func TestToolMessage(t *testing.T) {
tests := []struct {
name string
call api.ToolCall
content string
expected api.Message
}{
{
name: "basic tool message with ID",
call: api.ToolCall{
ID: "call_abc123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
content: "Sunny, 22°C",
expected: api.Message{
Role: "tool",
Content: "Sunny, 22°C",
ToolName: "get_weather",
ToolCallID: "call_abc123",
},
},
{
name: "tool message without ID",
call: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: api.ToolCallFunctionArguments{
"expression": "2+2",
},
},
},
content: "4",
expected: api.Message{
Role: "tool",
Content: "4",
ToolName: "calculate",
// ToolCallID should be empty when call.ID is empty
},
},
{
name: "MCP tool message",
call: api.ToolCall{
ID: "call_mcp123",
Function: api.ToolCallFunction{
Name: "mcp_websearch_search",
Arguments: api.ToolCallFunctionArguments{
"query": "ollama agents",
},
},
},
content: "Found 10 results",
expected: api.Message{
Role: "tool",
Content: "Found 10 results",
ToolName: "mcp_websearch_search",
ToolCallID: "call_mcp123",
},
},
{
name: "skill tool message",
call: api.ToolCall{
ID: "call_skill456",
Function: api.ToolCallFunction{
Name: "run_skill_script",
Arguments: api.ToolCallFunctionArguments{
"skill": "calculator",
"command": "python scripts/calc.py 2+2",
},
},
},
content: "Result: 4",
expected: api.Message{
Role: "tool",
Content: "Result: 4",
ToolName: "run_skill_script",
ToolCallID: "call_skill456",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := toolMessage(tt.call, tt.content)
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("toolMessage() mismatch (-want +got):\n%s", diff)
}
})
}
}
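The helper these cases exercise is not part of this diff; a plausible reconstruction, inferred purely from the assertions above (hypothetical, not the actual implementation):

    func toolMessage(call api.ToolCall, content string) api.Message {
        return api.Message{
            Role:       "tool",
            Content:    content,
            ToolName:   call.Function.Name, // preserved from the tool call
            ToolCallID: call.ID,            // stays empty when the call has no ID
        }
    }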
// TestAssistantMessageConstruction verifies that assistant messages
// in the tool loop include thinking content alongside tool calls.
func TestAssistantMessageConstruction(t *testing.T) {
tests := []struct {
name string
content string
thinking string
toolCalls []api.ToolCall
expectedMsg api.Message
}{
{
name: "assistant with thinking and tool calls",
content: "",
thinking: "I need to check the weather for Paris.",
toolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
},
expectedMsg: api.Message{
Role: "assistant",
Content: "",
Thinking: "I need to check the weather for Paris.",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
},
},
},
{
name: "assistant with content, thinking, and tool calls",
content: "Let me check that for you.",
thinking: "User wants weather info.",
toolCalls: []api.ToolCall{
{
ID: "call_2",
Function: api.ToolCallFunction{
Name: "search",
Arguments: api.ToolCallFunctionArguments{"query": "weather"},
},
},
},
expectedMsg: api.Message{
Role: "assistant",
Content: "Let me check that for you.",
Thinking: "User wants weather info.",
ToolCalls: []api.ToolCall{
{
ID: "call_2",
Function: api.ToolCallFunction{
Name: "search",
Arguments: api.ToolCallFunctionArguments{"query": "weather"},
},
},
},
},
},
{
name: "assistant with multiple tool calls",
content: "",
thinking: "I'll check both cities.",
toolCalls: []api.ToolCall{
{
ID: "call_a",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
{
ID: "call_b",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "London"},
},
},
},
expectedMsg: api.Message{
Role: "assistant",
Content: "",
Thinking: "I'll check both cities.",
ToolCalls: []api.ToolCall{
{
ID: "call_a",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
{
ID: "call_b",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "London"},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the assistant message construction as done in chat()
assistantMsg := api.Message{
Role: "assistant",
Content: tt.content,
Thinking: tt.thinking,
ToolCalls: tt.toolCalls,
}
if diff := cmp.Diff(tt.expectedMsg, assistantMsg); diff != "" {
t.Errorf("assistant message mismatch (-want +got):\n%s", diff)
}
})
}
}
// TestMessageStitchingOrder verifies that messages in a tool loop
// are stitched in the correct order:
// 1. User message
// 2. Assistant message with tool calls (and thinking)
// 3. Tool result messages (one per tool call, in order)
// 4. Next assistant response
func TestMessageStitchingOrder(t *testing.T) {
// Simulate a complete tool loop conversation
messages := []api.Message{
// Initial user message
{Role: "user", Content: "What's the weather in Paris and London?"},
// Assistant's first response with tool calls
{
Role: "assistant",
Content: "",
Thinking: "I need to check the weather for both cities.",
ToolCalls: []api.ToolCall{
{ID: "call_1", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
{ID: "call_2", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}},
},
},
// Tool results (in order matching tool calls)
{Role: "tool", Content: "Sunny, 22°C", ToolName: "get_weather", ToolCallID: "call_1"},
{Role: "tool", Content: "Rainy, 15°C", ToolName: "get_weather", ToolCallID: "call_2"},
// Final assistant response
{Role: "assistant", Content: "Paris is sunny at 22°C, and London is rainy at 15°C.", Thinking: "Got the data, now summarizing."},
}
// Verify structure
expectedRoles := []string{"user", "assistant", "tool", "tool", "assistant"}
for i, msg := range messages {
if msg.Role != expectedRoles[i] {
t.Errorf("message %d: expected role %q, got %q", i, expectedRoles[i], msg.Role)
}
}
// Verify tool results match tool calls in order
assistantWithTools := messages[1]
toolResults := []api.Message{messages[2], messages[3]}
if len(toolResults) != len(assistantWithTools.ToolCalls) {
t.Errorf("expected %d tool results for %d tool calls", len(assistantWithTools.ToolCalls), len(toolResults))
}
for i, result := range toolResults {
expectedToolCallID := assistantWithTools.ToolCalls[i].ID
if result.ToolCallID != expectedToolCallID {
t.Errorf("tool result %d: expected ToolCallID %q, got %q", i, expectedToolCallID, result.ToolCallID)
}
expectedToolName := assistantWithTools.ToolCalls[i].Function.Name
if result.ToolName != expectedToolName {
t.Errorf("tool result %d: expected ToolName %q, got %q", i, expectedToolName, result.ToolName)
}
}
// Verify thinking is present in assistant messages
if messages[1].Thinking == "" {
t.Error("first assistant message should have thinking content")
}
if messages[4].Thinking == "" {
t.Error("final assistant message should have thinking content")
}
}
// TestMultiTurnToolLoop verifies message stitching across multiple
// tool call iterations.
func TestMultiTurnToolLoop(t *testing.T) {
messages := []api.Message{
{Role: "user", Content: "What's 2+2 and also what's the weather in Paris?"},
// First tool call: calculate
{
Role: "assistant",
Thinking: "I'll start with the calculation.",
ToolCalls: []api.ToolCall{
{ID: "calc_1", Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expr": "2+2"}}},
},
},
{Role: "tool", Content: "4", ToolName: "calculate", ToolCallID: "calc_1"},
// Second tool call: weather
{
Role: "assistant",
Thinking: "Got the calculation. Now checking weather.",
ToolCalls: []api.ToolCall{
{ID: "weather_1", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
},
},
{Role: "tool", Content: "Sunny, 20°C", ToolName: "get_weather", ToolCallID: "weather_1"},
// Final response
{Role: "assistant", Content: "2+2 equals 4, and Paris is sunny at 20°C."},
}
// Count message types
roleCounts := map[string]int{}
for _, msg := range messages {
roleCounts[msg.Role]++
}
if roleCounts["user"] != 1 {
t.Errorf("expected 1 user message, got %d", roleCounts["user"])
}
if roleCounts["assistant"] != 3 {
t.Errorf("expected 3 assistant messages, got %d", roleCounts["assistant"])
}
if roleCounts["tool"] != 2 {
t.Errorf("expected 2 tool messages, got %d", roleCounts["tool"])
}
// Verify each tool message follows an assistant with matching tool call
for i, msg := range messages {
if msg.Role == "tool" {
// Find preceding assistant message with tool calls
var precedingAssistant *api.Message
for j := i - 1; j >= 0; j-- {
if messages[j].Role == "assistant" && len(messages[j].ToolCalls) > 0 {
precedingAssistant = &messages[j]
break
}
}
if precedingAssistant == nil {
t.Errorf("tool message at index %d has no preceding assistant with tool calls", i)
continue
}
// Verify tool result matches one of the tool calls
found := false
for _, tc := range precedingAssistant.ToolCalls {
if tc.ID == msg.ToolCallID {
found = true
break
}
}
if !found {
t.Errorf("tool message at index %d has ToolCallID %q not found in preceding tool calls", i, msg.ToolCallID)
}
}
}
}
// TestSkillCatalogToolMessageFields verifies that toolMessage
// preserves the tool name and call ID for a skill tool call.
func TestSkillCatalogToolMessageFields(t *testing.T) {
// Create a minimal test for toolMessage function
call := api.ToolCall{
ID: "test_id_123",
Function: api.ToolCallFunction{
Name: "run_skill_script",
Arguments: api.ToolCallFunctionArguments{
"skill": "test-skill",
"command": "echo hello",
},
},
}
msg := toolMessage(call, "hello")
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.Content != "hello" {
t.Errorf("expected content 'hello', got %q", msg.Content)
}
if msg.ToolName != "run_skill_script" {
t.Errorf("expected ToolName 'run_skill_script', got %q", msg.ToolName)
}
if msg.ToolCallID != "test_id_123" {
t.Errorf("expected ToolCallID 'test_id_123', got %q", msg.ToolCallID)
}
}

View File

@@ -15,7 +15,6 @@ import (
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
@@ -47,6 +46,8 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -97,6 +98,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return imagegenclient.CreateModel(args[0], ".", quantize, p)
}
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
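In practice this means `ollama create` can now be run from inside a directory of image-generation tensors without a Modelfile; for example (the model name and quantization value below are illustrative):

    ollama create my-image-model --quantize fp8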
@@ -458,6 +464,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -496,16 +503,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts.ParentModel = info.Details.ParentModel
// Check if this is an agent
isAgent := info.AgentType != "" || len(info.Skills) > 0 || len(info.MCPs) > 0 || info.Entrypoint != ""
if isAgent {
opts.IsAgent = true
opts.AgentType = info.AgentType
opts.Skills = info.Skills
opts.MCPs = info.MCPs
opts.Entrypoint = info.Entrypoint
}
// Check if this is an embedding model
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
@@ -529,12 +526,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
// If agent has entrypoint, run it instead of chat loop
if opts.Entrypoint != "" {
return runEntrypoint(cmd, opts)
}
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -562,69 +564,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
// Use experimental agent loop with
// Use experimental agent loop with tools
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
}
return generateInteractive(cmd, opts)
}
// For agents, use chat API even in non-interactive mode to support tools
if opts.IsAgent {
opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: opts.Prompt})
_, err := chat(cmd, opts)
return err
}
return generate(cmd, opts)
}
// runEntrypoint executes the agent's entrypoint command instead of the built-in chat loop.
func runEntrypoint(cmd *cobra.Command, opts runOptions) error {
entrypoint := opts.Entrypoint
// Check if entrypoint contains $PROMPT placeholder
hasPlaceholder := strings.Contains(entrypoint, "$PROMPT")
if hasPlaceholder && opts.Prompt != "" {
// Replace $PROMPT with the actual prompt
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", opts.Prompt)
} else if hasPlaceholder {
// No prompt provided but placeholder exists - remove placeholder
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", "")
}
// Parse entrypoint into command and args
parts := strings.Fields(entrypoint)
if len(parts) == 0 {
return fmt.Errorf("empty entrypoint")
}
command := parts[0]
args := parts[1:]
// If user provided a prompt and no placeholder was used, append it as argument
if opts.Prompt != "" && !hasPlaceholder {
args = append(args, opts.Prompt)
}
// Look up command in PATH
execPath, err := exec.LookPath(command)
if err != nil {
return fmt.Errorf("entrypoint command not found: %s", command)
}
// Create subprocess
proc := exec.Command(execPath, args...)
proc.Stdin = os.Stdin
proc.Stdout = os.Stdout
proc.Stderr = os.Stderr
// Run and wait
return proc.Run()
}
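Note that the removed logic split the expanded entrypoint with strings.Fields, so it performed no shell-style quoting: a multi-word prompt substituted into $PROMPT became several separate arguments. A small illustration of that behavior:

    entrypoint := strings.ReplaceAll("mytool --query $PROMPT", "$PROMPT", "hello world")
    parts := strings.Fields(entrypoint)
    // parts == ["mytool", "--query", "hello", "world"]; the prompt is split
    // into two arguments rather than passed as one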
func SigninHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -723,7 +672,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -984,96 +937,47 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
fmt.Fprintln(w)
}
// Only show Model section if there's actual model info (not for entrypoint-only agents)
hasModelInfo := resp.RemoteHost != "" || resp.ModelInfo != nil || resp.Details.Family != "" || resp.Details.ParameterSize != "" || resp.Details.QuantizationLevel != ""
if hasModelInfo {
tableRender("Model", func() (rows [][]string) {
if resp.RemoteHost != "" {
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
}
tableRender("Model", func() (rows [][]string) {
if resp.RemoteHost != "" {
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
}
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
var paramStr string
if resp.Details.ParameterSize != "" {
paramStr = resp.Details.ParameterSize
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
if f, ok := v.(float64); ok {
paramStr = format.HumanNumber(uint64(f))
}
}
rows = append(rows, []string{"", "parameters", paramStr})
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
} else {
rows = append(rows, []string{"", "architecture", resp.Details.Family})
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
}
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
if resp.Requires != "" {
rows = append(rows, []string{"", "requires", resp.Requires})
}
return
})
}
// Display agent information if this is an agent
if resp.AgentType != "" || len(resp.Skills) > 0 || len(resp.MCPs) > 0 || resp.Entrypoint != "" {
tableRender("Agent", func() (rows [][]string) {
if resp.AgentType != "" {
rows = append(rows, []string{"", "type", resp.AgentType})
}
if resp.Entrypoint != "" {
rows = append(rows, []string{"", "entrypoint", resp.Entrypoint})
}
if len(resp.Skills) > 0 {
for i, skill := range resp.Skills {
label := "skill"
if i > 0 {
label = ""
}
// Show skill name or digest
skillDisplay := skill.Name
if skillDisplay == "" && skill.Digest != "" {
skillDisplay = skill.Digest[:12] + "..."
}
rows = append(rows, []string{"", label, skillDisplay})
var paramStr string
if resp.Details.ParameterSize != "" {
paramStr = resp.Details.ParameterSize
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
if f, ok := v.(float64); ok {
paramStr = format.HumanNumber(uint64(f))
}
}
if len(resp.MCPs) > 0 {
for i, mcp := range resp.MCPs {
label := "mcp"
if i > 0 {
label = ""
}
// Show MCP name and command
mcpDisplay := mcp.Name
if mcp.Command != "" {
cmdLine := mcp.Command
if len(mcp.Args) > 0 {
cmdLine += " " + strings.Join(mcp.Args, " ")
}
mcpDisplay += " (" + cmdLine + ")"
}
rows = append(rows, []string{"", label, mcpDisplay})
rows = append(rows, []string{"", "parameters", paramStr})
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
return
})
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
} else {
rows = append(rows, []string{"", "architecture", resp.Details.Family})
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
}
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
if resp.Requires != "" {
rows = append(rows, []string{"", "requires", resp.Requires})
}
return
})
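For an image generation model, the rendered Model table would now look roughly like this (values hypothetical):

    Model
      architecture    zimage
      parameters      6.2B
      quantization    FP8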
if len(resp.Capabilities) > 0 {
tableRender("Capabilities", func() (rows [][]string) {
@@ -1315,11 +1219,6 @@ type runOptions struct {
Think *api.ThinkValue
HideThinking bool
ShowConnect bool
IsAgent bool
AgentType string
Skills []api.SkillRef
MCPs []api.MCPRef
Entrypoint string
}
func (r runOptions) Copy() runOptions {
@@ -1349,12 +1248,6 @@ func (r runOptions) Copy() runOptions {
think = &cThink
}
var skills []api.SkillRef
if r.Skills != nil {
skills = make([]api.SkillRef, len(r.Skills))
copy(skills, r.Skills)
}
return runOptions{
Model: r.Model,
ParentModel: r.ParentModel,
@@ -1370,9 +1263,6 @@ func (r runOptions) Copy() runOptions {
Think: think,
HideThinking: r.HideThinking,
ShowConnect: r.ShowConnect,
IsAgent: r.IsAgent,
AgentType: r.AgentType,
Skills: skills,
}
}
@@ -1456,65 +1346,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
return nil, err
}
// Load skills for agents
var skillsCatalog *skillCatalog
if opts.IsAgent && len(opts.Skills) > 0 {
skillsCatalog, err = loadSkillsFromRefs(opts.Skills)
if err != nil {
return nil, fmt.Errorf("failed to load skills: %w", err)
}
if skillsCatalog != nil && len(skillsCatalog.Skills) > 0 {
var skillNames []string
for _, s := range skillsCatalog.Skills {
skillNames = append(skillNames, s.Name)
}
fmt.Fprintf(os.Stderr, "Loaded skills: %s\n", strings.Join(skillNames, ", "))
}
}
// Load MCP servers for agents (from opts and global config)
var mcpMgr *mcpManager
allMCPs := opts.MCPs
// Load global MCPs from ~/.ollama/mcp.json
if globalConfig, err := loadMCPConfig(); err == nil && len(globalConfig.MCPServers) > 0 {
for name, srv := range globalConfig.MCPServers {
// Skip disabled MCPs
if srv.Disabled {
continue
}
// Check if already in opts.MCPs (model takes precedence)
found := false
for _, m := range opts.MCPs {
if m.Name == name {
found = true
break
}
}
if !found {
allMCPs = append(allMCPs, api.MCPRef{
Name: name,
Command: srv.Command,
Args: srv.Args,
Env: srv.Env,
Type: srv.Type,
})
}
}
}
if len(allMCPs) > 0 {
mcpMgr = newMCPManager()
if err := mcpMgr.loadMCPsFromRefs(allMCPs); err != nil {
return nil, fmt.Errorf("failed to load MCP servers: %w", err)
}
if mcpMgr.ToolCount() > 0 {
fmt.Fprintf(os.Stderr, "Loaded MCP servers: %s (%d tools)\n",
strings.Join(mcpMgr.ServerNames(), ", "), mcpMgr.ToolCount())
}
defer mcpMgr.Shutdown()
}
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
@@ -1538,7 +1369,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
var fullResponse strings.Builder
var thinkTagOpened bool = false
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
role := "assistant"
@@ -1579,13 +1409,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
if response.Message.ToolCalls != nil {
toolCalls := response.Message.ToolCalls
if len(toolCalls) > 0 {
if skillsCatalog != nil || mcpMgr != nil {
// Store tool calls for execution after response is complete
pendingToolCalls = append(pendingToolCalls, toolCalls...)
} else {
// No skills catalog or MCP, just display tool calls
fmt.Print(renderToolCalls(toolCalls, false))
}
fmt.Print(renderToolCalls(toolCalls, false))
}
}
@@ -1598,161 +1422,31 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
opts.Format = `"` + opts.Format + `"`
}
// Prepare messages with agent-specific system prompt
messages := opts.Messages
if skillsCatalog != nil {
// Add skills system prompt as the first system message
skillsPrompt := skillsCatalog.SystemPrompt()
if skillsPrompt != "" {
// Insert skills prompt at the beginning, or append to existing system message
if len(messages) > 0 && messages[0].Role == "system" {
// Append to existing system message
messages[0].Content = messages[0].Content + "\n\n" + skillsPrompt
} else {
// Insert new system message at the beginning
systemMsg := api.Message{Role: "system", Content: skillsPrompt}
messages = append([]api.Message{systemMsg}, messages...)
}
}
req := &api.ChatRequest{
Model: opts.Model,
Messages: opts.Messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
Think: opts.Think,
}
// Agentic loop: continue until no more tool calls
for {
req := &api.ChatRequest{
Model: opts.Model,
Messages: messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
Think: opts.Think,
if opts.KeepAlive != nil {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
// Add tools for agents (combine skills and MCP tools)
var allTools api.Tools
if skillsCatalog != nil {
allTools = append(allTools, skillsCatalog.Tools()...)
// this error should ideally be wrapped properly by the client
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
fmt.Println()
return nil, nil
}
if mcpMgr != nil {
allTools = append(allTools, mcpMgr.Tools()...)
}
if len(allTools) > 0 {
req.Tools = allTools
}
if opts.KeepAlive != nil {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
// this error should ideally be wrapped properly by the client
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
fmt.Println()
return nil, nil
}
return nil, err
}
// If no tool calls, we're done
if len(pendingToolCalls) == 0 || (skillsCatalog == nil && mcpMgr == nil) {
break
}
// Execute tool calls and continue the conversation
fmt.Fprintf(os.Stderr, "\n")
// Add assistant's tool call message to history (include thinking for proper rendering)
assistantMsg := api.Message{
Role: "assistant",
Content: fullResponse.String(),
Thinking: thinkingContent.String(),
ToolCalls: pendingToolCalls,
}
messages = append(messages, assistantMsg)
// Execute each tool call and collect results
var toolResults []api.Message
for _, call := range pendingToolCalls {
// Show what's being executed
switch call.Function.Name {
case "run_skill_script":
skillVal, _ := call.Function.Arguments.Get("skill")
skill, _ := skillVal.(string)
commandVal, _ := call.Function.Arguments.Get("command")
command, _ := commandVal.(string)
fmt.Fprintf(os.Stderr, "Running script in %s: %s\n", skill, command)
case "read_skill_file":
skillVal, _ := call.Function.Arguments.Get("skill")
skill, _ := skillVal.(string)
pathVal, _ := call.Function.Arguments.Get("path")
path, _ := pathVal.(string)
fmt.Fprintf(os.Stderr, "Reading file from %s: %s\n", skill, path)
default:
fmt.Fprintf(os.Stderr, "Executing: %s\n", call.Function.Name)
}
var result api.Message
var handled bool
var err error
// Try skill catalog first
if skillsCatalog != nil {
result, handled, err = skillsCatalog.RunToolCall(call)
}
// If not handled by skills, try MCP
if !handled && mcpMgr != nil {
result, handled, err = mcpMgr.RunToolCall(call)
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
// Add error result
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Error: %v", err),
})
continue
}
if !handled {
fmt.Fprintf(os.Stderr, "Warning: Unknown tool %s\n", call.Function.Name)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Unknown tool: %s", call.Function.Name),
})
continue
}
// Display tool output
if result.Content != "" {
fmt.Fprintf(os.Stderr, "Output:\n%s\n", result.Content)
}
// Add tool result to messages (preserves ToolName, ToolCallID from result)
toolResults = append(toolResults, result)
}
// Add tool results to message history
messages = append(messages, toolResults...)
fmt.Fprintf(os.Stderr, "\n")
// Reset state for next iteration
fullResponse.Reset()
thinkingContent.Reset()
thinkTagOpened = false
thinkTagClosed = false
pendingToolCalls = nil
state = &displayResponseState{}
// Start new progress spinner for next API call
p = progress.NewProgress(os.Stderr)
spinner = progress.NewSpinner("")
p.Add("", spinner)
return nil, err
}
if len(opts.Messages) > 0 {
@@ -2091,6 +1785,10 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
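An illustrative invocation once these flags are registered, assuming flags named --width, --height, --steps, and --seed per the comment above (exact names and defaults are defined by imagegen.RegisterFlags):

    ollama run my-image-model "a lighthouse at dusk" --width 1024 --height 1024 --steps 30 --seed 42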
stopCmd := &cobra.Command{
Use: "stop MODEL",
@@ -2245,8 +1943,6 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
NewSkillCommand(),
NewMCPCommand(),
)
return rootCmd

View File

@@ -34,9 +34,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /skills Show available skills")
fmt.Fprintln(os.Stderr, " /skill Add or remove skills dynamically")
fmt.Fprintln(os.Stderr, " /mcp Show/add/remove MCP servers")
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /clear Clear session context")
@@ -447,411 +444,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} else {
usageShow()
}
case strings.HasPrefix(line, "/skill "):
args := strings.Fields(line)
if len(args) < 2 {
fmt.Fprintln(os.Stderr, "Usage:")
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
fmt.Fprintln(os.Stderr, " /skill list List current skills")
continue
}
switch args[1] {
case "add":
if len(args) < 3 {
fmt.Println("Usage: /skill add <path>")
continue
}
skillPath := args[2]
// Expand ~ to home directory
if strings.HasPrefix(skillPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
fmt.Printf("Error expanding path: %v\n", err)
continue
}
skillPath = filepath.Join(home, skillPath[1:])
}
// Make absolute
absPath, err := filepath.Abs(skillPath)
if err != nil {
fmt.Printf("Error resolving path: %v\n", err)
continue
}
// Verify SKILL.md exists
skillMdPath := filepath.Join(absPath, "SKILL.md")
if _, err := os.Stat(skillMdPath); err != nil {
fmt.Printf("Error: %s does not contain SKILL.md\n", skillPath)
continue
}
// Extract skill name from SKILL.md
content, err := os.ReadFile(skillMdPath)
if err != nil {
fmt.Printf("Error reading SKILL.md: %v\n", err)
continue
}
skillName, _ := extractSkillMetadata(string(content))
if skillName == "" {
skillName = filepath.Base(absPath)
}
// Check if already added; a bare continue inside the loop would only
// advance to the next skill, so track the duplicate with a flag
alreadyLoaded := false
for _, s := range opts.Skills {
if s.Name == skillName {
alreadyLoaded = true
break
}
}
if alreadyLoaded {
fmt.Printf("Skill '%s' is already loaded\n", skillName)
continue
}
// Add to skills (using path as Name, no digest for local skills)
opts.Skills = append(opts.Skills, api.SkillRef{Name: absPath})
opts.IsAgent = true // Enable agent mode if not already
fmt.Printf("Added skill '%s' from %s\n", skillName, skillPath)
case "remove", "rm":
if len(args) < 3 {
fmt.Println("Usage: /skill remove <name>")
continue
}
skillName := args[2]
found := false
newSkills := make([]api.SkillRef, 0, len(opts.Skills))
for _, s := range opts.Skills {
// Match by name or by path basename
name := s.Name
if strings.Contains(name, string(os.PathSeparator)) {
name = filepath.Base(name)
}
if name == skillName || s.Name == skillName {
found = true
fmt.Printf("Removed skill '%s'\n", skillName)
} else {
newSkills = append(newSkills, s)
}
}
if !found {
fmt.Printf("Skill '%s' not found\n", skillName)
} else {
opts.Skills = newSkills
}
case "list", "ls":
if len(opts.Skills) == 0 {
fmt.Println("No skills loaded in this session.")
} else {
fmt.Println("Skills loaded in this session:")
for _, skill := range opts.Skills {
if skill.Digest != "" {
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
} else {
// For local paths, show basename
name := skill.Name
if strings.Contains(name, string(os.PathSeparator)) {
name = filepath.Base(name) + " (local: " + skill.Name + ")"
}
fmt.Printf(" %s\n", name)
}
}
}
fmt.Println()
default:
fmt.Printf("Unknown skill command '%s'. Use /skill add, /skill remove, or /skill list\n", args[1])
}
continue
case strings.HasPrefix(line, "/skills"):
// Show skills from model (bundled) + session skills
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := &api.ShowRequest{
Name: opts.Model,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model info")
return err
}
// Combine model skills with session skills
allSkills := make([]api.SkillRef, 0)
allSkills = append(allSkills, resp.Skills...)
// Add session skills that aren't already in model skills
for _, sessionSkill := range opts.Skills {
found := false
for _, modelSkill := range resp.Skills {
if modelSkill.Name == sessionSkill.Name || modelSkill.Digest == sessionSkill.Digest {
found = true
break
}
}
if !found {
allSkills = append(allSkills, sessionSkill)
}
}
if len(allSkills) == 0 {
fmt.Println("No skills available.")
} else {
fmt.Println("Available Skills:")
for _, skill := range allSkills {
if skill.Digest != "" {
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
} else {
name := skill.Name
if strings.Contains(name, string(os.PathSeparator)) {
name = filepath.Base(name) + " (session)"
}
fmt.Printf(" %s\n", name)
}
}
}
fmt.Println()
continue
case strings.HasPrefix(line, "/mcp"):
args := strings.Fields(line)
// If just "/mcp" with no args, show all MCP servers
if len(args) == 1 {
// Show MCPs from model (bundled) + global config
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := &api.ShowRequest{
Name: opts.Model,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model info")
return err
}
// Combine model MCPs with global config MCPs
allMCPs := make([]api.MCPRef, 0)
allMCPs = append(allMCPs, resp.MCPs...)
// Load global config
globalConfig, _ := loadMCPConfig()
globalMCPNames := make(map[string]bool)
if globalConfig != nil {
for name, srv := range globalConfig.MCPServers {
// Check if already in model MCPs
found := false
for _, modelMCP := range resp.MCPs {
if modelMCP.Name == name {
found = true
break
}
}
if !found {
allMCPs = append(allMCPs, api.MCPRef{
Name: name,
Command: srv.Command,
Args: srv.Args,
Env: srv.Env,
Type: srv.Type,
})
}
globalMCPNames[name] = true
}
}
if len(allMCPs) == 0 {
fmt.Println("No MCP servers available.")
fmt.Println("Use '/mcp add <name> <command> [args...]' to add one.")
} else {
fmt.Println("Available MCP Servers:")
for _, mcp := range allMCPs {
cmdLine := mcp.Command
if len(mcp.Args) > 0 {
cmdLine += " " + strings.Join(mcp.Args, " ")
}
source := ""
disabled := ""
// Check if it's from model or global config
isFromModel := false
for _, modelMCP := range resp.MCPs {
if modelMCP.Name == mcp.Name {
isFromModel = true
break
}
}
if isFromModel {
source = " (model)"
} else if globalMCPNames[mcp.Name] {
source = " (global)"
// Check if disabled
if srv, ok := globalConfig.MCPServers[mcp.Name]; ok && srv.Disabled {
disabled = " [disabled]"
}
}
fmt.Printf(" %s: %s%s%s\n", mcp.Name, cmdLine, source, disabled)
}
}
fmt.Println()
continue
}
switch args[1] {
case "add":
if len(args) < 4 {
fmt.Println("Usage: /mcp add <name> <command> [args...]")
continue
}
mcpName := args[2]
mcpCommand := args[3]
mcpArgs := args[4:]
// Load global config
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
// Check if already exists
if _, exists := config.MCPServers[mcpName]; exists {
fmt.Printf("Warning: overwriting existing MCP server '%s'\n", mcpName)
}
// Add to global config
config.MCPServers[mcpName] = MCPServerConfig{
Type: "stdio",
Command: mcpCommand,
Args: mcpArgs,
}
// Save config
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
cmdLine := mcpCommand
if len(mcpArgs) > 0 {
cmdLine += " " + strings.Join(mcpArgs, " ")
}
fmt.Printf("Added MCP server '%s' (%s) to %s\n", mcpName, cmdLine, getMCPConfigPath())
fmt.Println("Note: MCP server will be started on next message.")
case "remove", "rm":
if len(args) < 3 {
fmt.Println("Usage: /mcp remove <name>")
continue
}
mcpName := args[2]
// Load global config
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
if _, exists := config.MCPServers[mcpName]; !exists {
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
continue
}
delete(config.MCPServers, mcpName)
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
fmt.Printf("Removed MCP server '%s' from %s\n", mcpName, getMCPConfigPath())
fmt.Println("Note: Changes will take effect on next message.")
case "disable":
if len(args) < 3 {
fmt.Println("Usage: /mcp disable <name>")
continue
}
mcpName := args[2]
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
srv, exists := config.MCPServers[mcpName]
if !exists {
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
continue
}
if srv.Disabled {
fmt.Printf("MCP server '%s' is already disabled\n", mcpName)
continue
}
srv.Disabled = true
config.MCPServers[mcpName] = srv
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
fmt.Printf("Disabled MCP server '%s'\n", mcpName)
fmt.Println("Note: Changes will take effect on next message.")
case "enable":
if len(args) < 3 {
fmt.Println("Usage: /mcp enable <name>")
continue
}
mcpName := args[2]
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
srv, exists := config.MCPServers[mcpName]
if !exists {
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
continue
}
if !srv.Disabled {
fmt.Printf("MCP server '%s' is already enabled\n", mcpName)
continue
}
srv.Disabled = false
config.MCPServers[mcpName] = srv
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
fmt.Printf("Enabled MCP server '%s'\n", mcpName)
fmt.Println("Note: Changes will take effect on next message.")
default:
fmt.Printf("Unknown mcp command '%s'. Use /mcp, /mcp add, /mcp remove, /mcp disable, or /mcp enable\n", args[1])
}
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
args := strings.Fields(line)
if len(args) > 1 {
@@ -860,20 +452,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usageSet()
case "show", "/show":
usageShow()
case "skill", "/skill":
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
fmt.Fprintln(os.Stderr, " /skill list List current session skills")
fmt.Fprintln(os.Stderr, "")
case "mcp", "/mcp":
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /mcp Show all MCP servers")
fmt.Fprintln(os.Stderr, " /mcp add <name> <command> [args...] Add an MCP server to global config")
fmt.Fprintln(os.Stderr, " /mcp remove <name> Remove an MCP server from global config")
fmt.Fprintln(os.Stderr, " /mcp disable <name> Disable an MCP server (keep in config)")
fmt.Fprintln(os.Stderr, " /mcp enable <name> Re-enable a disabled MCP server")
fmt.Fprintln(os.Stderr, "")
case "shortcut", "shortcuts":
usageShortcuts()
}

View File

@@ -1,570 +0,0 @@
package cmd
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"text/tabwriter"
"time"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
)
// SkillPushHandler handles the skill push command.
func SkillPushHandler(cmd *cobra.Command, args []string) error {
if len(args) != 2 {
return fmt.Errorf("usage: ollama skill push NAME[:TAG] PATH")
}
name := args[0]
path := args[1]
// Expand path
if strings.HasPrefix(path, "~") {
home, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("expanding home directory: %w", err)
}
path = filepath.Join(home, path[1:])
}
absPath, err := filepath.Abs(path)
if err != nil {
return fmt.Errorf("resolving path: %w", err)
}
// Validate skill directory
skillMdPath := filepath.Join(absPath, "SKILL.md")
if _, err := os.Stat(skillMdPath); err != nil {
return fmt.Errorf("skill directory must contain SKILL.md: %w", err)
}
// Parse skill name (will set Kind="skill")
n := server.ParseSkillName(name)
if n.Model == "" {
return fmt.Errorf("invalid skill name: %s", name)
}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
// Create skill layer
displayName := n.DisplayShortest()
status := fmt.Sprintf("Creating skill layer for %s", displayName)
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
layer, err := server.CreateSkillLayer(absPath)
if err != nil {
return fmt.Errorf("creating skill layer: %w", err)
}
spinner.Stop()
// Create skill manifest
manifest, configLayer, err := createSkillManifest(absPath, layer)
if err != nil {
return fmt.Errorf("creating skill manifest: %w", err)
}
// Write manifest locally
manifestPath, err := server.GetSkillManifestPath(n)
if err != nil {
return fmt.Errorf("getting manifest path: %w", err)
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
return fmt.Errorf("creating manifest directory: %w", err)
}
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return fmt.Errorf("marshaling manifest: %w", err)
}
if err := os.WriteFile(manifestPath, manifestJSON, 0o644); err != nil {
return fmt.Errorf("writing manifest: %w", err)
}
fmt.Fprintf(os.Stderr, "Skill %s created locally\n", displayName)
fmt.Fprintf(os.Stderr, " Config: %s (%s)\n", configLayer.Digest, format.HumanBytes(configLayer.Size))
fmt.Fprintf(os.Stderr, " Layer: %s (%s)\n", layer.Digest, format.HumanBytes(layer.Size))
// Push to registry
client, err := api.ClientFromEnvironment()
if err != nil {
return fmt.Errorf("creating client: %w", err)
}
insecure, _ := cmd.Flags().GetBool("insecure")
// For now, we'll use the existing push mechanism
fmt.Fprintf(os.Stderr, "\nPushing to registry...\n")
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
p.Add(resp.Digest, bar)
} else if resp.Status != "" {
spinner := progress.NewSpinner(resp.Status)
p.Add(resp.Status, spinner)
}
return nil
}
req := &api.PushRequest{
Model: displayName,
Insecure: insecure,
}
if err := client.Push(context.Background(), req, fn); err != nil {
// If push fails, still show success for local creation
fmt.Fprintf(os.Stderr, "\nNote: Local skill created but push failed: %v\n", err)
fmt.Fprintf(os.Stderr, "You can try pushing later with: ollama skill push %s\n", name)
return nil
}
fmt.Fprintf(os.Stderr, "Successfully pushed %s\n", displayName)
return nil
}
// SkillPullHandler handles the skill pull command.
func SkillPullHandler(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return fmt.Errorf("usage: ollama skill pull NAME[:TAG]")
}
name := args[0]
n := server.ParseSkillName(name)
if n.Model == "" {
return fmt.Errorf("invalid skill name: %s", name)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return fmt.Errorf("creating client: %w", err)
}
insecure, _ := cmd.Flags().GetBool("insecure")
p := progress.NewProgress(os.Stderr)
defer p.Stop()
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
p.Add(resp.Digest, bar)
} else if resp.Status != "" {
spinner := progress.NewSpinner(resp.Status)
p.Add(resp.Status, spinner)
}
return nil
}
displayName := n.DisplayShortest()
req := &api.PullRequest{
Model: displayName,
Insecure: insecure,
}
if err := client.Pull(context.Background(), req, fn); err != nil {
return fmt.Errorf("pulling skill: %w", err)
}
fmt.Fprintf(os.Stderr, "Successfully pulled %s\n", displayName)
return nil
}
// SkillListHandler handles the skill list command.
func SkillListHandler(cmd *cobra.Command, args []string) error {
skills, err := listLocalSkills()
if err != nil {
return fmt.Errorf("listing skills: %w", err)
}
if len(skills) == 0 {
fmt.Println("No skills installed")
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
fmt.Fprintln(w, "NAME\tTAG\tSIZE\tMODIFIED")
for _, skill := range skills {
fmt.Fprintf(w, "%s/%s\t%s\t%s\t%s\n",
skill.Namespace,
skill.Name,
skill.Tag,
format.HumanBytes(skill.Size),
format.HumanTime(skill.ModifiedAt, "Never"),
)
}
return w.Flush()
}
// SkillRemoveHandler handles the skill rm command.
func SkillRemoveHandler(cmd *cobra.Command, args []string) error {
if len(args) == 0 {
return fmt.Errorf("usage: ollama skill rm NAME[:TAG] [NAME[:TAG]...]")
}
for _, name := range args {
n := server.ParseSkillName(name)
if n.Model == "" {
fmt.Fprintf(os.Stderr, "Invalid skill name: %s\n", name)
continue
}
displayName := n.DisplayShortest()
manifestPath, err := server.GetSkillManifestPath(n)
if err != nil {
fmt.Fprintf(os.Stderr, "Error getting manifest path for %s: %v\n", name, err)
continue
}
if _, err := os.Stat(manifestPath); os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Skill not found: %s\n", displayName)
continue
}
if err := os.Remove(manifestPath); err != nil {
fmt.Fprintf(os.Stderr, "Error removing %s: %v\n", displayName, err)
continue
}
// Clean up empty parent directories
dir := filepath.Dir(manifestPath)
for dir != filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests") {
entries, _ := os.ReadDir(dir)
if len(entries) == 0 {
os.Remove(dir)
dir = filepath.Dir(dir)
} else {
break
}
}
fmt.Fprintf(os.Stderr, "Deleted '%s'\n", displayName)
}
return nil
}
// SkillShowHandler handles the skill show command.
func SkillShowHandler(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return fmt.Errorf("usage: ollama skill show NAME[:TAG]")
}
name := args[0]
n := server.ParseSkillName(name)
if n.Model == "" {
return fmt.Errorf("invalid skill name: %s", name)
}
displayName := n.DisplayShortest()
manifestPath, err := server.GetSkillManifestPath(n)
if err != nil {
return fmt.Errorf("getting manifest path: %w", err)
}
data, err := os.ReadFile(manifestPath)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("skill not found: %s", displayName)
}
return fmt.Errorf("reading manifest: %w", err)
}
var manifest server.Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return fmt.Errorf("parsing manifest: %w", err)
}
fmt.Printf("Skill: %s\n\n", displayName)
fmt.Println("Layers:")
for _, layer := range manifest.Layers {
fmt.Printf(" %s %s %s\n", layer.MediaType, layer.Digest[:19], format.HumanBytes(layer.Size))
}
// Try to read and display SKILL.md content
if len(manifest.Layers) > 0 {
for _, layer := range manifest.Layers {
if layer.MediaType == server.MediaTypeSkill {
skillPath, err := server.GetSkillsPath(layer.Digest)
if err == nil {
skillMdPath := filepath.Join(skillPath, "SKILL.md")
if content, err := os.ReadFile(skillMdPath); err == nil {
fmt.Println("\nContent:")
fmt.Println(string(content))
}
}
}
}
}
return nil
}
// SkillInfo represents information about an installed skill.
type SkillInfo struct {
Namespace string
Name string
Tag string
Size int64
ModifiedAt time.Time
}
// listLocalSkills returns a list of locally installed skills.
// Skills are stored with 5-part paths: host/namespace/kind/model/tag
// where kind is "skill".
func listLocalSkills() ([]SkillInfo, error) {
manifestsPath := filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests")
var skills []SkillInfo
// Walk through all registries
registries, err := os.ReadDir(manifestsPath)
if err != nil {
if os.IsNotExist(err) {
return skills, nil
}
return nil, err
}
for _, registry := range registries {
if !registry.IsDir() {
continue
}
// Walk namespaces
namespaces, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name()))
if err != nil {
continue
}
for _, namespace := range namespaces {
if !namespace.IsDir() {
continue
}
// Walk kinds looking for "skill"
kinds, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name()))
if err != nil {
continue
}
for _, kind := range kinds {
if !kind.IsDir() {
continue
}
// Only process skill kind
if kind.Name() != server.SkillNamespace {
continue
}
// Walk skill names (model names)
skillNames, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name()))
if err != nil {
continue
}
for _, skillName := range skillNames {
if !skillName.IsDir() {
continue
}
// Walk tags
tags, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name()))
if err != nil {
continue
}
for _, tag := range tags {
manifestPath := filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name(), tag.Name())
fi, err := os.Stat(manifestPath)
if err != nil || fi.IsDir() {
continue
}
// Read manifest to get size
data, err := os.ReadFile(manifestPath)
if err != nil {
continue
}
var manifest server.Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
continue
}
var totalSize int64
for _, layer := range manifest.Layers {
totalSize += layer.Size
}
// Build display name using model.Name
n := model.Name{
Host: registry.Name(),
Namespace: namespace.Name(),
Kind: kind.Name(),
Model: skillName.Name(),
Tag: tag.Name(),
}
skills = append(skills, SkillInfo{
Namespace: n.Namespace + "/" + n.Kind,
Name: n.Model,
Tag: n.Tag,
Size: totalSize,
ModifiedAt: fi.ModTime(),
})
}
}
}
}
}
return skills, nil
}
// createSkillManifest creates a manifest for a standalone skill.
func createSkillManifest(skillDir string, layer server.Layer) (*server.Manifest, *server.Layer, error) {
// Read SKILL.md to extract metadata
skillMdPath := filepath.Join(skillDir, "SKILL.md")
content, err := os.ReadFile(skillMdPath)
if err != nil {
return nil, nil, fmt.Errorf("reading SKILL.md: %w", err)
}
// Extract name and description from frontmatter
name, description := extractSkillMetadata(string(content))
if name == "" {
return nil, nil, errors.New("skill name not found in SKILL.md frontmatter")
}
// Create config
config := map[string]any{
"name": name,
"description": description,
"architecture": "amd64",
"os": "linux",
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, nil, fmt.Errorf("marshaling config: %w", err)
}
// Create config layer
configLayer, err := server.NewLayer(strings.NewReader(string(configJSON)), "application/vnd.docker.container.image.v1+json")
if err != nil {
return nil, nil, fmt.Errorf("creating config layer: %w", err)
}
manifest := &server.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: configLayer,
Layers: []server.Layer{layer},
}
return manifest, &configLayer, nil
}
// extractSkillMetadata extracts name and description from SKILL.md frontmatter.
func extractSkillMetadata(content string) (name, description string) {
lines := strings.Split(content, "\n")
inFrontmatter := false
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed == "---" {
if !inFrontmatter {
inFrontmatter = true
continue
} else {
break // End of frontmatter
}
}
if inFrontmatter {
if strings.HasPrefix(trimmed, "name:") {
name = strings.TrimSpace(strings.TrimPrefix(trimmed, "name:"))
} else if strings.HasPrefix(trimmed, "description:") {
description = strings.TrimSpace(strings.TrimPrefix(trimmed, "description:"))
}
}
}
return name, description
}
// NewSkillCommand creates the skill parent command with subcommands.
func NewSkillCommand() *cobra.Command {
skillCmd := &cobra.Command{
Use: "skill",
Short: "Manage skills",
Long: "Commands for managing agent skills (push, pull, list, rm, show)",
}
pushCmd := &cobra.Command{
Use: "push NAME[:TAG] PATH",
Short: "Push a skill to a registry",
Long: "Package a local skill directory and push it to a registry",
Args: cobra.ExactArgs(2),
PreRunE: checkServerHeartbeat,
RunE: SkillPushHandler,
}
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
pullCmd := &cobra.Command{
Use: "pull NAME[:TAG]",
Short: "Pull a skill from a registry",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: SkillPullHandler,
}
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
listCmd := &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List installed skills",
Args: cobra.NoArgs,
RunE: SkillListHandler,
}
rmCmd := &cobra.Command{
Use: "rm NAME[:TAG] [NAME[:TAG]...]",
Aliases: []string{"remove", "delete"},
Short: "Remove a skill",
Args: cobra.MinimumNArgs(1),
RunE: SkillRemoveHandler,
}
showCmd := &cobra.Command{
Use: "show NAME[:TAG]",
Short: "Show skill details",
Args: cobra.ExactArgs(1),
RunE: SkillShowHandler,
}
skillCmd.AddCommand(pushCmd, pullCmd, listCmd, rmCmd, showCmd)
return skillCmd
}

View File

@@ -1,591 +0,0 @@
package cmd
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io/fs"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strings"
"time"
"gopkg.in/yaml.v3"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/server"
)
const (
skillFileName = "SKILL.md"
maxSkillDescription = 1024
maxSkillNameLength = 64
)
var skillNamePattern = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
type skillMetadata struct {
Name string `yaml:"name"`
Description string `yaml:"description"`
}
type skillDefinition struct {
Name string
Description string
Content string // Full SKILL.md content (without frontmatter)
Dir string
SkillPath string
}
type skillCatalog struct {
Skills []skillDefinition
byName map[string]skillDefinition
}
func loadSkills(paths []string) (*skillCatalog, error) {
if len(paths) == 0 {
return nil, nil
}
var skills []skillDefinition
byName := make(map[string]skillDefinition)
for _, root := range paths {
info, err := os.Stat(root)
if err != nil {
return nil, fmt.Errorf("skills directory %q: %w", root, err)
}
if !info.IsDir() {
return nil, fmt.Errorf("skills path %q is not a directory", root)
}
err = filepath.WalkDir(root, func(path string, entry fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if entry.IsDir() {
return nil
}
if entry.Name() != skillFileName {
return nil
}
skillDir := filepath.Dir(path)
skill, err := parseSkillFile(path, skillDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
return nil
}
if _, exists := byName[skill.Name]; exists {
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
return nil
}
byName[skill.Name] = skill
skills = append(skills, skill)
return nil
})
if err != nil {
return nil, err
}
}
if len(skills) == 0 {
return nil, nil
}
sort.Slice(skills, func(i, j int) bool {
return skills[i].Name < skills[j].Name
})
return &skillCatalog{Skills: skills, byName: byName}, nil
}
// loadSkillsFromRefs loads skills from a list of SkillRef objects.
// Skills can be referenced by:
// - Digest: loaded from the extracted skill cache (for bundled/pulled skills)
// - Name (local path): loaded from the filesystem (for development)
func loadSkillsFromRefs(refs []api.SkillRef) (*skillCatalog, error) {
if len(refs) == 0 {
return nil, nil
}
var skills []skillDefinition
byName := make(map[string]skillDefinition)
for _, ref := range refs {
var skillDir string
if ref.Digest != "" {
// Load from extracted skill cache
path, err := server.GetSkillsPath(ref.Digest)
if err != nil {
return nil, fmt.Errorf("getting skill path for %s: %w", ref.Digest, err)
}
// Check if skill is already extracted
skillMdPath := filepath.Join(path, skillFileName)
if _, err := os.Stat(skillMdPath); os.IsNotExist(err) {
// Try to extract the skill blob
path, err = server.ExtractSkillBlob(ref.Digest)
if err != nil {
return nil, fmt.Errorf("extracting skill %s: %w", ref.Digest, err)
}
}
skillDir = path
} else if ref.Name != "" {
// Check if this is a local path or a registry reference
if !server.IsLocalSkillPath(ref.Name) {
// Registry reference without a digest - skill needs to be pulled first
// This happens when an agent references a skill that hasn't been bundled
return nil, fmt.Errorf("skill %q is a registry reference but has no digest - the agent may need to be recreated or the skill pulled separately", ref.Name)
}
// Local path - resolve it
skillPath := ref.Name
if strings.HasPrefix(skillPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("expanding home directory: %w", err)
}
skillPath = filepath.Join(home, skillPath[1:])
}
absPath, err := filepath.Abs(skillPath)
if err != nil {
return nil, fmt.Errorf("resolving skill path %q: %w", ref.Name, err)
}
// Check if this is a directory containing skills or a single skill
info, err := os.Stat(absPath)
if err != nil {
return nil, fmt.Errorf("skill path %q: %w", ref.Name, err)
}
if info.IsDir() {
// Check if it's a skill directory (has SKILL.md) or a parent of skill directories
skillMdPath := filepath.Join(absPath, skillFileName)
if _, err := os.Stat(skillMdPath); err == nil {
// Direct skill directory
skillDir = absPath
} else {
// Parent directory - walk to find skill subdirectories
err := filepath.WalkDir(absPath, func(path string, entry fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if entry.IsDir() {
return nil
}
if entry.Name() != skillFileName {
return nil
}
skillSubDir := filepath.Dir(path)
skill, err := parseSkillFile(path, skillSubDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
return nil
}
if _, exists := byName[skill.Name]; exists {
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
return nil
}
byName[skill.Name] = skill
skills = append(skills, skill)
return nil
})
if err != nil {
return nil, err
}
continue
}
} else {
return nil, fmt.Errorf("skill path %q is not a directory", ref.Name)
}
} else {
// Both empty - skip
continue
}
// Parse the skill from skillDir if set
if skillDir != "" {
skillMdPath := filepath.Join(skillDir, skillFileName)
skill, err := parseSkillFile(skillMdPath, skillDir)
if err != nil {
return nil, fmt.Errorf("parsing skill at %s: %w", skillDir, err)
}
if _, exists := byName[skill.Name]; exists {
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q\n", skill.Name)
continue
}
byName[skill.Name] = skill
skills = append(skills, skill)
}
}
if len(skills) == 0 {
return nil, nil
}
sort.Slice(skills, func(i, j int) bool {
return skills[i].Name < skills[j].Name
})
return &skillCatalog{Skills: skills, byName: byName}, nil
}
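For illustration, the two reference forms this loader accepted (the digest value is a placeholder):

    refs := []api.SkillRef{
        // Bundled or pulled skill: extracted from the local blob cache by digest
        {Name: "calculator", Digest: "sha256-<64-hex-digest>"},
        // Local development path: read directly from the filesystem
        {Name: "~/skills/calculator"},
    }
    catalog, err := loadSkillsFromRefs(refs)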
func parseSkillFile(path, skillDir string) (skillDefinition, error) {
rawContent, err := os.ReadFile(path)
if err != nil {
return skillDefinition{}, err
}
frontmatter, bodyContent, err := extractFrontmatterAndContent(string(rawContent))
if err != nil {
return skillDefinition{}, err
}
var meta skillMetadata
if err := yaml.Unmarshal([]byte(frontmatter), &meta); err != nil {
return skillDefinition{}, fmt.Errorf("invalid frontmatter: %w", err)
}
if err := validateSkillMetadata(meta, skillDir); err != nil {
return skillDefinition{}, err
}
absPath, err := filepath.Abs(path)
if err != nil {
return skillDefinition{}, err
}
absDir, err := filepath.Abs(skillDir)
if err != nil {
return skillDefinition{}, err
}
return skillDefinition{
Name: meta.Name,
Description: meta.Description,
Content: bodyContent,
Dir: absDir,
SkillPath: absPath,
}, nil
}
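// extractFrontmatterAndContent splits a SKILL.md document into its YAML
// frontmatter (the lines between the opening and closing "---" markers)
// and the remaining markdown body.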
func extractFrontmatterAndContent(content string) (frontmatter string, body string, err error) {
scanner := bufio.NewScanner(strings.NewReader(content))
if !scanner.Scan() {
return "", "", errors.New("empty SKILL.md")
}
if strings.TrimSpace(scanner.Text()) != "---" {
return "", "", errors.New("missing YAML frontmatter")
}
var fmLines []string
foundEnd := false
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) == "---" {
foundEnd = true
break
}
fmLines = append(fmLines, line)
}
if !foundEnd {
return "", "", errors.New("frontmatter not terminated")
}
// Collect remaining content as body
var bodyLines []string
for scanner.Scan() {
bodyLines = append(bodyLines, scanner.Text())
}
return strings.Join(fmLines, "\n"), strings.TrimSpace(strings.Join(bodyLines, "\n")), nil
}
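// validateSkillMetadata enforces the name and description constraints and,
// for non-digest paths, checks that the skill directory matches the declared name.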
func validateSkillMetadata(meta skillMetadata, skillDir string) error {
name := strings.TrimSpace(meta.Name)
description := strings.TrimSpace(meta.Description)
switch {
case name == "":
return errors.New("missing skill name")
case len(name) > maxSkillNameLength:
return fmt.Errorf("skill name exceeds %d characters", maxSkillNameLength)
case !skillNamePattern.MatchString(name):
return fmt.Errorf("invalid skill name %q", name)
}
if description == "" {
return errors.New("missing skill description")
}
if len(description) > maxSkillDescription {
return fmt.Errorf("skill description exceeds %d characters", maxSkillDescription)
}
// Skip directory name check for digest-based paths (extracted from blobs)
dirName := filepath.Base(skillDir)
if !strings.HasPrefix(dirName, "sha256-") && dirName != name {
return fmt.Errorf("skill directory %q does not match name %q", dirName, name)
}
return nil
}
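// SystemPrompt renders the loaded skills as a system prompt section,
// describing the skill tools and inlining each skill's instructions.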
func (c *skillCatalog) SystemPrompt() string {
if c == nil || len(c.Skills) == 0 {
return ""
}
var b strings.Builder
b.WriteString("# Skills\n\n")
b.WriteString("You have the following skills loaded. Each skill provides instructions and may include executable scripts.\n\n")
b.WriteString("## Available Tools\n\n")
b.WriteString("- `run_skill_script`: Execute a script bundled with a skill. Use this when the skill instructions tell you to run a script.\n")
b.WriteString("- `read_skill_file`: Read additional files from a skill directory.\n\n")
for _, skill := range c.Skills {
fmt.Fprintf(&b, "## Skill: %s\n\n", skill.Name)
fmt.Fprintf(&b, "%s\n\n", skill.Content)
b.WriteString("---\n\n")
}
return b.String()
}
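// Tools returns the run_skill_script and read_skill_file tool definitions
// exposed to the model when skills are loaded.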
func (c *skillCatalog) Tools() api.Tools {
if c == nil || len(c.Skills) == 0 {
return nil
}
runScriptProps := api.NewToolPropertiesMap()
runScriptProps.Set("skill", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The name of the skill containing the script",
})
runScriptProps.Set("command", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The command to execute (e.g., 'python scripts/calculate.py 25 4' or './scripts/run.sh')",
})
readFileProps := api.NewToolPropertiesMap()
readFileProps.Set("skill", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The name of the skill containing the file",
})
readFileProps.Set("path", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The relative path to the file within the skill directory",
})
return api.Tools{
{
Type: "function",
Function: api.ToolFunction{
Name: "run_skill_script",
Description: "Execute a script or command within a skill's directory. Use this to run Python scripts, shell scripts, or other executables bundled with a skill.",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"skill", "command"},
Properties: runScriptProps,
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "read_skill_file",
Description: "Read a file from a skill's directory. Use this to read additional documentation, reference files, or data files bundled with a skill.",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"skill", "path"},
Properties: readFileProps,
},
},
},
}
}
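// RunToolCall dispatches a skill tool call. The boolean result reports
// whether the call was handled by the skill catalog; argument and lookup
// errors are returned to the model as tool messages rather than as Go errors.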
func (c *skillCatalog) RunToolCall(call api.ToolCall) (api.Message, bool, error) {
switch call.Function.Name {
case "read_skill_file":
skillName, err := requireStringArg(call.Function.Arguments, "skill")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
relPath, err := requireStringArg(call.Function.Arguments, "path")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
skill, ok := c.byName[skillName]
if !ok {
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
}
content, err := readSkillFile(skill.Dir, relPath)
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
return toolMessage(call, content), true, nil
case "run_skill_script":
skillName, err := requireStringArg(call.Function.Arguments, "skill")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
command, err := requireStringArg(call.Function.Arguments, "command")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
skill, ok := c.byName[skillName]
if !ok {
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
}
output, err := runSkillScript(skill.Dir, command)
if err != nil {
return toolMessage(call, fmt.Sprintf("error: %v\noutput: %s", err, output)), true, nil
}
return toolMessage(call, output), true, nil
default:
return api.Message{}, false, nil
}
}
// runSkillScript executes a shell command within a skill's directory.
//
// SECURITY LIMITATIONS (TODO):
// - No sandboxing: commands run with full user permissions
// - No path validation: model can run any command, not just scripts in skill dir
// - Shell injection risk: sh -c is used, malicious input could be crafted
// - No executable allowlist: any program can be called (curl, rm, etc.)
// - No environment isolation: scripts inherit full environment variables
//
// POTENTIAL IMPROVEMENTS:
// - Restrict commands to only reference files within skill directory
// - Allowlist specific executables (python3, node, bash)
// - Use sandboxing (Docker, nsjail, seccomp)
// - Require explicit script registration in SKILL.md frontmatter
// - Add per-skill configurable timeouts
func runSkillScript(skillDir, command string) (string, error) {
// Validate the skill directory exists
absSkillDir, err := filepath.Abs(skillDir)
if err != nil {
return "", err
}
if _, err := os.Stat(absSkillDir); err != nil {
return "", fmt.Errorf("skill directory not found: %w", err)
}
// Create command with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "sh", "-c", command)
cmd.Dir = absSkillDir
// Inject the current working directory (where ollama run was called from)
// as an environment variable so scripts can reference files in that directory
workingDir, err := os.Getwd()
if err != nil {
return "", fmt.Errorf("failed to get working directory: %w", err)
}
cmd.Env = append(os.Environ(), "OLLAMA_WORKING_DIR="+workingDir)
// Capture both stdout and stderr
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err = cmd.Run()
// Combine output
output := stdout.String()
if stderr.Len() > 0 {
if output != "" {
output += "\n"
}
output += stderr.String()
}
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return output, errors.New("command timed out after 30 seconds")
}
return output, err
}
return output, nil
}
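// readSkillFile reads a file relative to the skill directory, rejecting
// absolute paths and paths that resolve outside the directory.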
func readSkillFile(skillDir, relPath string) (string, error) {
relPath = filepath.Clean(strings.TrimSpace(relPath))
if relPath == "" {
return "", errors.New("path is required")
}
if filepath.IsAbs(relPath) {
return "", errors.New("path must be relative to the skill directory")
}
target := filepath.Join(skillDir, relPath)
absTarget, err := filepath.Abs(target)
if err != nil {
return "", err
}
absSkillDir, err := filepath.Abs(skillDir)
if err != nil {
return "", err
}
rel, err := filepath.Rel(absSkillDir, absTarget)
if err != nil {
return "", err
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
return "", errors.New("path escapes the skill directory")
}
content, err := os.ReadFile(absTarget)
if err != nil {
return "", fmt.Errorf("failed to read %q: %w", relPath, err)
}
return string(content), nil
}
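// requireStringArg extracts a required, non-empty string argument from a tool call.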
func requireStringArg(args api.ToolCallFunctionArguments, name string) (string, error) {
value, ok := args.Get(name)
if !ok {
return "", fmt.Errorf("missing required argument %q", name)
}
str, ok := value.(string)
if !ok {
return "", fmt.Errorf("argument %q must be a string", name)
}
if strings.TrimSpace(str) == "" {
return "", fmt.Errorf("argument %q cannot be empty", name)
}
return str, nil
}
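// toolMessage wraps tool output in a message attributed to the calling tool.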
func toolMessage(call api.ToolCall, content string) api.Message {
msg := api.Message{
Role: "tool",
Content: content,
ToolName: call.Function.Name,
}
if call.ID != "" {
msg.ToolCallID = call.ID
}
return msg
}
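// A minimal sketch of how the catalog composes into a chat loop (same
// package). The send callback and the SkillRef literal are hypothetical;
// only loadSkillsFromRefs, SystemPrompt, Tools, and RunToolCall come from
// this file, and api.Message is assumed to carry the model's ToolCalls.
func exampleChatWithSkills(send func([]api.Message, api.Tools) (api.Message, error)) error {
catalog, err := loadSkillsFromRefs([]api.SkillRef{{Name: "./skills/calculator"}})
if err != nil {
return err
}
messages := []api.Message{
{Role: "system", Content: catalog.SystemPrompt()},
{Role: "user", Content: "What is 25 * 4?"},
}
for {
reply, err := send(messages, catalog.Tools())
if err != nil {
return err
}
messages = append(messages, reply)
if len(reply.ToolCalls) == 0 {
// No more tool calls: the model has produced its final answer
fmt.Println(reply.Content)
return nil
}
for _, call := range reply.ToolCalls {
msg, handled, err := catalog.RunToolCall(call)
if err != nil {
return err
}
if handled {
messages = append(messages, msg)
}
}
}
}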

View File

@@ -6,11 +6,14 @@ import (
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"maps"
"os"
"slices"
"strings"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -18,8 +21,13 @@ type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
// TODO is this needed?
ModelType string `json:"model_type"`
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
VocabSize uint32 `json:"vocab_size"`
HiddenSize uint32 `json:"hidden_size"`
ModelType string `json:"model_type"`
} `json:"text_config"`
}
@@ -33,8 +41,94 @@ type AdapterParameters struct {
} `json:"lora_parameters"`
}
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
kv := ggml.KV{
type KV map[string]any
func (kv KV) Architecture() string {
return kv.String("general.architecture", "unknown")
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
[]uint8 | []int8 | []uint16 | []int16 |
[]uint32 | []int32 | []uint64 | []int64 |
[]string | []float32 | []float64 | []bool
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
}
return defaultValue[0], false
}
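// For example, with general.architecture set to "llama", kv.Uint("block_count")
// looks up "llama.block_count", while "general." and "tokenizer." keys are used as-is.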
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
return val
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
return val
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
return val
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
return val
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
return val
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (ModelParameters) KV(t *Tokenizer) KV {
kv := KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
@@ -63,7 +157,7 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p AdapterParameters) KV() ggml.KV {
func (p AdapterParameters) KV() KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
@@ -71,7 +165,7 @@ func (p AdapterParameters) KV() ggml.KV {
alpha = p.LoraParameters.Alpha
}
kv := ggml.KV{
kv := KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
@@ -88,9 +182,14 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
type ModelConverter interface {
type ModelKV interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
KV(*Tokenizer) KV
}
type ModelConverter interface {
ModelKV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -107,7 +206,7 @@ type moreParser interface {
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ggml.KV) ggml.KV
KV(ofs.Config) KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -115,7 +214,7 @@ type AdapterConverter interface {
Replacements() []string
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -126,8 +225,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return err
}
arch, ok := baseKV["general.architecture"]
if !ok {
arch := baseKV.Architecture()
if arch == "" {
return errors.New("architecture not set for the base model")
}
@@ -153,23 +252,19 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return err
return nil, nil, err
}
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return err
return nil, nil, err
}
if len(p.Architectures) < 1 {
return errors.New("unknown architecture")
return nil, nil, errors.New("unknown architecture")
}
var conv ModelConverter
@@ -217,22 +312,22 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return err
return nil, nil, err
}
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return err
return nil, nil, err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil {
return err
return nil, nil, err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -254,6 +349,19 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
return conv, t, nil
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
kv, t, err := LoadModelMetadata(fsys)
if err != nil {
return err
}
conv, ok := kv.(ModelConverter)
if !ok {
return errors.New("model metadata does not support conversion")
}
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
@@ -263,7 +371,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return writeFile(f, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)

View File

@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
func (p *bertModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false

View File

@@ -24,7 +24,7 @@ type commandrModel struct {
var _ ModelConverter = (*commandrModel)(nil)
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
func (p *commandrModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r"

View File

@@ -47,7 +47,7 @@ type deepseek2Model struct {
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
func (p *deepseek2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"

View File

@@ -41,7 +41,7 @@ type deepseekocr struct {
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
func (m *deepseekocr) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers

View File

@@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
func (p *gemmaModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings

View File

@@ -1,7 +1,5 @@
package convert
import "github.com/ollama/ollama/fs/ggml"
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
@@ -9,7 +7,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
func (p *gemma2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings

View File

@@ -6,6 +6,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -15,7 +16,7 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv

View File

@@ -3,8 +3,6 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
@@ -55,7 +53,7 @@ const (
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
func (p *gemma3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"

View File

@@ -38,7 +38,7 @@ type gemma3nModel struct {
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
func (m *gemma3nModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {

View File

@@ -37,7 +37,7 @@ type gptossModel struct {
var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
func (m *gptossModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)

View File

@@ -48,7 +48,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
func (p *llamaModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize

View File

@@ -35,7 +35,7 @@ type llama4Model struct {
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
func (p *llama4Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"

View File

@@ -7,6 +7,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -18,13 +19,13 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
return kv
}

View File

@@ -60,7 +60,7 @@ type mistral3Model struct {
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
func (p *mistral3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize

View File

@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize

View File

@@ -12,7 +12,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
func (p *mixtralModel) KV(t *Tokenizer) KV {
kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 {

View File

@@ -34,7 +34,7 @@ type mllamaModel struct {
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
func (m *mllamaModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"

View File

@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
func (p *nomicbertModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)

View File

@@ -34,7 +34,7 @@ type olmoModel struct {
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
func (p *olmoModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers

View File

@@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
func (p *phi3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings

View File

@@ -22,7 +22,7 @@ type qwen2Model struct {
var _ ModelConverter = (*qwen2Model)(nil)
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
func (q *qwen2Model) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen2"
kv["qwen2.block_count"] = q.HiddenLayers

View File

@@ -29,7 +29,7 @@ type qwen25VLModel struct {
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"

View File

@@ -32,7 +32,7 @@ type qwen3Model struct {
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
func (q *qwen3Model) KV(t *Tokenizer) KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"

View File

@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"

View File

@@ -19,6 +19,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
fsc "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -28,7 +29,7 @@ type tensorData struct {
Shape []int `json:"shape"`
}
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -59,9 +60,10 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
return r, m.KV(), m.Tensors()
}
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
actual := make(map[string]string)
for k, v := range kv {
for k := range kv.Keys() {
v := kv.Value(k)
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
@@ -277,7 +279,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
func TestConvertAdapter(t *testing.T) {
type AdapterCase struct {
Name string
BaseKV map[string]any
BaseKV KV
Expected map[string]string
}

View File

@@ -14,6 +14,7 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](https://docs.ollama.com/api/anthropic-compatibility)
### Resources

View File

@@ -0,0 +1,406 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='qwen3-coder',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
tools: [
{
name: "get_weather",
description: "Get the current weather in a location",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});
for (const block of message.content) {
if (block.type === "tool_use") {
console.log("Tool:", block.name);
console.log("Input:", block.input);
}
}
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
## Models
Ollama supports both local and cloud models.
### Local models
Pull a local model before use:
```shell
ollama pull qwen3-coder
```
Recommended local models:
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
### Cloud models
Cloud models are available immediately without pulling:
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model to the expected name:
```shell
ollama cp qwen3-coder claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

View File

@@ -32,7 +32,9 @@
"codeblocks": "system"
},
"contextual": {
"options": ["copy"]
"options": [
"copy"
]
},
"navbar": {
"links": [
@@ -52,7 +54,9 @@
"display": "simple"
},
"examples": {
"languages": ["curl"]
"languages": [
"curl"
]
}
},
"redirects": [
@@ -97,6 +101,7 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
@@ -139,7 +144,8 @@
"/api/streaming",
"/api/usage",
"/api/errors",
"/api/openai-compatibility"
"/api/openai-compatibility",
"/api/anthropic-compatibility"
]
},
{

View File

@@ -0,0 +1,69 @@
---
title: Claude Code
---
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
<CodeGroup>
```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```
```powershell Windows
irm https://claude.ai/install.ps1 | iex
```
</CodeGroup>
## Usage with Ollama
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
2. Run Claude Code with an Ollama model:
```shell
claude --model qwen3-coder
```
Or run with environment variables inline:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model

View File

@@ -1,5 +1,5 @@
---
title: Linux
title: "Linux"
---
## Install
@@ -13,8 +13,7 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
<Note>
If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
</Note>
Download and extract the package:
@@ -113,11 +112,7 @@ sudo systemctl status ollama
```
<Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
</Note>
## Customizing
@@ -196,4 +191,4 @@ Remove the downloaded models and Ollama service user and group:
sudo userdel ollama
sudo groupdel ollama
sudo rm -r /usr/share/ollama
```
```

View File

@@ -1,548 +0,0 @@
# Ollama Skills
Skills are reusable capability packages that extend what agents can do. They bundle instructions, scripts, and data that teach an agent how to perform specific tasks.
## Quick Start
### Creating a Skill
Create a directory with a `SKILL.md` file:
```
my-skill/
├── SKILL.md # Required: Instructions for the agent
└── scripts/ # Optional: Executable scripts
└── run.py
```
The `SKILL.md` file must have YAML frontmatter:
```markdown
---
name: my-skill
description: A brief description of what this skill does
---
# My Skill
## Purpose
Explain what this skill does and when to use it.
## Instructions
Step-by-step instructions for the agent on how to use this skill.
## Examples
Show example inputs and expected outputs.
```
### Using Skills in an Agent
Reference skills in your Agentfile:
```dockerfile
FROM llama3.2:3b
AGENT_TYPE conversational
# Local skill (bundled with agent)
SKILL ./path/to/my-skill
# Registry skill (pulled from ollama.com)
SKILL library/skill/calculator:1.0.0
# User skill from registry
SKILL myname/skill/calculator:1.0.0
SYSTEM You are a helpful assistant.
```
### Managing Skills
```bash
# Push a skill to the registry (uses your namespace)
ollama skill push myname/skill/calculator:1.0.0 ./my-skill
# Pull a skill from the official library
ollama skill pull skill/calculator:1.0.0
# Pull a skill from a user's namespace
ollama skill pull myname/skill/calculator:1.0.0
# List installed skills
ollama skill list
# Show skill details
ollama skill show skill/calculator:1.0.0
# Remove a skill
ollama skill rm skill/calculator:1.0.0
```
### Dynamic Skills in Chat
You can add and remove skills dynamically during an interactive chat session:
```
>>> /skills
Available Skills:
calculator (sha256:abc123def456...)
>>> /skill add ./my-local-skill
Added skill 'my-skill' from ./my-local-skill
>>> /skill list
Skills loaded in this session:
my-skill (local: /path/to/my-local-skill)
>>> /skill remove my-skill
Removed skill 'my-skill'
```
| Command | Description |
|---------|-------------|
| `/skills` | Show all available skills (model + session) |
| `/skill add <path>` | Add a skill from a local path |
| `/skill remove <name>` | Remove a skill by name |
| `/skill list` | List skills loaded in this session |
Dynamic skills take effect on the next message. This is useful for:
- Testing skills during development
- Temporarily adding capabilities to a model
- Experimenting with skill combinations
## Skill Reference Formats
Skills use a 5-part name structure: `host/namespace/kind/model:tag`
| Format | Example | Description |
|--------|---------|-------------|
| Local path | `./skills/calc` | Bundled with agent at create time |
| Library skill | `skill/calculator:1.0.0` | From the official skill library (library/skill/calculator) |
| User skill | `alice/skill/calc:1.0.0` | From a user's namespace |
| Full path | `registry.ollama.ai/alice/skill/calc:1.0.0` | Fully qualified with host |
The `kind` field distinguishes skills from models:
- `skill` - Skill packages
- `agent` - Agent packages (future)
- (empty) - Regular models
## SKILL.md Structure
### Required Frontmatter
```yaml
---
name: skill-name # Must match directory name
description: Brief description of the skill
---
```
### Recommended Sections
1. **Purpose**: What the skill does and when to use it
2. **When to use**: Trigger conditions for the agent
3. **Instructions**: Step-by-step usage guide
4. **Examples**: Input/output examples
5. **Scripts**: Documentation for any bundled scripts
### Example: Calculator Skill
```markdown
---
name: calculator
description: Performs mathematical calculations using Python
---
# Calculator Skill
## Purpose
This skill performs mathematical calculations using a bundled Python script.
## When to use
- User asks to calculate something
- User wants to do math operations
- Any arithmetic is needed
## Instructions
1. When calculation is needed, use the `run_skill_script` tool
2. Call: `python3 scripts/calculate.py "<expression>"`
3. Return the result to the user
## Examples
**Input**: "What is 25 * 4?"
**Action**: `run_skill_script` with command `python3 scripts/calculate.py '25 * 4'`
**Output**: "25 * 4 = 100"
```
## Storage Layout
```
~/.ollama/models/
├── blobs/
│ └── sha256-<digest> # Skill tar.gz blob
├── manifests/
│ └── registry.ollama.ai/
│ └── skill/ # Library skills
│ └── calculator/
│ └── 1.0.0
│ └── skill-username/ # User skills
│ └── my-skill/
│ └── latest
└── skills/
└── sha256-<digest>/ # Extracted skill cache
├── SKILL.md
└── scripts/
```
---
# Security Considerations
## Current State (Development)
The current implementation has several security considerations that need to be addressed before production use.
### 1. Script Execution
**Risk**: Skills can bundle arbitrary scripts that execute on the host system.
**Current behavior**:
- Scripts run with the same permissions as the Ollama process
- No sandboxing or isolation
- Full filesystem access
**Mitigations needed**:
- [ ] Sandbox script execution (containers, seccomp, etc.)
- [ ] Resource limits (CPU, memory, time)
- [ ] Filesystem isolation (read-only mounts, restricted paths)
- [ ] Network policy controls
- [ ] Capability dropping
### 2. Skill Provenance
**Risk**: Malicious skills could be pushed to the registry.
**Current behavior**:
- No code signing or verification
- No malware scanning
- Trust based on namespace ownership
**Mitigations needed**:
- [ ] Skill signing with author keys
- [ ] Registry-side malware scanning
- [ ] Content policy enforcement
- [ ] Reputation system for skill authors
### 3. Namespace Squatting
**Risk**: Malicious actors could register skill names that impersonate official tools.
**Current behavior**:
- First-come-first-served namespace registration
- No verification of skill names
**Mitigations needed**:
- [ ] Reserved namespace list (official tools, common names)
- [ ] Trademark/name verification for popular skills
- [ ] Clear namespacing conventions
### 4. Supply Chain Attacks
**Risk**: Compromised skills could inject malicious code into agents.
**Current behavior**:
- Skills pulled without integrity verification beyond digest
- No dependency tracking
**Mitigations needed**:
- [ ] SBOM (Software Bill of Materials) for skills
- [ ] Dependency vulnerability scanning
- [ ] Pinned versions in Agentfiles
- [ ] Audit logging of skill usage
### 5. Data Exfiltration
**Risk**: Skills could exfiltrate sensitive data from conversations or the host.
**Current behavior**:
- Skills have access to conversation context
- Scripts can make network requests
**Mitigations needed**:
- [ ] Network egress controls
- [ ] Sensitive data detection/masking
- [ ] Audit logging of script network activity
- [ ] User consent for data access
### 6. Privilege Escalation
**Risk**: Skills could escalate privileges through script execution.
**Current behavior**:
- Scripts inherit Ollama process privileges
- No capability restrictions
**Mitigations needed**:
- [ ] Run scripts as unprivileged user
- [ ] Drop all capabilities
- [ ] Mandatory access controls (SELinux/AppArmor)
## Recommended Security Model
### Skill Trust Levels
```
┌─────────────────────────────────────────────────────────────┐
│ Level 0: Untrusted (default) │
│ - No script execution │
│ - Instructions only │
│ - Safe for any skill │
├─────────────────────────────────────────────────────────────┤
│ Level 1: Sandboxed │
│ - Scripts run in isolated container │
│ - No network access │
│ - Read-only filesystem │
│ - Resource limits enforced │
├─────────────────────────────────────────────────────────────┤
│ Level 2: Trusted │
│ - Scripts run with network access │
│ - Can write to designated directories │
│ - Requires explicit user approval │
├─────────────────────────────────────────────────────────────┤
│ Level 3: Privileged (admin only) │
│ - Full host access │
│ - System administration skills │
│ - Requires admin approval │
└─────────────────────────────────────────────────────────────┘
```
### Skill Manifest Security Fields (Future)
```yaml
---
name: my-skill
description: A skill description
security:
trust_level: sandboxed
permissions:
- network:read # Can make HTTP GET requests
- filesystem:read:/data # Can read from /data
resource_limits:
max_memory: 256MB
max_cpu_time: 30s
max_disk: 100MB
signature: sha256:abc... # Author signature
---
```
---
# Future Considerations
## Feature Roadmap
### Phase 1: Foundation (Current)
- [x] Skill bundling with agents
- [x] Local skill development
- [x] Basic CLI commands (push, pull, list, rm, show)
- [x] Registry blob storage
- [ ] Registry namespace configuration
### Phase 2: Security
- [ ] Script sandboxing
- [ ] Permission model
- [ ] Skill signing
- [ ] Audit logging
### Phase 3: Discovery
- [ ] Skill search on ollama.com
- [ ] Skill ratings and reviews
- [ ] Usage analytics
- [ ] Featured/trending skills
### Phase 4: Advanced Features
- [ ] Skill dependencies
- [ ] Skill versioning constraints
- [ ] Skill composition (skills using skills)
- [ ] Skill testing framework
## Open Questions
### 1. Skill Execution Model
**Question**: How should skills execute scripts?
Options:
- **A) In-process**: Fast but unsafe
- **B) Subprocess**: Current approach, moderate isolation
- **C) Container**: Good isolation, requires container runtime
- **D) WASM**: Portable and safe, limited capabilities
- **E) Remote execution**: Offload to secure service
### 2. Skill Versioning
**Question**: How strict should version pinning be?
Options:
- **A) Always latest**: Simple but risky
- **B) Semantic versioning**: `^1.0.0` allows minor updates
- **C) Exact pinning**: `=1.0.0` requires explicit updates
- **D) Digest pinning**: `@sha256:abc` immutable reference
### 3. Skill Permissions
**Question**: How should users grant permissions to skills?
Options:
- **A) All or nothing**: Accept all permissions or don't use
- **B) Granular consent**: Approve each permission individually
- **C) Trust levels**: Pre-defined permission bundles
- **D) Runtime prompts**: Ask when permission is first used
### 4. Skill Discovery
**Question**: How should users find skills?
Options:
- **A) Central registry only**: ollama.com/skills
- **B) Federated registries**: Multiple skill sources
- **C) Git repositories**: Pull from GitHub, etc.
- **D) All of the above**: Multiple discovery mechanisms
### 5. Skill Monetization
**Question**: Should skill authors be able to monetize?
Options:
- **A) Free only**: All skills are free and open
- **B) Paid skills**: Authors can charge for skills
- **C) Freemium**: Free tier with paid features
- **D) Donations**: Voluntary support for authors
### 6. Skill Updates
**Question**: How should skill updates be handled?
Options:
- **A) Manual**: User explicitly updates
- **B) Auto-update**: Always use latest
- **C) Notify**: Alert user to available updates
- **D) Policy-based**: Organization controls update policy
## API Considerations
### Skill Metadata API
```
GET /api/skills
GET /api/skills/:namespace/:name
GET /api/skills/:namespace/:name/versions
GET /api/skills/:namespace/:name/readme
```
### Skill Execution API
```
POST /api/skills/:namespace/:name/execute
{
"command": "python3 scripts/run.py",
"args": ["--input", "data"],
"timeout": 30
}
```
### Skill Permissions API
```
GET /api/skills/:namespace/:name/permissions
POST /api/skills/:namespace/:name/permissions/grant
DELETE /api/skills/:namespace/:name/permissions/revoke
```
## Testing Considerations
### Skill Testing Framework
```bash
# Run skill tests
ollama skill test ./my-skill
# Test with specific model
ollama skill test ./my-skill --model llama3.2:3b
# Generate test report
ollama skill test ./my-skill --report
```
### Test File Format
```yaml
# my-skill/tests/test.yaml
tests:
- name: "basic calculation"
input: "What is 2 + 2?"
expect:
contains: "4"
tool_called: "run_skill_script"
- name: "complex expression"
input: "Calculate 15% of 200"
expect:
contains: "30"
```
## Compatibility Considerations
### Minimum Ollama Version
Skills should declare minimum Ollama version:
```yaml
---
name: my-skill
requires:
ollama: ">=0.4.0"
---
```
### Model Compatibility
Skills may require specific model capabilities:
```yaml
---
name: vision-skill
requires:
capabilities:
- vision
- tools
---
```
## Migration Path
### From Local to Registry
```bash
# Develop locally
SKILL ./my-skill
# Push when ready
ollama skill push myname/skill/my-skill:1.0.0 ./my-skill
# Update Agentfile
SKILL myname/skill/my-skill:1.0.0
```
### Version Upgrades
```bash
# Check for updates
ollama skill outdated
# Update specific skill
ollama skill update calculator:1.0.0
# Update all skills
ollama skill update --all
```

View File

@@ -148,16 +148,6 @@ func Remotes() []string {
return r
}
// Skills returns the list of skill directories. Skills directories can be configured via the OLLAMA_SKILLS environment variable.
// Returns empty slice if not configured.
func Skills() []string {
raw := strings.TrimSpace(Var("OLLAMA_SKILLS"))
if raw == "" {
return []string{}
}
return strings.Split(raw, ",")
}
func BoolWithDefault(k string) func(defaultValue bool) bool {
return func(defaultValue bool) bool {
if s := Var(k); s != "" {
@@ -327,9 +317,6 @@ func AsMap() map[string]EnvVar {
ret["OLLAMA_VULKAN"] = EnvVar{"OLLAMA_VULKAN", EnableVulkan(), "Enable experimental Vulkan support"}
}
// Skills configuration would go here when added
ret["OLLAMA_SKILLS"] = EnvVar{"OLLAMA_SKILLS", Skills(), "Comma-separated list of skill directories"}
return ret
}

View File

@@ -1,5 +1,7 @@
package fs
import "iter"
type Config interface {
Architecture() string
String(string, ...string) string
@@ -11,4 +13,8 @@ type Config interface {
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
Len() int
Keys() iter.Seq[string]
Value(key string) any
}

View File

@@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"io"
"iter"
"log/slog"
"maps"
"math"
"slices"
"strings"
@@ -239,6 +241,18 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
return val.values
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"bert",

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"github.com/ollama/ollama/fs"
"golang.org/x/sync/errgroup"
)
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
return err
}
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
for _, key := range slices.Sorted(kv.Keys()) {
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
return err
}
}

go.mod
View File

@@ -87,5 +87,5 @@ require (
golang.org/x/term v0.36.0
golang.org/x/text v0.30.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

middleware/anthropic.go
View File

@@ -0,0 +1,149 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
type AnthropicWriter struct {
BaseWriter
stream bool
id string
model string
converter *anthropic.StreamConverter
}
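// writeError translates an upstream Ollama error payload into the Anthropic error envelope.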
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
var errData struct {
Error string `json:"error"`
}
if err := json.Unmarshal(data, &errData); err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
if err != nil {
return 0, err
}
return len(data), nil
}
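// writeEvent writes a single server-sent event and flushes it to the client
// when the underlying writer supports flushing.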
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
}
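// writeResponse converts an Ollama chat response into either a stream of
// Anthropic SSE events or a single non-streaming Messages response.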
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
if w.stream {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
}
}
return len(data), nil
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := anthropic.ToMessagesResponse(w.id, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
func (w *AnthropicWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req anthropic.MessagesRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.MaxTokens <= 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
return
}
chatReq, err := anthropic.FromMessagesRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
messageID := anthropic.GenerateMessageID()
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
}
if req.Stream {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
c.Next()
}
}
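A minimal sketch of how this middleware could be wired up, assuming a gin engine and an existing chat handler; the handler name and port here are illustrative, not taken from this diff:

package main

import (
	"github.com/gin-gonic/gin"
	"github.com/ollama/ollama/middleware"
)

// chatHandler stands in for the existing Ollama chat endpoint.
func chatHandler(c *gin.Context) { /* ... */ }

func main() {
	r := gin.New()
	// The middleware validates the Anthropic-format body, rewrites it into an
	// api.ChatRequest for the downstream handler, and swaps in AnthropicWriter
	// so the handler's Ollama-format output is translated back on the way out.
	r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), chatHandler)
	_ = r.Run(":11434")
}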

View File

@@ -0,0 +1,584 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
_ = json.Unmarshal(bodyBytes, capturedRequest)
c.Next()
}
}
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAnthropicMessagesMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.ChatRequest
err anthropic.ErrorResponse
}
var capturedRequest *api.ChatRequest
stream := true
testCases := []testCase{
{
name: "basic message",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with system prompt",
body: `{
"model": "test-model",
"max_tokens": 1024,
"system": "You are helpful.",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with options",
body: `{
"model": "test-model",
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{
"num_predict": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop": []string{"\n", "END"},
},
Stream: &False,
},
},
{
name: "streaming",
body: `{
"model": "test-model",
"max_tokens": 1024,
"stream": true,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &stream,
},
},
{
name: "with tools",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"}
],
"tools": [{
"name": "get_weather",
"description": "Get current weather",
"input_schema": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: testProps(map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with tool result",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"},
{"role": "assistant", "content": [
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
]},
{"role": "user", "content": [
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with thinking enabled",
body: `{
"model": "test-model",
"max_tokens": 1024,
"thinking": {"type": "enabled", "budget_tokens": 1000},
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
Think: &api.ThinkValue{Value: true},
},
},
{
name: "missing model error",
body: `{
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "model is required",
},
},
},
{
name: "missing max_tokens error",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "max_tokens is required and must be positive",
},
},
},
{
name: "missing messages error",
body: `{
"model": "test-model",
"max_tokens": 1024
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "messages is required",
},
},
},
{
name: "tool_use missing id error",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "assistant", "content": [
{"type": "tool_use", "name": "test"}
]}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "tool_use block missing required 'id' field",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
router.Handle(http.MethodPost, "/v1/messages", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Type != "" {
// Expect error
if resp.Code == http.StatusOK {
t.Fatalf("expected error response, got 200 OK")
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != tc.err.Type {
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
}
if errResp.Error.Type != tc.err.Error.Type {
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
}
if errResp.Error.Message != tc.err.Error.Message {
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
}
if capturedRequest == nil {
t.Fatal("request was not captured")
}
// Compare relevant fields
if capturedRequest.Model != tc.req.Model {
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
}
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
t.Errorf("messages mismatch (-want +got):\n%s", diff)
}
if tc.req.Stream != nil && capturedRequest.Stream != nil {
if *tc.req.Stream != *capturedRequest.Stream {
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
}
}
if tc.req.Think != nil {
if capturedRequest.Think == nil {
t.Error("expected Think to be set")
} else if capturedRequest.Think.Value != tc.req.Think.Value {
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
}
}
})
}
}
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("streaming sets correct headers", func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Check headers were set
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
}
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
}
c.Status(http.StatusOK)
})
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
})
}
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != "invalid_request_error" {
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
}
}
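For reference, a parse failure like the one above is returned in Anthropic's error envelope; the exact message text comes from encoding/json and may vary by Go version:

{"type": "error", "error": {"type": "invalid_request_error", "message": "invalid character 'i' looking for beginning of value"}}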
func TestAnthropicWriter_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate Ollama response
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
data, _ := json.Marshal(resp)
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.Code)
}
var result anthropic.MessagesResponse
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 {
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
}
if result.Usage.OutputTokens != 5 {
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
}
}
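The non-streaming body asserted above corresponds to a response of roughly this shape (fields abbreviated; the id format produced by GenerateMessageID is not shown in this diff):

{
  "id": "msg_...",
  "type": "message",
  "role": "assistant",
  "content": [{"type": "text", "text": "Hello there!"}],
  "stop_reason": "end_turn",
  "usage": {"input_tokens": 10, "output_tokens": 5}
}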
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
errorPayload any
wantErrorType string
wantMessage string
}{
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
{
name: "404 with gin.H error (model not found)",
statusCode: http.StatusNotFound,
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
wantErrorType: "not_found_error",
wantMessage: "model 'nonexistent' not found",
},
{
name: "400 with gin.H error (bad request)",
statusCode: http.StatusBadRequest,
errorPayload: gin.H{"error": "model is required"},
wantErrorType: "invalid_request_error",
wantMessage: "model is required",
},
{
name: "500 with gin.H error (internal error)",
statusCode: http.StatusInternalServerError,
errorPayload: gin.H{"error": "something went wrong"},
wantErrorType: "api_error",
wantMessage: "something went wrong",
},
{
name: "404 with api.StatusError",
statusCode: http.StatusNotFound,
errorPayload: api.StatusError{
StatusCode: http.StatusNotFound,
ErrorMessage: "model not found via StatusError",
},
wantErrorType: "not_found_error",
wantMessage: "model not found via StatusError",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate what routes.go does - set status and write error JSON
data, _ := json.Marshal(tt.errorPayload)
c.Writer.WriteHeader(tt.statusCode)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.statusCode {
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != tt.wantErrorType {
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
}
if errResp.Error.Message != tt.wantMessage {
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
}
})
}
}
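In the streaming case, writeEvent produces standard SSE frames. With Anthropic's published streaming format the sequence looks roughly like the following; event names are Anthropic's and payloads are abbreviated, since StreamConverter's internals are not part of this diff:

event: message_start
data: {"type":"message_start","message":{"id":"msg_...","role":"assistant","content":[]}}

event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}

event: message_stop
data: {"type":"message_stop"}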

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"bytes"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
@@ -59,8 +58,6 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
var messages []api.Message
var licenses []string
var skills []api.SkillRef
var mcps []api.MCPRef
params := make(map[string]any)
for _, c := range f.Commands {
@@ -121,32 +118,6 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
case "message":
role, msg, _ := strings.Cut(c.Args, ": ")
messages = append(messages, api.Message{Role: role, Content: msg})
case "skill":
skillName := c.Args
// Expand local paths relative to the Agentfile directory
if isLocalPath(skillName) {
expanded, err := expandPath(skillName, relativeDir)
if err != nil {
return nil, fmt.Errorf("expanding skill path %q: %w", skillName, err)
}
skillName = expanded
}
skills = append(skills, api.SkillRef{Name: skillName})
case "mcp":
mcpRef, err := parseMCPArg(c.Args, relativeDir)
if err != nil {
return nil, fmt.Errorf("invalid MCP: %w", err)
}
mcps = append(mcps, mcpRef)
case "agent_type":
// Handle "AGENT TYPE conversational" -> strip "TYPE " prefix
args := c.Args
if strings.HasPrefix(strings.ToLower(args), "type ") {
args = strings.TrimSpace(args[5:])
}
req.AgentType = args
case "entrypoint":
req.Entrypoint = c.Args
default:
if slices.Contains(deprecatedParameters, c.Name) {
fmt.Printf("warning: parameter %s is deprecated\n", c.Name)
@@ -179,12 +150,6 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
if len(licenses) > 0 {
req.License = licenses
}
if len(skills) > 0 {
req.Skills = skills
}
if len(mcps) > 0 {
req.MCPs = mcps
}
return req, nil
}
@@ -368,7 +333,7 @@ func (c Command) String() string {
switch c.Name {
case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args)
case "license", "template", "system", "adapter", "renderer", "parser", "requires", "skill", "agent_type", "entrypoint":
case "license", "template", "system", "adapter", "renderer", "parser", "requires":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
case "message":
role, message, _ := strings.Cut(c.Args, ": ")
@@ -394,7 +359,7 @@ const (
var (
errMissingFrom = errors.New("no FROM line")
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", \"requires\", \"skill\", \"agent_type\", \"mcp\", or \"entrypoint\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"")
)
type ParserError struct {
@@ -458,9 +423,6 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
switch s := strings.ToLower(b.String()); s {
case "from":
cmd.Name = "model"
case "agent":
// "AGENT TYPE" -> "agent_type", consume next word
cmd.Name = "agent_type"
case "parameter":
// transition to stateParameter which sets command name
next = stateParameter
@@ -538,10 +500,6 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
if cmd.Name == "model" {
return &f, nil
}
// Allow entrypoint-only agents without FROM
if cmd.Name == "entrypoint" {
return &f, nil
}
}
return nil, errMissingFrom
@@ -560,7 +518,7 @@ func parseRuneForState(r rune, cs state) (state, rune, error) {
}
case stateName:
switch {
case isAlpha(r), r == '_':
case isAlpha(r):
return stateName, r, nil
case isSpace(r):
return stateValue, 0, nil
@@ -661,7 +619,7 @@ func isValidMessageRole(role string) bool {
func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) {
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires", "skill", "agent_type", "agent", "mcp", "entrypoint":
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires":
return true
default:
return false
@@ -708,79 +666,3 @@ func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User
func expandPath(path, relativeDir string) (string, error) {
return expandPathImpl(path, relativeDir, user.Current, user.Lookup)
}
// parseMCPArg parses MCP command arguments.
// Supports two formats:
//
// JSON: {"name": "web-search", "command": "uv", "args": ["run", "./script.py"]}
// Simple: web-search uv run ./script.py (name, command, args...)
func parseMCPArg(args string, relativeDir string) (api.MCPRef, error) {
args = strings.TrimSpace(args)
if args == "" {
return api.MCPRef{}, errors.New("MCP requires arguments")
}
// Try JSON format first
if strings.HasPrefix(args, "{") {
var ref api.MCPRef
if err := json.Unmarshal([]byte(args), &ref); err != nil {
return api.MCPRef{}, fmt.Errorf("invalid JSON: %w", err)
}
if ref.Name == "" {
return api.MCPRef{}, errors.New("MCP name is required")
}
if ref.Command == "" {
return api.MCPRef{}, errors.New("MCP command is required")
}
if ref.Type == "" {
ref.Type = "stdio"
}
// Expand relative paths in args
for i, arg := range ref.Args {
if isLocalPath(arg) {
expanded, err := expandPath(arg, relativeDir)
if err != nil {
return api.MCPRef{}, fmt.Errorf("expanding path %q: %w", arg, err)
}
ref.Args[i] = expanded
}
}
return ref, nil
}
// Simple format: name command args...
parts := strings.Fields(args)
if len(parts) < 2 {
return api.MCPRef{}, errors.New("MCP requires at least name and command")
}
ref := api.MCPRef{
Name: parts[0],
Command: parts[1],
Type: "stdio",
}
if len(parts) > 2 {
ref.Args = parts[2:]
}
// Expand relative paths in args
for i, arg := range ref.Args {
if isLocalPath(arg) {
expanded, err := expandPath(arg, relativeDir)
if err != nil {
return api.MCPRef{}, fmt.Errorf("expanding path %q: %w", arg, err)
}
ref.Args[i] = expanded
}
}
return ref, nil
}
// isLocalPath checks if a string looks like a local filesystem path.
func isLocalPath(s string) bool {
return strings.HasPrefix(s, "/") ||
strings.HasPrefix(s, "./") ||
strings.HasPrefix(s, "../") ||
strings.HasPrefix(s, "~")
}
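With this change the parser is back to the command set listed in errInvalidCommand. A minimal Modelfile using only surviving commands (model name and values are illustrative):

FROM llama3
PARAMETER temperature 0.7
SYSTEM You are a concise assistant.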

View File

@@ -21,6 +21,7 @@ import (
"golang.org/x/text/encoding/unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/fs/ggml"
)
@@ -801,7 +802,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
base := map[string]any{"general.architecture": "test"}
var base convert.KV = map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {

progress/stepbar.go (new file, 33 lines)
View File

@@ -0,0 +1,33 @@
package progress
import (
"fmt"
"strings"
)
// StepBar displays step-based progress (e.g., for image generation steps).
type StepBar struct {
message string
current int
total int
}
func NewStepBar(message string, total int) *StepBar {
return &StepBar{message: message, total: total}
}
func (s *StepBar) Set(current int) {
s.current = current
}
func (s *StepBar) String() string {
percent := float64(s.current) / float64(s.total) * 100
barWidth := s.total
empty := barWidth - s.current
// "Generating 0% ▕ ▏ 0/9"
return fmt.Sprintf("%s %3.0f%% ▕%s%s▏ %d/%d",
s.message, percent,
strings.Repeat("█", s.current), strings.Repeat(" ", empty),
s.current, s.total)
}
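Usage sketch for the new StepBar; the rendered line follows the format shown in the comment above:

bar := progress.NewStepBar("Generating", 9)
bar.Set(3)
fmt.Println(bar.String()) // Generating  33% ▕███      ▏ 3/9

Note that the bar width equals the step count, so this renderer suits small totals such as diffusion step counts; a zero total yields NaN in the percentage.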

View File

@@ -18,6 +18,7 @@ const (
CharCtrlL = 12
CharEnter = 13
CharNext = 14
CharCtrlO = 15 // Ctrl+O - used for expanding tool output
CharPrev = 16
CharBckSearch = 18
CharFwdSearch = 19

View File

@@ -3,6 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
)
func Execute(args []string) error {
@@ -11,12 +12,19 @@ func Execute(args []string) error {
}
var newRunner bool
if args[0] == "--ollama-engine" {
var imageRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--image-engine" {
args = args[1:]
imageRunner = true
}
if newRunner {
if imageRunner {
return imagerunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)

View File

@@ -42,18 +42,39 @@ shift $(( $OPTIND - 1 ))
_build_darwin() {
for ARCH in $ARCHS; do
status "Building darwin $ARCH"
INSTALL_PREFIX=dist/darwin-$ARCH/
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
INSTALL_PREFIX=dist/darwin-$ARCH/
if [ "$ARCH" = "amd64" ]; then
status "Building darwin $ARCH dynamic backends"
cmake -B build/darwin-$ARCH \
BUILD_DIR=build/darwin-$ARCH
cmake -B $BUILD_DIR \
-DCMAKE_OSX_ARCHITECTURES=x86_64 \
-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \
-DMLX_ENGINE=ON \
-DMLX_ENABLE_X64_MAC=ON \
-DOLLAMA_RUNNER_DIR=./
cmake --build $BUILD_DIR --target ggml-cpu -j
cmake --build $BUILD_DIR --target mlx mlxc -j
cmake --install $BUILD_DIR --component CPU
cmake --install $BUILD_DIR --component MLX
# Override CGO flags to point to the amd64 build directory
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
else
BUILD_DIR=build
cmake --preset MLX \
-DOLLAMA_RUNNER_DIR=./ \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX
cmake --build build/darwin-$ARCH --target ggml-cpu -j
cmake --install build/darwin-$ARCH --component CPU
cmake --build --preset MLX --parallel
cmake --install $BUILD_DIR --component MLX
# Use default CGO flags from mlx.go for arm64
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
done
}
@@ -61,10 +82,12 @@ _sign_darwin() {
status "Creating universal binary..."
mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
chmod +x dist/darwin/ollama
chmod +x dist/darwin/imagegen
if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-amd64/lib/ollama/*; do
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done
@@ -131,17 +154,23 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
cp dist/darwin-amd64/lib/ollama/*.so dist/darwin-amd64/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
else
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
chmod a+x dist/Ollama.app/Contents/Resources/ollama
# Sign
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib ; do
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
@@ -149,7 +178,7 @@ _build_macapp() {
rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then

View File

@@ -12,6 +12,17 @@ set -eu
. $(dirname $0)/env.sh
# Check for required tools
if ! command -v zstd >/dev/null 2>&1; then
echo "ERROR: zstd is required but not installed." >&2
echo "Please install zstd:" >&2
echo " - macOS: brew install zstd" >&2
echo " - Debian/Ubuntu: sudo apt-get install zstd" >&2
echo " - RHEL/CentOS/Fedora: sudo dnf install zstd" >&2
echo " - Arch: sudo pacman -S zstd" >&2
exit 1
fi
mkdir -p dist
docker buildx build \
@@ -37,19 +48,68 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
.
fi
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
deduplicate_cuda_libs() {
local base_dir="$1"
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
}
# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
deduplicate_cuda_libs "./dist/linux_amd64"
deduplicate_cuda_libs "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
deduplicate_cuda_libs "./dist"
fi
# buildx behavior changes for single vs. multiplatform
echo "Compressing linux tar bundles..."
if echo $PLATFORM | grep "," > /dev/null ; then
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
tar c -C ./dist/linux_amd64 --exclude rocm . | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
tar c -C ./dist/ --exclude rocm bin lib | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/ ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
fi

View File

@@ -66,6 +66,36 @@ if [ -n "$NEEDS" ]; then
exit 1
fi
# Function to download and extract with fallback from zst to tgz
download_and_extract() {
local url_base="$1"
local dest_dir="$2"
local filename="$3"
# Check if .tar.zst is available
if curl --fail --silent --head --location "${url_base}/${filename}.tar.zst${VER_PARAM}" >/dev/null 2>&1; then
# zst file exists - check if we have zstd tool
if ! available zstd; then
error "This version requires zstd for extraction. Please install zstd and try again:
- Debian/Ubuntu: sudo apt-get install zstd
- RHEL/CentOS/Fedora: sudo dnf install zstd
- Arch: sudo pacman -S zstd"
fi
status "Downloading ${filename}.tar.zst"
curl --fail --show-error --location --progress-bar \
"${url_base}/${filename}.tar.zst${VER_PARAM}" | \
zstd -d | $SUDO tar -xf - -C "${dest_dir}"
return 0
fi
# Fall back to .tgz for older versions
status "Downloading ${filename}.tgz"
curl --fail --show-error --location --progress-bar \
"${url_base}/${filename}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "${dest_dir}"
}
for BINDIR in /usr/local/bin /usr/bin /bin; do
echo $PATH | grep -q $BINDIR && break || continue
done
@@ -78,10 +108,7 @@ fi
status "Installing ollama to $OLLAMA_INSTALL_DIR"
$SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
status "Downloading Linux ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}"
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
status "Making ollama accessible in the PATH in $BINDIR"
@@ -91,15 +118,9 @@ fi
# Check for NVIDIA JetPack systems with additional downloads
if [ -f /etc/nv_tegra_release ] ; then
if grep R36 /etc/nv_tegra_release > /dev/null ; then
status "Downloading JetPack 6 components"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack6.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack6"
elif grep R35 /etc/nv_tegra_release > /dev/null ; then
status "Downloading JetPack 5 components"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack5.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack5"
else
warning "Unsupported JetPack version detected. GPU may not be supported"
fi
@@ -222,10 +243,7 @@ if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdg
fi
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
status "Downloading Linux ROCm ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-rocm.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-rocm"
install_success
status "AMD GPU ready."

View File

@@ -26,6 +26,7 @@ import (
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
@@ -62,10 +63,6 @@ func (s *Server) CreateHandler(c *gin.Context) {
config.Renderer = r.Renderer
config.Parser = r.Parser
config.Requires = r.Requires
config.Skills = r.Skills
config.MCPs = r.MCPs
config.AgentType = r.AgentType
config.Entrypoint = r.Entrypoint
for v := range r.Files {
if !fs.ValidPath(v) {
@@ -125,10 +122,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- gin.H{"error": err.Error()}
}
// Inherit config from base model (Renderer, Parser, Requires, Capabilities, etc.)
// This is especially important for cloud models which don't have GGUF files
// to detect capabilities from.
if err == nil && !remote {
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
manifest, mErr := ParseNamedManifest(fromName)
if mErr == nil && manifest.Config.Digest != "" {
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
@@ -145,29 +139,6 @@ func (s *Server) CreateHandler(c *gin.Context) {
if config.Requires == "" {
config.Requires = baseConfig.Requires
}
// Inherit capabilities for cloud/remote models
// (local models detect capabilities from GGUF file)
if len(config.Capabilities) == 0 && len(baseConfig.Capabilities) > 0 {
config.Capabilities = baseConfig.Capabilities
}
// Inherit remote host/model if base is a cloud model
if config.RemoteHost == "" && baseConfig.RemoteHost != "" {
config.RemoteHost = baseConfig.RemoteHost
}
if config.RemoteModel == "" && baseConfig.RemoteModel != "" {
config.RemoteModel = baseConfig.RemoteModel
}
// Inherit model family for proper rendering
if config.ModelFamily == "" && baseConfig.ModelFamily != "" {
config.ModelFamily = baseConfig.ModelFamily
}
if len(config.ModelFamilies) == 0 && len(baseConfig.ModelFamilies) > 0 {
config.ModelFamilies = baseConfig.ModelFamilies
}
// Inherit context length for cloud models
if config.ContextLen == 0 && baseConfig.ContextLen > 0 {
config.ContextLen = baseConfig.ContextLen
}
}
cfgFile.Close()
}
@@ -187,9 +158,6 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- gin.H{"error": err.Error()}
return
}
} else if r.Entrypoint != "" {
// Entrypoint-only agent: no base model needed
slog.Debug("create entrypoint-only agent", "entrypoint", r.Entrypoint)
} else {
ch <- gin.H{"error": errNeitherFromOrFiles.Error(), "status": http.StatusBadRequest}
return
@@ -487,7 +455,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return layers, nil
}
func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
for _, l := range baseLayers {
if l.GGML != nil {
return l.KV(), nil
@@ -576,18 +544,6 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
return err
}
// Handle skill layers for agents
layers, config.Skills, err = setSkillLayers(layers, config.Skills, fn)
if err != nil {
return err
}
// Handle MCP layers for agents
layers, config.MCPs, err = setMCPLayers(layers, config.MCPs, fn)
if err != nil {
return err
}
configLayer, err := createConfigLayer(layers, *config)
if err != nil {
return err
@@ -838,135 +794,6 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
return layers, nil
}
// setSkillLayers creates skill layers for local skill paths and updates the skill refs.
// Local paths are converted to bundled skill layers with digests.
// Registry references are kept as-is for later resolution during pull.
func setSkillLayers(layers []Layer, skills []model.SkillRef, fn func(resp api.ProgressResponse)) ([]Layer, []model.SkillRef, error) {
if len(skills) == 0 {
return layers, skills, nil
}
// Remove any existing skill layers
layers = removeLayer(layers, MediaTypeSkill)
var updatedSkills []model.SkillRef
for _, skill := range skills {
// Check if this is a local path
if IsLocalSkillPath(skill.Name) {
// Expand home directory if needed
skillPath := skill.Name
if strings.HasPrefix(skillPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
return nil, nil, fmt.Errorf("expanding home directory: %w", err)
}
skillPath = filepath.Join(home, skillPath[1:])
}
// Make absolute
absPath, err := filepath.Abs(skillPath)
if err != nil {
return nil, nil, fmt.Errorf("resolving skill path %q: %w", skill.Name, err)
}
// Check if this is a direct skill directory or a parent containing skills
skillMdPath := filepath.Join(absPath, "SKILL.md")
if _, err := os.Stat(skillMdPath); err == nil {
// Direct skill directory
fn(api.ProgressResponse{Status: fmt.Sprintf("packaging skill: %s", filepath.Base(absPath))})
layer, err := CreateSkillLayer(absPath)
if err != nil {
return nil, nil, fmt.Errorf("creating skill layer for %q: %w", skill.Name, err)
}
layers = append(layers, layer)
updatedSkills = append(updatedSkills, model.SkillRef{
Name: filepath.Base(absPath),
Digest: layer.Digest,
})
} else {
// Parent directory - walk to find skill subdirectories
err := filepath.WalkDir(absPath, func(path string, entry fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if entry.IsDir() {
return nil
}
if entry.Name() != "SKILL.md" {
return nil
}
skillDir := filepath.Dir(path)
skillName := filepath.Base(skillDir)
fn(api.ProgressResponse{Status: fmt.Sprintf("packaging skill: %s", skillName)})
layer, err := CreateSkillLayer(skillDir)
if err != nil {
return fmt.Errorf("creating skill layer for %q: %w", skillDir, err)
}
layers = append(layers, layer)
updatedSkills = append(updatedSkills, model.SkillRef{
Name: skillName,
Digest: layer.Digest,
})
return nil
})
if err != nil {
return nil, nil, fmt.Errorf("walking skill directory %q: %w", skill.Name, err)
}
}
} else if skill.Digest != "" {
// Already has a digest (from a pulled agent), keep as-is
updatedSkills = append(updatedSkills, skill)
} else {
// Registry reference - keep as-is for later resolution
updatedSkills = append(updatedSkills, skill)
}
}
return layers, updatedSkills, nil
}
// setMCPLayers handles MCP server references.
// Currently, MCPs are stored as config data (command/args).
// Future: support bundling MCP server directories as layers.
func setMCPLayers(layers []Layer, mcps []model.MCPRef, fn func(resp api.ProgressResponse)) ([]Layer, []model.MCPRef, error) {
if len(mcps) == 0 {
return layers, mcps, nil
}
// Remove any existing MCP layers
layers = removeLayer(layers, MediaTypeMCP)
var updatedMCPs []model.MCPRef
for _, mcp := range mcps {
// Validate MCP has required fields
if mcp.Name == "" {
return nil, nil, fmt.Errorf("MCP server requires a name")
}
if mcp.Command == "" {
return nil, nil, fmt.Errorf("MCP server %q requires a command", mcp.Name)
}
// Set default type if not specified
if mcp.Type == "" {
mcp.Type = "stdio"
}
// For now, just keep MCPs as config data
// Future: detect local paths in args and bundle them
updatedMCPs = append(updatedMCPs, mcp)
fn(api.ProgressResponse{Status: fmt.Sprintf("configuring MCP: %s", mcp.Name)})
}
return layers, updatedMCPs, nil
}
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
digests := make([]string, len(layers))
for i, layer := range layers {

View File

@@ -30,6 +30,7 @@ import (
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen/transfer"
)
var (
@@ -73,6 +74,11 @@ type Model struct {
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImageGeneration}
}
// Check for completion capability
if m.ModelPath != "" {
f, err := gguf.Open(m.ModelPath)
@@ -232,13 +238,6 @@ func (m *Model) String() string {
})
}
if m.Config.Entrypoint != "" {
modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "entrypoint",
Args: m.Config.Entrypoint,
})
}
for k, v := range m.Options {
switch v := v.(type) {
case []any:
@@ -562,6 +561,24 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := mp.GetManifestPath()
if err != nil {
return err
}
manifestJSON, err := os.ReadFile(manifestPath)
if err != nil {
return err
}
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
@@ -627,6 +644,15 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
@@ -641,7 +667,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
@@ -650,13 +675,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
@@ -664,15 +687,10 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
}
// Extract skill layers to the skills cache
for _, layer := range manifest.Layers {
if layer.MediaType == MediaTypeSkill {
fn(api.ProgressResponse{Status: fmt.Sprintf("extracting skill %s", layer.Digest)})
if _, err := ExtractSkillBlob(layer.Digest); err != nil {
return fmt.Errorf("extracting skill layer %s: %w", layer.Digest, err)
}
}
for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "writing manifest"})
@@ -707,6 +725,148 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}
// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []Layer) bool {
for _, layer := range layers {
if layer.MediaType == MediaTypeImageTensor {
return true
}
}
return false
}
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
}
}
destDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pulling model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}
if err := transfer.Download(ctx, transfer.DownloadOptions{
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: mp.GetNamespaceRepository(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
}); err != nil {
return err
}
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
return os.WriteFile(fp, manifestJSON, 0o644)
}
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
From: layer.From,
}
}
srcDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pushing model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}
return transfer.Upload(ctx, transfer.UploadOptions{
Blobs: blobs,
BaseURL: baseURL,
SrcDir: srcDir,
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
})
}
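A minimal sketch of a progress callback compatible with the fn parameter threaded through these transfer paths; the terminal rendering is illustrative:

fn := func(resp api.ProgressResponse) {
	if resp.Total > 0 {
		fmt.Printf("\r%s %3.0f%%", resp.Status, 100*float64(resp.Completed)/float64(resp.Total))
	}
}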
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)

View File

@@ -47,6 +47,15 @@ func TestModelCapabilities(t *testing.T) {
model Model
expectedCaps []model.Capability
}{
{
name: "model with image generation capability via config",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
},
{
name: "model with completion capability",
model: Model{

View File

@@ -13,9 +13,14 @@ type Layer struct {
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
status string
}
const (
MediaTypeImageTensor = "application/vnd.ollama.image.tensor"
)
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {
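Illustratively, a per-tensor layer under the new media type carries entries like the following; digest and size are placeholders, and Name mirrors the example in the struct comment:

layer := Layer{
	MediaType: MediaTypeImageTensor,
	Digest:    "sha256:...", // placeholder
	Size:      1 << 20,      // placeholder
	Name:      "text_encoder/model.embed_tokens.weight",
}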

View File

@@ -129,30 +129,11 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
return nil, err
}
// Find both 4-part (models) and 5-part (skills/agents) manifest paths
matches4, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
// TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
if err != nil {
return nil, err
}
matches5, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*", "*"))
if err != nil {
return nil, err
}
// Combine matches, filtering to only include files
var matches []string
for _, match := range matches4 {
fi, err := os.Stat(match)
if err == nil && !fi.IsDir() {
matches = append(matches, match)
}
}
for _, match := range matches5 {
fi, err := os.Stat(match)
if err == nil && !fi.IsDir() {
matches = append(matches, match)
}
}
ms := make(map[model.Name]*Manifest)
for _, match := range matches {
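For orientation, the restored 4-part glob corresponds to the standard manifest layout of registry/namespace/model/tag; with the default registry and namespace a matched path looks like (models root elided):

manifests/registry.ollama.ai/library/llama3/latest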

View File

@@ -1,315 +0,0 @@
package server
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// MediaTypeMCP is the media type for MCP server layers in manifests.
const MediaTypeMCP = "application/vnd.ollama.image.mcp"
// GetMCPsPath returns the path to the extracted MCPs cache directory.
// If digest is empty, returns the mcps directory itself.
// If digest is provided, returns the path to the extracted MCP for that digest.
func GetMCPsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "mcps", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// ExtractMCPBlob extracts an MCP tar.gz blob to the mcps cache.
// The blob is expected to be at the blobs path for the given digest.
// Returns the path to the extracted MCP directory.
func ExtractMCPBlob(digest string) (string, error) {
// Get the blob path
blobPath, err := GetBlobsPath(digest)
if err != nil {
return "", fmt.Errorf("getting blob path: %w", err)
}
// Get the extraction path
mcpPath, err := GetMCPsPath(digest)
if err != nil {
return "", fmt.Errorf("getting mcp path: %w", err)
}
// Check if already extracted (look for any file)
entries, err := os.ReadDir(mcpPath)
if err == nil && len(entries) > 0 {
return mcpPath, nil
}
// Open the blob
f, err := os.Open(blobPath)
if err != nil {
return "", fmt.Errorf("opening blob: %w", err)
}
defer f.Close()
// Create gzip reader
gzr, err := gzip.NewReader(f)
if err != nil {
return "", fmt.Errorf("creating gzip reader: %w", err)
}
defer gzr.Close()
// Create tar reader
tr := tar.NewReader(gzr)
// Create the mcp directory
if err := os.MkdirAll(mcpPath, 0o755); err != nil {
return "", fmt.Errorf("creating mcp directory: %w", err)
}
// Extract files
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return "", fmt.Errorf("reading tar: %w", err)
}
// Clean the name and ensure it doesn't escape the target directory
name := filepath.Clean(header.Name)
if strings.HasPrefix(name, "..") {
return "", fmt.Errorf("invalid path in archive: %s", header.Name)
}
target := filepath.Join(mcpPath, name)
// Verify the target is within mcpPath
if !strings.HasPrefix(target, filepath.Clean(mcpPath)+string(os.PathSeparator)) && target != filepath.Clean(mcpPath) {
return "", fmt.Errorf("path escapes mcp directory: %s", header.Name)
}
switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(target, 0o755); err != nil {
return "", fmt.Errorf("creating directory: %w", err)
}
case tar.TypeReg:
// Ensure parent directory exists
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
return "", fmt.Errorf("creating parent directory: %w", err)
}
outFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return "", fmt.Errorf("creating file: %w", err)
}
if _, err := io.Copy(outFile, tr); err != nil {
outFile.Close()
return "", fmt.Errorf("writing file: %w", err)
}
outFile.Close()
}
}
return mcpPath, nil
}
// CreateMCPLayer creates an MCP layer from a local directory.
// The directory can optionally contain an mcp.json or package.json file.
// Returns the created layer.
func CreateMCPLayer(mcpDir string) (Layer, error) {
// Verify directory exists
info, err := os.Stat(mcpDir)
if err != nil {
return Layer{}, fmt.Errorf("mcp directory not found: %w", err)
}
if !info.IsDir() {
return Layer{}, fmt.Errorf("mcp path is not a directory: %s", mcpDir)
}
// Create a temporary file for the tar.gz
blobsPath, err := GetBlobsPath("")
if err != nil {
return Layer{}, fmt.Errorf("getting blobs path: %w", err)
}
tmpFile, err := os.CreateTemp(blobsPath, "mcp-*.tar.gz")
if err != nil {
return Layer{}, fmt.Errorf("creating temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer func() {
tmpFile.Close()
os.Remove(tmpPath)
}()
// Create gzip writer
gzw := gzip.NewWriter(tmpFile)
defer gzw.Close()
// Create tar writer
tw := tar.NewWriter(gzw)
defer tw.Close()
// Walk the mcp directory and add files to tar
err = filepath.Walk(mcpDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Get relative path
relPath, err := filepath.Rel(mcpDir, path)
if err != nil {
return err
}
// Skip the root directory itself
if relPath == "." {
return nil
}
// Create tar header
header, err := tar.FileInfoHeader(info, "")
if err != nil {
return err
}
header.Name = relPath
if err := tw.WriteHeader(header); err != nil {
return err
}
// Write file contents if it's a regular file
if !info.IsDir() {
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
if _, err := io.Copy(tw, f); err != nil {
return err
}
}
return nil
})
if err != nil {
return Layer{}, fmt.Errorf("creating tar archive: %w", err)
}
// Close writers to flush
if err := tw.Close(); err != nil {
return Layer{}, fmt.Errorf("closing tar writer: %w", err)
}
if err := gzw.Close(); err != nil {
return Layer{}, fmt.Errorf("closing gzip writer: %w", err)
}
if err := tmpFile.Close(); err != nil {
return Layer{}, fmt.Errorf("closing temp file: %w", err)
}
// Open the temp file for reading
tmpFile, err = os.Open(tmpPath)
if err != nil {
return Layer{}, fmt.Errorf("reopening temp file: %w", err)
}
defer tmpFile.Close()
// Create the layer (this will compute the digest and move to blobs)
layer, err := NewLayer(tmpFile, MediaTypeMCP)
if err != nil {
return Layer{}, fmt.Errorf("creating layer: %w", err)
}
// Extract the mcp to the cache so it's ready to use
if _, err := ExtractMCPBlob(layer.Digest); err != nil {
return Layer{}, fmt.Errorf("extracting mcp: %w", err)
}
return layer, nil
}
// IsLocalMCPPath checks if an MCP reference looks like a local path.
// Local paths are explicitly prefixed with /, ./, ../, or ~.
func IsLocalMCPPath(name string) bool {
return strings.HasPrefix(name, "/") ||
strings.HasPrefix(name, "./") ||
strings.HasPrefix(name, "../") ||
strings.HasPrefix(name, "~")
}
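// Usage sketch (illustrative, not in the original source): callers can use
// this to choose between CreateMCPLayer for a local directory and a registry
// pull for everything else.
//
//	IsLocalMCPPath("./servers/websearch") // true: explicit relative path
//	IsLocalMCPPath("~/mcps/websearch")    // true: home-relative path
//	IsLocalMCPPath("mcp/websearch:1.0")   // false: registry reference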
// MCPNamespace is the namespace used for standalone MCPs in the registry.
const MCPNamespace = "mcp"
// IsMCPReference checks if a name refers to an MCP (has mcp/ prefix).
func IsMCPReference(name string) bool {
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
// mcp/name or mcp/name:tag
if len(parts) >= 1 && parts[0] == MCPNamespace {
return true
}
// namespace/mcp/name (e.g., myuser/mcp/websearch)
if len(parts) >= 2 && parts[1] == MCPNamespace {
return true
}
return false
}
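// Expected classification (illustrative, matching the cases above):
//
//	IsMCPReference("mcp/websearch")        // true: mcp/name
//	IsMCPReference("myuser/mcp/websearch") // true: namespace/mcp/name
//	IsMCPReference("myuser/websearch")     // false: plain model reference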
// ParseMCPName parses an MCP reference string into a model.Name.
// The Kind field is set to "mcp".
func ParseMCPName(name string) model.Name {
n := model.ParseName(name)
// If Kind wasn't set (old format without mcp/), set it
if n.Kind == "" {
n.Kind = MCPNamespace
}
return n
}
// GetMCPManifestPath returns the path to the MCP manifest file.
func GetMCPManifestPath(n model.Name) (string, error) {
if n.Model == "" {
return "", fmt.Errorf("mcp name is required")
}
// Ensure Kind is set
if n.Kind == "" {
n.Kind = MCPNamespace
}
path := filepath.Join(
envconfig.Models(),
"manifests",
n.Filepath(),
)
return path, nil
}
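// Sketch of the resulting layout (an assumption based on Name.Filepath joining
// host/namespace/kind/model/tag when Kind is set, and on the default registry
// and namespace): "mcp/websearch" would resolve to roughly
//
//	<models>/manifests/registry.ollama.ai/library/mcp/websearch/latest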

View File

@@ -18,7 +18,6 @@ type ModelPath struct {
ProtocolScheme string
Registry string
Namespace string
Kind string // Optional: "skill", "agent", or empty for models
Repository string
Tag string
}
@@ -43,7 +42,6 @@ func ParseModelPath(name string) ModelPath {
ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Kind: "",
Repository: "",
Tag: DefaultTag,
}
@@ -57,41 +55,13 @@ func ParseModelPath(name string) ModelPath {
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
switch len(parts) {
-case 4:
-// host/namespace/kind/model or host/namespace/model:tag with kind
-mp.Registry = parts[0]
-mp.Namespace = parts[1]
-if model.ValidKinds[parts[2]] {
-mp.Kind = parts[2]
-mp.Repository = parts[3]
-} else {
-// Not a valid kind, treat as old format with extra part
-mp.Repository = parts[3]
-}
-case 3:
-// Could be: host/namespace/model OR namespace/kind/model
-if model.ValidKinds[parts[1]] {
-// namespace/kind/model
-mp.Namespace = parts[0]
-mp.Kind = parts[1]
-mp.Repository = parts[2]
-} else {
-// host/namespace/model
-mp.Registry = parts[0]
-mp.Namespace = parts[1]
-mp.Repository = parts[2]
-}
-case 2:
-// Could be: namespace/model OR kind/model
-if model.ValidKinds[parts[0]] {
-// kind/model (library skill)
-mp.Kind = parts[0]
-mp.Repository = parts[1]
-} else {
-// namespace/model
-mp.Namespace = parts[0]
-mp.Repository = parts[1]
-}
+case 3:
+mp.Registry = parts[0]
+mp.Namespace = parts[1]
+mp.Repository = parts[2]
+case 2:
+mp.Namespace = parts[0]
+mp.Repository = parts[1]
case 1:
mp.Repository = parts[0]
}
@@ -105,35 +75,20 @@ func ParseModelPath(name string) ModelPath {
}
func (mp ModelPath) GetNamespaceRepository() string {
if mp.Kind != "" {
return fmt.Sprintf("%s/%s/%s", mp.Namespace, mp.Kind, mp.Repository)
}
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}
func (mp ModelPath) GetFullTagname() string {
if mp.Kind != "" {
return fmt.Sprintf("%s/%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry {
if mp.Namespace == DefaultNamespace {
if mp.Kind != "" {
return fmt.Sprintf("%s/%s:%s", mp.Kind, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
if mp.Kind != "" {
return fmt.Sprintf("%s/%s/%s:%s", mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
if mp.Kind != "" {
return fmt.Sprintf("%s/%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
@@ -142,7 +97,6 @@ func (mp ModelPath) GetManifestPath() (string, error) {
name := model.Name{
Host: mp.Registry,
Namespace: mp.Namespace,
Kind: mp.Kind,
Model: mp.Repository,
Tag: mp.Tag,
}

View File

@@ -50,6 +50,8 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -162,6 +164,29 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}
// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}
return runner.llama, nil
}
func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
@@ -189,6 +214,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -978,9 +1009,6 @@ func getExistingName(n model.Name) (model.Name, error) {
if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
n.Namespace = e.Namespace
}
if set.Kind == "" && strings.EqualFold(e.Kind, n.Kind) {
n.Kind = e.Kind
}
if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
n.Model = e.Model
}
@@ -1096,6 +1124,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
QuantizationLevel: m.Config.FileType,
}
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
}
}
if req.System != "" {
m.System = req.System
}
@@ -1119,10 +1156,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
Requires: m.Config.Requires,
Skills: m.Config.Skills,
MCPs: m.Config.MCPs,
AgentType: m.Config.AgentType,
Entrypoint: m.Config.Entrypoint,
}
if m.Config.RemoteHost != "" {
@@ -1177,13 +1210,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
fmt.Fprint(&sb, m.String())
resp.Modelfile = sb.String()
-// skip loading tensor information if this is a remote model or a skill
+// skip loading tensor information if this is a remote model
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
return resp, nil
}
-// Skills don't have model weights, skip tensor loading
-if m.ModelPath == "" {
+if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
return resp, nil
}
@@ -1556,6 +1588,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)
if rc != nil {
// wrap old with new
rs := &registry.Local{

View File

@@ -22,6 +22,7 @@ import (
gocmpopts "github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/types/model"
@@ -41,7 +42,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
-base := map[string]any{"general.architecture": "test"}
+var base convert.KV = map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {

View File

@@ -21,6 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
type LlmRequest struct {
@@ -194,6 +195,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation model before attempting GGML load
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadImageGen(pending) {
break
}
continue
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -543,6 +552,48 @@ iGPUScan:
return false
}
// loadImageGen loads an image generation model.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Use model name for imagegen (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName)
if err != nil {
req.errCh <- err
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
llama: server,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
refCount: 1,
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()
// Set up expiration timer
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
return true
}
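// Keep-alive resolution above, illustrated with hypothetical values: with
// OLLAMA_KEEP_ALIVE=5m and no per-request keep_alive, the expiration timer
// fires five minutes after the model is loaded; a request-level keep_alive
// overrides the default, and a non-positive duration skips the timer entirely.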
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
if len(allGpus) == 0 {
return

View File

@@ -6,6 +6,7 @@ import (
"errors"
"log/slog"
"os"
"slices"
"testing"
"time"
@@ -16,6 +17,7 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
func TestMain(m *testing.M) {
@@ -804,3 +806,61 @@ func (s *mockLlm) GetPort() int { return -
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
// Simulate an image gen runner already loaded
imageGenRunner := &runnerRef{
model: &Model{Name: "z-image", ModelPath: "/fake/image/model"},
modelPath: "/fake/image/model",
llama: &mockLlm{vramSize: 21 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{}},
sessionDuration: 5 * time.Millisecond,
refCount: 0, // idle
}
s.loadedMu.Lock()
s.loaded["/fake/image/model"] = imageGenRunner
s.loadedMu.Unlock()
// Verify the image gen runner is loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
// findRunnerToUnload should find the idle image gen runner
runner := s.findRunnerToUnload()
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}

View File

@@ -1,326 +0,0 @@
package server
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// MediaTypeSkill is the media type for skill layers in manifests.
const MediaTypeSkill = "application/vnd.ollama.image.skill"
// GetSkillsPath returns the path to the extracted skills cache directory.
// If digest is empty, returns the skills directory itself.
// If digest is provided, returns the path to the extracted skill for that digest.
func GetSkillsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "skills", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// ExtractSkillBlob extracts a skill tar.gz blob to the skills cache.
// The blob is expected to be at the blobs path for the given digest.
// Returns the path to the extracted skill directory.
func ExtractSkillBlob(digest string) (string, error) {
// Get the blob path
blobPath, err := GetBlobsPath(digest)
if err != nil {
return "", fmt.Errorf("getting blob path: %w", err)
}
// Get the extraction path
skillPath, err := GetSkillsPath(digest)
if err != nil {
return "", fmt.Errorf("getting skill path: %w", err)
}
// Check if already extracted
if _, err := os.Stat(filepath.Join(skillPath, "SKILL.md")); err == nil {
return skillPath, nil
}
// Open the blob
f, err := os.Open(blobPath)
if err != nil {
return "", fmt.Errorf("opening blob: %w", err)
}
defer f.Close()
// Create gzip reader
gzr, err := gzip.NewReader(f)
if err != nil {
return "", fmt.Errorf("creating gzip reader: %w", err)
}
defer gzr.Close()
// Create tar reader
tr := tar.NewReader(gzr)
// Create the skill directory
if err := os.MkdirAll(skillPath, 0o755); err != nil {
return "", fmt.Errorf("creating skill directory: %w", err)
}
// Extract files
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return "", fmt.Errorf("reading tar: %w", err)
}
// Clean the name and ensure it doesn't escape the target directory
name := filepath.Clean(header.Name)
if strings.HasPrefix(name, "..") {
return "", fmt.Errorf("invalid path in archive: %s", header.Name)
}
target := filepath.Join(skillPath, name)
// Verify the target is within skillPath
if !strings.HasPrefix(target, filepath.Clean(skillPath)+string(os.PathSeparator)) && target != filepath.Clean(skillPath) {
return "", fmt.Errorf("path escapes skill directory: %s", header.Name)
}
switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(target, 0o755); err != nil {
return "", fmt.Errorf("creating directory: %w", err)
}
case tar.TypeReg:
// Ensure parent directory exists
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
return "", fmt.Errorf("creating parent directory: %w", err)
}
outFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return "", fmt.Errorf("creating file: %w", err)
}
if _, err := io.Copy(outFile, tr); err != nil {
outFile.Close()
return "", fmt.Errorf("writing file: %w", err)
}
outFile.Close()
}
}
return skillPath, nil
}
// CreateSkillLayer creates a skill layer from a local directory.
// The directory must contain a SKILL.md file.
// Returns the created layer.
func CreateSkillLayer(skillDir string) (Layer, error) {
// Verify SKILL.md exists
skillMdPath := filepath.Join(skillDir, "SKILL.md")
if _, err := os.Stat(skillMdPath); err != nil {
return Layer{}, fmt.Errorf("skill directory must contain SKILL.md: %w", err)
}
// Create a temporary file for the tar.gz
blobsPath, err := GetBlobsPath("")
if err != nil {
return Layer{}, fmt.Errorf("getting blobs path: %w", err)
}
tmpFile, err := os.CreateTemp(blobsPath, "skill-*.tar.gz")
if err != nil {
return Layer{}, fmt.Errorf("creating temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer func() {
tmpFile.Close()
os.Remove(tmpPath)
}()
// Create gzip writer
gzw := gzip.NewWriter(tmpFile)
defer gzw.Close()
// Create tar writer
tw := tar.NewWriter(gzw)
defer tw.Close()
// Walk the skill directory and add files to tar
err = filepath.Walk(skillDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Get relative path
relPath, err := filepath.Rel(skillDir, path)
if err != nil {
return err
}
// Skip the root directory itself
if relPath == "." {
return nil
}
// Create tar header
header, err := tar.FileInfoHeader(info, "")
if err != nil {
return err
}
header.Name = relPath
if err := tw.WriteHeader(header); err != nil {
return err
}
// Write file contents if it's a regular file
if !info.IsDir() {
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
if _, err := io.Copy(tw, f); err != nil {
return err
}
}
return nil
})
if err != nil {
return Layer{}, fmt.Errorf("creating tar archive: %w", err)
}
// Close writers to flush
if err := tw.Close(); err != nil {
return Layer{}, fmt.Errorf("closing tar writer: %w", err)
}
if err := gzw.Close(); err != nil {
return Layer{}, fmt.Errorf("closing gzip writer: %w", err)
}
if err := tmpFile.Close(); err != nil {
return Layer{}, fmt.Errorf("closing temp file: %w", err)
}
// Open the temp file for reading
tmpFile, err = os.Open(tmpPath)
if err != nil {
return Layer{}, fmt.Errorf("reopening temp file: %w", err)
}
defer tmpFile.Close()
// Create the layer (this will compute the digest and move to blobs)
layer, err := NewLayer(tmpFile, MediaTypeSkill)
if err != nil {
return Layer{}, fmt.Errorf("creating layer: %w", err)
}
// Extract the skill to the cache so it's ready to use
if _, err := ExtractSkillBlob(layer.Digest); err != nil {
return Layer{}, fmt.Errorf("extracting skill: %w", err)
}
return layer, nil
}
// IsLocalSkillPath checks if a skill reference looks like a local path.
// Local paths are explicitly prefixed with /, ./, ../, or ~.
// Registry references like "skill/calculator:1.0.0" should NOT be treated as local paths.
func IsLocalSkillPath(name string) bool {
// Local paths are explicitly indicated by path prefixes
return strings.HasPrefix(name, "/") ||
strings.HasPrefix(name, "./") ||
strings.HasPrefix(name, "../") ||
strings.HasPrefix(name, "~")
}
// SkillNamespace is the namespace used for standalone skills in the registry.
const SkillNamespace = "skill"
// IsSkillReference checks if a name refers to a skill (has skill/ prefix).
func IsSkillReference(name string) bool {
// Check for skill/ prefix (handles both "skill/foo" and "registry/skill/foo")
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
// skill/name or skill/name:tag
if len(parts) >= 1 && parts[0] == SkillNamespace {
return true
}
// namespace/skill/name (e.g., myuser/skill/calc)
// registry/skill/name (e.g., registry.ollama.ai/skill/calc)
if len(parts) >= 2 && parts[1] == SkillNamespace {
return true
}
return false
}
// ParseSkillName parses a skill reference string into a model.Name.
// The Kind field is set to "skill".
// Examples:
// - "calculator" -> library/skill/calculator:latest
// - "myname/calculator" -> myname/skill/calculator:latest
// - "myname/skill/calculator:1.0.0" -> myname/skill/calculator:1.0.0
func ParseSkillName(name string) model.Name {
// Use the standard parser which now handles Kind
n := model.ParseName(name)
// If Kind wasn't set (old format without skill/), set it
if n.Kind == "" {
n.Kind = SkillNamespace
}
return n
}
// SkillDisplayName returns a user-friendly display name for a skill.
func SkillDisplayName(n model.Name) string {
return n.DisplayShortest()
}
// GetSkillManifestPath returns the path to the skill manifest file.
// Uses the 5-part structure: host/namespace/kind/model/tag
func GetSkillManifestPath(n model.Name) (string, error) {
if n.Model == "" {
return "", fmt.Errorf("skill name is required")
}
// Ensure Kind is set
if n.Kind == "" {
n.Kind = SkillNamespace
}
path := filepath.Join(
envconfig.Models(),
"manifests",
n.Filepath(),
)
return path, nil
}

View File

@@ -381,6 +381,28 @@ func (t templateTools) String() string {
return string(bts)
}
// templateArgs is a map type with JSON string output for templates.
type templateArgs map[string]any
func (t templateArgs) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
// templateProperties is a map type with JSON string output for templates.
type templateProperties map[string]api.ToolProperty
func (t templateProperties) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
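// Illustration (not in the original source): both wrappers render as JSON in
// templates instead of Go's default map formatting.
//
//	templateArgs{"location": "Tokyo"}.String() // {"location":"Tokyo"}
//	templateProperties(nil).String()           // {}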
// templateTool is a template-compatible representation of api.Tool
// with Properties as a regular map for template ranging.
type templateTool struct {
@@ -396,11 +418,11 @@ type templateToolFunction struct {
}
type templateToolFunctionParameters struct {
-Type string `json:"type"`
-Defs any `json:"$defs,omitempty"`
-Items any `json:"items,omitempty"`
-Required []string `json:"required,omitempty"`
-Properties map[string]api.ToolProperty `json:"properties"`
+Type string `json:"type"`
+Defs any `json:"$defs,omitempty"`
+Items any `json:"items,omitempty"`
+Required []string `json:"required,omitempty"`
+Properties templateProperties `json:"properties"`
}
// templateToolCall is a template-compatible representation of api.ToolCall
@@ -413,7 +435,7 @@ type templateToolCall struct {
type templateToolCallFunction struct {
Index int
Name string
-Arguments map[string]any
+Arguments templateArgs
}
// templateMessage is a template-compatible representation of api.Message
@@ -446,7 +468,7 @@ func convertToolsForTemplate(tools api.Tools) templateTools {
Defs: tool.Function.Parameters.Defs,
Items: tool.Function.Parameters.Items,
Required: tool.Function.Parameters.Required,
-Properties: tool.Function.Parameters.Properties.ToMap(),
+Properties: templateProperties(tool.Function.Parameters.Properties.ToMap()),
},
},
}
@@ -468,7 +490,7 @@ func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
Function: templateToolCallFunction{
Index: tc.Function.Index,
Name: tc.Function.Name,
-Arguments: tc.Function.Arguments.ToMap(),
+Arguments: templateArgs(tc.Function.Arguments.ToMap()),
},
})
}

View File

@@ -613,3 +613,159 @@ func TestCollate(t *testing.T) {
})
}
}
func TestTemplateArgumentsJSON(t *testing.T) {
// Test that {{ .Function.Arguments }} outputs valid JSON, not map[key:value]
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{ .Function.Arguments }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("location", "Tokyo")
args.Set("unit", "celsius")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:Tokyo unit:celsius]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Arguments output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Arguments not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplatePropertiesJSON(t *testing.T) {
// Test that {{ .Function.Parameters.Properties }} outputs valid JSON
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{ .Function.Parameters.Properties }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:{...}]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Properties output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Properties not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplateArgumentsRange(t *testing.T) {
// Test that we can range over Arguments in templates
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{- range $k, $v := .Function.Arguments }}{{ $k }}={{ $v }};{{- end }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("city", "Tokyo")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "city=Tokyo;" {
t.Errorf("Range over Arguments failed, got: %s, want: city=Tokyo;", got)
}
}
func TestTemplatePropertiesRange(t *testing.T) {
// Test that we can range over Properties in templates
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{- range $name, $prop := .Function.Parameters.Properties }}{{ $name }}:{{ $prop.Type }};{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "location:string;" {
t.Errorf("Range over Properties failed, got: %s, want: location:string;", got)
}
}

View File

@@ -3,12 +3,13 @@ package model
type Capability string
const (
-CapabilityCompletion = Capability("completion")
-CapabilityTools = Capability("tools")
-CapabilityInsert = Capability("insert")
-CapabilityVision = Capability("vision")
-CapabilityEmbedding = Capability("embedding")
-CapabilityThinking = Capability("thinking")
+CapabilityCompletion = Capability("completion")
+CapabilityTools = Capability("tools")
+CapabilityInsert = Capability("insert")
+CapabilityVision = Capability("vision")
+CapabilityEmbedding = Capability("embedding")
+CapabilityThinking = Capability("thinking")
+CapabilityImageGeneration = Capability("image")
)
func (c Capability) String() string {

View File

@@ -1,29 +1,5 @@
package model
// SkillRef represents a reference to a skill, either by local path or by registry digest.
type SkillRef struct {
// Name is the local path (for development) or registry name (e.g., "skill/calculator:1.0.0")
Name string `json:"name,omitempty"`
// Digest is the content-addressable digest of the skill blob (e.g., "sha256:abc123...")
Digest string `json:"digest,omitempty"`
}
// MCPRef represents a reference to an MCP (Model Context Protocol) server.
type MCPRef struct {
// Name is the identifier for the MCP server (used for tool namespacing)
Name string `json:"name,omitempty"`
// Digest is the content-addressable digest of the bundled MCP server blob
Digest string `json:"digest,omitempty"`
// Command is the executable to run (e.g., "uv", "node", "python3")
Command string `json:"command,omitempty"`
// Args are the arguments to pass to the command
Args []string `json:"args,omitempty"`
// Env is optional environment variables for the MCP server
Env map[string]string `json:"env,omitempty"`
// Type is the transport type (currently only "stdio" is supported)
Type string `json:"type,omitempty"`
}
// ConfigV2 represents the configuration metadata for a model.
type ConfigV2 struct {
ModelFormat string `json:"model_format"`
@@ -44,12 +20,6 @@ type ConfigV2 struct {
EmbedLen int `json:"embedding_length,omitempty"`
BaseName string `json:"base_name,omitempty"`
// agent-specific fields
Skills []SkillRef `json:"skills,omitempty"`
MCPs []MCPRef `json:"mcps,omitempty"`
AgentType string `json:"agent_type,omitempty"`
Entrypoint string `json:"entrypoint,omitempty"`
// required by spec
Architecture string `json:"architecture"`
OS string `json:"os"`

View File

@@ -59,7 +59,6 @@ type partKind int
const (
kindHost partKind = iota
kindNamespace
kindKind
kindModel
kindTag
kindDigest
@@ -71,8 +70,6 @@ func (k partKind) String() string {
return "host"
case kindNamespace:
return "namespace"
case kindKind:
return "kind"
case kindModel:
return "model"
case kindTag:
@@ -92,7 +89,6 @@ func (k partKind) String() string {
type Name struct {
Host string
Namespace string
Kind string // Optional: "skill", "agent", or empty for models
Model string
Tag string
}
@@ -101,27 +97,34 @@ type Name struct {
// format of a valid name string is:
//
// s:
-// { host } "/" { namespace } "/" { kind } "/" { model } ":" { tag }
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
// { host } "/" { namespace } "/" { model } ":" { tag }
// { namespace } "/" { kind } "/" { model } ":" { tag }
// { host } "/" { namespace } "/" { model } "@" { digest }
// { host } "/" { namespace } "/" { model }
// { namespace } "/" { model } ":" { tag } "@" { digest }
// { namespace } "/" { model } ":" { tag }
// { namespace } "/" { model } "@" { digest }
// { namespace } "/" { model }
// { model } ":" { tag } "@" { digest }
// { model } ":" { tag }
// { model } "@" { digest }
// { model }
// "@" { digest }
// host:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." | ":" }*
// length: [1, 350]
// namespace:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
// length: [1, 80]
-// kind:
-// pattern: "skill" | "agent" | "" (empty for models)
-// length: [0, 80]
// model:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// tag:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// digest:
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
// length: [1, 80]
//
// Most users should use [ParseName] instead, unless they need to support
// different defaults than DefaultName.
@@ -133,13 +136,6 @@ func ParseName(s string) Name {
return Merge(ParseNameBare(s), DefaultName())
}
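// Illustrative parses (assuming the defaults used elsewhere in this document:
// host registry.ollama.ai, namespace library, tag latest):
//
//	ParseName("llama3")       // registry.ollama.ai/library/llama3:latest
//	ParseName("myuser/m:1.0") // registry.ollama.ai/myuser/m:1.0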
// ValidKinds are the allowed values for the Kind field
var ValidKinds = map[string]bool{
"skill": true,
"agent": true,
"mcp": true,
}
// ParseNameBare parses s as a name string and returns a Name. No merge with
// [DefaultName] is performed.
func ParseNameBare(s string) Name {
@@ -157,30 +153,6 @@ func ParseNameBare(s string) Name {
return n
}
s, n.Kind, promised = cutPromised(s, "/")
if !promised {
// Only 2 parts: namespace/model - what we parsed as Kind is actually Namespace
n.Namespace = n.Kind
n.Kind = ""
return n
}
// Check if what we parsed as Kind is actually a valid kind value
if !ValidKinds[n.Kind] {
// Not a valid kind - this is the old 3-part format: host/namespace/model
// Shift: Kind -> Namespace, s -> Host
n.Namespace = n.Kind
n.Kind = ""
scheme, host, ok := strings.Cut(s, "://")
if !ok {
host = scheme
}
n.Host = host
return n
}
// Valid kind found - continue parsing for namespace and optional host
s, n.Namespace, promised = cutPromised(s, "/")
if !promised {
n.Namespace = s
@@ -196,32 +168,20 @@ func ParseNameBare(s string) Name {
return n
}
-// ParseNameFromFilepath parses a 4 or 5-part filepath as a Name. The parts are
+// ParseNameFromFilepath parses a 4-part filepath as a Name. The parts are
// expected to be in the form:
//
// { host } "/" { namespace } "/" { model } "/" { tag }
-// { host } "/" { namespace } "/" { kind } "/" { model } "/" { tag }
func ParseNameFromFilepath(s string) (n Name) {
parts := strings.Split(s, string(filepath.Separator))
-switch len(parts) {
-case 4:
-// Old format: host/namespace/model/tag
-n.Host = parts[0]
-n.Namespace = parts[1]
-n.Model = parts[2]
-n.Tag = parts[3]
-case 5:
-// New format: host/namespace/kind/model/tag
-n.Host = parts[0]
-n.Namespace = parts[1]
-n.Kind = parts[2]
-n.Model = parts[3]
-n.Tag = parts[4]
-default:
-return Name{}
-}
+if len(parts) != 4 {
+return Name{}
+}
+n.Host = parts[0]
+n.Namespace = parts[1]
+n.Model = parts[2]
+n.Tag = parts[3]
if !n.IsFullyQualified() {
return Name{}
}
@@ -229,12 +189,11 @@ func ParseNameFromFilepath(s string) (n Name) {
return n
}
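// Illustration (assuming "/" as the platform separator): the input
// "registry.ollama.ai/library/llama3/latest" fills Host, Namespace, Model,
// and Tag; any other part count yields the zero Name.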
-// Merge merges the host, namespace, kind, and tag parts of the two names,
+// Merge merges the host, namespace, and tag parts of the two names,
// preferring the non-empty parts of a.
func Merge(a, b Name) Name {
a.Host = cmp.Or(a.Host, b.Host)
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
-a.Kind = cmp.Or(a.Kind, b.Kind)
a.Tag = cmp.Or(a.Tag, b.Tag)
return a
}
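// Illustration: since ParseName is Merge(ParseNameBare(s), DefaultName()),
// Merge(Name{Model: "llama3"}, DefaultName()) keeps Model and fills Host,
// Namespace, and Tag from the defaults; Model is deliberately never merged.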
@@ -252,10 +211,6 @@ func (n Name) String() string {
b.WriteString(n.Namespace)
b.WriteByte('/')
}
if n.Kind != "" {
b.WriteString(n.Kind)
b.WriteByte('/')
}
b.WriteString(n.Model)
if n.Tag != "" {
b.WriteByte(':')
@@ -278,12 +233,6 @@ func (n Name) DisplayShortest() string {
sb.WriteByte('/')
}
// include kind if present
if n.Kind != "" {
sb.WriteString(n.Kind)
sb.WriteByte('/')
}
// always include model and tag
sb.WriteString(n.Model)
sb.WriteString(":")
@@ -307,23 +256,18 @@ func (n Name) IsValid() bool {
}
// IsFullyQualified returns true if all parts of the name are present and
-// valid without the digest. Kind is optional and only validated if non-empty.
+// valid without the digest.
func (n Name) IsFullyQualified() bool {
-if !isValidPart(kindHost, n.Host) {
-return false
-}
-if !isValidPart(kindNamespace, n.Namespace) {
-return false
-}
-// Kind is optional - only validate if present
-if n.Kind != "" && !isValidPart(kindKind, n.Kind) {
-return false
-}
-if !isValidPart(kindModel, n.Model) {
-return false
-}
-if !isValidPart(kindTag, n.Tag) {
-return false
-}
+parts := []string{
+n.Host,
+n.Namespace,
+n.Model,
+n.Tag,
+}
+for i, part := range parts {
+if !isValidPart(partKind(i), part) {
+return false
+}
+}
return true
}
@@ -332,7 +276,6 @@ func (n Name) IsFullyQualified() bool {
// host to tag as a directory in the form:
//
// {host}/{namespace}/{model}/{tag}
// {host}/{namespace}/{kind}/{model}/{tag}
//
// It uses the system's filepath separator and ensures the path is clean.
//
@@ -342,15 +285,6 @@ func (n Name) Filepath() string {
if !n.IsFullyQualified() {
panic("illegal attempt to get filepath of invalid name")
}
if n.Kind != "" {
return filepath.Join(
n.Host,
n.Namespace,
n.Kind,
n.Model,
n.Tag,
)
}
return filepath.Join(
n.Host,
n.Namespace,
@@ -367,7 +301,6 @@ func (n Name) LogValue() slog.Value {
func (n Name) EqualFold(o Name) bool {
return strings.EqualFold(n.Host, o.Host) &&
strings.EqualFold(n.Namespace, o.Namespace) &&
strings.EqualFold(n.Kind, o.Kind) &&
strings.EqualFold(n.Model, o.Model) &&
strings.EqualFold(n.Tag, o.Tag)
}
@@ -384,11 +317,6 @@ func isValidLen(kind partKind, s string) bool {
}
func isValidPart(kind partKind, s string) bool {
// Kind must be one of the valid values
if kind == kindKind {
return ValidKinds[s]
}
if !isValidLen(kind, s) {
return false
}

x/README.md Normal file
View File

@@ -0,0 +1,24 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx).
Support is currently limited to macOS, and to Linux with CUDA GPUs. We plan to add support for Windows CUDA and other GPU vendors soon. To build:
```
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
go build -tags mlx .
```
On Linux, use the "MLX CUDA 13" or "MLX CUDA 12" preset to enable CUDA with the default Ollama NVIDIA GPU architectures.
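For example, a CUDA 13 build could look like this (a sketch assuming the preset names above also exist as build presets; adjust to your configuration):
```
cmake --preset "MLX CUDA 13"
cmake --build --preset "MLX CUDA 13" --parallel
cmake --install build --component MLX
go build -tags mlx .
```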
## Image Generation
Building on the experimental MLX backend, we're also working on image generation support. After running the cmake commands above:
```
go build -o imagegen ./x/imagegen/cmd/engine
```

View File

@@ -4,6 +4,7 @@ package agent
import (
"fmt"
"os"
"path"
"path/filepath"
"strings"
"sync"
@@ -32,10 +33,29 @@ type ApprovalResult struct {
// Option labels for the selector (numbered for quick selection)
var optionLabels = []string{
"1. Execute once",
"2. Always allow",
"2. Allow for this session",
"3. Deny",
}
// toolDisplayNames maps internal tool names to human-readable display names.
var toolDisplayNames = map[string]string{
"bash": "Bash",
"web_search": "Web Search",
}
// ToolDisplayName returns the human-readable display name for a tool.
func ToolDisplayName(toolName string) string {
if displayName, ok := toolDisplayNames[toolName]; ok {
return displayName
}
// Default: capitalize first letter and replace underscores with spaces
name := strings.ReplaceAll(toolName, "_", " ")
if len(name) > 0 {
return strings.ToUpper(name[:1]) + name[1:]
}
return toolName
}
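// Examples (illustrative): "web_search" resolves through the table to
// "Web Search", while an unmapped name such as "file_read" (hypothetical)
// falls back to "File read" via the underscore-and-capitalization rule.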
// autoAllowCommands are commands that are always allowed without prompting.
// These are zero-risk, read-only commands.
var autoAllowCommands = map[string]bool{
@@ -179,6 +199,7 @@ func FormatDeniedResult(command string, pattern string) string {
// extractBashPrefix extracts a prefix pattern from a bash command.
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
// For commands without path args, returns empty string.
// Paths with ".." traversal that escape the base directory return empty string for security.
func extractBashPrefix(command string) string {
// Split command by pipes and get the first part
parts := strings.Split(command, "|")
@@ -204,8 +225,8 @@ func extractBashPrefix(command string) string {
return ""
}
-// Find the first path-like argument (must contain / or start with .)
-// First pass: look for clear paths (containing / or starting with .)
+// Find the first path-like argument (must contain / or \ or start with .)
+// First pass: look for clear paths (containing path separators or starting with .)
for _, arg := range fields[1:] {
// Skip flags
if strings.HasPrefix(arg, "-") {
@@ -215,19 +236,49 @@ func extractBashPrefix(command string) string {
if isNumeric(arg) {
continue
}
-// Only process if it looks like a path (contains / or starts with .)
-if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
+// Only process if it looks like a path (contains / or \ or starts with .)
+if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
continue
}
-// If arg ends with /, it's a directory - use it directly
-if strings.HasSuffix(arg, "/") {
-return fmt.Sprintf("%s:%s", baseCmd, arg)
-}
-// Get the directory part of a file path
-dir := filepath.Dir(arg)
+// Normalize to forward slashes for consistent cross-platform matching
+arg = strings.ReplaceAll(arg, "\\", "/")
+// Security: reject absolute paths
+if path.IsAbs(arg) {
+return "" // Absolute path - don't create prefix
+}
+// Normalize the path using stdlib path.Clean (resolves . and ..)
+cleaned := path.Clean(arg)
+// Security: reject if cleaned path escapes to parent directory
+if strings.HasPrefix(cleaned, "..") {
+return "" // Path escapes - don't create prefix
+}
+// Security: if original had "..", verify cleaned path didn't escape to sibling
+// e.g., "tools/a/b/../../../etc" -> "etc" (escaped tools/ to sibling)
+if strings.Contains(arg, "..") {
+origBase := strings.SplitN(arg, "/", 2)[0]
+cleanedBase := strings.SplitN(cleaned, "/", 2)[0]
+if origBase != cleanedBase {
+return "" // Path escaped to sibling directory
+}
+}
+// Check if arg ends with / (explicit directory)
+isDir := strings.HasSuffix(arg, "/")
+// Get the directory part
+var dir string
+if isDir {
+dir = cleaned
+} else {
+dir = path.Dir(cleaned)
+}
if dir == "." {
-// Path is just a directory like "tools" or "src" (no trailing /)
-return fmt.Sprintf("%s:%s/", baseCmd, arg)
+return fmt.Sprintf("%s:./", baseCmd)
}
return fmt.Sprintf("%s:%s/", baseCmd, dir)
}
@@ -332,6 +383,8 @@ func AllowlistKey(toolName string, args map[string]any) string {
}
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
// For bash commands, hierarchical path matching is used - if "cat:tools/" is allowed,
// then "cat:tools/subdir/" is also allowed (subdirectories inherit parent permissions).
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
a.mu.RLock()
defer a.mu.RUnlock()
@@ -342,12 +395,20 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
return true
}
-// For bash commands, check prefix matches
+// For bash commands, check prefix matches with hierarchical path support
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
prefix := extractBashPrefix(cmd)
if prefix != "" && a.prefixes[prefix] {
return true
if prefix != "" {
// Check exact prefix match first
if a.prefixes[prefix] {
return true
}
// Check hierarchical match: if any stored prefix is a parent of current prefix
// e.g., stored "cat:tools/" should match current "cat:tools/subdir/"
if a.matchesHierarchicalPrefix(prefix) {
return true
}
}
}
}
@@ -360,6 +421,40 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
return false
}
// matchesHierarchicalPrefix checks if the given prefix matches any stored prefix hierarchically.
// For example, if "cat:tools/" is stored, it will match "cat:tools/subdir/" or "cat:tools/a/b/c/".
func (a *ApprovalManager) matchesHierarchicalPrefix(currentPrefix string) bool {
// Split prefix into command and path parts (format: "cmd:path/")
colonIdx := strings.Index(currentPrefix, ":")
if colonIdx == -1 {
return false
}
currentCmd := currentPrefix[:colonIdx]
currentPath := currentPrefix[colonIdx+1:]
for storedPrefix := range a.prefixes {
storedColonIdx := strings.Index(storedPrefix, ":")
if storedColonIdx == -1 {
continue
}
storedCmd := storedPrefix[:storedColonIdx]
storedPath := storedPrefix[storedColonIdx+1:]
// Commands must match exactly
if currentCmd != storedCmd {
continue
}
// Check if current path starts with stored path (hierarchical match)
// e.g., "tools/subdir/" starts with "tools/"
if strings.HasPrefix(currentPath, storedPath) {
return true
}
}
return false
}
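// Illustration: with "cat:tools/" stored, a later "cat tools/parser/lex.go"
// extracts the prefix "cat:tools/parser/", which matches hierarchically;
// "ls:tools/sub/" does not, because the command parts must be equal.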
// AddToAllowlist adds a tool/command to the session allowlist.
// For bash commands, it adds the prefix pattern instead of exact command.
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
@@ -399,16 +494,32 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// This prevents buffered input from causing double-press issues
flushStdin(fd)
// Check if bash command targets paths outside cwd
isWarning := false
var warningMsg string
var allowlistInfo string
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
-isWarning = isCommandOutsideCwd(cmd)
+if isCommandOutsideCwd(cmd) {
+isWarning = true
+warningMsg = "command targets paths outside project"
+}
if prefix := extractBashPrefix(cmd); prefix != "" {
colonIdx := strings.Index(prefix, ":")
if colonIdx != -1 {
cmdName := prefix[:colonIdx]
dirPath := prefix[colonIdx+1:]
if dirPath != "./" {
allowlistInfo = fmt.Sprintf("%s in %s directory (includes subdirs)", cmdName, dirPath)
} else {
allowlistInfo = fmt.Sprintf("%s in %s directory", cmdName, dirPath)
}
}
}
}
}
// Run interactive selector
-selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
+selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning, warningMsg, allowlistInfo)
if err != nil {
term.Restore(fd, oldState)
return ApprovalResult{Decision: ApprovalDeny}, err
@@ -433,27 +544,29 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// formatToolDisplay creates the display string for a tool call.
func formatToolDisplay(toolName string, args map[string]any) string {
var sb strings.Builder
displayName := ToolDisplayName(toolName)
// For bash, show command directly
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
return sb.String()
}
}
-// For web search, show query
+// For web search, show query and internet notice
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Query: %s", query))
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
sb.WriteString("Uses internet via ollama.com")
return sb.String()
}
}
// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
if len(args) > 0 {
sb.WriteString("\nArguments: ")
first := true
@@ -470,24 +583,28 @@ func formatToolDisplay(toolName string, args map[string]any) string {
// selectorState holds the state for the interactive selector
type selectorState struct {
-toolDisplay string
-selected int
-totalLines int
-termWidth int
-termHeight int
-boxWidth int
-innerWidth int
-denyReason string // deny reason (always visible in box)
-isWarning bool // true if command targets paths outside cwd (red box)
+toolDisplay string
+selected int
+totalLines int
+termWidth int
+termHeight int
+boxWidth int
+innerWidth int
+denyReason string // deny reason (always visible in box)
+isWarning bool // true if command has warning
+warningMessage string // dynamic warning message to display
+allowlistInfo string // show what will be allowlisted (for "Allow for this session" option)
}
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
-// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
-func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
+func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool, warningMessage string, allowlistInfo string) (int, string, error) {
state := &selectorState{
-toolDisplay: toolDisplay,
-selected: 0,
-isWarning: isWarning,
+toolDisplay: toolDisplay,
+selected: 0,
+isWarning: isWarning,
+warningMessage: warningMessage,
+allowlistInfo: allowlistInfo,
}
// Get terminal size
@@ -647,7 +764,7 @@ func wrapText(text string, maxWidth int) []string {
// getHintLines returns the hint text wrapped to terminal width
func getHintLines(state *selectorState) []string {
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
hint := "up/down select, enter confirm, 1-3 quick select, ctrl+c cancel"
if state.termWidth >= len(hint)+1 {
return []string{hint}
}
@@ -657,86 +774,70 @@ func getHintLines(state *selectorState) []string {
// calculateTotalLines calculates how many lines the selector will use
func calculateTotalLines(state *selectorState) int {
-toolLines := wrapText(state.toolDisplay, state.innerWidth)
+toolLines := strings.Split(state.toolDisplay, "\n")
hintLines := getHintLines(state)
-// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
+// warning line (if applicable) + tool lines + blank line + options + blank line + hint lines
warningLines := 0
if state.isWarning {
-warningLines = 1
+warningLines = 2 // warning line + blank line after
}
-return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
+return warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
}
-// renderSelectorBox renders the complete selector box
+// renderSelectorBox renders the selector (minimal, no box)
func renderSelectorBox(state *selectorState) {
-toolLines := wrapText(state.toolDisplay, state.innerWidth)
+toolLines := strings.Split(state.toolDisplay, "\n")
hintLines := getHintLines(state)
-// Use red for warning (outside cwd), cyan for normal
-boxColor := "\033[36m" // cyan
-if state.isWarning {
-boxColor = "\033[91m" // bright red
-}
-// Draw box top
-fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
-// Draw warning line if needed (inside the box)
-if state.isWarning {
-warning := "!! OUTSIDE PROJECT !!"
-padding := (state.innerWidth - len(warning)) / 2
-if padding < 0 {
-padding = 0
-}
-fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
-strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
-}
+// Draw warning line if needed
+if state.isWarning {
+if state.warningMessage != "" {
+fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m %s\033[K\r\n", state.warningMessage)
+} else {
+fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
+}
+fmt.Fprintf(os.Stderr, "\033[K\r\n") // blank line after warning
+}
-// Draw tool info
+// Draw tool info (plain white)
for _, line := range toolLines {
-fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
+fmt.Fprintf(os.Stderr, "%s\033[K\r\n", line)
}
-// Draw separator
-fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
+// Blank line separator
+fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Draw options with numbers (Deny option includes reason input)
for i, label := range optionLabels {
-if i == 2 { // Deny option - show with reason input beside it
+if i == 2 {
denyLabel := "3. Deny: "
-availableWidth := state.innerWidth - 2 - len(denyLabel)
-if availableWidth < 5 {
-availableWidth = 5
-}
inputDisplay := state.denyReason
-if len(inputDisplay) > availableWidth {
-inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
-}
+if inputDisplay == "" {
+inputDisplay = "\033[90m(optional reason)\033[0m"
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
} else {
displayLabel := label
-if len(displayLabel) > state.innerWidth-2 {
-displayLabel = displayLabel[:state.innerWidth-5] + "..."
-}
+if i == 1 && state.allowlistInfo != "" {
+displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
}
}
}
-// Draw box bottom
-fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
+// Blank line before hint
+fmt.Fprintf(os.Stderr, "\033[K\r\n")
-// Draw hint (may be multiple lines)
+// Draw hint (dark grey)
for i, line := range hintLines {
if i == len(hintLines)-1 {
// Last line - no newline
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
} else {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
@@ -748,50 +849,39 @@ func renderSelectorBox(state *selectorState) {
func updateSelectorOptions(state *selectorState) {
hintLines := getHintLines(state)
-// Use red for warning (outside cwd), cyan for normal
-boxColor := "\033[36m" // cyan
-if state.isWarning {
-boxColor = "\033[91m" // bright red
-}
// Move up to the first option line
// Cursor is at end of last hint line, need to go up:
-// (hint lines - 1) + 1 (bottom border) + numOptions
+// (hint lines - 1) + 1 (blank line) + numOptions
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw options (Deny option includes reason input)
for i, label := range optionLabels {
-if i == 2 { // Deny option
+if i == 2 {
denyLabel := "3. Deny: "
-availableWidth := state.innerWidth - 2 - len(denyLabel)
-if availableWidth < 5 {
-availableWidth = 5
-}
inputDisplay := state.denyReason
-if len(inputDisplay) > availableWidth {
-inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
-}
+if inputDisplay == "" {
+inputDisplay = "\033[90m(optional reason)\033[0m"
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
} else {
displayLabel := label
-if len(displayLabel) > state.innerWidth-2 {
-displayLabel = displayLabel[:state.innerWidth-5] + "..."
-}
+if i == 1 && state.allowlistInfo != "" {
+displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
}
}
}
-// Redraw bottom and hint
-fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
+// Blank line + hint
+fmt.Fprintf(os.Stderr, "\033[K\r\n")
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
@@ -805,36 +895,26 @@ func updateSelectorOptions(state *selectorState) {
func updateReasonInput(state *selectorState) {
hintLines := getHintLines(state)
-// Use red for warning (outside cwd), cyan for normal
-boxColor := "\033[36m" // cyan
-if state.isWarning {
-boxColor = "\033[91m" // bright red
-}
// Move up to the Deny line (3rd option, index 2)
// Cursor is at end of last hint line, need to go up:
-// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
+// (hint lines - 1) + 1 (blank line) + 1 (Deny is last option)
linesToMove := len(hintLines) - 1 + 1 + 1
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw Deny line with reason
denyLabel := "3. Deny: "
-availableWidth := state.innerWidth - 2 - len(denyLabel)
-if availableWidth < 5 {
-availableWidth = 5
-}
inputDisplay := state.denyReason
-if len(inputDisplay) > availableWidth {
-inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
-}
+if inputDisplay == "" {
+inputDisplay = "\033[90m(optional reason)\033[0m"
}
if state.selected == 2 {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
-// Redraw bottom and hint
-fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
+// Blank line + hint
+fmt.Fprintf(os.Stderr, "\033[K\r\n")
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
@@ -858,11 +938,10 @@ func clearSelectorBox(state *selectorState) {
// fallbackApproval handles approval when terminal control isn't available.
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, toolDisplay)
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
fmt.Fprint(os.Stderr, "Choice: ")
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Allow for this session [3] Deny")
fmt.Fprint(os.Stderr, "choice: ")
var input string
fmt.Scanln(&input)
@@ -905,19 +984,16 @@ func (a *ApprovalManager) AllowedTools() []string {
// FormatApprovalResult returns a formatted string showing the approval result.
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
var status string
var icon string
var label string
displayName := ToolDisplayName(toolName)
switch result.Decision {
case ApprovalOnce:
status = "Approved"
icon = "\033[32m✓\033[0m"
label = "Approved"
case ApprovalAlways:
status = "Always allowed"
icon = "\033[32m✓\033[0m"
label = "Always allowed"
case ApprovalDeny:
status = "Denied"
icon = "\033[31m✗\033[0m"
label = "Denied"
}
// Format based on tool type
@@ -927,7 +1003,7 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
if len(cmd) > 40 {
cmd = cmd[:37] + "..."
}
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, cmd)
}
}
@@ -937,11 +1013,11 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
if len(query) > 40 {
query = query[:37] + "..."
}
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, query)
}
}
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
}
// FormatDenyResult returns the tool result message when a tool is denied.
@@ -951,3 +1027,78 @@ func FormatDenyResult(toolName string, reason string) string {
}
return fmt.Sprintf("User denied execution of %s.", toolName)
}
// PromptYesNo displays a simple Yes/No prompt and returns the user's choice.
// Returns true for Yes, false for No.
func PromptYesNo(question string) (bool, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
selected := 0 // 0 = Yes, 1 = No
options := []string{"Yes", "No"}
// Hide cursor
fmt.Fprint(os.Stderr, "\033[?25l")
defer fmt.Fprint(os.Stderr, "\033[?25h")
renderYesNo := func() {
// Move to start of line and clear
fmt.Fprintf(os.Stderr, "\r\033[K")
fmt.Fprintf(os.Stderr, "%s ", question)
for i, opt := range options {
if i == selected {
fmt.Fprintf(os.Stderr, "\033[1m%s\033[0m ", opt)
} else {
fmt.Fprintf(os.Stderr, "\033[37m%s\033[0m ", opt)
}
}
}
renderYesNo()
buf := make([]byte, 3)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
return false, err
}
if n == 1 {
switch buf[0] {
case 'y', 'Y':
selected = 0
renderYesNo()
case 'n', 'N':
selected = 1
renderYesNo()
case '\r', '\n': // Enter
fmt.Fprintf(os.Stderr, "\r\033[K") // Clear line
return selected == 0, nil
case 3: // Ctrl+C
fmt.Fprintf(os.Stderr, "\r\033[K")
return false, nil
case 27: // Escape - could be arrow key
// Read more bytes for arrow keys
continue
}
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
// Arrow keys
switch buf[2] {
case 'D': // Left
if selected > 0 {
selected--
}
renderYesNo()
case 'C': // Right
if selected < len(options)-1 {
selected++
}
renderYesNo()
}
}
}
}


@@ -151,6 +151,27 @@ func TestExtractBashPrefix(t *testing.T) {
command: "head -n 100",
expected: "",
},
// Path traversal security tests
{
name: "path traversal - parent escape",
command: "cat tools/../../etc/passwd",
expected: "", // Should NOT create a prefix - path escapes
},
{
name: "path traversal - deep escape",
command: "cat tools/a/b/../../../etc/passwd",
expected: "", // Normalizes to "../etc/passwd" - escapes
},
{
name: "path traversal - absolute path",
command: "cat /etc/passwd",
expected: "", // Absolute paths should not create prefix
},
{
name: "path with safe dotdot - normalized",
command: "cat tools/subdir/../file.go",
expected: "cat:tools/", // Normalizes to tools/file.go - safe, creates prefix
},
}
for _, tt := range tests {
@@ -164,6 +185,34 @@ func TestExtractBashPrefix(t *testing.T) {
}
}
func TestApprovalManager_PathTraversalBlocked(t *testing.T) {
am := NewApprovalManager()
// Allow "cat tools/file.go" - creates prefix "cat:tools/"
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Path traversal attack: should NOT be allowed
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../etc/passwd"}) {
t.Error("SECURITY: path traversal attack should NOT be allowed")
}
// Another traversal variant
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../../etc/shadow"}) {
t.Error("SECURITY: deep path traversal should NOT be allowed")
}
// Valid subdirectory access should still work
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
t.Error("expected cat tools/subdir/file.go to be allowed")
}
// Safe ".." that normalizes to within allowed directory should work
// tools/subdir/../other.go normalizes to tools/other.go which is under tools/
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/../other.go"}) {
t.Error("expected cat tools/subdir/../other.go to be allowed (normalizes to tools/other.go)")
}
}
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
am := NewApprovalManager()
@@ -186,6 +235,119 @@ func TestApprovalManager_PrefixAllowlist(t *testing.T) {
}
}
func TestApprovalManager_HierarchicalPrefixAllowlist(t *testing.T) {
am := NewApprovalManager()
// Allow "cat tools/file.go" - this creates prefix "cat:tools/"
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Should allow subdirectories (hierarchical matching)
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
t.Error("expected cat tools/subdir/file.go to be allowed via hierarchical prefix")
}
// Should allow deeply nested subdirectories
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/a/b/c/deep.go"}) {
t.Error("expected cat tools/a/b/c/deep.go to be allowed via hierarchical prefix")
}
// Should still allow same directory
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/another.go"}) {
t.Error("expected cat tools/another.go to be allowed")
}
// Should NOT allow different base directory
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
t.Error("expected cat src/main.go to NOT be allowed")
}
// Should NOT allow different command even in subdirectory
if am.IsAllowed("bash", map[string]any{"command": "ls tools/subdir/"}) {
t.Error("expected ls tools/subdir/ to NOT be allowed (different command)")
}
// Should NOT allow similar but different directory name
if am.IsAllowed("bash", map[string]any{"command": "cat toolsbin/file.go"}) {
t.Error("expected cat toolsbin/file.go to NOT be allowed (different directory)")
}
}
func TestApprovalManager_HierarchicalPrefixAllowlist_CrossPlatform(t *testing.T) {
am := NewApprovalManager()
// Allow with forward slashes (Unix-style)
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Should work with backslashes too (Windows-style) - normalized internally
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\subdir\\file.go"}) {
t.Error("expected cat tools\\subdir\\file.go to be allowed via hierarchical prefix (Windows path)")
}
// Mixed slashes should also work
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\a/b\\c/deep.go"}) {
t.Error("expected mixed slash path to be allowed via hierarchical prefix")
}
}
func TestMatchesHierarchicalPrefix(t *testing.T) {
am := NewApprovalManager()
// Add prefix for "cat:tools/"
am.prefixes["cat:tools/"] = true
tests := []struct {
name string
prefix string
expected bool
}{
{
name: "exact match",
prefix: "cat:tools/",
expected: true, // exact match also passes HasPrefix - caller handles exact match first
},
{
name: "subdirectory",
prefix: "cat:tools/subdir/",
expected: true,
},
{
name: "deeply nested",
prefix: "cat:tools/a/b/c/",
expected: true,
},
{
name: "different base directory",
prefix: "cat:src/",
expected: false,
},
{
name: "different command same path",
prefix: "ls:tools/",
expected: false,
},
{
name: "similar directory name",
prefix: "cat:toolsbin/",
expected: false,
},
{
name: "invalid prefix format",
prefix: "cattools",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := am.matchesHierarchicalPrefix(tt.prefix)
if result != tt.expected {
t.Errorf("matchesHierarchicalPrefix(%q) = %v, expected %v",
tt.prefix, result, tt.expected)
}
})
}
}
func TestFormatApprovalResult(t *testing.T) {
tests := []struct {
name string


@@ -6,10 +6,12 @@ import (
"errors"
"fmt"
"io"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/spf13/cobra"
"golang.org/x/term"
@@ -22,6 +24,101 @@ import (
"github.com/ollama/ollama/x/tools"
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
localModelTokenLimit = 4000
// defaultTokenLimit is the token limit for cloud/remote models.
defaultTokenLimit = 10000
// charsPerToken is a rough estimate of characters per token.
// TODO: Estimate tokens more accurately using tokenizer if available
charsPerToken = 4
)
// isLocalModel checks if the model is running locally (not a cloud model).
// TODO: Improve local/cloud model identification - could check model metadata
func isLocalModel(modelName string) bool {
return !strings.HasSuffix(modelName, "-cloud")
}
// isLocalServer checks if connecting to a local Ollama server.
// TODO: Could also check other indicators of local vs cloud server
func isLocalServer() bool {
host := os.Getenv("OLLAMA_HOST")
if host == "" {
return true // Default is localhost:11434
}
// Parse the URL to check host
parsed, err := url.Parse(host)
if err != nil {
return true // If can't parse, assume local
}
hostname := parsed.Hostname()
return hostname == "localhost" || hostname == "127.0.0.1" || strings.Contains(parsed.Host, ":11434")
}
// truncateToolOutput truncates tool output to prevent context overflow.
// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote.
func truncateToolOutput(output, modelName string) string {
var tokenLimit int
if isLocalModel(modelName) && isLocalServer() {
tokenLimit = localModelTokenLimit
} else {
tokenLimit = defaultTokenLimit
}
maxChars := tokenLimit * charsPerToken
if len(output) > maxChars {
return output[:maxChars] + "\n... (output truncated)"
}
return output
}
// waitForOllamaSignin shows the signin URL and polls until authentication completes.
func waitForOllamaSignin(ctx context.Context) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Get signin URL from initial Whoami call
_, err = client.Whoami(ctx)
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
fmt.Fprintf(os.Stderr, " %s\n\n", aErr.SigninURL)
fmt.Fprintf(os.Stderr, " \033[90mwaiting for sign in to complete...\033[0m")
// Poll until auth succeeds
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\n")
return ctx.Err()
case <-ticker.C:
user, whoamiErr := client.Whoami(ctx)
if whoamiErr == nil && user != nil && user.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K \033[1msigned in:\033[0m %s\n", user.Name)
return nil
}
// Still waiting, show dot
fmt.Fprintf(os.Stderr, ".")
}
}
}
return err
}
return nil
}
// RunOptions contains options for running an interactive agent session.
type RunOptions struct {
Model string
@@ -37,6 +134,9 @@ type RunOptions struct {
// Agent fields (managed externally for session persistence)
Tools *tools.Registry
Approval *agent.ApprovalManager
// YoloMode skips all tool approval prompts
YoloMode bool
}
// Chat runs an agent chat loop with tool support.
@@ -77,6 +177,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagOpened bool = false
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
role := "assistant"
messages := opts.Messages
@@ -159,6 +260,58 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, nil
}
// Check for 401 Unauthorized - prompt user to sign in
var authErr api.AuthorizationError
if errors.As(err, &authErr) {
p.StopAndClear()
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m cloud model requires authentication\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the chat request
fmt.Fprintf(os.Stderr, "\033[90mretrying...\033[0m\n")
continue // Retry the loop
}
}
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
}
// Check for 500 errors (often tool parsing failures) - inform the model
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode >= 500 {
consecutiveErrors++
p.StopAndClear()
if consecutiveErrors >= 3 {
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m too many consecutive errors, giving up\n")
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
}
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m server error (attempt %d/3): %s\n", consecutiveErrors, statusErr.ErrorMessage)
// Include both the model's response and the error so it can learn
assistantContent := fullResponse.String()
if assistantContent == "" {
assistantContent = "(empty response)"
}
errorMsg := fmt.Sprintf("Your previous response caused an error: %s\n\nYour response was:\n%s\n\nPlease try again with a valid response.", statusErr.ErrorMessage, assistantContent)
messages = append(messages,
api.Message{Role: "user", Content: errorMsg},
)
// Reset state and retry
fullResponse.Reset()
thinkingContent.Reset()
thinkTagOpened = false
thinkTagClosed = false
pendingToolCalls = nil
state = &displayResponseState{}
p = progress.NewProgress(os.Stderr)
spinner = progress.NewSpinner("")
p.Add("", spinner)
continue
}
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
@@ -168,6 +321,9 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, err
}
// Reset consecutive error counter on success
consecutiveErrors = 0
// If no tool calls, we're done
if len(pendingToolCalls) == 0 || toolRegistry == nil {
break
@@ -197,8 +353,8 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
if cmd, ok := args["command"].(string); ok {
// Check if command is denied (dangerous pattern)
if denied, pattern := agent.IsDenied(cmd); denied {
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
fmt.Fprintf(os.Stderr, "\033[1mblocked:\033[0m %s\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, " matches dangerous pattern: %s\n", pattern)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: agent.FormatDeniedResult(cmd, pattern),
@@ -208,15 +364,21 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
// Check if command is auto-allowed (safe command)
if agent.IsAutoAllowed(cmd) {
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
skipApproval = true
}
// TODO(parthsareen): re-enable with tighter scoped allowlist
// if agent.IsAutoAllowed(cmd) {
// fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
// skipApproval = true
// }
}
}
// Check approval (uses prefix matching for bash commands)
if !skipApproval && !approval.IsAllowed(toolName, args) {
// In yolo mode, skip all approval prompts
if opts.YoloMode {
if !skipApproval {
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
}
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
result, err := approval.RequestApproval(toolName, args)
if err != nil {
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
@@ -244,13 +406,30 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
} else if !skipApproval {
// Already allowed - show running indicator
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
}
// Execute the tool
toolResult, err := toolRegistry.Execute(call)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
// Check if web search needs authentication
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
// Prompt user to sign in
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m web search requires authentication\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
// Get signin URL and wait for auth completion
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the web search
fmt.Fprintf(os.Stderr, "\033[90mretrying web search...\033[0m\n")
toolResult, err = toolRegistry.Execute(call)
if err == nil {
goto toolSuccess
}
}
}
}
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m %v\n", err)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Error: %v", err),
@@ -258,6 +437,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
})
continue
}
toolSuccess:
// Display tool output (truncated for display)
if toolResult != "" {
@@ -269,9 +449,12 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
}
// Truncate output to prevent context overflow
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: toolResult,
Content: toolResultForLLM,
ToolCallID: call.ID,
})
}
@@ -317,17 +500,18 @@ func truncateUTF8(s string, limit int) string {
// formatToolShort returns a short description of a tool call.
func formatToolShort(toolName string, args map[string]any) string {
displayName := agent.ToolDisplayName(toolName)
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(cmd, 50))
}
}
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(query, 50))
}
}
return toolName
return displayName
}
// Helper types and functions for display
@@ -449,7 +633,8 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
// If yoloMode is true, all tool approvals are skipped.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
@@ -466,7 +651,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
// Check if model supports tools
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m could not check model capabilities: %v\n", err)
supportsTools = false
}
@@ -474,14 +659,17 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var toolRegistry *tools.Registry
if supportsTools {
toolRegistry = tools.DefaultRegistry()
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
// Check for OLLAMA_API_KEY for web search
if os.Getenv("OLLAMA_API_KEY") == "" {
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
if toolRegistry.Has("bash") {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
fmt.Fprintln(os.Stderr, "Models can read files on your computer, or run commands (after you allow them).")
fmt.Fprintln(os.Stderr)
}
if yoloMode {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
}
} else {
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
}
// Create approval manager for session
@@ -524,6 +712,9 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Keyboard Shortcuts:")
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
@@ -546,6 +737,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
}
assistant, err := Chat(cmd.Context(), opts)

x/cmd/run_test.go (new file, 180 lines)

@@ -0,0 +1,180 @@
package cmd
import (
"testing"
)
func TestIsLocalModel(t *testing.T) {
tests := []struct {
name string
modelName string
expected bool
}{
{
name: "local model without suffix",
modelName: "llama3.2",
expected: true,
},
{
name: "local model with version",
modelName: "qwen2.5:7b",
expected: true,
},
{
name: "cloud model",
modelName: "gpt-4-cloud",
expected: false,
},
{
name: "cloud model with version",
modelName: "claude-3-cloud",
expected: false,
},
{
name: "empty model name",
modelName: "",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isLocalModel(tt.modelName)
if result != tt.expected {
t.Errorf("isLocalModel(%q) = %v, expected %v", tt.modelName, result, tt.expected)
}
})
}
}
func TestIsLocalServer(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
{
name: "empty host (default)",
host: "",
expected: true,
},
{
name: "localhost",
host: "http://localhost:11434",
expected: true,
},
{
name: "127.0.0.1",
host: "http://127.0.0.1:11434",
expected: true,
},
{
name: "custom port on localhost",
host: "http://localhost:8080",
expected: true, // localhost is always considered local
},
{
name: "remote host",
host: "http://ollama.example.com:11434",
expected: true, // has :11434
},
{
name: "remote host different port",
host: "http://ollama.example.com:8080",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
result := isLocalServer()
if result != tt.expected {
t.Errorf("isLocalServer() with OLLAMA_HOST=%q = %v, expected %v", tt.host, result, tt.expected)
}
})
}
}
func TestTruncateToolOutput(t *testing.T) {
// Create outputs of different sizes
localLimitOutput := make([]byte, 20000) // > 4k tokens (16k chars)
defaultLimitOutput := make([]byte, 50000) // > 10k tokens (40k chars)
for i := range localLimitOutput {
localLimitOutput[i] = 'a'
}
for i := range defaultLimitOutput {
defaultLimitOutput[i] = 'b'
}
tests := []struct {
name string
output string
modelName string
host string
shouldTrim bool
expectedLimit int
}{
{
name: "short output local model",
output: "hello world",
modelName: "llama3.2",
host: "",
shouldTrim: false,
expectedLimit: localModelTokenLimit,
},
{
name: "long output local model - trimmed at 4k",
output: string(localLimitOutput),
modelName: "llama3.2",
host: "",
shouldTrim: true,
expectedLimit: localModelTokenLimit,
},
{
name: "long output cloud model - uses 10k limit",
output: string(localLimitOutput), // 20k chars, under 10k token limit
modelName: "gpt-4-cloud",
host: "",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
},
{
name: "very long output cloud model - trimmed at 10k",
output: string(defaultLimitOutput),
modelName: "gpt-4-cloud",
host: "",
shouldTrim: true,
expectedLimit: defaultTokenLimit,
},
{
name: "long output remote server - uses 10k limit",
output: string(localLimitOutput),
modelName: "llama3.2",
host: "http://remote.example.com:8080",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
result := truncateToolOutput(tt.output, tt.modelName)
if tt.shouldTrim {
maxLen := tt.expectedLimit * charsPerToken
if len(result) > maxLen+50 { // +50 for the truncation message
t.Errorf("expected output to be truncated to ~%d chars, got %d", maxLen, len(result))
}
if result == tt.output {
t.Error("expected output to be truncated but it wasn't")
}
} else {
if result != tt.output {
t.Error("expected output to not be truncated")
}
}
})
}
}

x/imagegen/.gitignore (new file, 38 lines)

@@ -0,0 +1,38 @@
# Build directories
build/
dist/
# CMake
CMakeCache.txt
CMakeFiles/
cmake_install.cmake
Makefile
*.cmake
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# macOS
.DS_Store
*.dSYM/
# Go
*.exe
*.exe~
*.dll
*.so
*.dylib
# Python
*.npy
/engine
weights
outputs
prompt.txt
negative.txt

x/imagegen/README.md (new file, 250 lines)

@@ -0,0 +1,250 @@
# Image Generation in Ollama (Experimental)
Generate images from text prompts using local AI models.
## Quick Start
```bash
# Run with a prompt
ollama run z-image "a sunset over mountains"
Generating: step 30/30
Image saved to: /tmp/ollama-image-1704067200.png
```
On macOS, the generated image will automatically open in Preview.
## Supported Models
| Model | VRAM Required | Notes |
|-------|---------------|-------|
| z-image | ~12GB | Based on Flux architecture |
## CLI Usage
```bash
# Generate an image
ollama run z-image "a cat playing piano"
# Check if model is running
ollama ps
# Stop the model
ollama stop z-image
```
## API
### OpenAI-Compatible Endpoint
```bash
POST /v1/images/generations
```
**Request:**
```json
{
"model": "z-image",
"prompt": "a sunset over mountains",
"size": "1024x1024",
"response_format": "b64_json"
}
```
**Response:**
```json
{
"created": 1704067200,
"data": [
{
"b64_json": "iVBORw0KGgo..."
}
]
}
```
### Example: cURL
```bash
curl http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "z-image",
"prompt": "a white cat",
"size": "1024x1024"
}'
```
### Example: Save to File
```bash
curl -s http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "z-image",
"prompt": "a white cat",
"size": "1024x1024"
}' | jq -r '.data[0].b64_json' | base64 -d > image.png
```
### Streaming Progress
Enable streaming to receive progress updates via SSE:
```bash
curl http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{"model": "z-image", "prompt": "a sunset", "stream": true}'
```
Events:
```
event: progress
data: {"step": 1, "total": 30}
event: progress
data: {"step": 2, "total": 30}
...
event: done
data: {"created": 1704067200, "data": [{"b64_json": "..."}]}
```
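For programmatic use, the stream can be consumed with any SSE-capable HTTP client. Below is a minimal Go sketch; the hand-rolled event parsing is illustrative only and is not part of the Ollama client library:

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body, _ := json.Marshal(map[string]any{
		"model": "z-image", "prompt": "a sunset", "stream": true,
	})
	resp, err := http.Post("http://localhost:11434/v1/images/generations",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 4096), 64<<20) // the done event carries a full base64 image
	var event string
	for scanner.Scan() {
		line := scanner.Text()
		switch {
		case strings.HasPrefix(line, "event:"):
			event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
		case strings.HasPrefix(line, "data:"):
			data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			if event == "progress" {
				var p struct{ Step, Total int }
				if json.Unmarshal([]byte(data), &p) == nil {
					fmt.Printf("\rstep %d/%d", p.Step, p.Total)
				}
			} else if event == "done" {
				fmt.Println("\ndone") // data holds the final JSON with b64_json
			}
		}
	}
}
```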
## Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| model | string | required | Model name |
| prompt | string | required | Text description of image |
| size | string | "1024x1024" | Image dimensions (WxH) |
| n | int | 1 | Number of images (currently only 1 supported) |
| response_format | string | "b64_json" | "b64_json" or "url" |
| stream | bool | false | Enable progress streaming |
## Requirements
- macOS with Apple Silicon (M1/M2/M3/M4)
- CUDA: tested on CUDA 12 Blackwell, more testing coming soon
- Sufficient VRAM (see model table above)
- Ollama built with MLX support
## Limitations
- macOS only (uses MLX backend)
- Single image per request
- Fixed step count (30 steps)
- Modelfiles not yet supported (use `ollama create` from model directory)
---
# Tensor Model Storage Format
Tensor models store each tensor as a separate blob with metadata in the manifest. This enables faster downloads (parallel fetching) and deduplication (shared tensors are stored once).
## Manifest Structure
The manifest follows the standard ollama format with tensor-specific layer metadata:
```json
{
"schemaVersion": 2,
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
"config": { "digest": "sha256:...", "size": 1234 },
"layers": [
{
"mediaType": "application/vnd.ollama.image.tensor",
"digest": "sha256:25b36eed...",
"size": 49807448,
"name": "text_encoder/model.layers.0.mlp.down_proj.weight",
"dtype": "BF16",
"shape": [2560, 9728]
},
{
"mediaType": "application/vnd.ollama.image.json",
"digest": "sha256:abc123...",
"size": 512,
"name": "text_encoder/config.json"
}
]
}
```
Each tensor layer includes:
- `name`: Path-style tensor name (e.g., `text_encoder/model.layers.0.mlp.down_proj.weight`)
- `dtype`: Data type (BF16, F32, etc.)
- `shape`: Tensor dimensions
Config layers use the same path-style naming (e.g., `tokenizer/tokenizer.json`).
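For reference, a layer entry can be modeled in Go roughly as follows; this is an illustrative sketch mirroring the JSON above, not the server's internal type:

```go
// TensorLayer mirrors one entry in the manifest's "layers" array.
type TensorLayer struct {
	MediaType string  `json:"mediaType"`       // e.g. application/vnd.ollama.image.tensor
	Digest    string  `json:"digest"`          // sha256 of the blob contents
	Size      int64   `json:"size"`            // blob size in bytes
	Name      string  `json:"name"`            // path-style tensor name
	Dtype     string  `json:"dtype,omitempty"` // BF16, F32, ... (tensor layers only)
	Shape     []int64 `json:"shape,omitempty"` // tensor dimensions (tensor layers only)
}
```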
## Blob Format
Each tensor blob is a minimal safetensors file:
```
[8 bytes: header size (uint64 LE)]
[~80 bytes: JSON header, padded to 8-byte alignment]
[N bytes: raw tensor data]
```
Header contains a single tensor named `"data"`:
```json
{"data":{"dtype":"BF16","shape":[2560,9728],"data_offsets":[0,49807360]}}
```
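Putting the layout together, writing one blob looks roughly like this. This is a sketch of the byte format above, not the actual importer code:

```go
package blob

import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"strings"
)

// writeTensorBlob wraps raw tensor bytes in the minimal safetensors layout:
// an 8-byte little-endian header length, a JSON header padded with spaces to
// an 8-byte boundary, then the raw tensor data.
func writeTensorBlob(w io.Writer, dtype string, shape []int64, data []byte) error {
	shapeJSON, err := json.Marshal(shape)
	if err != nil {
		return err
	}
	header := fmt.Sprintf(`{"data":{"dtype":%q,"shape":%s,"data_offsets":[0,%d]}}`,
		dtype, shapeJSON, len(data))
	if pad := len(header) % 8; pad != 0 {
		header += strings.Repeat(" ", 8-pad) // safetensors allows trailing whitespace
	}
	if err := binary.Write(w, binary.LittleEndian, uint64(len(header))); err != nil {
		return err
	}
	if _, err := io.WriteString(w, header); err != nil {
		return err
	}
	_, err = w.Write(data)
	return err
}
```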
## Why Include the Header?
The ~88 byte safetensors header enables MLX's native `mlx_load_safetensors` function, which:
1. **Uses mmap** - Maps file directly into memory, no copies
2. **Zero-copy to GPU** - MLX reads directly from mapped pages
3. **No custom code** - Standard MLX API, battle-tested
Without the header, we'd need custom C++ code to create MLX arrays from raw mmap'd data. MLX's public API doesn't expose this - it always copies when creating arrays from external pointers.
The overhead is negligible: 88 bytes per tensor = ~100KB total for a 13GB model (0.0007%).
## Why Per-Tensor Blobs?
**Deduplication**: Blobs are content-addressed by SHA256. If two models share identical tensors (same weights, dtype, shape), they share the same blob file.
Example: Model A and Model B both use the same text encoder. The text encoder's 400 tensors are stored once, referenced by both manifests.
```
~/.ollama/models/
blobs/
sha256-25b36eed... <- shared by both models
sha256-abc123...
manifests/
library/model-a/latest <- references sha256-25b36eed
library/model-b/latest <- references sha256-25b36eed
```
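Because blobs are content-addressed, a blob's on-disk location is a pure function of its bytes. A sketch, assuming the default `~/.ollama/models` location:

```go
package blob

import (
	"crypto/sha256"
	"encoding/hex"
	"os"
	"path/filepath"
)

// blobPath returns where a blob with the given contents would be stored,
// following the sha256-<hex> naming shown above.
func blobPath(contents []byte) string {
	sum := sha256.Sum256(contents)
	home, _ := os.UserHomeDir()
	return filepath.Join(home, ".ollama", "models", "blobs",
		"sha256-"+hex.EncodeToString(sum[:]))
}
```

Two models that reference the same digest therefore share one file, regardless of how each was imported.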
## Import Flow
```
cd ./weights/Z-Image-Turbo
ollama create z-image
1. Scan component directories (text_encoder/, transformer/, vae/)
2. For each .safetensors file:
- Extract individual tensors
- Wrap each in minimal safetensors format (88B header + data)
- Write to blob store (SHA256 content-addressed)
- Add layer entry to manifest with path-style name
3. Copy config files (*.json) as config layers
4. Write manifest
```
## FP8 Quantization
Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.
### Usage
```bash
cd ./weights/Z-Image-Turbo
ollama create z-image-fp8 --quantize fp8
```
This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.
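Conceptually, each FP8 weight tensor is stored alongside a companion `.weight_scale` tensor (the same marker the CLI in this change uses to detect FP8 models), and dequantization is an elementwise multiply. A minimal sketch of the idea, shown with a single per-tensor scale for simplicity; real inference consumes the quantized weights directly in backend kernels rather than materializing full-precision tensors:

```go
package fp8

// dequantize reconstructs approximate full-precision weights from
// quantized values and their scale: w ≈ q * weight_scale.
func dequantize(quantized []float32, scale float32) []float32 {
	weights := make([]float32, len(quantized))
	for i, q := range quantized {
		weights[i] = q * scale
	}
	return weights
}
```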

x/imagegen/api/handler.go (new file, 231 lines)

@@ -0,0 +1,231 @@
package api
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/imagegen"
)
// RunnerScheduler is the interface for scheduling a model runner.
// This is implemented by server.Server to avoid circular imports.
type RunnerScheduler interface {
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
}
// RegisterRoutes registers the image generation API routes.
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
r.POST("/v1/images/generations", func(c *gin.Context) {
ImageGenerationHandler(c, scheduler)
})
}
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
var req ImageGenerationRequest
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Validate required fields
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
return
}
if req.Prompt == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
return
}
// Apply defaults
if req.N == 0 {
req.N = 1
}
if req.Size == "" {
req.Size = "1024x1024"
}
if req.ResponseFormat == "" {
req.ResponseFormat = "b64_json"
}
// Verify model exists
if imagegen.ResolveModelName(req.Model) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
return
}
// Parse size
width, height := parseSize(req.Size)
// Build options - we repurpose NumCtx/NumGPU for width/height
opts := api.Options{}
opts.NumCtx = int(width)
opts.NumGPU = int(height)
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
}
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Options: &opts,
}
if req.Stream {
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
} else {
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
}
}
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
c.SSEvent("progress", progress)
c.Writer.Flush()
}
}
})
if err != nil {
c.SSEvent("error", gin.H{"error": err.Error()})
return
}
c.SSEvent("done", buildResponse(imageBase64, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
return
}
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
}
func parseSize(size string) (int32, int32) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return 1024, 1024
}
w, _ := strconv.Atoi(parts[0])
h, _ := strconv.Atoi(parts[1])
if w == 0 {
w = 1024
}
if h == 0 {
h = 1024
}
return int32(w), int32(h)
}
func extractBase64(content string) string {
if strings.HasPrefix(content, "IMAGE_BASE64:") {
return content[13:]
}
return ""
}
func parseProgress(content string) ImageProgressEvent {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imageBase64, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imageBase64 == "" {
return resp
}
if format == "url" {
// URL format not supported when using base64 transfer
resp.Data[0].B64JSON = imageBase64
} else {
resp.Data[0].B64JSON = imageBase64
}
return resp
}
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
// This allows routes.go to delegate image generation with minimal code.
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
opts := api.Options{}
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: prompt,
Options: &opts,
}
// Stream responses via channel
ch := make(chan any)
go func() {
defer close(ch)
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
ch <- GenerateResponse{
Model: modelName,
CreatedAt: time.Now().UTC(),
Response: resp.Content,
Done: resp.Done,
}
})
if err != nil {
// Intentionally drop the error: the channel is already being consumed,
// so there is nowhere to surface it from this goroutine.
_ = err
}
}()
streamFn(c, ch)
}
// GenerateResponse matches api.GenerateResponse structure for streaming.
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
}

x/imagegen/api/types.go (new file, 31 lines)

@@ -0,0 +1,31 @@
// Package api provides OpenAI-compatible image generation API types.
package api
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`
}
// ImageData contains the generated image data.
type ImageData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
// ImageProgressEvent is sent during streaming to indicate generation progress.
type ImageProgressEvent struct {
Step int `json:"step"`
Total int `json:"total"`
}

x/imagegen/cache/cache.go (new file, 156 lines)

@@ -0,0 +1,156 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
type Cache interface {
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
Offset() int
Len() int
State() []*mlx.Array
}
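// Illustrative decode-loop usage (a sketch: projectKV, attend, q and x are
// assumed helpers and inputs, not part of this package):
//
//	c := NewKVCache()
//	for step := 0; step < numTokens; step++ {
//	    k, v := projectKV(x)              // shapes [B, H, 1, D]
//	    keys, values := c.Update(k, v, 1) // full history cached so far
//	    x = attend(q, keys, values)
//	}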
type KVCache struct {
keys, values *mlx.Array
offset int
step int
}
func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
prev := c.offset
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if needed
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
nSteps := (c.step + seqLen - 1) / c.step
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
if c.keys != nil {
if prev%c.step != 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
}
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
c.offset += seqLen
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
}
func (c *KVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
keys, values *mlx.Array
offset int
maxSize int
step int
idx int
}
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, step: 256}
}
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
if seqLen > 1 {
return c.updateConcat(k, v, seqLen)
}
return c.updateInPlace(k, v)
}
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if not yet at max
if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
var cap int
if c.keys != nil {
cap = int(c.keys.Shape()[2])
}
newSize := min(c.step, c.maxSize-cap)
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
if c.keys != nil {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
// Rotate when hitting max
if c.idx >= c.maxSize {
c.idx = 0
}
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
c.offset++
c.idx++
validLen := int32(min(c.offset, c.maxSize))
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
}
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
if c.keys == nil {
c.keys, c.values = k, v
} else {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
}
c.offset += seqLen
// Trim to max_size to maintain sliding window
cap := int(c.keys.Shape()[2])
if trim := cap - c.maxSize; trim > 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
}
c.idx = int(c.keys.Shape()[2])
return c.keys, c.values
}
func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *RotatingKVCache) Offset() int { return c.offset }
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }

x/imagegen/cache/step.go (new file, 164 lines)

@@ -0,0 +1,164 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
// StepCache caches layer outputs across diffusion denoising steps.
// Based on DeepCache (CVPR 2024) and Learning-to-Cache (NeurIPS 2024):
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
// - Single-stream: use Get/Set for the single output per layer
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
// Usage (single-stream):
//
// cache := NewStepCache(15) // cache first 15 layers
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3) // refresh every 3 steps
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// output = cache.Get(i) // reuse cached
// } else {
// output = layer.Forward(input)
// if i < 15 && refresh {
// cache.Set(i, output)
// }
// }
// }
// }
// cache.Free() // cleanup when done
//
// Usage (dual-stream):
//
// cache := NewStepCache(15)
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3)
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// imgH, txtH = cache.Get(i), cache.Get2(i)
// } else {
// imgH, txtH = layer.Forward(imgH, txtH, ...)
// if i < 15 && refresh {
// cache.Set(i, imgH)
// cache.Set2(i, txtH)
// }
// }
// }
// }
type StepCache struct {
layers []*mlx.Array // cached layer outputs (stream 1)
layers2 []*mlx.Array // cached layer outputs (stream 2, for dual-stream models)
constant *mlx.Array // optional constant (e.g., text embeddings)
}
// NewStepCache creates a cache for the given number of layers.
func NewStepCache(numLayers int) *StepCache {
return &StepCache{
layers: make([]*mlx.Array, numLayers),
layers2: make([]*mlx.Array, numLayers),
}
}
// ShouldRefresh returns true if the cache should be refreshed at this step.
// Refresh happens on step 0, interval, 2*interval, etc.
func (c *StepCache) ShouldRefresh(step, interval int) bool {
return step%interval == 0
}
// Get returns the cached output for a layer, or nil if not cached.
func (c *StepCache) Get(layer int) *mlx.Array {
if layer < len(c.layers) {
return c.layers[layer]
}
return nil
}
// Set stores a layer output (stream 1), freeing any previous value.
func (c *StepCache) Set(layer int, arr *mlx.Array) {
if layer < len(c.layers) {
if c.layers[layer] != nil {
c.layers[layer].Free()
}
c.layers[layer] = arr
}
}
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Get2(layer int) *mlx.Array {
if layer < len(c.layers2) {
return c.layers2[layer]
}
return nil
}
// Set2 stores a layer output (stream 2), freeing any previous value.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
if layer < len(c.layers2) {
if c.layers2[layer] != nil {
c.layers2[layer].Free()
}
c.layers2[layer] = arr
}
}
// GetConstant returns the cached constant value.
func (c *StepCache) GetConstant() *mlx.Array {
return c.constant
}
// SetConstant stores a constant value, freeing any previous value.
func (c *StepCache) SetConstant(arr *mlx.Array) {
if c.constant != nil {
c.constant.Free()
}
c.constant = arr
}
// Arrays returns all non-nil cached arrays (for pool.Keep).
func (c *StepCache) Arrays() []*mlx.Array {
var result []*mlx.Array
if c.constant != nil {
result = append(result, c.constant)
}
for _, arr := range c.layers {
if arr != nil {
result = append(result, arr)
}
}
for _, arr := range c.layers2 {
if arr != nil {
result = append(result, arr)
}
}
return result
}
// Free releases all cached arrays. Call when generation completes.
func (c *StepCache) Free() {
if c.constant != nil {
c.constant.Free()
c.constant = nil
}
for i, arr := range c.layers {
if arr != nil {
arr.Free()
c.layers[i] = nil
}
}
for i, arr := range c.layers2 {
if arr != nil {
arr.Free()
c.layers2[i] = nil
}
}
}
// NumLayers returns the number of layers this cache can store.
func (c *StepCache) NumLayers() int {
return len(c.layers)
}

x/imagegen/cache/teacache.go (new file, 197 lines)

@@ -0,0 +1,197 @@
//go:build mlx
// Package cache provides caching mechanisms for diffusion model inference.
package cache
import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
// It caches the transformer output and reuses it when timestep values
// are similar between consecutive steps.
//
// For CFG (classifier-free guidance), it caches pos and neg predictions
// separately and always computes CFG fresh to avoid error amplification.
//
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
// https://github.com/ali-vilab/TeaCache
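//
// Usage sketch (non-CFG path; transformer.Forward, x and timesteps are
// assumptions for illustration, not part of this package):
//
//	tc := NewTeaCache(&TeaCacheConfig{Threshold: 0.1, RescaleFactor: 1.0, SkipEarlySteps: 2})
//	for step, t := range timesteps {
//	    var out *mlx.Array
//	    if tc.ShouldCompute(step, t) {
//	        out = transformer.Forward(x, t)
//	        tc.UpdateCache(out, t) // takes ownership; frees the previous cache entry
//	    } else {
//	        out = tc.GetCached()
//	    }
//	    _ = out
//	}
//	tc.Free()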
type TeaCache struct {
// Cached transformer output from last computed step (non-CFG mode)
cachedOutput *mlx.Array
// Cached CFG outputs (pos and neg separately)
cachedPosOutput *mlx.Array
cachedNegOutput *mlx.Array
// Previous timestep value for difference calculation
prevTimestep float32
// Accumulated difference for rescaling
accumulatedDiff float32
// Configuration
threshold float32 // Threshold for recomputation decision
rescaleFactor float32 // Model-specific rescaling factor
skipEarlySteps int // Number of early steps to never cache
// Statistics
cacheHits int
cacheMisses int
}
// TeaCacheConfig holds configuration for TeaCache.
type TeaCacheConfig struct {
// Threshold for recomputation. Lower = more cache hits, potential quality loss.
// Recommended: 0.05-0.15 for image models
Threshold float32
// Rescale factor to adjust timestep embedding differences.
// Model-specific, typically 1.0-2.0
RescaleFactor float32
// SkipEarlySteps: number of early steps to always compute (never cache).
// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
SkipEarlySteps int
}
// DefaultTeaCacheConfig returns default configuration for TeaCache.
func DefaultTeaCacheConfig() *TeaCacheConfig {
return &TeaCacheConfig{
Threshold: 0.1,
RescaleFactor: 1.0,
}
}
// NewTeaCache creates a new TeaCache instance.
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
if cfg == nil {
cfg = DefaultTeaCacheConfig()
}
return &TeaCache{
threshold: cfg.Threshold,
rescaleFactor: cfg.RescaleFactor,
skipEarlySteps: cfg.SkipEarlySteps,
}
}
// ShouldCompute determines if we should compute the full forward pass
// or reuse the cached output based on timestep similarity.
//
// Algorithm:
// 1. First step always computes
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
// 3. If accumulated difference > threshold, compute new output
// 4. Otherwise, reuse cached output
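//
// Worked example (threshold 0.1, rescale 1.0, timesteps 1.0, 0.95, 0.90, 0.85):
// step 0 always computes and records prev=1.0; step 1 adds |0.95-1.0|=0.05 to
// the accumulator (not above 0.1) and reuses the cache; step 2 adds
// |0.90-1.0|=0.10 (prev only advances on recompute, so drift compounds) for a
// total of 0.15 > 0.1, so it recomputes, resets the accumulator, and records
// prev=0.90; step 3 then starts accumulating again from 0.05.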
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
// Always compute early steps (critical for structure)
// Check both regular cache and CFG cache
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
return true
}
// Compute absolute difference between current and previous timestep
diff := timestep - tc.prevTimestep
if diff < 0 {
diff = -diff
}
// Apply rescaling factor
scaledDiff := diff * tc.rescaleFactor
// Accumulate difference (helps track drift over multiple cached steps)
tc.accumulatedDiff += scaledDiff
// Decision based on accumulated difference
if tc.accumulatedDiff > tc.threshold {
tc.accumulatedDiff = 0 // Reset accumulator
return true
}
return false
}
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
// Free previous cached output
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
}
// Store new cached values
tc.cachedOutput = output
tc.prevTimestep = timestep
tc.cacheMisses++
}
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
// This allows CFG to be computed fresh each step, avoiding error amplification.
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
// Free previous cached outputs
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
}
// Store new cached values
tc.cachedPosOutput = posOutput
tc.cachedNegOutput = negOutput
tc.prevTimestep = timestep
tc.cacheMisses++
}
// GetCached returns the cached output (non-CFG mode).
func (tc *TeaCache) GetCached() *mlx.Array {
tc.cacheHits++
return tc.cachedOutput
}
// GetCFGCached returns cached pos and neg outputs for CFG mode.
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
tc.cacheHits++
return tc.cachedPosOutput, tc.cachedNegOutput
}
// HasCFGCache returns true if CFG cache is available.
func (tc *TeaCache) HasCFGCache() bool {
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
}
// Arrays returns all arrays that should be kept alive.
func (tc *TeaCache) Arrays() []*mlx.Array {
var arrays []*mlx.Array
if tc.cachedOutput != nil {
arrays = append(arrays, tc.cachedOutput)
}
if tc.cachedPosOutput != nil {
arrays = append(arrays, tc.cachedPosOutput)
}
if tc.cachedNegOutput != nil {
arrays = append(arrays, tc.cachedNegOutput)
}
return arrays
}
// Stats returns cache hit/miss statistics.
func (tc *TeaCache) Stats() (hits, misses int) {
return tc.cacheHits, tc.cacheMisses
}
// Free releases all cached arrays.
func (tc *TeaCache) Free() {
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
tc.cachedOutput = nil
}
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
tc.cachedPosOutput = nil
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
tc.cachedNegOutput = nil
}
}

x/imagegen/cli.go (new file, 541 lines)

@@ -0,0 +1,541 @@
// cli.go provides CLI commands for image generation models.
//
// TODO (jmorganca): Integrate these commands into cmd/cmd.go when stable.
// Currently these are separate to keep experimental code isolated.
package imagegen
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"strconv"
"strings"
"time"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
)
// ImageGenOptions holds options for image generation.
// These can be set via environment variables or interactive commands.
type ImageGenOptions struct {
Width int
Height int
Steps int
Seed int
NegativePrompt string
}
// DefaultOptions returns the default image generation options.
func DefaultOptions() ImageGenOptions {
return ImageGenOptions{
Width: 1024,
Height: 1024,
Steps: 9,
Seed: 0, // 0 means random
}
}
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}
// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("width", 1024, "Image width")
cmd.Flags().Int("height", 1024, "Image height")
cmd.Flags().Int("steps", 9, "Denoising steps")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt")
cmd.Flags().MarkHidden("width")
cmd.Flags().MarkHidden("height")
cmd.Flags().MarkHidden("steps")
cmd.Flags().MarkHidden("seed")
cmd.Flags().MarkHidden("negative")
}
// RunCLI handles the CLI for image generation models.
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
// Verify it's a valid image gen model
if ResolveModelName(name) == "" {
return fmt.Errorf("unknown image generation model: %s", name)
}
// Get options from command-line flags
opts := DefaultOptions()
if cmd != nil && cmd.Flags() != nil {
if v, err := cmd.Flags().GetInt("width"); err == nil && v > 0 {
opts.Width = v
}
if v, err := cmd.Flags().GetInt("height"); err == nil && v > 0 {
opts.Height = v
}
if v, err := cmd.Flags().GetInt("steps"); err == nil && v > 0 {
opts.Steps = v
}
if v, err := cmd.Flags().GetInt("seed"); err == nil && v != 0 {
opts.Seed = v
}
if v, err := cmd.Flags().GetString("negative"); err == nil && v != "" {
opts.NegativePrompt = v
}
}
if interactive {
return runInteractive(cmd, name, keepAlive, opts)
}
// One-shot generation
return generateImageWithOptions(cmd, name, prompt, keepAlive, opts)
}
// generateImageWithOptions generates an image with the given options.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Build request with image gen options encoded in Options fields
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
}
if keepAlive != nil {
req.KeepAlive = keepAlive
}
// Show loading spinner until generation starts
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
var stepBar *progress.StepBar
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
}
return nil
})
p.Stop()
if err != nil {
return err
}
if imageBase64 != "" {
// Decode base64 and save to CWD
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
return fmt.Errorf("failed to decode image: %w", err)
}
// Create filename from prompt
safeName := sanitizeFilename(prompt)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
return fmt.Errorf("failed to save image: %w", err)
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
return nil
}
// runInteractive runs an interactive REPL for image generation.
func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
Placeholder: "Describe an image to generate (/help for commands)",
})
if err != nil {
return err
}
if envconfig.NoHistory() {
scanner.HistoryDisable()
}
for {
line, err := scanner.Readline()
switch {
case errors.Is(err, io.EOF):
fmt.Println()
return nil
case errors.Is(err, readline.ErrInterrupt):
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
continue
case err != nil:
return err
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
// Handle commands
switch {
case strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"):
printInteractiveHelp(opts)
continue
case strings.HasPrefix(line, "/set "):
if err := handleSetCommand(line[5:], &opts); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
}
continue
case strings.HasPrefix(line, "/show"):
printCurrentSettings(opts)
continue
case strings.HasPrefix(line, "/"):
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
continue
}
// Generate image with current options
req := &api.GenerateRequest{
Model: modelName,
Prompt: line,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
}
if keepAlive != nil {
req.KeepAlive = keepAlive
}
// Show loading spinner until generation starts
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
var stepBar *progress.StepBar
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
}
return nil
})
p.Stop()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
continue
}
// Save image to current directory with descriptive name
if imageBase64 != "" {
// Decode base64 image data
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
continue
}
// Create filename from prompt (sanitized)
safeName := sanitizeFilename(line)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
continue
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
fmt.Println()
}
}
// sanitizeFilename removes characters that aren't safe for filenames.
func sanitizeFilename(s string) string {
s = strings.ToLower(s)
s = strings.ReplaceAll(s, " ", "-")
// Remove any character that's not alphanumeric or hyphen
var result strings.Builder
for _, r := range s {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}
// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
fmt.Fprintln(os.Stderr, "Commands:")
fmt.Fprintln(os.Stderr, " /set width <n> Set image width (current:", opts.Width, ")")
fmt.Fprintln(os.Stderr, " /set height <n> Set image height (current:", opts.Height, ")")
fmt.Fprintln(os.Stderr, " /set steps <n> Set denoising steps (current:", opts.Steps, ")")
fmt.Fprintln(os.Stderr, " /set seed <n> Set random seed (current:", opts.Seed, ", 0=random)")
fmt.Fprintln(os.Stderr, " /set negative <s> Set negative prompt")
fmt.Fprintln(os.Stderr, " /show Show current settings")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "Or type a prompt to generate an image.")
fmt.Fprintln(os.Stderr)
}
// printCurrentSettings prints the current image generation settings.
func printCurrentSettings(opts ImageGenOptions) {
fmt.Fprintf(os.Stderr, "Current settings:\n")
fmt.Fprintf(os.Stderr, " width: %d\n", opts.Width)
fmt.Fprintf(os.Stderr, " height: %d\n", opts.Height)
fmt.Fprintf(os.Stderr, " steps: %d\n", opts.Steps)
fmt.Fprintf(os.Stderr, " seed: %d (0=random)\n", opts.Seed)
if opts.NegativePrompt != "" {
fmt.Fprintf(os.Stderr, " negative: %s\n", opts.NegativePrompt)
}
fmt.Fprintln(os.Stderr)
}
// handleSetCommand handles /set commands to change options.
func handleSetCommand(args string, opts *ImageGenOptions) error {
parts := strings.SplitN(args, " ", 2)
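// e.g. "width 768" splits into ["width", "768"]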
if len(parts) < 2 {
return fmt.Errorf("usage: /set <option> <value>")
}
key := strings.ToLower(parts[0])
value := strings.TrimSpace(parts[1])
switch key {
case "width", "w":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("width must be a positive integer")
}
opts.Width = v
fmt.Fprintf(os.Stderr, "Set width to %d\n", v)
case "height", "h":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("height must be a positive integer")
}
opts.Height = v
fmt.Fprintf(os.Stderr, "Set height to %d\n", v)
case "steps", "s":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("steps must be a positive integer")
}
opts.Steps = v
fmt.Fprintf(os.Stderr, "Set steps to %d\n", v)
case "seed":
v, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("seed must be an integer")
}
opts.Seed = v
fmt.Fprintf(os.Stderr, "Set seed to %d\n", v)
case "negative", "neg", "n":
opts.NegativePrompt = value
if value == "" {
fmt.Fprintln(os.Stderr, "Cleared negative prompt")
} else {
fmt.Fprintf(os.Stderr, "Set negative prompt to: %s\n", value)
}
default:
return fmt.Errorf("unknown option: %s (try /help)", key)
}
return nil
}
// displayImageInTerminal attempts to render an image inline in the terminal.
// Supports iTerm2, Kitty, WezTerm, Ghostty, and other terminals with inline image support.
// Returns true if the image was displayed, false otherwise.
func displayImageInTerminal(imagePath string) bool {
// Check if terminal supports inline images
termProgram := os.Getenv("TERM_PROGRAM")
kittyWindowID := os.Getenv("KITTY_WINDOW_ID")
weztermPane := os.Getenv("WEZTERM_PANE")
ghostty := os.Getenv("GHOSTTY_RESOURCES_DIR")
// Read the image file
data, err := os.ReadFile(imagePath)
if err != nil {
return false
}
encoded := base64.StdEncoding.EncodeToString(data)
switch {
case termProgram == "iTerm.app" || termProgram == "WezTerm" || weztermPane != "":
// iTerm2/WezTerm inline image protocol
// ESC ] 1337 ; File = [arguments] : base64 BEL
fmt.Printf("\033]1337;File=inline=1;preserveAspectRatio=1:%s\a\n", encoded)
return true
case kittyWindowID != "" || ghostty != "" || termProgram == "ghostty":
// Kitty graphics protocol (also used by Ghostty)
// Send in chunks for large images
const chunkSize = 4096
for i := 0; i < len(encoded); i += chunkSize {
end := i + chunkSize
if end > len(encoded) {
end = len(encoded)
}
chunk := encoded[i:end]
if i == 0 {
// First chunk: a=T (transmit), f=100 (PNG), m=1 (more chunks follow) or m=0 (last chunk)
more := 1
if end >= len(encoded) {
more = 0
}
fmt.Printf("\033_Ga=T,f=100,m=%d;%s\033\\", more, chunk)
} else if end >= len(encoded) {
// Last chunk
fmt.Printf("\033_Gm=0;%s\033\\", chunk)
} else {
// Middle chunk
fmt.Printf("\033_Gm=1;%s\033\\", chunk)
}
}
fmt.Println()
return true
default:
return false
}
}

x/imagegen/client/create.go Normal file

@@ -0,0 +1,190 @@
// Package client provides client-side model creation for tensor-based models.
//
// This package is in x/ because the tensor model storage format is under development.
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
// cannot import server. This sub-package can import server because server doesn't
// import it.
//
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
// 1. Add proper API endpoints for tensor model creation
// 2. Move tensor extraction to server-side
// 3. Remove this package
// 4. Follow the same client→server pattern as regular model creation
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
// MinOllamaVersion is the minimum Ollama version required for image generation models.
const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
status := "importing image generation model"
spinner := progress.NewSpinner(status)
p.Add("imagegen", spinner)
// Create layer callback for config files
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
// Create tensor layer callback for individual tensors.
// name is path-style: "component/tensor_name".
// When doQuantize is true, returns multiple layers (weight + scales, plus qbias in affine mode).
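// e.g. "transformer/ff.weight" yields layers "transformer/ff.weight", "transformer/ff.weight_scale",
// and (in affine mode) "transformer/ff.weight_qbias".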
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
if doQuantize {
// Check if quantization is supported
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales (use _scale suffix convention)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []imagegen.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name, // Keep original name for weight
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale", // Add _scale suffix
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, imagegen.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias", // Add _qbias suffix
})
}
return layers, nil
}
// Non-quantized path: just create a single layer
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
return []imagegen.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
// Create manifest writer callback
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create a proper config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: []string{"image"},
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
}
}
return server.WriteManifest(name, configLayer, serverLayers)
}
// Progress callback
progressFn := func(msg string) {
spinner.Stop()
status = msg
spinner = progress.NewSpinner(status)
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created image generation model '%s'\n", modelName)
return nil
}


@@ -0,0 +1,120 @@
//go:build mlx
package client
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
// and returns safetensors data for the quantized weights, scales, and biases.
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
if _, err := io.Copy(tmpFile, r); err != nil {
tmpFile.Close()
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
}
tmpFile.Close()
// Load the tensor using MLX's native loader
st, err := mlx.LoadSafetensorsNative(tmpPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
}
defer st.Free()
// Get the tensor (it's stored as "data" in our minimal safetensors format)
arr := st.Get("data")
if arr == nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
}
// Convert to BFloat16 if needed (quantize expects float type)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr)
}
// Quantize with affine mode: group_size=32, bits=8
// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")
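// e.g. a [4096, 4096] weight packs four 8-bit values per uint32, so qweight is
// uint32 [4096, 1024] and scales is [4096, 128] (one scale per 32-element group).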
// Eval and make contiguous for data access
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
mlx.Eval(qweight, scales, qbiases)
} else {
mlx.Eval(qweight, scales)
}
// Get shapes
qweightShape = qweight.Shape()
scalesShape = scales.Shape()
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
defer os.Remove(qweightPath)
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
}
qweightData, err = os.ReadFile(qweightPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
}
// Save scales using MLX's native safetensors
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
defer os.Remove(scalesPath)
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
}
scalesData, err = os.ReadFile(scalesPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
}
// Affine mode returns qbiases for zero-point offset
if qbiases != nil {
qbiasShape = qbiases.Shape()
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
defer os.Remove(qbiasPath)
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
}
qbiasData, err = os.ReadFile(qbiasPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
}
}
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
func QuantizeSupported() bool {
return true
}
// ensureTempDir creates the temp directory for quantization if it doesn't exist
func ensureTempDir() string {
tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
os.MkdirAll(tmpDir, 0o755)
return tmpDir
}


@@ -0,0 +1,18 @@
//go:build !mlx
package client
import (
"fmt"
"io"
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
return false
}


@@ -0,0 +1,35 @@
# MLX Engine
Experimental MLX backend for running models on Apple Silicon and CUDA GPUs.
## Build
```bash
go build -tags mlx -o engine ./x/imagegen/cmd/engine
```
## Text Generation
```bash
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
```
Options:
- `-temperature` - sampling temperature (default 0.7)
- `-top-p` - nucleus sampling (default 0.9)
- `-top-k` - top-k sampling (default 40)
Supports: Llama, Gemma3, GPT-OSS
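For example, a lower-temperature run (paths are placeholders):
```bash
./engine -model /path/to/model -prompt "Hello" -max-tokens 200 -temperature 0.2
```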
## Image Generation
```bash
./engine -zimage -model /path/to/z-image -prompt "a cat" -output cat.png
```
Options:
- `-width`, `-height` - image dimensions (default 1024x1024)
- `-steps` - denoising steps (default 9)
- `-seed` - random seed (default 42)
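For example, a reproducible non-square render (paths are placeholders):
```bash
./engine -zimage -model /path/to/z-image -prompt "a cat" -width 768 -height 1024 -steps 9 -seed 7 -output cat.png
```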


@@ -0,0 +1,359 @@
//go:build mlx
package main
import (
"context"
"fmt"
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Dedicated stream for generation (like mlx-lm's generation_stream)
var generationStream *mlx.Stream
// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
// This handles cases where tokenizers output partial multi-byte sequences.
type utf8Streamer struct {
buffer []byte
}
// Write adds decoded text to the buffer and returns complete UTF-8 characters.
func (s *utf8Streamer) Write(text string) string {
s.buffer = append(s.buffer, text...)
// Find the last position that ends with a complete UTF-8 character
validLen := 0
for i := 0; i < len(s.buffer); {
r, size := utf8.DecodeRune(s.buffer[i:])
if r == utf8.RuneError && size == 1 {
// Invalid or incomplete UTF-8 sequence at this position
// Check if it could be a valid start of a multi-byte sequence
if len(s.buffer)-i < 4 {
// Might be incomplete, keep it in buffer
break
}
// Definitely invalid, skip this byte
i++
validLen = i
} else {
i += size
validLen = i
}
}
if validLen == 0 {
return ""
}
result := string(s.buffer[:validLen])
s.buffer = s.buffer[validLen:]
return result
}
// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
func (s *utf8Streamer) Flush() string {
if len(s.buffer) == 0 {
return ""
}
result := string(s.buffer)
s.buffer = nil
return result
}
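// Illustrative use: feeding "€" (bytes 0xE2 0x82 0xAC) one byte at a time
// emits nothing until the final byte completes the rune:
//
// var s utf8Streamer
// s.Write("\xe2") // returns ""
// s.Write("\x82") // returns ""
// s.Write("\xac") // returns "€"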
func init() {
generationStream = mlx.NewStream()
}
// withStream runs fn with the generation stream as default
func withStream(fn func()) {
orig := mlx.GetDefaultStream()
mlx.SetDefaultStream(generationStream)
fn()
mlx.SetDefaultStream(orig)
}
type Model interface {
Tokenizer() *tokenizer.Tokenizer
VocabSize() int32
NewCache(maxSeqLen int32) []cache.Cache
Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
}
// ChatModel is an optional interface for models that support chat formatting
type ChatModel interface {
FormatPrompt(prompt string) string
}
// MultimodalModel is for models that support image input
type MultimodalModel interface {
Model
FormatPromptWithImage(prompt string) string
ExpandImageTokens(tokens []int32) []int32
ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
ImageSize() int32 // Returns expected image size for preprocessing
}
// ImageLoader loads and preprocesses an image for multimodal models
// Returns nil if path is empty
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)
type input struct {
Prompt string
Image *mlx.Array // Optional preprocessed image for multimodal models
MaxTokens int
Temperature float32
TopP float32
TopK int
WiredLimitGB int // Metal wired memory limit in GB (default 32)
}
type output struct {
Text string
Done bool
PrefillTokSec float64
GenTokSec float64
}
// Decoder wraps model + cache for autoregressive generation.
type Decoder struct {
model Model
caches []cache.Cache
vocabSize int32
temp float32
topK int
topP float32
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
}
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
caches := m.NewCache(0)
return &Decoder{
model: m,
caches: caches,
vocabSize: m.VocabSize(),
temp: temp,
topK: topK,
topP: topP,
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
}
}
// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
d.image = img
}
func (d *Decoder) prefill(inputIDs []int32) int {
processed := 0
// Track old cache state to free after each chunk
var oldCacheState []*mlx.Array
// For multimodal models with an image, we need to process all tokens together
// in the first forward pass so the image embeddings can be inserted properly.
// Skip chunking for multimodal prefill.
isMultimodal := d.image != nil
// Process all-but-1 tokens in chunks, eval cache state for memory management
// Skip chunking for multimodal - process everything in the final step
if !isMultimodal {
for len(inputIDs) > 1 {
chunkSize := min(2048, len(inputIDs)-1)
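// e.g. a 5000-token prompt prefills in chunks of 2048, 2048, and 903, leaving one token for the sampling step below.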
if chunkSize <= 0 {
break
}
chunk := inputIDs[:chunkSize]
// Save old cache state before forward
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
var cacheState []*mlx.Array
withStream(func() {
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
d.model.Forward(x, d.caches)
for _, c := range d.caches {
cacheState = append(cacheState, c.State()...)
}
})
mlx.Eval(cacheState...)
// Free old cache state
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
inputIDs = inputIDs[chunkSize:]
processed += chunkSize
}
}
// Save old cache state before final step
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
// Final token + sampling (or all tokens for multimodal)
withStream(func() {
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
mlx.Eval(x) // Materialize before any other evals
var logits *mlx.Array
// Use ForwardWithImage if we have an image and model supports it
if d.image != nil {
if mm, ok := d.model.(MultimodalModel); ok {
logits = mm.ForwardWithImage(x, d.image, d.caches)
d.image = nil // Only use image for first forward
} else {
logits = d.model.Forward(x, d.caches)
}
} else {
logits = d.model.Forward(x, d.caches)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Free old cache state from before final step
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
mlx.ClearCache()
return processed + len(inputIDs)
}
func (d *Decoder) step() int32 {
prevToken := d.token
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
d.oldCacheState = append(d.oldCacheState, c.State()...)
}
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
mlx.Keep(d.token)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
arr.Free()
}
return val
}
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
mlx.EnableCompile()
wiredLimit := in.WiredLimitGB
if wiredLimit <= 0 {
wiredLimit = 32 // default 32GB
}
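// Shift by 30 converts GiB to bytes.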
mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)
temp := in.Temperature
if temp < 0 {
temp = 0.7
}
tok := m.Tokenizer()
dec := NewDecoder(m, temp, in.TopK, in.TopP)
// Apply chat template - use image template if we have an image
prompt := in.Prompt
var tokens []int32
if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
prompt = mm.FormatPromptWithImage(prompt)
tokens = tok.Encode(prompt, true)
tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
dec.SetImage(in.Image)
} else if cm, ok := m.(ChatModel); ok {
prompt = cm.FormatPrompt(prompt)
tokens = tok.Encode(prompt, true)
} else {
tokens = tok.Encode(prompt, true)
}
prefillStart := time.Now()
prefillTokens := dec.prefill(tokens)
// Prefill measurement should include time to first token (like mlx-lm)
// Step() waits for prefill to complete and returns first token
firstToken := dec.step()
prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()
genStart := time.Now()
maxTokens := max(in.MaxTokens, 100)
var genTokens int
// UTF-8 streamer to handle partial multi-byte characters
streamer := &utf8Streamer{}
// Handle first token
genTokens++
if tok.IsEOS(firstToken) {
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
return nil
}
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
cb(output{Text: text})
}
for n := 1; n < maxTokens; n++ {
if ctx.Err() != nil {
return ctx.Err()
}
token := dec.step()
genTokens++
if tok.IsEOS(token) {
break
}
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
cb(output{Text: text})
}
if n%256 == 0 {
mlx.ClearCache()
}
}
// Flush any remaining buffered bytes
if text := streamer.Flush(); text != "" {
cb(output{Text: text})
}
fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
cb(output{Done: true, PrefillTokSec: prefillTokSec,
GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
return nil
}


@@ -0,0 +1,89 @@
//go:build mlx
package main
import (
"fmt"
"image"
"image/png"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// saveImageArray saves an MLX array as a PNG image.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func saveImageArray(arr *mlx.Array, path string) error {
img, err := arrayToImage(arr)
if err != nil {
return err
}
return savePNG(img, path)
}
func savePNG(img *image.RGBA, path string) error {
if filepath.Ext(path) != ".png" {
path = path + ".png"
}
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
return png.Encode(f, img)
}
func arrayToImage(arr *mlx.Array) (*image.RGBA, error) {
shape := arr.Shape()
if len(shape) != 4 {
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
}
// Transform to [H, W, C] for image conversion
img := mlx.Squeeze(arr, 0)
arr.Free()
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
mlx.Eval(img)
imgShape := img.Shape()
H := int(imgShape[0])
W := int(imgShape[1])
C := int(imgShape[2])
if C != 3 {
img.Free()
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
}
// Copy to CPU and free GPU memory
data := img.Data()
img.Free()
// Write directly to Pix slice (faster than SetRGBA)
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
pix := goImg.Pix
for y := 0; y < H; y++ {
for x := 0; x < W; x++ {
srcIdx := (y*W + x) * C
dstIdx := (y*W + x) * 4
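// Scale [0,1] to [0,255]; the +0.5 rounds to nearest before clamping.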
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
pix[dstIdx+3] = 255
}
}
return goImg, nil
}
func clampF(v, min, max float32) float32 {
if v < min {
return min
}
if v > max {
return max
}
return v
}


@@ -0,0 +1,293 @@
//go:build mlx
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"runtime/pprof"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// stringSlice is a flag type that accumulates multiple values
type stringSlice []string
func (s *stringSlice) String() string {
return fmt.Sprintf("%v", *s)
}
func (s *stringSlice) Set(value string) error {
*s = append(*s, value)
return nil
}
func main() {
modelPath := flag.String("model", "", "Model directory")
prompt := flag.String("prompt", "Hello", "Prompt")
// Text generation params
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
temperature := flag.Float64("temperature", 0.7, "Temperature")
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
topK := flag.Int("top-k", 40, "Top-k sampling")
imagePath := flag.String("image", "", "Image path for multimodal models")
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
steps := flag.Int("steps", 9, "Denoising steps")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
// Utility flags
listTensors := flag.Bool("list", false, "List tensors only")
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
flag.Parse()
if *modelPath == "" {
flag.Usage()
return
}
// CPU profiling
if *cpuProfile != "" {
f, err := os.Create(*cpuProfile)
if err != nil {
log.Fatal(err)
}
defer f.Close()
if err := pprof.StartCPUProfile(f); err != nil {
log.Fatal(err)
}
defer pprof.StopCPUProfile()
}
var err error
// Handle legacy mode flags that aren't unified yet
switch {
case *zimageFlag:
m := &zimage.Model{}
if loadErr := m.Load(*modelPath); loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
TeaCache: *teaCache,
TeaCacheThreshold: float32(*teaCacheThreshold),
FusedQKV: *fusedQKV,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImageEdit:
if len(inputImages) == 0 {
log.Fatal("qwen-image-edit requires at least one -input-image")
}
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
// For image editing, use 0 for dimensions to auto-detect from the input image,
// unless the flags were explicitly changed from their defaults
editWidth := int32(0)
editHeight := int32(0)
if *width != 1024 {
editWidth = int32(*width)
}
if *height != 1024 {
editHeight = int32(*height)
}
cfg := &qwen_image_edit.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: editWidth,
Height: editHeight,
Steps: *steps,
Seed: *seed,
}
var img *mlx.Array
img, err = m.EditFromConfig(inputImages, cfg)
if err == nil {
err = saveImageArray(img, *out)
}
case *listTensors:
err = listModelTensors(*modelPath)
default:
// llm path
m, loadErr := load(*modelPath) // distinct name so the err checked after the switch isn't shadowed
if loadErr != nil {
log.Fatal(loadErr)
}
// Load image if provided and model supports it
var image *mlx.Array
if *imagePath != "" {
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
if err != nil {
log.Fatal("load image:", err)
}
} else {
log.Fatal("model does not support image input")
}
}
err = generate(context.Background(), m, input{
Prompt: *prompt,
Image: image,
MaxTokens: *maxTokens,
Temperature: float32(*temperature),
TopP: float32(*topP),
TopK: *topK,
WiredLimitGB: *wiredLimitGB,
}, func(out output) {
if out.Text != "" {
fmt.Print(out.Text)
}
if out.Done {
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
}
})
}
if err != nil {
log.Fatal(err)
}
}
func listModelTensors(modelPath string) error {
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return err
}
for _, name := range weights.ListTensors() {
info, _ := weights.GetTensorInfo(name)
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
}
return nil
}
// loadModel builds and evaluates a model using the common load pattern.
// Release safetensors BEFORE eval - lazy arrays have captured their data,
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
func loadModel[T Model](build func() T, cleanup func()) T {
m := build()
weights := mlx.Collect(m)
cleanup()
mlx.Eval(weights...)
return m
}
func load(modelPath string) (Model, error) {
kind, err := detectModelKind(modelPath)
if err != nil {
return nil, fmt.Errorf("detect model kind: %w", err)
}
switch kind {
case "gpt_oss":
return gpt_oss.Load(modelPath)
case "gemma3":
return gemma3.Load(modelPath)
case "gemma3_text":
return gemma3.LoadText(modelPath)
default:
return llama.Load(modelPath)
}
}
func detectModelKind(modelPath string) (string, error) {
indexPath := filepath.Join(modelPath, "model_index.json")
if _, err := os.Stat(indexPath); err == nil {
data, err := os.ReadFile(indexPath)
if err != nil {
return "zimage", nil
}
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err == nil {
switch index.ClassName {
case "FluxPipeline", "ZImagePipeline":
return "zimage", nil
}
}
return "zimage", nil
}
configPath := filepath.Join(modelPath, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
}
var cfg struct {
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return "", fmt.Errorf("parse config.json: %w", err)
}
return cfg.ModelType, nil
}


@@ -0,0 +1,49 @@
//go:build mlx
package main
import "github.com/ollama/ollama/x/imagegen/mlx"
// sampleTopK samples from top-k logits using global random state
func sampleTopK(scaledLogits *mlx.Array, k int) *mlx.Array {
neg := mlx.Neg(scaledLogits)
indices := mlx.Argpartition(neg, k-1, -1)
topKIdx := mlx.Slice(indices, []int32{0}, []int32{int32(k)})
values := mlx.TakeAlongAxis(scaledLogits, topKIdx, -1)
sampled := mlx.RandomCategorical(values, -1, 1)
return mlx.Take(topKIdx, sampled, -1)
}
// sampleTopP samples using nucleus sampling with global random state
func sampleTopP(scaledLogits *mlx.Array, p float32, vocabSize int32) *mlx.Array {
sorted := mlx.Argsort(mlx.Neg(scaledLogits), -1)
sortedLogits := mlx.TakeAlongAxis(scaledLogits, sorted, -1)
probs := mlx.Softmax(sortedLogits, -1)
cumProbs := mlx.Cumsum(probs, -1)
mask := mlx.LessScalar(cumProbs, p)
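// Keeps tokens while the cumulative probability is still below p; the token that crosses the threshold is masked out.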
negInf := mlx.FullDtype(float32(-1e9), scaledLogits.Dtype(), vocabSize)
masked := mlx.Where(mask, sortedLogits, negInf)
sampled := mlx.RandomCategorical(masked, -1, 1)
return mlx.Take(sorted, sampled, -1)
}
// sample samples from logits at the last position
func sample(logits *mlx.Array, temp float32, topK int, topP float32, vocab int32) *mlx.Array {
// Get last position logits: [1, L, vocab] -> [vocab]
shape := logits.Shape()
seqLen := shape[1]
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocab})
lastLogits = mlx.Reshape(lastLogits, vocab)
if temp == 0 {
return mlx.Argmax(lastLogits, -1, false)
}
scaled := mlx.DivScalar(lastLogits, temp)
if topK > 0 && topK < int(vocab) {
return sampleTopK(scaled, topK)
}
if topP > 0 && topP < 1.0 {
return sampleTopP(scaled, topP, vocab)
}
return mlx.RandomCategorical(scaled, -1, 1)
}

x/imagegen/create.go Normal file

@@ -0,0 +1,213 @@
package imagegen
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// IsTensorModelDir checks if the directory contains a tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor blob layers with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight").
// When doQuantize is true, it may return multiple layers (weight plus scales/qbias).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
var totalParams int64 // Count parameters from original tensor shapes
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"}
for _, component := range components {
componentDir := filepath.Join(modelDir, component)
if _, err := os.Stat(componentDir); os.IsNotExist(err) {
continue
}
// Find all safetensors files in this component
entries, err := os.ReadDir(componentDir)
if err != nil {
return fmt.Errorf("failed to read %s: %w", component, err)
}
for _, entry := range entries {
if !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
stPath := filepath.Join(componentDir, entry.Name())
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
tensorNames := extractor.ListTensors()
quantizeMsg := ""
if quantize == "fp8" && component != "vae" {
quantizeMsg = ", quantizing to fp8"
}
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Count parameters from original tensor shape
if len(td.Shape) > 0 {
numElements := int64(1)
for _, dim := range td.Shape {
numElements *= int64(dim)
}
totalParams += numElements
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName
// Determine if this tensor should be quantized
doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
}
layers = append(layers, newLayers...)
}
extractor.Close()
}
}
// Import config files
configFiles := []string{
"model_index.json",
"text_encoder/config.json",
"text_encoder/generation_config.json",
"transformer/config.json",
"vae/config.json",
"scheduler/scheduler_config.json",
"tokenizer/tokenizer.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
}
for _, cfgPath := range configFiles {
fullPath := filepath.Join(modelDir, cfgPath)
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
continue
}
fn(fmt.Sprintf("importing config %s", cfgPath))
var r io.Reader
// For model_index.json, normalize to Ollama format and add metadata
if cfgPath == "model_index.json" {
data, err := os.ReadFile(fullPath)
if err != nil {
return fmt.Errorf("failed to read %s: %w", cfgPath, err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("failed to parse %s: %w", cfgPath, err)
}
// Rename _class_name to architecture, remove diffusers-specific fields
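// e.g. {"_class_name": "ZImagePipeline"} becomes {"architecture": "ZImagePipeline"}.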
if className, ok := cfg["_class_name"]; ok {
cfg["architecture"] = className
delete(cfg, "_class_name")
}
delete(cfg, "_diffusers_version")
// Add parameter count (counted from tensor shapes during import)
cfg["parameter_count"] = totalParams
// Add quantization info
if quantize == "fp8" {
cfg["quantization"] = "FP8"
} else {
cfg["quantization"] = "BF16"
}
data, err = json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
}
r = bytes.NewReader(data)
} else {
f, err := os.Open(fullPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
}
defer f.Close()
r = f
}
layer, err := createLayer(r, "application/vnd.ollama.image.json", cfgPath)
if err != nil {
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
}
// Use model_index.json as the config layer
if cfgPath == "model_index.json" {
configLayer = layer
}
layers = append(layers, layer)
}
if configLayer.Digest == "" {
return fmt.Errorf("model_index.json not found in %s", modelDir)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
return fmt.Errorf("failed to write manifest: %w", err)
}
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}

Some files were not shown because too many files have changed in this diff.