Compare commits

..

11 Commits

Author SHA1 Message Date
ParthSareen
0c2c2b8de9 fixes from rebase 2026-01-07 15:44:19 -08:00
ParthSareen
5e23c4f2f7 fix: agent loop message handling and cloud model inheritance
- Fix tool result messages losing ToolName and ToolCallID fields
- Include Thinking in intermediate assistant messages during tool loops
- Inherit capabilities, remote config, and model family from base model
  when creating agents (fixes "does not support generate" for cloud models)
- Add tests for tool message construction and message stitching

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 12:15:18 -08:00
ParthSareen
5c0caaff86 agents: add MCP server support and ENTRYPOINT command
MCP (Model Context Protocol) support:
- Add MCPRef type for agent MCP server references
- Parse MCP command in Agentfiles (MCP name command [args...])
- Load and manage MCP servers with mcpManager
- Implement agentic loop for multi-turn tool execution
- Add /mcp REPL commands (add, remove, disable, enable)
- Add 'ollama mcp' CLI commands for global config management
- Support both model-bundled and global (~/.ollama/mcp.json) MCPs
- Display MCPs in 'ollama show' output

ENTRYPOINT support:
- Add ENTRYPOINT command to Agentfiles for custom runtimes
- Allow agents without FROM when ENTRYPOINT is specified
- Execute entrypoint as subprocess with stdin/stdout connected
- Support $PROMPT placeholder for prompt insertion control
- Hide Model section in 'ollama show' for entrypoint-only agents
- Pass user prompt as argument to entrypoint command
2026-01-06 12:15:18 -08:00
ParthSareen
e28ee8524d skills: add registry reference check and working directory env var
- Add check for registry references without digest in loadSkillsFromRefs
- Fix IsLocalSkillPath to not treat registry refs as local paths
- Inject OLLAMA_WORKING_DIR env var so skill scripts can access the
  directory where 'ollama run' was called from
2026-01-06 12:12:27 -08:00
ParthSareen
623e539a09 docs: add skills documentation
Add comprehensive documentation for the skills feature:

- Quick start guide for creating skills
- SKILL.md structure and frontmatter
- Skill reference formats (local, library, user)
- CLI commands (push, pull, list, show, rm)
- Dynamic skills in interactive chat
- Storage layout
- Security considerations
- Future roadmap
2026-01-06 12:12:27 -08:00
ParthSareen
51911a5f6f cmd: add skill CLI and REPL commands
Add skill management commands and interactive REPL support:

CLI commands (cmd/skill_cmd.go):
  ollama skill push NAME PATH  - Push skill to registry
  ollama skill pull NAME       - Pull skill from registry
  ollama skill list            - List installed skills
  ollama skill show NAME       - Show skill details
  ollama skill rm NAME         - Remove a skill

Skill loading (cmd/skills.go):
  - Load skills from model manifests
  - Parse SKILL.md frontmatter for metadata
  - Inject skill instructions into system prompt
  - Provide run_skill_script tool for script execution

Interactive mode (cmd/interactive.go):
  /skills              - Show available skills
  /skill add PATH      - Add skill from local path
  /skill remove NAME   - Remove skill from session
  /skill list          - List session skills
2026-01-06 12:12:27 -08:00
ParthSareen
2c2354e980 parser: add SKILL command for Agentfiles
Add SKILL command to the Modelfile/Agentfile parser.

Supports both local paths and registry references:
  SKILL ./path/to/skill       # Local skill bundled with agent
  SKILL skill/calc:1.0.0      # Registry skill reference
  SKILL alice/skill/calc:1.0  # User skill from registry
2026-01-06 12:12:27 -08:00
ParthSareen
ce6b19d8be api,types: add skill types and configuration
Add skill-related types to the API and configuration:

- api/types.go: Skill reference types for API requests/responses
- types/model/config.go: Skill configuration in model config
- envconfig/config.go: Environment configuration for skills
2026-01-06 12:12:27 -08:00
ParthSareen
1de00fada0 server: add skill layer support
Add support for skill layers in model manifests:

- server/skill.go: New file with skill extraction and packaging
  - GetSkillsPath: Returns path to extracted skills cache
  - ExtractSkillBlob: Extracts skill tar.gz to cache
  - CreateSkillLayer: Creates skill blob from directory
  - ParseSkillName/GetSkillManifestPath: Skill name handling

- server/images.go: Extract skill layers on pull
- server/create.go: Create skill layers from SKILL directives
- server/routes.go: Skill-related route handling

Skills are stored as gzipped tar archives with MediaType
"application/vnd.ollama.image.skill".
2026-01-06 12:12:27 -08:00
ParthSareen
7ecae75c4c server: add Kind field to ModelPath for 5-part naming
Updates ModelPath struct and parsing to support the Kind field,
enabling skills and agents to use the 5-part naming structure.

- ParseModelPath detects valid kinds (skill, agent)
- GetNamespaceRepository includes kind in path
- GetManifestPath returns correct 5-part filepath
- GetFullTagname/GetShortTagname include kind when present
2026-01-06 12:12:27 -08:00
ParthSareen
ad5c276cf6 types: add Kind field to model.Name for 5-part naming
Extends the model name structure from 4-part to 5-part:
  host/namespace/kind/model:tag

The Kind field is optional and supports:
- "skill" for skill packages
- "agent" for agent packages (future)
- empty for regular models

Parser detects valid kinds to distinguish between old format
(host/namespace/model) and new format (host/namespace/kind/model).
2026-01-06 12:12:27 -08:00
198 changed files with 4446 additions and 38890 deletions

View File

@@ -13,7 +13,7 @@ body:
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
render: shell
validations:
required: false

View File

@@ -68,7 +68,6 @@ jobs:
name: bundles-darwin
path: |
dist/*.tgz
dist/*.tar.zst
dist/*.zip
dist/*.dmg
@@ -393,13 +392,13 @@ jobs:
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
done
- uses: actions/upload-artifact@v4
with:
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: |
*.tar.zst
*.tgz
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push:
@@ -532,7 +531,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!

View File

@@ -2,22 +2,6 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
# host architecture, but downstream projects (like MLX) use it to detect the
# target architecture.
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
# Single architecture specified
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
set(CMAKE_SYSTEM_PROCESSOR "arm64")
endif()
endif()
include(CheckLanguage)
include(GNUInstallDirs)
@@ -28,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
set(CMAKE_CXX_EXTENSIONS OFF)
set(GGML_BUILD ON)
set(GGML_SHARED ON)
@@ -163,48 +147,14 @@ if(CMAKE_HIP_COMPILER)
endif()
endif()
if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()
endif()

View File

@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 4",
"CMAKE_CUDA_FLAGS": "-t 2",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
@@ -83,28 +83,6 @@
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan"
}
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
],
"buildPresets": [
@@ -162,21 +140,6 @@
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
}
]
}

View File

@@ -131,36 +131,8 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -181,7 +153,6 @@ FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -200,7 +171,7 @@ COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin

View File

@@ -1,778 +0,0 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

View File

@@ -1,953 +0,0 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 8 * format.MegaByte
const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader

View File

@@ -19,6 +19,12 @@ import (
"github.com/ollama/ollama/types/model"
)
// SkillRef is an alias for model.SkillRef representing a skill reference.
type SkillRef = model.SkillRef
// MCPRef is an alias for model.MCPRef representing an MCP server reference.
type MCPRef = model.MCPRef
// StatusError is an error with an HTTP status code and message.
type StatusError struct {
StatusCode int
@@ -690,6 +696,18 @@ type CreateRequest struct {
// Requires is the minimum version of Ollama required by the model.
Requires string `json:"requires,omitempty"`
// Skills is a list of skill references for the agent (local paths or registry refs)
Skills []SkillRef `json:"skills,omitempty"`
// MCPs is a list of MCP server references for the agent
MCPs []MCPRef `json:"mcps,omitempty"`
// AgentType defines the type of agent (e.g., "conversational", "task-based")
AgentType string `json:"agent_type,omitempty"`
// Entrypoint specifies an external command to run instead of the built-in chat loop
Entrypoint string `json:"entrypoint,omitempty"`
// Info is a map of additional information for the model
Info map[string]any `json:"info,omitempty"`
@@ -741,6 +759,10 @@ type ShowResponse struct {
Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
Requires string `json:"requires,omitempty"`
Skills []SkillRef `json:"skills,omitempty"`
MCPs []MCPRef `json:"mcps,omitempty"`
AgentType string `json:"agent_type,omitempty"`
Entrypoint string `json:"entrypoint,omitempty"`
}
// CopyRequest is the request passed to [Client.Copy].

402
cmd/agent_loop_test.go Normal file
View File

@@ -0,0 +1,402 @@
package cmd
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
// TestToolMessage verifies that tool messages are constructed correctly
// with ToolName and ToolCallID preserved from the tool call.
func TestToolMessage(t *testing.T) {
tests := []struct {
name string
call api.ToolCall
content string
expected api.Message
}{
{
name: "basic tool message with ID",
call: api.ToolCall{
ID: "call_abc123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
content: "Sunny, 22°C",
expected: api.Message{
Role: "tool",
Content: "Sunny, 22°C",
ToolName: "get_weather",
ToolCallID: "call_abc123",
},
},
{
name: "tool message without ID",
call: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: api.ToolCallFunctionArguments{
"expression": "2+2",
},
},
},
content: "4",
expected: api.Message{
Role: "tool",
Content: "4",
ToolName: "calculate",
// ToolCallID should be empty when call.ID is empty
},
},
{
name: "MCP tool message",
call: api.ToolCall{
ID: "call_mcp123",
Function: api.ToolCallFunction{
Name: "mcp_websearch_search",
Arguments: api.ToolCallFunctionArguments{
"query": "ollama agents",
},
},
},
content: "Found 10 results",
expected: api.Message{
Role: "tool",
Content: "Found 10 results",
ToolName: "mcp_websearch_search",
ToolCallID: "call_mcp123",
},
},
{
name: "skill tool message",
call: api.ToolCall{
ID: "call_skill456",
Function: api.ToolCallFunction{
Name: "run_skill_script",
Arguments: api.ToolCallFunctionArguments{
"skill": "calculator",
"command": "python scripts/calc.py 2+2",
},
},
},
content: "Result: 4",
expected: api.Message{
Role: "tool",
Content: "Result: 4",
ToolName: "run_skill_script",
ToolCallID: "call_skill456",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := toolMessage(tt.call, tt.content)
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("toolMessage() mismatch (-want +got):\n%s", diff)
}
})
}
}
// TestAssistantMessageWithThinking verifies that assistant messages
// in the tool loop should include thinking content.
func TestAssistantMessageConstruction(t *testing.T) {
tests := []struct {
name string
content string
thinking string
toolCalls []api.ToolCall
expectedMsg api.Message
}{
{
name: "assistant with thinking and tool calls",
content: "",
thinking: "I need to check the weather for Paris.",
toolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
},
expectedMsg: api.Message{
Role: "assistant",
Content: "",
Thinking: "I need to check the weather for Paris.",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
},
},
},
{
name: "assistant with content, thinking, and tool calls",
content: "Let me check that for you.",
thinking: "User wants weather info.",
toolCalls: []api.ToolCall{
{
ID: "call_2",
Function: api.ToolCallFunction{
Name: "search",
Arguments: api.ToolCallFunctionArguments{"query": "weather"},
},
},
},
expectedMsg: api.Message{
Role: "assistant",
Content: "Let me check that for you.",
Thinking: "User wants weather info.",
ToolCalls: []api.ToolCall{
{
ID: "call_2",
Function: api.ToolCallFunction{
Name: "search",
Arguments: api.ToolCallFunctionArguments{"query": "weather"},
},
},
},
},
},
{
name: "assistant with multiple tool calls",
content: "",
thinking: "I'll check both cities.",
toolCalls: []api.ToolCall{
{
ID: "call_a",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
{
ID: "call_b",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "London"},
},
},
},
expectedMsg: api.Message{
Role: "assistant",
Content: "",
Thinking: "I'll check both cities.",
ToolCalls: []api.ToolCall{
{
ID: "call_a",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
{
ID: "call_b",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "London"},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the assistant message construction as done in chat()
assistantMsg := api.Message{
Role: "assistant",
Content: tt.content,
Thinking: tt.thinking,
ToolCalls: tt.toolCalls,
}
if diff := cmp.Diff(tt.expectedMsg, assistantMsg); diff != "" {
t.Errorf("assistant message mismatch (-want +got):\n%s", diff)
}
})
}
}
// TestMessageStitchingOrder verifies that messages in a tool loop
// are stitched in the correct order:
// 1. User message
// 2. Assistant message with tool calls (and thinking)
// 3. Tool result messages (one per tool call, in order)
// 4. Next assistant response
func TestMessageStitchingOrder(t *testing.T) {
// Simulate a complete tool loop conversation
messages := []api.Message{
// Initial user message
{Role: "user", Content: "What's the weather in Paris and London?"},
// Assistant's first response with tool calls
{
Role: "assistant",
Content: "",
Thinking: "I need to check the weather for both cities.",
ToolCalls: []api.ToolCall{
{ID: "call_1", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
{ID: "call_2", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}},
},
},
// Tool results (in order matching tool calls)
{Role: "tool", Content: "Sunny, 22°C", ToolName: "get_weather", ToolCallID: "call_1"},
{Role: "tool", Content: "Rainy, 15°C", ToolName: "get_weather", ToolCallID: "call_2"},
// Final assistant response
{Role: "assistant", Content: "Paris is sunny at 22°C, and London is rainy at 15°C.", Thinking: "Got the data, now summarizing."},
}
// Verify structure
expectedRoles := []string{"user", "assistant", "tool", "tool", "assistant"}
for i, msg := range messages {
if msg.Role != expectedRoles[i] {
t.Errorf("message %d: expected role %q, got %q", i, expectedRoles[i], msg.Role)
}
}
// Verify tool results match tool calls in order
assistantWithTools := messages[1]
toolResults := []api.Message{messages[2], messages[3]}
if len(toolResults) != len(assistantWithTools.ToolCalls) {
t.Errorf("expected %d tool results for %d tool calls", len(assistantWithTools.ToolCalls), len(toolResults))
}
for i, result := range toolResults {
expectedToolCallID := assistantWithTools.ToolCalls[i].ID
if result.ToolCallID != expectedToolCallID {
t.Errorf("tool result %d: expected ToolCallID %q, got %q", i, expectedToolCallID, result.ToolCallID)
}
expectedToolName := assistantWithTools.ToolCalls[i].Function.Name
if result.ToolName != expectedToolName {
t.Errorf("tool result %d: expected ToolName %q, got %q", i, expectedToolName, result.ToolName)
}
}
// Verify thinking is present in assistant messages
if messages[1].Thinking == "" {
t.Error("first assistant message should have thinking content")
}
if messages[4].Thinking == "" {
t.Error("final assistant message should have thinking content")
}
}
// TestMultiTurnToolLoop verifies message stitching across multiple
// tool call iterations.
func TestMultiTurnToolLoop(t *testing.T) {
messages := []api.Message{
{Role: "user", Content: "What's 2+2 and also what's the weather in Paris?"},
// First tool call: calculate
{
Role: "assistant",
Thinking: "I'll start with the calculation.",
ToolCalls: []api.ToolCall{
{ID: "calc_1", Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expr": "2+2"}}},
},
},
{Role: "tool", Content: "4", ToolName: "calculate", ToolCallID: "calc_1"},
// Second tool call: weather
{
Role: "assistant",
Thinking: "Got the calculation. Now checking weather.",
ToolCalls: []api.ToolCall{
{ID: "weather_1", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
},
},
{Role: "tool", Content: "Sunny, 20°C", ToolName: "get_weather", ToolCallID: "weather_1"},
// Final response
{Role: "assistant", Content: "2+2 equals 4, and Paris is sunny at 20°C."},
}
// Count message types
roleCounts := map[string]int{}
for _, msg := range messages {
roleCounts[msg.Role]++
}
if roleCounts["user"] != 1 {
t.Errorf("expected 1 user message, got %d", roleCounts["user"])
}
if roleCounts["assistant"] != 3 {
t.Errorf("expected 3 assistant messages, got %d", roleCounts["assistant"])
}
if roleCounts["tool"] != 2 {
t.Errorf("expected 2 tool messages, got %d", roleCounts["tool"])
}
// Verify each tool message follows an assistant with matching tool call
for i, msg := range messages {
if msg.Role == "tool" {
// Find preceding assistant message with tool calls
var precedingAssistant *api.Message
for j := i - 1; j >= 0; j-- {
if messages[j].Role == "assistant" && len(messages[j].ToolCalls) > 0 {
precedingAssistant = &messages[j]
break
}
}
if precedingAssistant == nil {
t.Errorf("tool message at index %d has no preceding assistant with tool calls", i)
continue
}
// Verify tool result matches one of the tool calls
found := false
for _, tc := range precedingAssistant.ToolCalls {
if tc.ID == msg.ToolCallID {
found = true
break
}
}
if !found {
t.Errorf("tool message at index %d has ToolCallID %q not found in preceding tool calls", i, msg.ToolCallID)
}
}
}
}
// TestSkillCatalogRunToolCallPreservesFields tests that skill catalog
// returns tool messages with correct fields.
func TestSkillCatalogToolMessageFields(t *testing.T) {
// Create a minimal test for toolMessage function
call := api.ToolCall{
ID: "test_id_123",
Function: api.ToolCallFunction{
Name: "run_skill_script",
Arguments: api.ToolCallFunctionArguments{
"skill": "test-skill",
"command": "echo hello",
},
},
}
msg := toolMessage(call, "hello")
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.Content != "hello" {
t.Errorf("expected content 'hello', got %q", msg.Content)
}
if msg.ToolName != "run_skill_script" {
t.Errorf("expected ToolName 'run_skill_script', got %q", msg.ToolName)
}
if msg.ToolCallID != "test_id_123" {
t.Errorf("expected ToolCallID 'test_id_123', got %q", msg.ToolCallID)
}
}

View File

@@ -15,6 +15,7 @@ import (
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
@@ -46,8 +47,6 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -98,11 +97,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return imagegenclient.CreateModel(args[0], ".", quantize, p)
}
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
@@ -464,7 +458,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -503,6 +496,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts.ParentModel = info.Details.ParentModel
// Check if this is an agent
isAgent := info.AgentType != "" || len(info.Skills) > 0 || len(info.MCPs) > 0 || info.Entrypoint != ""
if isAgent {
opts.IsAgent = true
opts.AgentType = info.AgentType
opts.Skills = info.Skills
opts.MCPs = info.MCPs
opts.Entrypoint = info.Entrypoint
}
// Check if this is an embedding model
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
@@ -526,17 +529,12 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
// If agent has entrypoint, run it instead of chat loop
if opts.Entrypoint != "" {
return runEntrypoint(cmd, opts)
}
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -564,16 +562,69 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
// Use experimental agent loop with tools
// Use experimental agent loop with
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
}
return generateInteractive(cmd, opts)
}
// For agents, use chat API even in non-interactive mode to support tools
if opts.IsAgent {
opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: opts.Prompt})
_, err := chat(cmd, opts)
return err
}
return generate(cmd, opts)
}
// runEntrypoint executes the agent's entrypoint command instead of the built-in chat loop.
func runEntrypoint(cmd *cobra.Command, opts runOptions) error {
entrypoint := opts.Entrypoint
// Check if entrypoint contains $PROMPT placeholder
hasPlaceholder := strings.Contains(entrypoint, "$PROMPT")
if hasPlaceholder && opts.Prompt != "" {
// Replace $PROMPT with the actual prompt
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", opts.Prompt)
} else if hasPlaceholder {
// No prompt provided but placeholder exists - remove placeholder
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", "")
}
// Parse entrypoint into command and args
parts := strings.Fields(entrypoint)
if len(parts) == 0 {
return fmt.Errorf("empty entrypoint")
}
command := parts[0]
args := parts[1:]
// If user provided a prompt and no placeholder was used, append it as argument
if opts.Prompt != "" && !hasPlaceholder {
args = append(args, opts.Prompt)
}
// Look up command in PATH
execPath, err := exec.LookPath(command)
if err != nil {
return fmt.Errorf("entrypoint command not found: %s", command)
}
// Create subprocess
proc := exec.Command(execPath, args...)
proc.Stdin = os.Stdin
proc.Stdout = os.Stdout
proc.Stderr = os.Stderr
// Run and wait
return proc.Run()
}
func SigninHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -672,11 +723,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest]
if !ok {
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -937,47 +984,96 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
fmt.Fprintln(w)
}
tableRender("Model", func() (rows [][]string) {
if resp.RemoteHost != "" {
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
}
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
var paramStr string
if resp.Details.ParameterSize != "" {
paramStr = resp.Details.ParameterSize
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
if f, ok := v.(float64); ok {
paramStr = format.HumanNumber(uint64(f))
}
}
rows = append(rows, []string{"", "parameters", paramStr})
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
}
// Only show Model section if there's actual model info (not for entrypoint-only agents)
hasModelInfo := resp.RemoteHost != "" || resp.ModelInfo != nil || resp.Details.Family != "" || resp.Details.ParameterSize != "" || resp.Details.QuantizationLevel != ""
if hasModelInfo {
tableRender("Model", func() (rows [][]string) {
if resp.RemoteHost != "" {
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
var paramStr string
if resp.Details.ParameterSize != "" {
paramStr = resp.Details.ParameterSize
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
if f, ok := v.(float64); ok {
paramStr = format.HumanNumber(uint64(f))
}
}
rows = append(rows, []string{"", "parameters", paramStr})
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
} else {
rows = append(rows, []string{"", "architecture", resp.Details.Family})
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
}
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
if resp.Requires != "" {
rows = append(rows, []string{"", "requires", resp.Requires})
}
return
})
}
// Display agent information if this is an agent
if resp.AgentType != "" || len(resp.Skills) > 0 || len(resp.MCPs) > 0 || resp.Entrypoint != "" {
tableRender("Agent", func() (rows [][]string) {
if resp.AgentType != "" {
rows = append(rows, []string{"", "type", resp.AgentType})
}
if resp.Entrypoint != "" {
rows = append(rows, []string{"", "entrypoint", resp.Entrypoint})
}
if len(resp.Skills) > 0 {
for i, skill := range resp.Skills {
label := "skill"
if i > 0 {
label = ""
}
// Show skill name or digest
skillDisplay := skill.Name
if skillDisplay == "" && skill.Digest != "" {
skillDisplay = skill.Digest[:12] + "..."
}
rows = append(rows, []string{"", label, skillDisplay})
}
}
} else {
rows = append(rows, []string{"", "architecture", resp.Details.Family})
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
}
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
if resp.Requires != "" {
rows = append(rows, []string{"", "requires", resp.Requires})
}
return
})
if len(resp.MCPs) > 0 {
for i, mcp := range resp.MCPs {
label := "mcp"
if i > 0 {
label = ""
}
// Show MCP name and command
mcpDisplay := mcp.Name
if mcp.Command != "" {
cmdLine := mcp.Command
if len(mcp.Args) > 0 {
cmdLine += " " + strings.Join(mcp.Args, " ")
}
mcpDisplay += " (" + cmdLine + ")"
}
rows = append(rows, []string{"", label, mcpDisplay})
}
}
return
})
}
if len(resp.Capabilities) > 0 {
tableRender("Capabilities", func() (rows [][]string) {
@@ -1219,6 +1315,11 @@ type runOptions struct {
Think *api.ThinkValue
HideThinking bool
ShowConnect bool
IsAgent bool
AgentType string
Skills []api.SkillRef
MCPs []api.MCPRef
Entrypoint string
}
func (r runOptions) Copy() runOptions {
@@ -1248,6 +1349,12 @@ func (r runOptions) Copy() runOptions {
think = &cThink
}
var skills []api.SkillRef
if r.Skills != nil {
skills = make([]api.SkillRef, len(r.Skills))
copy(skills, r.Skills)
}
return runOptions{
Model: r.Model,
ParentModel: r.ParentModel,
@@ -1263,6 +1370,9 @@ func (r runOptions) Copy() runOptions {
Think: think,
HideThinking: r.HideThinking,
ShowConnect: r.ShowConnect,
IsAgent: r.IsAgent,
AgentType: r.AgentType,
Skills: skills,
}
}
@@ -1346,6 +1456,65 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
return nil, err
}
// Load skills for agents
var skillsCatalog *skillCatalog
if opts.IsAgent && len(opts.Skills) > 0 {
skillsCatalog, err = loadSkillsFromRefs(opts.Skills)
if err != nil {
return nil, fmt.Errorf("failed to load skills: %w", err)
}
if skillsCatalog != nil && len(skillsCatalog.Skills) > 0 {
var skillNames []string
for _, s := range skillsCatalog.Skills {
skillNames = append(skillNames, s.Name)
}
fmt.Fprintf(os.Stderr, "Loaded skills: %s\n", strings.Join(skillNames, ", "))
}
}
// Load MCP servers for agents (from opts and global config)
var mcpMgr *mcpManager
allMCPs := opts.MCPs
// Load global MCPs from ~/.ollama/mcp.json
if globalConfig, err := loadMCPConfig(); err == nil && len(globalConfig.MCPServers) > 0 {
for name, srv := range globalConfig.MCPServers {
// Skip disabled MCPs
if srv.Disabled {
continue
}
// Check if already in opts.MCPs (model takes precedence)
found := false
for _, m := range opts.MCPs {
if m.Name == name {
found = true
break
}
}
if !found {
allMCPs = append(allMCPs, api.MCPRef{
Name: name,
Command: srv.Command,
Args: srv.Args,
Env: srv.Env,
Type: srv.Type,
})
}
}
}
if len(allMCPs) > 0 {
mcpMgr = newMCPManager()
if err := mcpMgr.loadMCPsFromRefs(allMCPs); err != nil {
return nil, fmt.Errorf("failed to load MCP servers: %w", err)
}
if mcpMgr.ToolCount() > 0 {
fmt.Fprintf(os.Stderr, "Loaded MCP servers: %s (%d tools)\n",
strings.Join(mcpMgr.ServerNames(), ", "), mcpMgr.ToolCount())
}
defer mcpMgr.Shutdown()
}
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
@@ -1369,6 +1538,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
var fullResponse strings.Builder
var thinkTagOpened bool = false
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
role := "assistant"
@@ -1409,7 +1579,13 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
if response.Message.ToolCalls != nil {
toolCalls := response.Message.ToolCalls
if len(toolCalls) > 0 {
fmt.Print(renderToolCalls(toolCalls, false))
if skillsCatalog != nil || mcpMgr != nil {
// Store tool calls for execution after response is complete
pendingToolCalls = append(pendingToolCalls, toolCalls...)
} else {
// No skills catalog or MCP, just display tool calls
fmt.Print(renderToolCalls(toolCalls, false))
}
}
}
@@ -1422,31 +1598,161 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
opts.Format = `"` + opts.Format + `"`
}
req := &api.ChatRequest{
Model: opts.Model,
Messages: opts.Messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
Think: opts.Think,
// Prepare messages with agent-specific system prompt
messages := opts.Messages
if skillsCatalog != nil {
// Add skills system prompt as the first system message
skillsPrompt := skillsCatalog.SystemPrompt()
if skillsPrompt != "" {
// Insert skills prompt at the beginning, or append to existing system message
if len(messages) > 0 && messages[0].Role == "system" {
// Append to existing system message
messages[0].Content = messages[0].Content + "\n\n" + skillsPrompt
} else {
// Insert new system message at the beginning
systemMsg := api.Message{Role: "system", Content: skillsPrompt}
messages = append([]api.Message{systemMsg}, messages...)
}
}
}
if opts.KeepAlive != nil {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
// Agentic loop: continue until no more tool calls
for {
req := &api.ChatRequest{
Model: opts.Model,
Messages: messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
Think: opts.Think,
}
// this error should ideally be wrapped properly by the client
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
fmt.Println()
return nil, nil
// Add tools for agents (combine skills and MCP tools)
var allTools api.Tools
if skillsCatalog != nil {
allTools = append(allTools, skillsCatalog.Tools()...)
}
return nil, err
if mcpMgr != nil {
allTools = append(allTools, mcpMgr.Tools()...)
}
if len(allTools) > 0 {
req.Tools = allTools
}
if opts.KeepAlive != nil {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
// this error should ideally be wrapped properly by the client
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
fmt.Println()
return nil, nil
}
return nil, err
}
// If no tool calls, we're done
if len(pendingToolCalls) == 0 || (skillsCatalog == nil && mcpMgr == nil) {
break
}
// Execute tool calls and continue the conversation
fmt.Fprintf(os.Stderr, "\n")
// Add assistant's tool call message to history (include thinking for proper rendering)
assistantMsg := api.Message{
Role: "assistant",
Content: fullResponse.String(),
Thinking: thinkingContent.String(),
ToolCalls: pendingToolCalls,
}
messages = append(messages, assistantMsg)
// Execute each tool call and collect results
var toolResults []api.Message
for _, call := range pendingToolCalls {
// Show what's being executed
switch call.Function.Name {
case "run_skill_script":
skillVal, _ := call.Function.Arguments.Get("skill")
skill, _ := skillVal.(string)
commandVal, _ := call.Function.Arguments.Get("command")
command, _ := commandVal.(string)
fmt.Fprintf(os.Stderr, "Running script in %s: %s\n", skill, command)
case "read_skill_file":
skillVal, _ := call.Function.Arguments.Get("skill")
skill, _ := skillVal.(string)
pathVal, _ := call.Function.Arguments.Get("path")
path, _ := pathVal.(string)
fmt.Fprintf(os.Stderr, "Reading file from %s: %s\n", skill, path)
default:
fmt.Fprintf(os.Stderr, "Executing: %s\n", call.Function.Name)
}
var result api.Message
var handled bool
var err error
// Try skill catalog first
if skillsCatalog != nil {
result, handled, err = skillsCatalog.RunToolCall(call)
}
// If not handled by skills, try MCP
if !handled && mcpMgr != nil {
result, handled, err = mcpMgr.RunToolCall(call)
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
// Add error result
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Error: %v", err),
})
continue
}
if !handled {
fmt.Fprintf(os.Stderr, "Warning: Unknown tool %s\n", call.Function.Name)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Unknown tool: %s", call.Function.Name),
})
continue
}
// Display tool output
if result.Content != "" {
fmt.Fprintf(os.Stderr, "Output:\n%s\n", result.Content)
}
// Add tool result to messages (preserves ToolName, ToolCallID from result)
toolResults = append(toolResults, result)
}
// Add tool results to message history
messages = append(messages, toolResults...)
fmt.Fprintf(os.Stderr, "\n")
// Reset state for next iteration
fullResponse.Reset()
thinkingContent.Reset()
thinkTagOpened = false
thinkTagClosed = false
pendingToolCalls = nil
state = &displayResponseState{}
// Start new progress spinner for next API call
p = progress.NewProgress(os.Stderr)
spinner = progress.NewSpinner("")
p.Add("", spinner)
}
if len(opts.Messages) > 0 {
@@ -1785,10 +2091,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
stopCmd := &cobra.Command{
Use: "stop MODEL",
@@ -1943,6 +2245,8 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
NewSkillCommand(),
NewMCPCommand(),
)
return rootCmd

View File

@@ -34,6 +34,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /skills Show available skills")
fmt.Fprintln(os.Stderr, " /skill Add or remove skills dynamically")
fmt.Fprintln(os.Stderr, " /mcp Show/add/remove MCP servers")
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /clear Clear session context")
@@ -444,6 +447,411 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} else {
usageShow()
}
case strings.HasPrefix(line, "/skill "):
args := strings.Fields(line)
if len(args) < 2 {
fmt.Fprintln(os.Stderr, "Usage:")
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
fmt.Fprintln(os.Stderr, " /skill list List current skills")
continue
}
switch args[1] {
case "add":
if len(args) < 3 {
fmt.Println("Usage: /skill add <path>")
continue
}
skillPath := args[2]
// Expand ~ to home directory
if strings.HasPrefix(skillPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
fmt.Printf("Error expanding path: %v\n", err)
continue
}
skillPath = filepath.Join(home, skillPath[1:])
}
// Make absolute
absPath, err := filepath.Abs(skillPath)
if err != nil {
fmt.Printf("Error resolving path: %v\n", err)
continue
}
// Verify SKILL.md exists
skillMdPath := filepath.Join(absPath, "SKILL.md")
if _, err := os.Stat(skillMdPath); err != nil {
fmt.Printf("Error: %s does not contain SKILL.md\n", skillPath)
continue
}
// Extract skill name from SKILL.md
content, err := os.ReadFile(skillMdPath)
if err != nil {
fmt.Printf("Error reading SKILL.md: %v\n", err)
continue
}
skillName, _ := extractSkillMetadata(string(content))
if skillName == "" {
skillName = filepath.Base(absPath)
}
// Check if already added
for _, s := range opts.Skills {
if s.Name == skillName {
fmt.Printf("Skill '%s' is already loaded\n", skillName)
continue
}
}
// Add to skills (using path as Name, no digest for local skills)
opts.Skills = append(opts.Skills, api.SkillRef{Name: absPath})
opts.IsAgent = true // Enable agent mode if not already
fmt.Printf("Added skill '%s' from %s\n", skillName, skillPath)
case "remove", "rm":
if len(args) < 3 {
fmt.Println("Usage: /skill remove <name>")
continue
}
skillName := args[2]
found := false
newSkills := make([]api.SkillRef, 0, len(opts.Skills))
for _, s := range opts.Skills {
// Match by name or by path basename
name := s.Name
if strings.Contains(name, string(os.PathSeparator)) {
name = filepath.Base(name)
}
if name == skillName || s.Name == skillName {
found = true
fmt.Printf("Removed skill '%s'\n", skillName)
} else {
newSkills = append(newSkills, s)
}
}
if !found {
fmt.Printf("Skill '%s' not found\n", skillName)
} else {
opts.Skills = newSkills
}
case "list", "ls":
if len(opts.Skills) == 0 {
fmt.Println("No skills loaded in this session.")
} else {
fmt.Println("Skills loaded in this session:")
for _, skill := range opts.Skills {
if skill.Digest != "" {
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
} else {
// For local paths, show basename
name := skill.Name
if strings.Contains(name, string(os.PathSeparator)) {
name = filepath.Base(name) + " (local: " + skill.Name + ")"
}
fmt.Printf(" %s\n", name)
}
}
}
fmt.Println()
default:
fmt.Printf("Unknown skill command '%s'. Use /skill add, /skill remove, or /skill list\n", args[1])
}
continue
case strings.HasPrefix(line, "/skills"):
// Show skills from model (bundled) + session skills
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := &api.ShowRequest{
Name: opts.Model,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model info")
return err
}
// Combine model skills with session skills
allSkills := make([]api.SkillRef, 0)
allSkills = append(allSkills, resp.Skills...)
// Add session skills that aren't already in model skills
for _, sessionSkill := range opts.Skills {
found := false
for _, modelSkill := range resp.Skills {
if modelSkill.Name == sessionSkill.Name || modelSkill.Digest == sessionSkill.Digest {
found = true
break
}
}
if !found {
allSkills = append(allSkills, sessionSkill)
}
}
if len(allSkills) == 0 {
fmt.Println("No skills available.")
} else {
fmt.Println("Available Skills:")
for _, skill := range allSkills {
if skill.Digest != "" {
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
} else {
name := skill.Name
if strings.Contains(name, string(os.PathSeparator)) {
name = filepath.Base(name) + " (session)"
}
fmt.Printf(" %s\n", name)
}
}
}
fmt.Println()
continue
case strings.HasPrefix(line, "/mcp"):
args := strings.Fields(line)
// If just "/mcp" with no args, show all MCP servers
if len(args) == 1 {
// Show MCPs from model (bundled) + global config
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := &api.ShowRequest{
Name: opts.Model,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model info")
return err
}
// Combine model MCPs with global config MCPs
allMCPs := make([]api.MCPRef, 0)
allMCPs = append(allMCPs, resp.MCPs...)
// Load global config
globalConfig, _ := loadMCPConfig()
globalMCPNames := make(map[string]bool)
if globalConfig != nil {
for name, srv := range globalConfig.MCPServers {
// Check if already in model MCPs
found := false
for _, modelMCP := range resp.MCPs {
if modelMCP.Name == name {
found = true
break
}
}
if !found {
allMCPs = append(allMCPs, api.MCPRef{
Name: name,
Command: srv.Command,
Args: srv.Args,
Env: srv.Env,
Type: srv.Type,
})
}
globalMCPNames[name] = true
}
}
if len(allMCPs) == 0 {
fmt.Println("No MCP servers available.")
fmt.Println("Use '/mcp add <name> <command> [args...]' to add one.")
} else {
fmt.Println("Available MCP Servers:")
for _, mcp := range allMCPs {
cmdLine := mcp.Command
if len(mcp.Args) > 0 {
cmdLine += " " + strings.Join(mcp.Args, " ")
}
source := ""
disabled := ""
// Check if it's from model or global config
isFromModel := false
for _, modelMCP := range resp.MCPs {
if modelMCP.Name == mcp.Name {
isFromModel = true
break
}
}
if isFromModel {
source = " (model)"
} else if globalMCPNames[mcp.Name] {
source = " (global)"
// Check if disabled
if srv, ok := globalConfig.MCPServers[mcp.Name]; ok && srv.Disabled {
disabled = " [disabled]"
}
}
fmt.Printf(" %s: %s%s%s\n", mcp.Name, cmdLine, source, disabled)
}
}
fmt.Println()
continue
}
switch args[1] {
case "add":
if len(args) < 4 {
fmt.Println("Usage: /mcp add <name> <command> [args...]")
continue
}
mcpName := args[2]
mcpCommand := args[3]
mcpArgs := args[4:]
// Load global config
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
// Check if already exists
if _, exists := config.MCPServers[mcpName]; exists {
fmt.Printf("Warning: overwriting existing MCP server '%s'\n", mcpName)
}
// Add to global config
config.MCPServers[mcpName] = MCPServerConfig{
Type: "stdio",
Command: mcpCommand,
Args: mcpArgs,
}
// Save config
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
cmdLine := mcpCommand
if len(mcpArgs) > 0 {
cmdLine += " " + strings.Join(mcpArgs, " ")
}
fmt.Printf("Added MCP server '%s' (%s) to %s\n", mcpName, cmdLine, getMCPConfigPath())
fmt.Println("Note: MCP server will be started on next message.")
case "remove", "rm":
if len(args) < 3 {
fmt.Println("Usage: /mcp remove <name>")
continue
}
mcpName := args[2]
// Load global config
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
if _, exists := config.MCPServers[mcpName]; !exists {
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
continue
}
delete(config.MCPServers, mcpName)
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
fmt.Printf("Removed MCP server '%s' from %s\n", mcpName, getMCPConfigPath())
fmt.Println("Note: Changes will take effect on next message.")
case "disable":
if len(args) < 3 {
fmt.Println("Usage: /mcp disable <name>")
continue
}
mcpName := args[2]
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
srv, exists := config.MCPServers[mcpName]
if !exists {
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
continue
}
if srv.Disabled {
fmt.Printf("MCP server '%s' is already disabled\n", mcpName)
continue
}
srv.Disabled = true
config.MCPServers[mcpName] = srv
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
fmt.Printf("Disabled MCP server '%s'\n", mcpName)
fmt.Println("Note: Changes will take effect on next message.")
case "enable":
if len(args) < 3 {
fmt.Println("Usage: /mcp enable <name>")
continue
}
mcpName := args[2]
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
srv, exists := config.MCPServers[mcpName]
if !exists {
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
continue
}
if !srv.Disabled {
fmt.Printf("MCP server '%s' is already enabled\n", mcpName)
continue
}
srv.Disabled = false
config.MCPServers[mcpName] = srv
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
fmt.Printf("Enabled MCP server '%s'\n", mcpName)
fmt.Println("Note: Changes will take effect on next message.")
default:
fmt.Printf("Unknown mcp command '%s'. Use /mcp, /mcp add, /mcp remove, /mcp disable, or /mcp enable\n", args[1])
}
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
args := strings.Fields(line)
if len(args) > 1 {
@@ -452,6 +860,20 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usageSet()
case "show", "/show":
usageShow()
case "skill", "/skill":
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
fmt.Fprintln(os.Stderr, " /skill list List current session skills")
fmt.Fprintln(os.Stderr, "")
case "mcp", "/mcp":
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /mcp Show all MCP servers")
fmt.Fprintln(os.Stderr, " /mcp add <name> <command> [args...] Add an MCP server to global config")
fmt.Fprintln(os.Stderr, " /mcp remove <name> Remove an MCP server from global config")
fmt.Fprintln(os.Stderr, " /mcp disable <name> Disable an MCP server (keep in config)")
fmt.Fprintln(os.Stderr, " /mcp enable <name> Re-enable a disabled MCP server")
fmt.Fprintln(os.Stderr, "")
case "shortcut", "shortcuts":
usageShortcuts()
}

570
cmd/skill_cmd.go Normal file
View File

@@ -0,0 +1,570 @@
package cmd
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"text/tabwriter"
"time"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
)
// SkillPushHandler handles the skill push command.
func SkillPushHandler(cmd *cobra.Command, args []string) error {
if len(args) != 2 {
return fmt.Errorf("usage: ollama skill push NAME[:TAG] PATH")
}
name := args[0]
path := args[1]
// Expand path
if strings.HasPrefix(path, "~") {
home, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("expanding home directory: %w", err)
}
path = filepath.Join(home, path[1:])
}
absPath, err := filepath.Abs(path)
if err != nil {
return fmt.Errorf("resolving path: %w", err)
}
// Validate skill directory
skillMdPath := filepath.Join(absPath, "SKILL.md")
if _, err := os.Stat(skillMdPath); err != nil {
return fmt.Errorf("skill directory must contain SKILL.md: %w", err)
}
// Parse skill name (will set Kind="skill")
n := server.ParseSkillName(name)
if n.Model == "" {
return fmt.Errorf("invalid skill name: %s", name)
}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
// Create skill layer
displayName := n.DisplayShortest()
status := fmt.Sprintf("Creating skill layer for %s", displayName)
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
layer, err := server.CreateSkillLayer(absPath)
if err != nil {
return fmt.Errorf("creating skill layer: %w", err)
}
spinner.Stop()
// Create skill manifest
manifest, configLayer, err := createSkillManifest(absPath, layer)
if err != nil {
return fmt.Errorf("creating skill manifest: %w", err)
}
// Write manifest locally
manifestPath, err := server.GetSkillManifestPath(n)
if err != nil {
return fmt.Errorf("getting manifest path: %w", err)
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
return fmt.Errorf("creating manifest directory: %w", err)
}
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return fmt.Errorf("marshaling manifest: %w", err)
}
if err := os.WriteFile(manifestPath, manifestJSON, 0o644); err != nil {
return fmt.Errorf("writing manifest: %w", err)
}
fmt.Fprintf(os.Stderr, "Skill %s created locally\n", displayName)
fmt.Fprintf(os.Stderr, " Config: %s (%s)\n", configLayer.Digest, format.HumanBytes(configLayer.Size))
fmt.Fprintf(os.Stderr, " Layer: %s (%s)\n", layer.Digest, format.HumanBytes(layer.Size))
// Push to registry
client, err := api.ClientFromEnvironment()
if err != nil {
return fmt.Errorf("creating client: %w", err)
}
insecure, _ := cmd.Flags().GetBool("insecure")
// For now, we'll use the existing push mechanism
fmt.Fprintf(os.Stderr, "\nPushing to registry...\n")
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
p.Add(resp.Digest, bar)
} else if resp.Status != "" {
spinner := progress.NewSpinner(resp.Status)
p.Add(resp.Status, spinner)
}
return nil
}
req := &api.PushRequest{
Model: displayName,
Insecure: insecure,
}
if err := client.Push(context.Background(), req, fn); err != nil {
// If push fails, still show success for local creation
fmt.Fprintf(os.Stderr, "\nNote: Local skill created but push failed: %v\n", err)
fmt.Fprintf(os.Stderr, "You can try pushing later with: ollama skill push %s\n", name)
return nil
}
fmt.Fprintf(os.Stderr, "Successfully pushed %s\n", displayName)
return nil
}
// SkillPullHandler handles the skill pull command.
func SkillPullHandler(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return fmt.Errorf("usage: ollama skill pull NAME[:TAG]")
}
name := args[0]
n := server.ParseSkillName(name)
if n.Model == "" {
return fmt.Errorf("invalid skill name: %s", name)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return fmt.Errorf("creating client: %w", err)
}
insecure, _ := cmd.Flags().GetBool("insecure")
p := progress.NewProgress(os.Stderr)
defer p.Stop()
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
p.Add(resp.Digest, bar)
} else if resp.Status != "" {
spinner := progress.NewSpinner(resp.Status)
p.Add(resp.Status, spinner)
}
return nil
}
displayName := n.DisplayShortest()
req := &api.PullRequest{
Model: displayName,
Insecure: insecure,
}
if err := client.Pull(context.Background(), req, fn); err != nil {
return fmt.Errorf("pulling skill: %w", err)
}
fmt.Fprintf(os.Stderr, "Successfully pulled %s\n", displayName)
return nil
}
// SkillListHandler handles the skill list command.
func SkillListHandler(cmd *cobra.Command, args []string) error {
skills, err := listLocalSkills()
if err != nil {
return fmt.Errorf("listing skills: %w", err)
}
if len(skills) == 0 {
fmt.Println("No skills installed")
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
fmt.Fprintln(w, "NAME\tTAG\tSIZE\tMODIFIED")
for _, skill := range skills {
fmt.Fprintf(w, "%s/%s\t%s\t%s\t%s\n",
skill.Namespace,
skill.Name,
skill.Tag,
format.HumanBytes(skill.Size),
format.HumanTime(skill.ModifiedAt, "Never"),
)
}
return w.Flush()
}
// SkillRemoveHandler handles the skill rm command.
func SkillRemoveHandler(cmd *cobra.Command, args []string) error {
if len(args) == 0 {
return fmt.Errorf("usage: ollama skill rm NAME[:TAG] [NAME[:TAG]...]")
}
for _, name := range args {
n := server.ParseSkillName(name)
if n.Model == "" {
fmt.Fprintf(os.Stderr, "Invalid skill name: %s\n", name)
continue
}
displayName := n.DisplayShortest()
manifestPath, err := server.GetSkillManifestPath(n)
if err != nil {
fmt.Fprintf(os.Stderr, "Error getting manifest path for %s: %v\n", name, err)
continue
}
if _, err := os.Stat(manifestPath); os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Skill not found: %s\n", displayName)
continue
}
if err := os.Remove(manifestPath); err != nil {
fmt.Fprintf(os.Stderr, "Error removing %s: %v\n", displayName, err)
continue
}
// Clean up empty parent directories
dir := filepath.Dir(manifestPath)
for dir != filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests") {
entries, _ := os.ReadDir(dir)
if len(entries) == 0 {
os.Remove(dir)
dir = filepath.Dir(dir)
} else {
break
}
}
fmt.Fprintf(os.Stderr, "Deleted '%s'\n", displayName)
}
return nil
}
// SkillShowHandler handles the skill show command.
func SkillShowHandler(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return fmt.Errorf("usage: ollama skill show NAME[:TAG]")
}
name := args[0]
n := server.ParseSkillName(name)
if n.Model == "" {
return fmt.Errorf("invalid skill name: %s", name)
}
displayName := n.DisplayShortest()
manifestPath, err := server.GetSkillManifestPath(n)
if err != nil {
return fmt.Errorf("getting manifest path: %w", err)
}
data, err := os.ReadFile(manifestPath)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("skill not found: %s", displayName)
}
return fmt.Errorf("reading manifest: %w", err)
}
var manifest server.Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return fmt.Errorf("parsing manifest: %w", err)
}
fmt.Printf("Skill: %s\n\n", displayName)
fmt.Println("Layers:")
for _, layer := range manifest.Layers {
fmt.Printf(" %s %s %s\n", layer.MediaType, layer.Digest[:19], format.HumanBytes(layer.Size))
}
// Try to read and display SKILL.md content
if len(manifest.Layers) > 0 {
for _, layer := range manifest.Layers {
if layer.MediaType == server.MediaTypeSkill {
skillPath, err := server.GetSkillsPath(layer.Digest)
if err == nil {
skillMdPath := filepath.Join(skillPath, "SKILL.md")
if content, err := os.ReadFile(skillMdPath); err == nil {
fmt.Println("\nContent:")
fmt.Println(string(content))
}
}
}
}
}
return nil
}
// SkillInfo represents information about an installed skill.
type SkillInfo struct {
Namespace string
Name string
Tag string
Size int64
ModifiedAt time.Time
}
// listLocalSkills returns a list of locally installed skills.
// Skills are stored with 5-part paths: host/namespace/kind/model/tag
// where kind is "skill".
func listLocalSkills() ([]SkillInfo, error) {
manifestsPath := filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests")
var skills []SkillInfo
// Walk through all registries
registries, err := os.ReadDir(manifestsPath)
if err != nil {
if os.IsNotExist(err) {
return skills, nil
}
return nil, err
}
for _, registry := range registries {
if !registry.IsDir() {
continue
}
// Walk namespaces
namespaces, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name()))
if err != nil {
continue
}
for _, namespace := range namespaces {
if !namespace.IsDir() {
continue
}
// Walk kinds looking for "skill"
kinds, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name()))
if err != nil {
continue
}
for _, kind := range kinds {
if !kind.IsDir() {
continue
}
// Only process skill kind
if kind.Name() != server.SkillNamespace {
continue
}
// Walk skill names (model names)
skillNames, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name()))
if err != nil {
continue
}
for _, skillName := range skillNames {
if !skillName.IsDir() {
continue
}
// Walk tags
tags, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name()))
if err != nil {
continue
}
for _, tag := range tags {
manifestPath := filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name(), tag.Name())
fi, err := os.Stat(manifestPath)
if err != nil || fi.IsDir() {
continue
}
// Read manifest to get size
data, err := os.ReadFile(manifestPath)
if err != nil {
continue
}
var manifest server.Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
continue
}
var totalSize int64
for _, layer := range manifest.Layers {
totalSize += layer.Size
}
// Build display name using model.Name
n := model.Name{
Host: registry.Name(),
Namespace: namespace.Name(),
Kind: kind.Name(),
Model: skillName.Name(),
Tag: tag.Name(),
}
skills = append(skills, SkillInfo{
Namespace: n.Namespace + "/" + n.Kind,
Name: n.Model,
Tag: n.Tag,
Size: totalSize,
ModifiedAt: fi.ModTime(),
})
}
}
}
}
}
return skills, nil
}
// createSkillManifest creates a manifest for a standalone skill.
func createSkillManifest(skillDir string, layer server.Layer) (*server.Manifest, *server.Layer, error) {
// Read SKILL.md to extract metadata
skillMdPath := filepath.Join(skillDir, "SKILL.md")
content, err := os.ReadFile(skillMdPath)
if err != nil {
return nil, nil, fmt.Errorf("reading SKILL.md: %w", err)
}
// Extract name and description from frontmatter
name, description := extractSkillMetadata(string(content))
if name == "" {
return nil, nil, errors.New("skill name not found in SKILL.md frontmatter")
}
// Create config
config := map[string]any{
"name": name,
"description": description,
"architecture": "amd64",
"os": "linux",
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, nil, fmt.Errorf("marshaling config: %w", err)
}
// Create config layer
configLayer, err := server.NewLayer(strings.NewReader(string(configJSON)), "application/vnd.docker.container.image.v1+json")
if err != nil {
return nil, nil, fmt.Errorf("creating config layer: %w", err)
}
manifest := &server.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: configLayer,
Layers: []server.Layer{layer},
}
return manifest, &configLayer, nil
}
// extractSkillMetadata extracts name and description from SKILL.md frontmatter.
func extractSkillMetadata(content string) (name, description string) {
lines := strings.Split(content, "\n")
inFrontmatter := false
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed == "---" {
if !inFrontmatter {
inFrontmatter = true
continue
} else {
break // End of frontmatter
}
}
if inFrontmatter {
if strings.HasPrefix(trimmed, "name:") {
name = strings.TrimSpace(strings.TrimPrefix(trimmed, "name:"))
} else if strings.HasPrefix(trimmed, "description:") {
description = strings.TrimSpace(strings.TrimPrefix(trimmed, "description:"))
}
}
}
return name, description
}
// NewSkillCommand creates the skill parent command with subcommands.
func NewSkillCommand() *cobra.Command {
skillCmd := &cobra.Command{
Use: "skill",
Short: "Manage skills",
Long: "Commands for managing agent skills (push, pull, list, rm, show)",
}
pushCmd := &cobra.Command{
Use: "push NAME[:TAG] PATH",
Short: "Push a skill to a registry",
Long: "Package a local skill directory and push it to a registry",
Args: cobra.ExactArgs(2),
PreRunE: checkServerHeartbeat,
RunE: SkillPushHandler,
}
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
pullCmd := &cobra.Command{
Use: "pull NAME[:TAG]",
Short: "Pull a skill from a registry",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: SkillPullHandler,
}
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
listCmd := &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List installed skills",
Args: cobra.NoArgs,
RunE: SkillListHandler,
}
rmCmd := &cobra.Command{
Use: "rm NAME[:TAG] [NAME[:TAG]...]",
Aliases: []string{"remove", "delete"},
Short: "Remove a skill",
Args: cobra.MinimumNArgs(1),
RunE: SkillRemoveHandler,
}
showCmd := &cobra.Command{
Use: "show NAME[:TAG]",
Short: "Show skill details",
Args: cobra.ExactArgs(1),
RunE: SkillShowHandler,
}
skillCmd.AddCommand(pushCmd, pullCmd, listCmd, rmCmd, showCmd)
return skillCmd
}

591
cmd/skills.go Normal file
View File

@@ -0,0 +1,591 @@
package cmd
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io/fs"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strings"
"time"
"gopkg.in/yaml.v3"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/server"
)
const (
skillFileName = "SKILL.md"
maxSkillDescription = 1024
maxSkillNameLength = 64
)
var skillNamePattern = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
type skillMetadata struct {
Name string `yaml:"name"`
Description string `yaml:"description"`
}
type skillDefinition struct {
Name string
Description string
Content string // Full SKILL.md content (without frontmatter)
Dir string
SkillPath string
}
type skillCatalog struct {
Skills []skillDefinition
byName map[string]skillDefinition
}
func loadSkills(paths []string) (*skillCatalog, error) {
if len(paths) == 0 {
return nil, nil
}
var skills []skillDefinition
byName := make(map[string]skillDefinition)
for _, root := range paths {
info, err := os.Stat(root)
if err != nil {
return nil, fmt.Errorf("skills directory %q: %w", root, err)
}
if !info.IsDir() {
return nil, fmt.Errorf("skills path %q is not a directory", root)
}
err = filepath.WalkDir(root, func(path string, entry fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if entry.IsDir() {
return nil
}
if entry.Name() != skillFileName {
return nil
}
skillDir := filepath.Dir(path)
skill, err := parseSkillFile(path, skillDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
return nil
}
if _, exists := byName[skill.Name]; exists {
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
return nil
}
byName[skill.Name] = skill
skills = append(skills, skill)
return nil
})
if err != nil {
return nil, err
}
}
if len(skills) == 0 {
return nil, nil
}
sort.Slice(skills, func(i, j int) bool {
return skills[i].Name < skills[j].Name
})
return &skillCatalog{Skills: skills, byName: byName}, nil
}
// loadSkillsFromRefs loads skills from a list of SkillRef objects.
// Skills can be referenced by:
// - Digest: loaded from the extracted skill cache (for bundled/pulled skills)
// - Name (local path): loaded from the filesystem (for development)
func loadSkillsFromRefs(refs []api.SkillRef) (*skillCatalog, error) {
if len(refs) == 0 {
return nil, nil
}
var skills []skillDefinition
byName := make(map[string]skillDefinition)
for _, ref := range refs {
var skillDir string
if ref.Digest != "" {
// Load from extracted skill cache
path, err := server.GetSkillsPath(ref.Digest)
if err != nil {
return nil, fmt.Errorf("getting skill path for %s: %w", ref.Digest, err)
}
// Check if skill is already extracted
skillMdPath := filepath.Join(path, skillFileName)
if _, err := os.Stat(skillMdPath); os.IsNotExist(err) {
// Try to extract the skill blob
path, err = server.ExtractSkillBlob(ref.Digest)
if err != nil {
return nil, fmt.Errorf("extracting skill %s: %w", ref.Digest, err)
}
}
skillDir = path
} else if ref.Name != "" {
// Check if this is a local path or a registry reference
if !server.IsLocalSkillPath(ref.Name) {
// Registry reference without a digest - skill needs to be pulled first
// This happens when an agent references a skill that hasn't been bundled
return nil, fmt.Errorf("skill %q is a registry reference but has no digest - the agent may need to be recreated or the skill pulled separately", ref.Name)
}
// Local path - resolve it
skillPath := ref.Name
if strings.HasPrefix(skillPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("expanding home directory: %w", err)
}
skillPath = filepath.Join(home, skillPath[1:])
}
absPath, err := filepath.Abs(skillPath)
if err != nil {
return nil, fmt.Errorf("resolving skill path %q: %w", ref.Name, err)
}
// Check if this is a directory containing skills or a single skill
info, err := os.Stat(absPath)
if err != nil {
return nil, fmt.Errorf("skill path %q: %w", ref.Name, err)
}
if info.IsDir() {
// Check if it's a skill directory (has SKILL.md) or a parent of skill directories
skillMdPath := filepath.Join(absPath, skillFileName)
if _, err := os.Stat(skillMdPath); err == nil {
// Direct skill directory
skillDir = absPath
} else {
// Parent directory - walk to find skill subdirectories
err := filepath.WalkDir(absPath, func(path string, entry fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if entry.IsDir() {
return nil
}
if entry.Name() != skillFileName {
return nil
}
skillSubDir := filepath.Dir(path)
skill, err := parseSkillFile(path, skillSubDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
return nil
}
if _, exists := byName[skill.Name]; exists {
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
return nil
}
byName[skill.Name] = skill
skills = append(skills, skill)
return nil
})
if err != nil {
return nil, err
}
continue
}
} else {
return nil, fmt.Errorf("skill path %q is not a directory", ref.Name)
}
} else {
// Both empty - skip
continue
}
// Parse the skill from skillDir if set
if skillDir != "" {
skillMdPath := filepath.Join(skillDir, skillFileName)
skill, err := parseSkillFile(skillMdPath, skillDir)
if err != nil {
return nil, fmt.Errorf("parsing skill at %s: %w", skillDir, err)
}
if _, exists := byName[skill.Name]; exists {
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q\n", skill.Name)
continue
}
byName[skill.Name] = skill
skills = append(skills, skill)
}
}
if len(skills) == 0 {
return nil, nil
}
sort.Slice(skills, func(i, j int) bool {
return skills[i].Name < skills[j].Name
})
return &skillCatalog{Skills: skills, byName: byName}, nil
}
func parseSkillFile(path, skillDir string) (skillDefinition, error) {
rawContent, err := os.ReadFile(path)
if err != nil {
return skillDefinition{}, err
}
frontmatter, bodyContent, err := extractFrontmatterAndContent(string(rawContent))
if err != nil {
return skillDefinition{}, err
}
var meta skillMetadata
if err := yaml.Unmarshal([]byte(frontmatter), &meta); err != nil {
return skillDefinition{}, fmt.Errorf("invalid frontmatter: %w", err)
}
if err := validateSkillMetadata(meta, skillDir); err != nil {
return skillDefinition{}, err
}
absPath, err := filepath.Abs(path)
if err != nil {
return skillDefinition{}, err
}
absDir, err := filepath.Abs(skillDir)
if err != nil {
return skillDefinition{}, err
}
return skillDefinition{
Name: meta.Name,
Description: meta.Description,
Content: bodyContent,
Dir: absDir,
SkillPath: absPath,
}, nil
}
func extractFrontmatterAndContent(content string) (frontmatter string, body string, err error) {
scanner := bufio.NewScanner(strings.NewReader(content))
if !scanner.Scan() {
return "", "", errors.New("empty SKILL.md")
}
if strings.TrimSpace(scanner.Text()) != "---" {
return "", "", errors.New("missing YAML frontmatter")
}
var fmLines []string
foundEnd := false
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) == "---" {
foundEnd = true
break
}
fmLines = append(fmLines, line)
}
if !foundEnd {
return "", "", errors.New("frontmatter not terminated")
}
// Collect remaining content as body
var bodyLines []string
for scanner.Scan() {
bodyLines = append(bodyLines, scanner.Text())
}
return strings.Join(fmLines, "\n"), strings.TrimSpace(strings.Join(bodyLines, "\n")), nil
}
func validateSkillMetadata(meta skillMetadata, skillDir string) error {
name := strings.TrimSpace(meta.Name)
description := strings.TrimSpace(meta.Description)
switch {
case name == "":
return errors.New("missing skill name")
case len(name) > maxSkillNameLength:
return fmt.Errorf("skill name exceeds %d characters", maxSkillNameLength)
case !skillNamePattern.MatchString(name):
return fmt.Errorf("invalid skill name %q", name)
}
if description == "" {
return errors.New("missing skill description")
}
if len(description) > maxSkillDescription {
return fmt.Errorf("skill description exceeds %d characters", maxSkillDescription)
}
// Skip directory name check for digest-based paths (extracted from blobs)
dirName := filepath.Base(skillDir)
if !strings.HasPrefix(dirName, "sha256-") && dirName != name {
return fmt.Errorf("skill directory %q does not match name %q", dirName, name)
}
return nil
}
func (c *skillCatalog) SystemPrompt() string {
if c == nil || len(c.Skills) == 0 {
return ""
}
var b strings.Builder
b.WriteString("# Skills\n\n")
b.WriteString("You have the following skills loaded. Each skill provides instructions and may include executable scripts.\n\n")
b.WriteString("## Available Tools\n\n")
b.WriteString("- `run_skill_script`: Execute a script bundled with a skill. Use this when the skill instructions tell you to run a script.\n")
b.WriteString("- `read_skill_file`: Read additional files from a skill directory.\n\n")
for _, skill := range c.Skills {
fmt.Fprintf(&b, "## Skill: %s\n\n", skill.Name)
fmt.Fprintf(&b, "%s\n\n", skill.Content)
b.WriteString("---\n\n")
}
return b.String()
}
func (c *skillCatalog) Tools() api.Tools {
if c == nil || len(c.Skills) == 0 {
return nil
}
runScriptProps := api.NewToolPropertiesMap()
runScriptProps.Set("skill", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The name of the skill containing the script",
})
runScriptProps.Set("command", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The command to execute (e.g., 'python scripts/calculate.py 25 4' or './scripts/run.sh')",
})
readFileProps := api.NewToolPropertiesMap()
readFileProps.Set("skill", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The name of the skill containing the file",
})
readFileProps.Set("path", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The relative path to the file within the skill directory",
})
return api.Tools{
{
Type: "function",
Function: api.ToolFunction{
Name: "run_skill_script",
Description: "Execute a script or command within a skill's directory. Use this to run Python scripts, shell scripts, or other executables bundled with a skill.",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"skill", "command"},
Properties: runScriptProps,
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "read_skill_file",
Description: "Read a file from a skill's directory. Use this to read additional documentation, reference files, or data files bundled with a skill.",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"skill", "path"},
Properties: readFileProps,
},
},
},
}
}
func (c *skillCatalog) RunToolCall(call api.ToolCall) (api.Message, bool, error) {
switch call.Function.Name {
case "read_skill_file":
skillName, err := requireStringArg(call.Function.Arguments, "skill")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
relPath, err := requireStringArg(call.Function.Arguments, "path")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
skill, ok := c.byName[skillName]
if !ok {
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
}
content, err := readSkillFile(skill.Dir, relPath)
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
return toolMessage(call, content), true, nil
case "run_skill_script":
skillName, err := requireStringArg(call.Function.Arguments, "skill")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
command, err := requireStringArg(call.Function.Arguments, "command")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
skill, ok := c.byName[skillName]
if !ok {
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
}
output, err := runSkillScript(skill.Dir, command)
if err != nil {
return toolMessage(call, fmt.Sprintf("error: %v\noutput: %s", err, output)), true, nil
}
return toolMessage(call, output), true, nil
default:
return api.Message{}, false, nil
}
}
// runSkillScript executes a shell command within a skill's directory.
//
// SECURITY LIMITATIONS (TODO):
// - No sandboxing: commands run with full user permissions
// - No path validation: model can run any command, not just scripts in skill dir
// - Shell injection risk: sh -c is used, malicious input could be crafted
// - No executable allowlist: any program can be called (curl, rm, etc.)
// - No environment isolation: scripts inherit full environment variables
//
// POTENTIAL IMPROVEMENTS:
// - Restrict commands to only reference files within skill directory
// - Allowlist specific executables (python3, node, bash)
// - Use sandboxing (Docker, nsjail, seccomp)
// - Require explicit script registration in SKILL.md frontmatter
// - Add per-skill configurable timeouts
func runSkillScript(skillDir, command string) (string, error) {
// Validate the skill directory exists
absSkillDir, err := filepath.Abs(skillDir)
if err != nil {
return "", err
}
if _, err := os.Stat(absSkillDir); err != nil {
return "", fmt.Errorf("skill directory not found: %w", err)
}
// Create command with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "sh", "-c", command)
cmd.Dir = absSkillDir
// Inject the current working directory (where ollama run was called from)
// as an environment variable so scripts can reference files in that directory
workingDir, err := os.Getwd()
if err != nil {
return "", fmt.Errorf("failed to get working directory: %w", err)
}
cmd.Env = append(os.Environ(), "OLLAMA_WORKING_DIR="+workingDir)
// Capture both stdout and stderr
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err = cmd.Run()
// Combine output
output := stdout.String()
if stderr.Len() > 0 {
if output != "" {
output += "\n"
}
output += stderr.String()
}
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return output, fmt.Errorf("command timed out after 30 seconds")
}
return output, err
}
return output, nil
}
func readSkillFile(skillDir, relPath string) (string, error) {
relPath = filepath.Clean(strings.TrimSpace(relPath))
if relPath == "" {
return "", errors.New("path is required")
}
if filepath.IsAbs(relPath) {
return "", errors.New("path must be relative to the skill directory")
}
target := filepath.Join(skillDir, relPath)
absTarget, err := filepath.Abs(target)
if err != nil {
return "", err
}
absSkillDir, err := filepath.Abs(skillDir)
if err != nil {
return "", err
}
rel, err := filepath.Rel(absSkillDir, absTarget)
if err != nil {
return "", err
}
if strings.HasPrefix(rel, "..") {
return "", errors.New("path escapes the skill directory")
}
content, err := os.ReadFile(absTarget)
if err != nil {
return "", fmt.Errorf("failed to read %q: %w", relPath, err)
}
return string(content), nil
}
func requireStringArg(args api.ToolCallFunctionArguments, name string) (string, error) {
value, ok := args.Get(name)
if !ok {
return "", fmt.Errorf("missing required argument %q", name)
}
str, ok := value.(string)
if !ok {
return "", fmt.Errorf("argument %q must be a string", name)
}
if strings.TrimSpace(str) == "" {
return "", fmt.Errorf("argument %q cannot be empty", name)
}
return str, nil
}
func toolMessage(call api.ToolCall, content string) api.Message {
msg := api.Message{
Role: "tool",
Content: content,
ToolName: call.Function.Name,
}
if call.ID != "" {
msg.ToolCallID = call.ID
}
return msg
}

View File

@@ -6,14 +6,11 @@ import (
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"maps"
"os"
"slices"
"strings"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -21,13 +18,8 @@ type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
// TODO is this needed?
ModelType string `json:"model_type"`
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
HiddenSize uint32 `json:"hidden_size"`
ModelType string `json:"model_type"`
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
}
@@ -41,94 +33,8 @@ type AdapterParameters struct {
} `json:"lora_parameters"`
}
type KV map[string]any
func (kv KV) Architecture() string {
return kv.String("general.architecture", "unknown")
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
[]uint8 | []int8 | []uint16 | []int16 |
[]uint32 | []int32 | []uint64 | []int64 |
[]string | []float32 | []float64 | []bool
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
}
return defaultValue[0], false
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
return val
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
return val
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
return val
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
return val
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
return val
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (ModelParameters) KV(t *Tokenizer) KV {
kv := KV{
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
kv := ggml.KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
@@ -157,7 +63,7 @@ func (ModelParameters) KV(t *Tokenizer) KV {
return kv
}
func (p AdapterParameters) KV() KV {
func (p AdapterParameters) KV() ggml.KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
@@ -165,7 +71,7 @@ func (p AdapterParameters) KV() KV {
alpha = p.LoraParameters.Alpha
}
kv := KV{
kv := ggml.KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
@@ -182,14 +88,9 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
type ModelKV interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) KV
}
type ModelConverter interface {
ModelKV
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -206,7 +107,7 @@ type moreParser interface {
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ofs.Config) KV
KV(ggml.KV) ggml.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -214,7 +115,7 @@ type AdapterConverter interface {
Replacements() []string
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -225,8 +126,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return err
}
arch := baseKV.Architecture()
if arch == "" {
arch, ok := baseKV["general.architecture"]
if !ok {
return errors.New("architecture not set for the base model")
}
@@ -252,19 +153,23 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
}
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return nil, nil, err
return err
}
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return nil, nil, err
return err
}
if len(p.Architectures) < 1 {
return nil, nil, errors.New("unknown architecture")
return errors.New("unknown architecture")
}
var conv ModelConverter
@@ -312,22 +217,22 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return nil, nil, err
return err
}
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return nil, nil, err
return err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil {
return nil, nil, err
return err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -349,19 +254,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
return conv, t, nil
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
kv, t, err := LoadModelMetadata(fsys)
if err != nil {
return err
}
conv := kv.(ModelConverter)
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
@@ -371,7 +263,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return writeFile(f, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)

View File

@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *bertModel) KV(t *Tokenizer) KV {
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false

View File

@@ -24,7 +24,7 @@ type commandrModel struct {
var _ ModelConverter = (*commandrModel)(nil)
func (p *commandrModel) KV(t *Tokenizer) KV {
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r"

View File

@@ -47,7 +47,7 @@ type deepseek2Model struct {
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) KV {
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"

View File

@@ -41,7 +41,7 @@ type deepseekocr struct {
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) KV {
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers

View File

@@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) KV {
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings

View File

@@ -1,5 +1,7 @@
package convert
import "github.com/ollama/ollama/fs/ggml"
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
@@ -7,7 +9,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) KV {
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings

View File

@@ -6,7 +6,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -16,7 +15,7 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv

View File

@@ -3,6 +3,8 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
@@ -53,7 +55,7 @@ const (
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) KV {
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"

View File

@@ -38,7 +38,7 @@ type gemma3nModel struct {
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) KV {
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {

View File

@@ -37,7 +37,7 @@ type gptossModel struct {
var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) KV {
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)

View File

@@ -48,7 +48,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) KV {
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize

View File

@@ -35,7 +35,7 @@ type llama4Model struct {
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) KV {
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"

View File

@@ -7,7 +7,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -19,13 +18,13 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
return kv
}

View File

@@ -60,7 +60,7 @@ type mistral3Model struct {
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) KV {
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize

View File

@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize

View File

@@ -12,7 +12,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
func (p *mixtralModel) KV(t *Tokenizer) KV {
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 {

View File

@@ -34,7 +34,7 @@ type mllamaModel struct {
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) KV {
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"

View File

@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) KV {
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)

View File

@@ -34,7 +34,7 @@ type olmoModel struct {
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) KV {
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers

View File

@@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) KV {
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings

View File

@@ -22,7 +22,7 @@ type qwen2Model struct {
var _ ModelConverter = (*qwen2Model)(nil)
func (q *qwen2Model) KV(t *Tokenizer) KV {
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen2"
kv["qwen2.block_count"] = q.HiddenLayers

View File

@@ -29,7 +29,7 @@ type qwen25VLModel struct {
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"

View File

@@ -32,7 +32,7 @@ type qwen3Model struct {
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) KV {
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"

View File

@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"

View File

@@ -19,7 +19,6 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
fsc "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -29,7 +28,7 @@ type tensorData struct {
Shape []int `json:"shape"`
}
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -60,10 +59,9 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors)
return r, m.KV(), m.Tensors()
}
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
actual := make(map[string]string)
for k := range kv.Keys() {
v := kv.Value(k)
for k, v := range kv {
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
@@ -279,7 +277,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
func TestConvertAdapter(t *testing.T) {
type AdapterCase struct {
Name string
BaseKV KV
BaseKV map[string]any
Expected map[string]string
}

View File

@@ -14,7 +14,6 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
### Resources

View File

@@ -1,406 +0,0 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='qwen3-coder',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
tools: [
{
name: "get_weather",
description: "Get the current weather in a location",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});
for (const block of message.content) {
if (block.type === "tool_use") {
console.log("Tool:", block.name);
console.log("Input:", block.input);
}
}
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
## Models
Ollama supports both local and cloud models.
### Local models
Pull a local model before use:
```shell
ollama pull qwen3-coder
```
Recommended local models:
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
### Cloud models
Cloud models are available immediately without pulling:
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
```shell
ollama cp qwen3-coder claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

View File

@@ -32,9 +32,7 @@
"codeblocks": "system"
},
"contextual": {
"options": [
"copy"
]
"options": ["copy"]
},
"navbar": {
"links": [
@@ -54,9 +52,7 @@
"display": "simple"
},
"examples": {
"languages": [
"curl"
]
"languages": ["curl"]
}
},
"redirects": [
@@ -101,7 +97,6 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
@@ -144,8 +139,7 @@
"/api/streaming",
"/api/usage",
"/api/errors",
"/api/openai-compatibility",
"/api/anthropic-compatibility"
"/api/openai-compatibility"
]
},
{

View File

@@ -1,69 +0,0 @@
---
title: Claude Code
---
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
<CodeGroup>
```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```
```powershell Windows
irm https://claude.ai/install.ps1 | iex
```
</CodeGroup>
## Usage with Ollama
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
2. Run Claude Code with an Ollama model:
```shell
claude --model qwen3-coder
```
Or run with environment variables inline:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model

View File

@@ -1,5 +1,5 @@
---
title: "Linux"
title: Linux
---
## Install
@@ -13,7 +13,8 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
<Note>
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
</Note>
Download and extract the package:
@@ -112,7 +113,11 @@ sudo systemctl status ollama
```
<Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
</Note>
## Customizing
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
sudo userdel ollama
sudo groupdel ollama
sudo rm -r /usr/share/ollama
```
```

548
docs/skills.md Normal file
View File

@@ -0,0 +1,548 @@
# Ollama Skills
Skills are reusable capability packages that extend what agents can do. They bundle instructions, scripts, and data that teach an agent how to perform specific tasks.
## Quick Start
### Creating a Skill
Create a directory with a `SKILL.md` file:
```
my-skill/
├── SKILL.md # Required: Instructions for the agent
└── scripts/ # Optional: Executable scripts
└── run.py
```
The `SKILL.md` file must have YAML frontmatter:
```markdown
---
name: my-skill
description: A brief description of what this skill does
---
# My Skill
## Purpose
Explain what this skill does and when to use it.
## Instructions
Step-by-step instructions for the agent on how to use this skill.
## Examples
Show example inputs and expected outputs.
```
### Using Skills in an Agent
Reference skills in your Agentfile:
```dockerfile
FROM llama3.2:3b
AGENT_TYPE conversational
# Local skill (bundled with agent)
SKILL ./path/to/my-skill
# Registry skill (pulled from ollama.com)
SKILL library/skill/calculator:1.0.0
# User skill from registry
SKILL myname/skill/calculator:1.0.0
SYSTEM You are a helpful assistant.
```
### Managing Skills
```bash
# Push a skill to the registry (uses your namespace)
ollama skill push myname/skill/calculator:1.0.0 ./my-skill
# Pull a skill from the official library
ollama skill pull skill/calculator:1.0.0
# Pull a skill from a user's namespace
ollama skill pull myname/skill/calculator:1.0.0
# List installed skills
ollama skill list
# Show skill details
ollama skill show skill/calculator:1.0.0
# Remove a skill
ollama skill rm skill/calculator:1.0.0
```
### Dynamic Skills in Chat
You can add and remove skills dynamically during an interactive chat session:
```
>>> /skills
Available Skills:
calculator (sha256:abc123def456...)
>>> /skill add ./my-local-skill
Added skill 'my-skill' from ./my-local-skill
>>> /skill list
Skills loaded in this session:
my-skill (local: /path/to/my-local-skill)
>>> /skill remove my-skill
Removed skill 'my-skill'
```
| Command | Description |
|---------|-------------|
| `/skills` | Show all available skills (model + session) |
| `/skill add <path>` | Add a skill from a local path |
| `/skill remove <name>` | Remove a skill by name |
| `/skill list` | List skills loaded in this session |
Dynamic skills take effect on the next message. This is useful for:
- Testing skills during development
- Temporarily adding capabilities to a model
- Experimenting with skill combinations
## Skill Reference Formats
Skills use a 5-part name structure: `host/namespace/kind/model:tag`
| Format | Example | Description |
|--------|---------|-------------|
| Local path | `./skills/calc` | Bundled with agent at create time |
| Library skill | `skill/calculator:1.0.0` | From the official skill library (library/skill/calculator) |
| User skill | `alice/skill/calc:1.0.0` | From a user's namespace |
| Full path | `registry.ollama.ai/alice/skill/calc:1.0.0` | Fully qualified with host |
The `kind` field distinguishes skills from models:
- `skill` - Skill packages
- `agent` - Agent packages (future)
- (empty) - Regular models
## SKILL.md Structure
### Required Frontmatter
```yaml
---
name: skill-name # Must match directory name
description: Brief description of the skill
---
```
### Recommended Sections
1. **Purpose**: What the skill does and when to use it
2. **When to use**: Trigger conditions for the agent
3. **Instructions**: Step-by-step usage guide
4. **Examples**: Input/output examples
5. **Scripts**: Documentation for any bundled scripts
### Example: Calculator Skill
```markdown
---
name: calculator
description: Performs mathematical calculations using Python
---
# Calculator Skill
## Purpose
This skill performs mathematical calculations using a bundled Python script.
## When to use
- User asks to calculate something
- User wants to do math operations
- Any arithmetic is needed
## Instructions
1. When calculation is needed, use the `run_skill_script` tool
2. Call: `python3 scripts/calculate.py "<expression>"`
3. Return the result to the user
## Examples
**Input**: "What is 25 * 4?"
**Action**: `run_skill_script` with command `python3 scripts/calculate.py '25 * 4'`
**Output**: "25 * 4 = 100"
```
## Storage Layout
```
~/.ollama/models/
├── blobs/
│ └── sha256-<digest> # Skill tar.gz blob
├── manifests/
│ └── registry.ollama.ai/
│ └── skill/ # Library skills
│ └── calculator/
│ └── 1.0.0
│ └── skill-username/ # User skills
│ └── my-skill/
│ └── latest
└── skills/
└── sha256-<digest>/ # Extracted skill cache
├── SKILL.md
└── scripts/
```
---
# Security Considerations
## Current State (Development)
The current implementation has several security considerations that need to be addressed before production use.
### 1. Script Execution
**Risk**: Skills can bundle arbitrary scripts that execute on the host system.
**Current behavior**:
- Scripts run with the same permissions as the Ollama process
- No sandboxing or isolation
- Full filesystem access
**Mitigations needed**:
- [ ] Sandbox script execution (containers, seccomp, etc.)
- [ ] Resource limits (CPU, memory, time)
- [ ] Filesystem isolation (read-only mounts, restricted paths)
- [ ] Network policy controls
- [ ] Capability dropping
### 2. Skill Provenance
**Risk**: Malicious skills could be pushed to the registry.
**Current behavior**:
- No code signing or verification
- No malware scanning
- Trust based on namespace ownership
**Mitigations needed**:
- [ ] Skill signing with author keys
- [ ] Registry-side malware scanning
- [ ] Content policy enforcement
- [ ] Reputation system for skill authors
### 3. Namespace Squatting
**Risk**: Malicious actors could register skill names that impersonate official tools.
**Current behavior**:
- First-come-first-served namespace registration
- No verification of skill names
**Mitigations needed**:
- [ ] Reserved namespace list (official tools, common names)
- [ ] Trademark/name verification for popular skills
- [ ] Clear namespacing conventions
### 4. Supply Chain Attacks
**Risk**: Compromised skills could inject malicious code into agents.
**Current behavior**:
- Skills pulled without integrity verification beyond digest
- No dependency tracking
**Mitigations needed**:
- [ ] SBOM (Software Bill of Materials) for skills
- [ ] Dependency vulnerability scanning
- [ ] Pinned versions in Agentfiles
- [ ] Audit logging of skill usage
### 5. Data Exfiltration
**Risk**: Skills could exfiltrate sensitive data from conversations or the host.
**Current behavior**:
- Skills have access to conversation context
- Scripts can make network requests
**Mitigations needed**:
- [ ] Network egress controls
- [ ] Sensitive data detection/masking
- [ ] Audit logging of script network activity
- [ ] User consent for data access
### 6. Privilege Escalation
**Risk**: Skills could escalate privileges through script execution.
**Current behavior**:
- Scripts inherit Ollama process privileges
- No capability restrictions
**Mitigations needed**:
- [ ] Run scripts as unprivileged user
- [ ] Drop all capabilities
- [ ] Mandatory access controls (SELinux/AppArmor)
## Recommended Security Model
### Skill Trust Levels
```
┌─────────────────────────────────────────────────────────────┐
│ Level 0: Untrusted (default) │
│ - No script execution │
│ - Instructions only │
│ - Safe for any skill │
├─────────────────────────────────────────────────────────────┤
│ Level 1: Sandboxed │
│ - Scripts run in isolated container │
│ - No network access │
│ - Read-only filesystem │
│ - Resource limits enforced │
├─────────────────────────────────────────────────────────────┤
│ Level 2: Trusted │
│ - Scripts run with network access │
│ - Can write to designated directories │
│ - Requires explicit user approval │
├─────────────────────────────────────────────────────────────┤
│ Level 3: Privileged (admin only) │
│ - Full host access │
│ - System administration skills │
│ - Requires admin approval │
└─────────────────────────────────────────────────────────────┘
```
### Skill Manifest Security Fields (Future)
```yaml
---
name: my-skill
description: A skill description
security:
trust_level: sandboxed
permissions:
- network:read # Can make HTTP GET requests
- filesystem:read:/data # Can read from /data
resource_limits:
max_memory: 256MB
max_cpu_time: 30s
max_disk: 100MB
signature: sha256:abc... # Author signature
---
```
---
# Future Considerations
## Feature Roadmap
### Phase 1: Foundation (Current)
- [x] Skill bundling with agents
- [x] Local skill development
- [x] Basic CLI commands (push, pull, list, rm, show)
- [x] Registry blob storage
- [ ] Registry namespace configuration
### Phase 2: Security
- [ ] Script sandboxing
- [ ] Permission model
- [ ] Skill signing
- [ ] Audit logging
### Phase 3: Discovery
- [ ] Skill search on ollama.com
- [ ] Skill ratings and reviews
- [ ] Usage analytics
- [ ] Featured/trending skills
### Phase 4: Advanced Features
- [ ] Skill dependencies
- [ ] Skill versioning constraints
- [ ] Skill composition (skills using skills)
- [ ] Skill testing framework
## Open Questions
### 1. Skill Execution Model
**Question**: How should skills execute scripts?
Options:
- **A) In-process**: Fast but unsafe
- **B) Subprocess**: Current approach, moderate isolation
- **C) Container**: Good isolation, requires container runtime
- **D) WASM**: Portable and safe, limited capabilities
- **E) Remote execution**: Offload to secure service
### 2. Skill Versioning
**Question**: How strict should version pinning be?
Options:
- **A) Always latest**: Simple but risky
- **B) Semantic versioning**: `^1.0.0` allows minor updates
- **C) Exact pinning**: `=1.0.0` requires explicit updates
- **D) Digest pinning**: `@sha256:abc` immutable reference
### 3. Skill Permissions
**Question**: How should users grant permissions to skills?
Options:
- **A) All or nothing**: Accept all permissions or don't use
- **B) Granular consent**: Approve each permission individually
- **C) Trust levels**: Pre-defined permission bundles
- **D) Runtime prompts**: Ask when permission is first used
### 4. Skill Discovery
**Question**: How should users find skills?
Options:
- **A) Central registry only**: ollama.com/skills
- **B) Federated registries**: Multiple skill sources
- **C) Git repositories**: Pull from GitHub, etc.
- **D) All of the above**: Multiple discovery mechanisms
### 5. Skill Monetization
**Question**: Should skill authors be able to monetize?
Options:
- **A) Free only**: All skills are free and open
- **B) Paid skills**: Authors can charge for skills
- **C) Freemium**: Free tier with paid features
- **D) Donations**: Voluntary support for authors
### 6. Skill Updates
**Question**: How should skill updates be handled?
Options:
- **A) Manual**: User explicitly updates
- **B) Auto-update**: Always use latest
- **C) Notify**: Alert user to available updates
- **D) Policy-based**: Organization controls update policy
## API Considerations
### Skill Metadata API
```
GET /api/skills
GET /api/skills/:namespace/:name
GET /api/skills/:namespace/:name/versions
GET /api/skills/:namespace/:name/readme
```
### Skill Execution API
```
POST /api/skills/:namespace/:name/execute
{
"command": "python3 scripts/run.py",
"args": ["--input", "data"],
"timeout": 30
}
```
### Skill Permissions API
```
GET /api/skills/:namespace/:name/permissions
POST /api/skills/:namespace/:name/permissions/grant
DELETE /api/skills/:namespace/:name/permissions/revoke
```
## Testing Considerations
### Skill Testing Framework
```bash
# Run skill tests
ollama skill test ./my-skill
# Test with specific model
ollama skill test ./my-skill --model llama3.2:3b
# Generate test report
ollama skill test ./my-skill --report
```
### Test File Format
```yaml
# my-skill/tests/test.yaml
tests:
- name: "basic calculation"
input: "What is 2 + 2?"
expect:
contains: "4"
tool_called: "run_skill_script"
- name: "complex expression"
input: "Calculate 15% of 200"
expect:
contains: "30"
```
## Compatibility Considerations
### Minimum Ollama Version
Skills should declare minimum Ollama version:
```yaml
---
name: my-skill
requires:
ollama: ">=0.4.0"
---
```
### Model Compatibility
Skills may require specific model capabilities:
```yaml
---
name: vision-skill
requires:
capabilities:
- vision
- tools
---
```
## Migration Path
### From Local to Registry
```bash
# Develop locally
SKILL ./my-skill
# Push when ready
ollama skill push myname/my-skill:1.0.0 ./my-skill
# Update Agentfile
SKILL skill/myname/my-skill:1.0.0
```
### Version Upgrades
```bash
# Check for updates
ollama skill outdated
# Update specific skill
ollama skill update calculator:1.0.0
# Update all skills
ollama skill update --all
```

View File

@@ -148,6 +148,16 @@ func Remotes() []string {
return r
}
// Skills returns the list of skill directories. Skills directories can be configured via the OLLAMA_SKILLS environment variable.
// Returns empty slice if not configured.
func Skills() []string {
raw := strings.TrimSpace(Var("OLLAMA_SKILLS"))
if raw == "" {
return []string{}
}
return strings.Split(raw, ",")
}
func BoolWithDefault(k string) func(defaultValue bool) bool {
return func(defaultValue bool) bool {
if s := Var(k); s != "" {
@@ -317,6 +327,9 @@ func AsMap() map[string]EnvVar {
ret["OLLAMA_VULKAN"] = EnvVar{"OLLAMA_VULKAN", EnableVulkan(), "Enable experimental Vulkan support"}
}
// Skills configuration would go here when added
ret["OLLAMA_SKILLS"] = EnvVar{"OLLAMA_SKILLS", Skills(), "Comma-separated list of skill directories"}
return ret
}

View File

@@ -1,7 +1,5 @@
package fs
import "iter"
type Config interface {
Architecture() string
String(string, ...string) string
@@ -13,8 +11,4 @@ type Config interface {
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
Len() int
Keys() iter.Seq[string]
Value(key string) any
}

View File

@@ -6,9 +6,7 @@ import (
"errors"
"fmt"
"io"
"iter"
"log/slog"
"maps"
"math"
"slices"
"strings"
@@ -241,18 +239,6 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
return val.values
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"bert",

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"github.com/ollama/ollama/fs"
"golang.org/x/sync/errgroup"
)
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
return err
}
for _, key := range slices.Sorted(kv.Keys()) {
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
return err
}
}

2
go.mod
View File

@@ -87,5 +87,5 @@ require (
golang.org/x/term v0.36.0
golang.org/x/text v0.30.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
gopkg.in/yaml.v3 v3.0.1
)

View File

@@ -1,149 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
type AnthropicWriter struct {
BaseWriter
stream bool
id string
model string
converter *anthropic.StreamConverter
}
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
var errData struct {
Error string `json:"error"`
}
if err := json.Unmarshal(data, &errData); err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
}
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
if w.stream {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
}
}
return len(data), nil
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := anthropic.ToMessagesResponse(w.id, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
func (w *AnthropicWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req anthropic.MessagesRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.MaxTokens <= 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
return
}
chatReq, err := anthropic.FromMessagesRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
messageID := anthropic.GenerateMessageID()
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
}
if req.Stream {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
c.Next()
}
}

View File

@@ -1,584 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
_ = json.Unmarshal(bodyBytes, capturedRequest)
c.Next()
}
}
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAnthropicMessagesMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.ChatRequest
err anthropic.ErrorResponse
}
var capturedRequest *api.ChatRequest
stream := true
testCases := []testCase{
{
name: "basic message",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with system prompt",
body: `{
"model": "test-model",
"max_tokens": 1024,
"system": "You are helpful.",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with options",
body: `{
"model": "test-model",
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{
"num_predict": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop": []string{"\n", "END"},
},
Stream: &False,
},
},
{
name: "streaming",
body: `{
"model": "test-model",
"max_tokens": 1024,
"stream": true,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &stream,
},
},
{
name: "with tools",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"}
],
"tools": [{
"name": "get_weather",
"description": "Get current weather",
"input_schema": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: testProps(map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with tool result",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"},
{"role": "assistant", "content": [
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
]},
{"role": "user", "content": [
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with thinking enabled",
body: `{
"model": "test-model",
"max_tokens": 1024,
"thinking": {"type": "enabled", "budget_tokens": 1000},
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
Think: &api.ThinkValue{Value: true},
},
},
{
name: "missing model error",
body: `{
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "model is required",
},
},
},
{
name: "missing max_tokens error",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "max_tokens is required and must be positive",
},
},
},
{
name: "missing messages error",
body: `{
"model": "test-model",
"max_tokens": 1024
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "messages is required",
},
},
},
{
name: "tool_use missing id error",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "assistant", "content": [
{"type": "tool_use", "name": "test"}
]}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "tool_use block missing required 'id' field",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
router.Handle(http.MethodPost, "/v1/messages", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Type != "" {
// Expect error
if resp.Code == http.StatusOK {
t.Fatalf("expected error response, got 200 OK")
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != tc.err.Type {
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
}
if errResp.Error.Type != tc.err.Error.Type {
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
}
if errResp.Error.Message != tc.err.Error.Message {
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
}
if capturedRequest == nil {
t.Fatal("request was not captured")
}
// Compare relevant fields
if capturedRequest.Model != tc.req.Model {
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
}
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
t.Errorf("messages mismatch (-want +got):\n%s", diff)
}
if tc.req.Stream != nil && capturedRequest.Stream != nil {
if *tc.req.Stream != *capturedRequest.Stream {
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
}
}
if tc.req.Think != nil {
if capturedRequest.Think == nil {
t.Error("expected Think to be set")
} else if capturedRequest.Think.Value != tc.req.Think.Value {
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
}
}
})
}
}
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("streaming sets correct headers", func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Check headers were set
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
}
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
}
c.Status(http.StatusOK)
})
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
})
}
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != "invalid_request_error" {
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
}
}
func TestAnthropicWriter_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate Ollama response
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
data, _ := json.Marshal(resp)
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.Code)
}
var result anthropic.MessagesResponse
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 {
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
}
if result.Usage.OutputTokens != 5 {
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
}
}
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
errorPayload any
wantErrorType string
wantMessage string
}{
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
{
name: "404 with gin.H error (model not found)",
statusCode: http.StatusNotFound,
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
wantErrorType: "not_found_error",
wantMessage: "model 'nonexistent' not found",
},
{
name: "400 with gin.H error (bad request)",
statusCode: http.StatusBadRequest,
errorPayload: gin.H{"error": "model is required"},
wantErrorType: "invalid_request_error",
wantMessage: "model is required",
},
{
name: "500 with gin.H error (internal error)",
statusCode: http.StatusInternalServerError,
errorPayload: gin.H{"error": "something went wrong"},
wantErrorType: "api_error",
wantMessage: "something went wrong",
},
{
name: "404 with api.StatusError",
statusCode: http.StatusNotFound,
errorPayload: api.StatusError{
StatusCode: http.StatusNotFound,
ErrorMessage: "model not found via StatusError",
},
wantErrorType: "not_found_error",
wantMessage: "model not found via StatusError",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate what routes.go does - set status and write error JSON
data, _ := json.Marshal(tt.errorPayload)
c.Writer.WriteHeader(tt.statusCode)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.statusCode {
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != tt.wantErrorType {
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
}
if errResp.Error.Message != tt.wantMessage {
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
}
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
@@ -58,6 +59,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
var messages []api.Message
var licenses []string
var skills []api.SkillRef
var mcps []api.MCPRef
params := make(map[string]any)
for _, c := range f.Commands {
@@ -118,6 +121,32 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
case "message":
role, msg, _ := strings.Cut(c.Args, ": ")
messages = append(messages, api.Message{Role: role, Content: msg})
case "skill":
skillName := c.Args
// Expand local paths relative to the Agentfile directory
if isLocalPath(skillName) {
expanded, err := expandPath(skillName, relativeDir)
if err != nil {
return nil, fmt.Errorf("expanding skill path %q: %w", skillName, err)
}
skillName = expanded
}
skills = append(skills, api.SkillRef{Name: skillName})
case "mcp":
mcpRef, err := parseMCPArg(c.Args, relativeDir)
if err != nil {
return nil, fmt.Errorf("invalid MCP: %w", err)
}
mcps = append(mcps, mcpRef)
case "agent_type":
// Handle "AGENT TYPE conversational" -> strip "TYPE " prefix
args := c.Args
if strings.HasPrefix(strings.ToLower(args), "type ") {
args = strings.TrimSpace(args[5:])
}
req.AgentType = args
case "entrypoint":
req.Entrypoint = c.Args
default:
if slices.Contains(deprecatedParameters, c.Name) {
fmt.Printf("warning: parameter %s is deprecated\n", c.Name)
@@ -150,6 +179,12 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
if len(licenses) > 0 {
req.License = licenses
}
if len(skills) > 0 {
req.Skills = skills
}
if len(mcps) > 0 {
req.MCPs = mcps
}
return req, nil
}
@@ -333,7 +368,7 @@ func (c Command) String() string {
switch c.Name {
case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args)
case "license", "template", "system", "adapter", "renderer", "parser", "requires":
case "license", "template", "system", "adapter", "renderer", "parser", "requires", "skill", "agent_type", "entrypoint":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
case "message":
role, message, _ := strings.Cut(c.Args, ": ")
@@ -359,7 +394,7 @@ const (
var (
errMissingFrom = errors.New("no FROM line")
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", \"requires\", \"skill\", \"agent_type\", \"mcp\", or \"entrypoint\"")
)
type ParserError struct {
@@ -423,6 +458,9 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
switch s := strings.ToLower(b.String()); s {
case "from":
cmd.Name = "model"
case "agent":
// "AGENT TYPE" -> "agent_type", consume next word
cmd.Name = "agent_type"
case "parameter":
// transition to stateParameter which sets command name
next = stateParameter
@@ -500,6 +538,10 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
if cmd.Name == "model" {
return &f, nil
}
// Allow entrypoint-only agents without FROM
if cmd.Name == "entrypoint" {
return &f, nil
}
}
return nil, errMissingFrom
@@ -518,7 +560,7 @@ func parseRuneForState(r rune, cs state) (state, rune, error) {
}
case stateName:
switch {
case isAlpha(r):
case isAlpha(r), r == '_':
return stateName, r, nil
case isSpace(r):
return stateValue, 0, nil
@@ -619,7 +661,7 @@ func isValidMessageRole(role string) bool {
func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) {
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires":
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires", "skill", "agent_type", "agent", "mcp", "entrypoint":
return true
default:
return false
@@ -666,3 +708,79 @@ func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User
func expandPath(path, relativeDir string) (string, error) {
return expandPathImpl(path, relativeDir, user.Current, user.Lookup)
}
// parseMCPArg parses MCP command arguments.
// Supports two formats:
//
// JSON: {"name": "web-search", "command": "uv", "args": ["run", "./script.py"]}
// Simple: web-search uv run ./script.py (name, command, args...)
func parseMCPArg(args string, relativeDir string) (api.MCPRef, error) {
args = strings.TrimSpace(args)
if args == "" {
return api.MCPRef{}, errors.New("MCP requires arguments")
}
// Try JSON format first
if strings.HasPrefix(args, "{") {
var ref api.MCPRef
if err := json.Unmarshal([]byte(args), &ref); err != nil {
return api.MCPRef{}, fmt.Errorf("invalid JSON: %w", err)
}
if ref.Name == "" {
return api.MCPRef{}, errors.New("MCP name is required")
}
if ref.Command == "" {
return api.MCPRef{}, errors.New("MCP command is required")
}
if ref.Type == "" {
ref.Type = "stdio"
}
// Expand relative paths in args
for i, arg := range ref.Args {
if isLocalPath(arg) {
expanded, err := expandPath(arg, relativeDir)
if err != nil {
return api.MCPRef{}, fmt.Errorf("expanding path %q: %w", arg, err)
}
ref.Args[i] = expanded
}
}
return ref, nil
}
// Simple format: name command args...
parts := strings.Fields(args)
if len(parts) < 2 {
return api.MCPRef{}, errors.New("MCP requires at least name and command")
}
ref := api.MCPRef{
Name: parts[0],
Command: parts[1],
Type: "stdio",
}
if len(parts) > 2 {
ref.Args = parts[2:]
}
// Expand relative paths in args
for i, arg := range ref.Args {
if isLocalPath(arg) {
expanded, err := expandPath(arg, relativeDir)
if err != nil {
return api.MCPRef{}, fmt.Errorf("expanding path %q: %w", arg, err)
}
ref.Args[i] = expanded
}
}
return ref, nil
}
// isLocalPath checks if a string looks like a local filesystem path.
func isLocalPath(s string) bool {
return strings.HasPrefix(s, "/") ||
strings.HasPrefix(s, "./") ||
strings.HasPrefix(s, "../") ||
strings.HasPrefix(s, "~")
}

View File

@@ -21,7 +21,6 @@ import (
"golang.org/x/text/encoding/unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/fs/ggml"
)
@@ -802,7 +801,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
var base convert.KV = map[string]any{"general.architecture": "test"}
base := map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {

View File

@@ -1,33 +0,0 @@
package progress
import (
"fmt"
"strings"
)
// StepBar displays step-based progress (e.g., for image generation steps).
type StepBar struct {
message string
current int
total int
}
func NewStepBar(message string, total int) *StepBar {
return &StepBar{message: message, total: total}
}
func (s *StepBar) Set(current int) {
s.current = current
}
func (s *StepBar) String() string {
percent := float64(s.current) / float64(s.total) * 100
barWidth := s.total
empty := barWidth - s.current
// "Generating 0% ▕ ▏ 0/9"
return fmt.Sprintf("%s %3.0f%% ▕%s%s▏ %d/%d",
s.message, percent,
strings.Repeat("█", s.current), strings.Repeat(" ", empty),
s.current, s.total)
}

View File

@@ -18,7 +18,6 @@ const (
CharCtrlL = 12
CharEnter = 13
CharNext = 14
CharCtrlO = 15 // Ctrl+O - used for expanding tool output
CharPrev = 16
CharBckSearch = 18
CharFwdSearch = 19

View File

@@ -3,7 +3,6 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
)
func Execute(args []string) error {
@@ -12,19 +11,12 @@ func Execute(args []string) error {
}
var newRunner bool
var imageRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
if args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--image-engine" {
args = args[1:]
imageRunner = true
}
if imageRunner {
return imagerunner.Execute(args)
} else if newRunner {
if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)

View File

@@ -42,39 +42,18 @@ shift $(( $OPTIND - 1 ))
_build_darwin() {
for ARCH in $ARCHS; do
status "Building darwin $ARCH"
INSTALL_PREFIX=dist/darwin-$ARCH/
INSTALL_PREFIX=dist/darwin-$ARCH/
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
if [ "$ARCH" = "amd64" ]; then
status "Building darwin $ARCH dynamic backends"
BUILD_DIR=build/darwin-$ARCH
cmake -B $BUILD_DIR \
cmake -B build/darwin-$ARCH \
-DCMAKE_OSX_ARCHITECTURES=x86_64 \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \
-DMLX_ENGINE=ON \
-DMLX_ENABLE_X64_MAC=ON \
-DOLLAMA_RUNNER_DIR=./
cmake --build $BUILD_DIR --target ggml-cpu -j
cmake --build $BUILD_DIR --target mlx mlxc -j
cmake --install $BUILD_DIR --component CPU
cmake --install $BUILD_DIR --component MLX
# Override CGO flags to point to the amd64 build directory
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
else
BUILD_DIR=build
cmake --preset MLX \
-DOLLAMA_RUNNER_DIR=./ \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 \
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX
cmake --build --preset MLX --parallel
cmake --install $BUILD_DIR --component MLX
# Use default CGO flags from mlx.go for arm64
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
cmake --build build/darwin-$ARCH --target ggml-cpu -j
cmake --install build/darwin-$ARCH --component CPU
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
done
}
@@ -82,12 +61,10 @@ _sign_darwin() {
status "Creating universal binary..."
mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
chmod +x dist/darwin/ollama
chmod +x dist/darwin/imagegen
if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
for F in dist/darwin/ollama dist/darwin-amd64/lib/ollama/*; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done
@@ -154,23 +131,17 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin-amd64/lib/ollama/*.so dist/darwin-amd64/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
else
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
chmod a+x dist/Ollama.app/Contents/Resources/ollama
# Sign
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
@@ -178,7 +149,7 @@ _build_macapp() {
rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then

View File

@@ -12,17 +12,6 @@ set -eu
. $(dirname $0)/env.sh
# Check for required tools
if ! command -v zstd >/dev/null 2>&1; then
echo "ERROR: zstd is required but not installed." >&2
echo "Please install zstd:" >&2
echo " - macOS: brew install zstd" >&2
echo " - Debian/Ubuntu: sudo apt-get install zstd" >&2
echo " - RHEL/CentOS/Fedora: sudo dnf install zstd" >&2
echo " - Arch: sudo pacman -S zstd" >&2
exit 1
fi
mkdir -p dist
docker buildx build \
@@ -48,68 +37,19 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
.
fi
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
deduplicate_cuda_libs() {
local base_dir="$1"
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
}
# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
deduplicate_cuda_libs "./dist/linux_amd64"
deduplicate_cuda_libs "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
deduplicate_cuda_libs "./dist"
fi
# buildx behavior changes for single vs. multiplatform
echo "Compressing linux tar bundles..."
if echo $PLATFORM | grep "," > /dev/null ; then
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
tar c -C ./dist/linux_amd64 --exclude rocm . | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
tar c -C ./dist/ --exclude rocm bin lib | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/ ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
fi

View File

@@ -66,36 +66,6 @@ if [ -n "$NEEDS" ]; then
exit 1
fi
# Function to download and extract with fallback from zst to tgz
download_and_extract() {
local url_base="$1"
local dest_dir="$2"
local filename="$3"
# Check if .tar.zst is available
if curl --fail --silent --head --location "${url_base}/${filename}.tar.zst${VER_PARAM}" >/dev/null 2>&1; then
# zst file exists - check if we have zstd tool
if ! available zstd; then
error "This version requires zstd for extraction. Please install zstd and try again:
- Debian/Ubuntu: sudo apt-get install zstd
- RHEL/CentOS/Fedora: sudo dnf install zstd
- Arch: sudo pacman -S zstd"
fi
status "Downloading ${filename}.tar.zst"
curl --fail --show-error --location --progress-bar \
"${url_base}/${filename}.tar.zst${VER_PARAM}" | \
zstd -d | $SUDO tar -xf - -C "${dest_dir}"
return 0
fi
# Fall back to .tgz for older versions
status "Downloading ${filename}.tgz"
curl --fail --show-error --location --progress-bar \
"${url_base}/${filename}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "${dest_dir}"
}
for BINDIR in /usr/local/bin /usr/bin /bin; do
echo $PATH | grep -q $BINDIR && break || continue
done
@@ -108,7 +78,10 @@ fi
status "Installing ollama to $OLLAMA_INSTALL_DIR"
$SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}"
status "Downloading Linux ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
status "Making ollama accessible in the PATH in $BINDIR"
@@ -118,9 +91,15 @@ fi
# Check for NVIDIA JetPack systems with additional downloads
if [ -f /etc/nv_tegra_release ] ; then
if grep R36 /etc/nv_tegra_release > /dev/null ; then
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack6"
status "Downloading JetPack 6 components"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack6.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
elif grep R35 /etc/nv_tegra_release > /dev/null ; then
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack5"
status "Downloading JetPack 5 components"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack5.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
else
warning "Unsupported JetPack version detected. GPU may not be supported"
fi
@@ -243,7 +222,10 @@ if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdg
fi
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-rocm"
status "Downloading Linux ROCm ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-rocm.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
install_success
status "AMD GPU ready."

View File

@@ -26,7 +26,6 @@ import (
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
@@ -63,6 +62,10 @@ func (s *Server) CreateHandler(c *gin.Context) {
config.Renderer = r.Renderer
config.Parser = r.Parser
config.Requires = r.Requires
config.Skills = r.Skills
config.MCPs = r.MCPs
config.AgentType = r.AgentType
config.Entrypoint = r.Entrypoint
for v := range r.Files {
if !fs.ValidPath(v) {
@@ -122,7 +125,10 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- gin.H{"error": err.Error()}
}
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
// Inherit config from base model (Renderer, Parser, Requires, Capabilities, etc.)
// This is especially important for cloud models which don't have GGUF files
// to detect capabilities from.
if err == nil && !remote {
manifest, mErr := ParseNamedManifest(fromName)
if mErr == nil && manifest.Config.Digest != "" {
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
@@ -139,6 +145,29 @@ func (s *Server) CreateHandler(c *gin.Context) {
if config.Requires == "" {
config.Requires = baseConfig.Requires
}
// Inherit capabilities for cloud/remote models
// (local models detect capabilities from GGUF file)
if len(config.Capabilities) == 0 && len(baseConfig.Capabilities) > 0 {
config.Capabilities = baseConfig.Capabilities
}
// Inherit remote host/model if base is a cloud model
if config.RemoteHost == "" && baseConfig.RemoteHost != "" {
config.RemoteHost = baseConfig.RemoteHost
}
if config.RemoteModel == "" && baseConfig.RemoteModel != "" {
config.RemoteModel = baseConfig.RemoteModel
}
// Inherit model family for proper rendering
if config.ModelFamily == "" && baseConfig.ModelFamily != "" {
config.ModelFamily = baseConfig.ModelFamily
}
if len(config.ModelFamilies) == 0 && len(baseConfig.ModelFamilies) > 0 {
config.ModelFamilies = baseConfig.ModelFamilies
}
// Inherit context length for cloud models
if config.ContextLen == 0 && baseConfig.ContextLen > 0 {
config.ContextLen = baseConfig.ContextLen
}
}
cfgFile.Close()
}
@@ -158,6 +187,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- gin.H{"error": err.Error()}
return
}
} else if r.Entrypoint != "" {
// Entrypoint-only agent: no base model needed
slog.Debug("create entrypoint-only agent", "entrypoint", r.Entrypoint)
} else {
ch <- gin.H{"error": errNeitherFromOrFiles.Error(), "status": http.StatusBadRequest}
return
@@ -455,7 +487,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return layers, nil
}
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
for _, l := range baseLayers {
if l.GGML != nil {
return l.KV(), nil
@@ -544,6 +576,18 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
return err
}
// Handle skill layers for agents
layers, config.Skills, err = setSkillLayers(layers, config.Skills, fn)
if err != nil {
return err
}
// Handle MCP layers for agents
layers, config.MCPs, err = setMCPLayers(layers, config.MCPs, fn)
if err != nil {
return err
}
configLayer, err := createConfigLayer(layers, *config)
if err != nil {
return err
@@ -794,6 +838,135 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
return layers, nil
}
// setSkillLayers creates skill layers for local skill paths and updates the skill refs.
// Local paths are converted to bundled skill layers with digests.
// Registry references are kept as-is for later resolution during pull.
func setSkillLayers(layers []Layer, skills []model.SkillRef, fn func(resp api.ProgressResponse)) ([]Layer, []model.SkillRef, error) {
if len(skills) == 0 {
return layers, skills, nil
}
// Remove any existing skill layers
layers = removeLayer(layers, MediaTypeSkill)
var updatedSkills []model.SkillRef
for _, skill := range skills {
// Check if this is a local path
if IsLocalSkillPath(skill.Name) {
// Expand home directory if needed
skillPath := skill.Name
if strings.HasPrefix(skillPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
return nil, nil, fmt.Errorf("expanding home directory: %w", err)
}
skillPath = filepath.Join(home, skillPath[1:])
}
// Make absolute
absPath, err := filepath.Abs(skillPath)
if err != nil {
return nil, nil, fmt.Errorf("resolving skill path %q: %w", skill.Name, err)
}
// Check if this is a direct skill directory or a parent containing skills
skillMdPath := filepath.Join(absPath, "SKILL.md")
if _, err := os.Stat(skillMdPath); err == nil {
// Direct skill directory
fn(api.ProgressResponse{Status: fmt.Sprintf("packaging skill: %s", filepath.Base(absPath))})
layer, err := CreateSkillLayer(absPath)
if err != nil {
return nil, nil, fmt.Errorf("creating skill layer for %q: %w", skill.Name, err)
}
layers = append(layers, layer)
updatedSkills = append(updatedSkills, model.SkillRef{
Name: filepath.Base(absPath),
Digest: layer.Digest,
})
} else {
// Parent directory - walk to find skill subdirectories
err := filepath.WalkDir(absPath, func(path string, entry fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if entry.IsDir() {
return nil
}
if entry.Name() != "SKILL.md" {
return nil
}
skillDir := filepath.Dir(path)
skillName := filepath.Base(skillDir)
fn(api.ProgressResponse{Status: fmt.Sprintf("packaging skill: %s", skillName)})
layer, err := CreateSkillLayer(skillDir)
if err != nil {
return fmt.Errorf("creating skill layer for %q: %w", skillDir, err)
}
layers = append(layers, layer)
updatedSkills = append(updatedSkills, model.SkillRef{
Name: skillName,
Digest: layer.Digest,
})
return nil
})
if err != nil {
return nil, nil, fmt.Errorf("walking skill directory %q: %w", skill.Name, err)
}
}
} else if skill.Digest != "" {
// Already has a digest (from a pulled agent), keep as-is
updatedSkills = append(updatedSkills, skill)
} else {
// Registry reference - keep as-is for later resolution
updatedSkills = append(updatedSkills, skill)
}
}
return layers, updatedSkills, nil
}
// setMCPLayers handles MCP server references.
// Currently, MCPs are stored as config data (command/args).
// Future: support bundling MCP server directories as layers.
func setMCPLayers(layers []Layer, mcps []model.MCPRef, fn func(resp api.ProgressResponse)) ([]Layer, []model.MCPRef, error) {
if len(mcps) == 0 {
return layers, mcps, nil
}
// Remove any existing MCP layers
layers = removeLayer(layers, MediaTypeMCP)
var updatedMCPs []model.MCPRef
for _, mcp := range mcps {
// Validate MCP has required fields
if mcp.Name == "" {
return nil, nil, fmt.Errorf("MCP server requires a name")
}
if mcp.Command == "" {
return nil, nil, fmt.Errorf("MCP server %q requires a command", mcp.Name)
}
// Set default type if not specified
if mcp.Type == "" {
mcp.Type = "stdio"
}
// For now, just keep MCPs as config data
// Future: detect local paths in args and bundle them
updatedMCPs = append(updatedMCPs, mcp)
fn(api.ProgressResponse{Status: fmt.Sprintf("configuring MCP: %s", mcp.Name)})
}
return layers, updatedMCPs, nil
}
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
digests := make([]string, len(layers))
for i, layer := range layers {

View File

@@ -30,7 +30,6 @@ import (
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen/transfer"
)
var (
@@ -74,11 +73,6 @@ type Model struct {
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImageGeneration}
}
// Check for completion capability
if m.ModelPath != "" {
f, err := gguf.Open(m.ModelPath)
@@ -238,6 +232,13 @@ func (m *Model) String() string {
})
}
if m.Config.Entrypoint != "" {
modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "entrypoint",
Args: m.Config.Entrypoint,
})
}
for k, v := range m.Options {
switch v := v.(type) {
case []any:
@@ -561,24 +562,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := mp.GetManifestPath()
if err != nil {
return err
}
manifestJSON, err := os.ReadFile(manifestPath)
if err != nil {
return err
}
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
@@ -644,15 +627,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
@@ -667,6 +641,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
@@ -675,11 +650,13 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
@@ -687,10 +664,15 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
}
for _, layer := range layers {
delete(deleteMap, layer.Digest)
// Extract skill layers to the skills cache
for _, layer := range manifest.Layers {
if layer.MediaType == MediaTypeSkill {
fn(api.ProgressResponse{Status: fmt.Sprintf("extracting skill %s", layer.Digest)})
if _, err := ExtractSkillBlob(layer.Digest); err != nil {
return fmt.Errorf("extracting skill layer %s: %w", layer.Digest, err)
}
}
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "writing manifest"})
@@ -725,148 +707,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}
// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []Layer) bool {
for _, layer := range layers {
if layer.MediaType == MediaTypeImageTensor {
return true
}
}
return false
}
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
}
}
destDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pulling model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}
if err := transfer.Download(ctx, transfer.DownloadOptions{
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: mp.GetNamespaceRepository(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
}); err != nil {
return err
}
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
return os.WriteFile(fp, manifestJSON, 0o644)
}
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
From: layer.From,
}
}
srcDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pushing model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}
return transfer.Upload(ctx, transfer.UploadOptions{
Blobs: blobs,
BaseURL: baseURL,
SrcDir: srcDir,
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
})
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)

View File

@@ -47,15 +47,6 @@ func TestModelCapabilities(t *testing.T) {
model Model
expectedCaps []model.Capability
}{
{
name: "model with image generation capability via config",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
},
{
name: "model with completion capability",
model: Model{

View File

@@ -13,14 +13,9 @@ type Layer struct {
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
status string
}
const (
MediaTypeImageTensor = "application/vnd.ollama.image.tensor"
)
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {

View File

@@ -129,11 +129,30 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
return nil, err
}
// TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
// Find both 4-part (models) and 5-part (skills/agents) manifest paths
matches4, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
if err != nil {
return nil, err
}
matches5, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*", "*"))
if err != nil {
return nil, err
}
// Combine matches, filtering to only include files
var matches []string
for _, match := range matches4 {
fi, err := os.Stat(match)
if err == nil && !fi.IsDir() {
matches = append(matches, match)
}
}
for _, match := range matches5 {
fi, err := os.Stat(match)
if err == nil && !fi.IsDir() {
matches = append(matches, match)
}
}
ms := make(map[model.Name]*Manifest)
for _, match := range matches {

315
server/mcp.go Normal file
View File

@@ -0,0 +1,315 @@
package server
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// MediaTypeMCP is the media type for MCP server layers in manifests.
const MediaTypeMCP = "application/vnd.ollama.image.mcp"
// GetMCPsPath returns the path to the extracted MCPs cache directory.
// If digest is empty, returns the mcps directory itself.
// If digest is provided, returns the path to the extracted MCP for that digest.
func GetMCPsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "mcps", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// ExtractMCPBlob extracts an MCP tar.gz blob to the mcps cache.
// The blob is expected to be at the blobs path for the given digest.
// Returns the path to the extracted MCP directory.
func ExtractMCPBlob(digest string) (string, error) {
// Get the blob path
blobPath, err := GetBlobsPath(digest)
if err != nil {
return "", fmt.Errorf("getting blob path: %w", err)
}
// Get the extraction path
mcpPath, err := GetMCPsPath(digest)
if err != nil {
return "", fmt.Errorf("getting mcp path: %w", err)
}
// Check if already extracted (look for any file)
entries, err := os.ReadDir(mcpPath)
if err == nil && len(entries) > 0 {
return mcpPath, nil
}
// Open the blob
f, err := os.Open(blobPath)
if err != nil {
return "", fmt.Errorf("opening blob: %w", err)
}
defer f.Close()
// Create gzip reader
gzr, err := gzip.NewReader(f)
if err != nil {
return "", fmt.Errorf("creating gzip reader: %w", err)
}
defer gzr.Close()
// Create tar reader
tr := tar.NewReader(gzr)
// Create the mcp directory
if err := os.MkdirAll(mcpPath, 0o755); err != nil {
return "", fmt.Errorf("creating mcp directory: %w", err)
}
// Extract files
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return "", fmt.Errorf("reading tar: %w", err)
}
// Clean the name and ensure it doesn't escape the target directory
name := filepath.Clean(header.Name)
if strings.HasPrefix(name, "..") {
return "", fmt.Errorf("invalid path in archive: %s", header.Name)
}
target := filepath.Join(mcpPath, name)
// Verify the target is within mcpPath
if !strings.HasPrefix(target, filepath.Clean(mcpPath)+string(os.PathSeparator)) && target != filepath.Clean(mcpPath) {
return "", fmt.Errorf("path escapes mcp directory: %s", header.Name)
}
switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(target, 0o755); err != nil {
return "", fmt.Errorf("creating directory: %w", err)
}
case tar.TypeReg:
// Ensure parent directory exists
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
return "", fmt.Errorf("creating parent directory: %w", err)
}
outFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return "", fmt.Errorf("creating file: %w", err)
}
if _, err := io.Copy(outFile, tr); err != nil {
outFile.Close()
return "", fmt.Errorf("writing file: %w", err)
}
outFile.Close()
}
}
return mcpPath, nil
}
// CreateMCPLayer creates an MCP layer from a local directory.
// The directory can optionally contain an mcp.json or package.json file.
// Returns the created layer.
func CreateMCPLayer(mcpDir string) (Layer, error) {
// Verify directory exists
info, err := os.Stat(mcpDir)
if err != nil {
return Layer{}, fmt.Errorf("mcp directory not found: %w", err)
}
if !info.IsDir() {
return Layer{}, fmt.Errorf("mcp path is not a directory: %s", mcpDir)
}
// Create a temporary file for the tar.gz
blobsPath, err := GetBlobsPath("")
if err != nil {
return Layer{}, fmt.Errorf("getting blobs path: %w", err)
}
tmpFile, err := os.CreateTemp(blobsPath, "mcp-*.tar.gz")
if err != nil {
return Layer{}, fmt.Errorf("creating temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer func() {
tmpFile.Close()
os.Remove(tmpPath)
}()
// Create gzip writer
gzw := gzip.NewWriter(tmpFile)
defer gzw.Close()
// Create tar writer
tw := tar.NewWriter(gzw)
defer tw.Close()
// Walk the mcp directory and add files to tar
err = filepath.Walk(mcpDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Get relative path
relPath, err := filepath.Rel(mcpDir, path)
if err != nil {
return err
}
// Skip the root directory itself
if relPath == "." {
return nil
}
// Create tar header
header, err := tar.FileInfoHeader(info, "")
if err != nil {
return err
}
header.Name = relPath
if err := tw.WriteHeader(header); err != nil {
return err
}
// Write file contents if it's a regular file
if !info.IsDir() {
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
if _, err := io.Copy(tw, f); err != nil {
return err
}
}
return nil
})
if err != nil {
return Layer{}, fmt.Errorf("creating tar archive: %w", err)
}
// Close writers to flush
if err := tw.Close(); err != nil {
return Layer{}, fmt.Errorf("closing tar writer: %w", err)
}
if err := gzw.Close(); err != nil {
return Layer{}, fmt.Errorf("closing gzip writer: %w", err)
}
if err := tmpFile.Close(); err != nil {
return Layer{}, fmt.Errorf("closing temp file: %w", err)
}
// Open the temp file for reading
tmpFile, err = os.Open(tmpPath)
if err != nil {
return Layer{}, fmt.Errorf("reopening temp file: %w", err)
}
defer tmpFile.Close()
// Create the layer (this will compute the digest and move to blobs)
layer, err := NewLayer(tmpFile, MediaTypeMCP)
if err != nil {
return Layer{}, fmt.Errorf("creating layer: %w", err)
}
// Extract the mcp to the cache so it's ready to use
if _, err := ExtractMCPBlob(layer.Digest); err != nil {
return Layer{}, fmt.Errorf("extracting mcp: %w", err)
}
return layer, nil
}
// IsLocalMCPPath checks if an MCP reference looks like a local path.
// Local paths are explicitly prefixed with /, ./, ../, or ~.
func IsLocalMCPPath(name string) bool {
return strings.HasPrefix(name, "/") ||
strings.HasPrefix(name, "./") ||
strings.HasPrefix(name, "../") ||
strings.HasPrefix(name, "~")
}
// MCPNamespace is the namespace used for standalone MCPs in the registry.
const MCPNamespace = "mcp"
// IsMCPReference checks if a name refers to an MCP (has mcp/ prefix).
func IsMCPReference(name string) bool {
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
// mcp/name or mcp/name:tag
if len(parts) >= 1 && parts[0] == MCPNamespace {
return true
}
// namespace/mcp/name (e.g., myuser/mcp/websearch)
if len(parts) >= 2 && parts[1] == MCPNamespace {
return true
}
return false
}
// ParseMCPName parses an MCP reference string into a model.Name.
// The Kind field is set to "mcp".
func ParseMCPName(name string) model.Name {
n := model.ParseName(name)
// If Kind wasn't set (old format without mcp/), set it
if n.Kind == "" {
n.Kind = MCPNamespace
}
return n
}
// GetMCPManifestPath returns the path to the MCP manifest file.
func GetMCPManifestPath(n model.Name) (string, error) {
if n.Model == "" {
return "", fmt.Errorf("mcp name is required")
}
// Ensure Kind is set
if n.Kind == "" {
n.Kind = MCPNamespace
}
path := filepath.Join(
envconfig.Models(),
"manifests",
n.Filepath(),
)
return path, nil
}

View File

@@ -18,6 +18,7 @@ type ModelPath struct {
ProtocolScheme string
Registry string
Namespace string
Kind string // Optional: "skill", "agent", or empty for models
Repository string
Tag string
}
@@ -42,6 +43,7 @@ func ParseModelPath(name string) ModelPath {
ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Kind: "",
Repository: "",
Tag: DefaultTag,
}
@@ -55,13 +57,41 @@ func ParseModelPath(name string) ModelPath {
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
case 4:
// host/namespace/kind/model or host/namespace/model:tag with kind
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
if model.ValidKinds[parts[2]] {
mp.Kind = parts[2]
mp.Repository = parts[3]
} else {
// Not a valid kind, treat as old format with extra part
mp.Repository = parts[3]
}
case 3:
// Could be: host/namespace/model OR namespace/kind/model
if model.ValidKinds[parts[1]] {
// namespace/kind/model
mp.Namespace = parts[0]
mp.Kind = parts[1]
mp.Repository = parts[2]
} else {
// host/namespace/model
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
}
case 2:
mp.Namespace = parts[0]
mp.Repository = parts[1]
// Could be: namespace/model OR kind/model
if model.ValidKinds[parts[0]] {
// kind/model (library skill)
mp.Kind = parts[0]
mp.Repository = parts[1]
} else {
// namespace/model
mp.Namespace = parts[0]
mp.Repository = parts[1]
}
case 1:
mp.Repository = parts[0]
}
@@ -75,20 +105,35 @@ func ParseModelPath(name string) ModelPath {
}
func (mp ModelPath) GetNamespaceRepository() string {
if mp.Kind != "" {
return fmt.Sprintf("%s/%s/%s", mp.Namespace, mp.Kind, mp.Repository)
}
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}
func (mp ModelPath) GetFullTagname() string {
if mp.Kind != "" {
return fmt.Sprintf("%s/%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry {
if mp.Namespace == DefaultNamespace {
if mp.Kind != "" {
return fmt.Sprintf("%s/%s:%s", mp.Kind, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
if mp.Kind != "" {
return fmt.Sprintf("%s/%s/%s:%s", mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
if mp.Kind != "" {
return fmt.Sprintf("%s/%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
@@ -97,6 +142,7 @@ func (mp ModelPath) GetManifestPath() (string, error) {
name := model.Name{
Host: mp.Registry,
Namespace: mp.Namespace,
Kind: mp.Kind,
Model: mp.Repository,
Tag: mp.Tag,
}

View File

@@ -50,8 +50,6 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -164,29 +162,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}
// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}
return runner.llama, nil
}
func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
@@ -214,12 +189,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -1009,6 +978,9 @@ func getExistingName(n model.Name) (model.Name, error) {
if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
n.Namespace = e.Namespace
}
if set.Kind == "" && strings.EqualFold(e.Kind, n.Kind) {
n.Kind = e.Kind
}
if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
n.Model = e.Model
}
@@ -1124,15 +1096,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
QuantizationLevel: m.Config.FileType,
}
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
}
}
if req.System != "" {
m.System = req.System
}
@@ -1156,6 +1119,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
Requires: m.Config.Requires,
Skills: m.Config.Skills,
MCPs: m.Config.MCPs,
AgentType: m.Config.AgentType,
Entrypoint: m.Config.Entrypoint,
}
if m.Config.RemoteHost != "" {
@@ -1210,12 +1177,13 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
fmt.Fprint(&sb, m.String())
resp.Modelfile = sb.String()
// skip loading tensor information if this is a remote model
// skip loading tensor information if this is a remote model or a skill
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
return resp, nil
}
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
// Skills don't have model weights, skip tensor loading
if m.ModelPath == "" {
return resp, nil
}
@@ -1588,12 +1556,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)
if rc != nil {
// wrap old with new
rs := &registry.Local{

View File

@@ -22,7 +22,6 @@ import (
gocmpopts "github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/types/model"
@@ -42,7 +41,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
var base convert.KV = map[string]any{"general.architecture": "test"}
base := map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {

View File

@@ -21,7 +21,6 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
type LlmRequest struct {
@@ -195,14 +194,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation model before attempting GGML load
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadImageGen(pending) {
break
}
continue
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -552,48 +543,6 @@ iGPUScan:
return false
}
// loadImageGen loads an image generation model.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Use model name for imagegen (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName)
if err != nil {
req.errCh <- err
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
llama: server,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
refCount: 1,
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()
// Set up expiration timer
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
return true
}
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
if len(allGpus) == 0 {
return

View File

@@ -6,7 +6,6 @@ import (
"errors"
"log/slog"
"os"
"slices"
"testing"
"time"
@@ -17,7 +16,6 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
func TestMain(m *testing.M) {
@@ -806,61 +804,3 @@ func (s *mockLlm) GetPort() int { return -
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
// Simulate an image gen runner already loaded
imageGenRunner := &runnerRef{
model: &Model{Name: "z-image", ModelPath: "/fake/image/model"},
modelPath: "/fake/image/model",
llama: &mockLlm{vramSize: 21 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{}},
sessionDuration: 5 * time.Millisecond,
refCount: 0, // idle
}
s.loadedMu.Lock()
s.loaded["/fake/image/model"] = imageGenRunner
s.loadedMu.Unlock()
// Verify the image gen runner is loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
// findRunnerToUnload should find the idle image gen runner
runner := s.findRunnerToUnload()
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}

326
server/skill.go Normal file
View File

@@ -0,0 +1,326 @@
package server
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// MediaTypeSkill is the media type for skill layers in manifests.
const MediaTypeSkill = "application/vnd.ollama.image.skill"
// GetSkillsPath returns the path to the extracted skills cache directory.
// If digest is empty, returns the skills directory itself.
// If digest is provided, returns the path to the extracted skill for that digest.
func GetSkillsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "skills", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// ExtractSkillBlob extracts a skill tar.gz blob to the skills cache.
// The blob is expected to be at the blobs path for the given digest.
// Returns the path to the extracted skill directory.
func ExtractSkillBlob(digest string) (string, error) {
// Get the blob path
blobPath, err := GetBlobsPath(digest)
if err != nil {
return "", fmt.Errorf("getting blob path: %w", err)
}
// Get the extraction path
skillPath, err := GetSkillsPath(digest)
if err != nil {
return "", fmt.Errorf("getting skill path: %w", err)
}
// Check if already extracted
if _, err := os.Stat(filepath.Join(skillPath, "SKILL.md")); err == nil {
return skillPath, nil
}
// Open the blob
f, err := os.Open(blobPath)
if err != nil {
return "", fmt.Errorf("opening blob: %w", err)
}
defer f.Close()
// Create gzip reader
gzr, err := gzip.NewReader(f)
if err != nil {
return "", fmt.Errorf("creating gzip reader: %w", err)
}
defer gzr.Close()
// Create tar reader
tr := tar.NewReader(gzr)
// Create the skill directory
if err := os.MkdirAll(skillPath, 0o755); err != nil {
return "", fmt.Errorf("creating skill directory: %w", err)
}
// Extract files
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return "", fmt.Errorf("reading tar: %w", err)
}
// Clean the name and ensure it doesn't escape the target directory
name := filepath.Clean(header.Name)
if strings.HasPrefix(name, "..") {
return "", fmt.Errorf("invalid path in archive: %s", header.Name)
}
target := filepath.Join(skillPath, name)
// Verify the target is within skillPath
if !strings.HasPrefix(target, filepath.Clean(skillPath)+string(os.PathSeparator)) && target != filepath.Clean(skillPath) {
return "", fmt.Errorf("path escapes skill directory: %s", header.Name)
}
switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(target, 0o755); err != nil {
return "", fmt.Errorf("creating directory: %w", err)
}
case tar.TypeReg:
// Ensure parent directory exists
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
return "", fmt.Errorf("creating parent directory: %w", err)
}
outFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return "", fmt.Errorf("creating file: %w", err)
}
if _, err := io.Copy(outFile, tr); err != nil {
outFile.Close()
return "", fmt.Errorf("writing file: %w", err)
}
outFile.Close()
}
}
return skillPath, nil
}
// CreateSkillLayer creates a skill layer from a local directory.
// The directory must contain a SKILL.md file.
// Returns the created layer.
func CreateSkillLayer(skillDir string) (Layer, error) {
// Verify SKILL.md exists
skillMdPath := filepath.Join(skillDir, "SKILL.md")
if _, err := os.Stat(skillMdPath); err != nil {
return Layer{}, fmt.Errorf("skill directory must contain SKILL.md: %w", err)
}
// Create a temporary file for the tar.gz
blobsPath, err := GetBlobsPath("")
if err != nil {
return Layer{}, fmt.Errorf("getting blobs path: %w", err)
}
tmpFile, err := os.CreateTemp(blobsPath, "skill-*.tar.gz")
if err != nil {
return Layer{}, fmt.Errorf("creating temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer func() {
tmpFile.Close()
os.Remove(tmpPath)
}()
// Create gzip writer
gzw := gzip.NewWriter(tmpFile)
defer gzw.Close()
// Create tar writer
tw := tar.NewWriter(gzw)
defer tw.Close()
// Walk the skill directory and add files to tar
err = filepath.Walk(skillDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Get relative path
relPath, err := filepath.Rel(skillDir, path)
if err != nil {
return err
}
// Skip the root directory itself
if relPath == "." {
return nil
}
// Create tar header
header, err := tar.FileInfoHeader(info, "")
if err != nil {
return err
}
header.Name = relPath
if err := tw.WriteHeader(header); err != nil {
return err
}
// Write file contents if it's a regular file
if !info.IsDir() {
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
if _, err := io.Copy(tw, f); err != nil {
return err
}
}
return nil
})
if err != nil {
return Layer{}, fmt.Errorf("creating tar archive: %w", err)
}
// Close writers to flush
if err := tw.Close(); err != nil {
return Layer{}, fmt.Errorf("closing tar writer: %w", err)
}
if err := gzw.Close(); err != nil {
return Layer{}, fmt.Errorf("closing gzip writer: %w", err)
}
if err := tmpFile.Close(); err != nil {
return Layer{}, fmt.Errorf("closing temp file: %w", err)
}
// Open the temp file for reading
tmpFile, err = os.Open(tmpPath)
if err != nil {
return Layer{}, fmt.Errorf("reopening temp file: %w", err)
}
defer tmpFile.Close()
// Create the layer (this will compute the digest and move to blobs)
layer, err := NewLayer(tmpFile, MediaTypeSkill)
if err != nil {
return Layer{}, fmt.Errorf("creating layer: %w", err)
}
// Extract the skill to the cache so it's ready to use
if _, err := ExtractSkillBlob(layer.Digest); err != nil {
return Layer{}, fmt.Errorf("extracting skill: %w", err)
}
return layer, nil
}
// IsLocalSkillPath checks if a skill reference looks like a local path.
// Local paths are explicitly prefixed with /, ./, ../, or ~.
// Registry references like "skill/calculator:1.0.0" should NOT be treated as local paths.
func IsLocalSkillPath(name string) bool {
// Local paths are explicitly indicated by path prefixes
return strings.HasPrefix(name, "/") ||
strings.HasPrefix(name, "./") ||
strings.HasPrefix(name, "../") ||
strings.HasPrefix(name, "~")
}
// SkillNamespace is the namespace used for standalone skills in the registry.
const SkillNamespace = "skill"
// IsSkillReference checks if a name refers to a skill (has skill/ prefix).
func IsSkillReference(name string) bool {
// Check for skill/ prefix (handles both "skill/foo" and "registry/skill/foo")
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
// skill/name or skill/name:tag
if len(parts) >= 1 && parts[0] == SkillNamespace {
return true
}
// namespace/skill/name (e.g., myuser/skill/calc) - not a skill ref
// registry/skill/name (e.g., registry.ollama.ai/skill/calc)
if len(parts) >= 2 && parts[1] == SkillNamespace {
return true
}
return false
}
// ParseSkillName parses a skill reference string into a model.Name.
// The Kind field is set to "skill".
// Examples:
// - "calculator" -> library/skill/calculator:latest
// - "myname/calculator" -> myname/skill/calculator:latest
// - "myname/skill/calculator:1.0.0" -> myname/skill/calculator:1.0.0
func ParseSkillName(name string) model.Name {
// Use the standard parser which now handles Kind
n := model.ParseName(name)
// If Kind wasn't set (old format without skill/), set it
if n.Kind == "" {
n.Kind = SkillNamespace
}
return n
}
// SkillDisplayName returns a user-friendly display name for a skill.
func SkillDisplayName(n model.Name) string {
return n.DisplayShortest()
}
// GetSkillManifestPath returns the path to the skill manifest file.
// Uses the 5-part structure: host/namespace/kind/model/tag
func GetSkillManifestPath(n model.Name) (string, error) {
if n.Model == "" {
return "", fmt.Errorf("skill name is required")
}
// Ensure Kind is set
if n.Kind == "" {
n.Kind = SkillNamespace
}
path := filepath.Join(
envconfig.Models(),
"manifests",
n.Filepath(),
)
return path, nil
}

View File

@@ -381,28 +381,6 @@ func (t templateTools) String() string {
return string(bts)
}
// templateArgs is a map type with JSON string output for templates.
type templateArgs map[string]any
func (t templateArgs) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
// templateProperties is a map type with JSON string output for templates.
type templateProperties map[string]api.ToolProperty
func (t templateProperties) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
// templateTool is a template-compatible representation of api.Tool
// with Properties as a regular map for template ranging.
type templateTool struct {
@@ -418,11 +396,11 @@ type templateToolFunction struct {
}
type templateToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties templateProperties `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties map[string]api.ToolProperty `json:"properties"`
}
// templateToolCall is a template-compatible representation of api.ToolCall
@@ -435,7 +413,7 @@ type templateToolCall struct {
type templateToolCallFunction struct {
Index int
Name string
Arguments templateArgs
Arguments map[string]any
}
// templateMessage is a template-compatible representation of api.Message
@@ -468,7 +446,7 @@ func convertToolsForTemplate(tools api.Tools) templateTools {
Defs: tool.Function.Parameters.Defs,
Items: tool.Function.Parameters.Items,
Required: tool.Function.Parameters.Required,
Properties: templateProperties(tool.Function.Parameters.Properties.ToMap()),
Properties: tool.Function.Parameters.Properties.ToMap(),
},
},
}
@@ -490,7 +468,7 @@ func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
Function: templateToolCallFunction{
Index: tc.Function.Index,
Name: tc.Function.Name,
Arguments: templateArgs(tc.Function.Arguments.ToMap()),
Arguments: tc.Function.Arguments.ToMap(),
},
})
}

View File

@@ -613,159 +613,3 @@ func TestCollate(t *testing.T) {
})
}
}
func TestTemplateArgumentsJSON(t *testing.T) {
// Test that {{ .Function.Arguments }} outputs valid JSON, not map[key:value]
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{ .Function.Arguments }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("location", "Tokyo")
args.Set("unit", "celsius")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:Tokyo unit:celsius]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Arguments output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Arguments not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplatePropertiesJSON(t *testing.T) {
// Test that {{ .Function.Parameters.Properties }} outputs valid JSON
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{ .Function.Parameters.Properties }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:{...}]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Properties output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Properties not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplateArgumentsRange(t *testing.T) {
// Test that we can range over Arguments in templates
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{- range $k, $v := .Function.Arguments }}{{ $k }}={{ $v }};{{- end }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("city", "Tokyo")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "city=Tokyo;" {
t.Errorf("Range over Arguments failed, got: %s, want: city=Tokyo;", got)
}
}
func TestTemplatePropertiesRange(t *testing.T) {
// Test that we can range over Properties in templates
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{- range $name, $prop := .Function.Parameters.Properties }}{{ $name }}:{{ $prop.Type }};{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "location:string;" {
t.Errorf("Range over Properties failed, got: %s, want: location:string;", got)
}
}

View File

@@ -3,13 +3,12 @@ package model
type Capability string
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImageGeneration = Capability("image")
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
)
func (c Capability) String() string {

View File

@@ -1,5 +1,29 @@
package model
// SkillRef represents a reference to a skill, either by local path or by registry digest.
type SkillRef struct {
// Name is the local path (for development) or registry name (e.g., "skill/calculator:1.0.0")
Name string `json:"name,omitempty"`
// Digest is the content-addressable digest of the skill blob (e.g., "sha256:abc123...")
Digest string `json:"digest,omitempty"`
}
// MCPRef represents a reference to an MCP (Model Context Protocol) server.
type MCPRef struct {
// Name is the identifier for the MCP server (used for tool namespacing)
Name string `json:"name,omitempty"`
// Digest is the content-addressable digest of the bundled MCP server blob
Digest string `json:"digest,omitempty"`
// Command is the executable to run (e.g., "uv", "node", "python3")
Command string `json:"command,omitempty"`
// Args are the arguments to pass to the command
Args []string `json:"args,omitempty"`
// Env is optional environment variables for the MCP server
Env map[string]string `json:"env,omitempty"`
// Type is the transport type (currently only "stdio" is supported)
Type string `json:"type,omitempty"`
}
// ConfigV2 represents the configuration metadata for a model.
type ConfigV2 struct {
ModelFormat string `json:"model_format"`
@@ -20,6 +44,12 @@ type ConfigV2 struct {
EmbedLen int `json:"embedding_length,omitempty"`
BaseName string `json:"base_name,omitempty"`
// agent-specific fields
Skills []SkillRef `json:"skills,omitempty"`
MCPs []MCPRef `json:"mcps,omitempty"`
AgentType string `json:"agent_type,omitempty"`
Entrypoint string `json:"entrypoint,omitempty"`
// required by spec
Architecture string `json:"architecture"`
OS string `json:"os"`

View File

@@ -59,6 +59,7 @@ type partKind int
const (
kindHost partKind = iota
kindNamespace
kindKind
kindModel
kindTag
kindDigest
@@ -70,6 +71,8 @@ func (k partKind) String() string {
return "host"
case kindNamespace:
return "namespace"
case kindKind:
return "kind"
case kindModel:
return "model"
case kindTag:
@@ -89,6 +92,7 @@ func (k partKind) String() string {
type Name struct {
Host string
Namespace string
Kind string // Optional: "skill", "agent", or empty for models
Model string
Tag string
}
@@ -97,34 +101,27 @@ type Name struct {
// format of a valid name string is:
//
// s:
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
// { host } "/" { namespace } "/" { kind } "/" { model } ":" { tag }
// { host } "/" { namespace } "/" { model } ":" { tag }
// { host } "/" { namespace } "/" { model } "@" { digest }
// { host } "/" { namespace } "/" { model }
// { namespace } "/" { model } ":" { tag } "@" { digest }
// { namespace } "/" { kind } "/" { model } ":" { tag }
// { namespace } "/" { model } ":" { tag }
// { namespace } "/" { model } "@" { digest }
// { namespace } "/" { model }
// { model } ":" { tag } "@" { digest }
// { model } ":" { tag }
// { model } "@" { digest }
// { model }
// "@" { digest }
// host:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." | ":" }*
// length: [1, 350]
// namespace:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
// length: [1, 80]
// kind:
// pattern: "skill" | "agent" | "" (empty for models)
// length: [0, 80]
// model:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// tag:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// digest:
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
// length: [1, 80]
//
// Most users should use [ParseName] instead, unless need to support
// different defaults than DefaultName.
@@ -136,6 +133,13 @@ func ParseName(s string) Name {
return Merge(ParseNameBare(s), DefaultName())
}
// ValidKinds are the allowed values for the Kind field
var ValidKinds = map[string]bool{
"skill": true,
"agent": true,
"mcp": true,
}
// ParseNameBare parses s as a name string and returns a Name. No merge with
// [DefaultName] is performed.
func ParseNameBare(s string) Name {
@@ -153,6 +157,30 @@ func ParseNameBare(s string) Name {
return n
}
s, n.Kind, promised = cutPromised(s, "/")
if !promised {
// Only 2 parts: namespace/model - what we parsed as Kind is actually Namespace
n.Namespace = n.Kind
n.Kind = ""
return n
}
// Check if what we parsed as Kind is actually a valid kind value
if !ValidKinds[n.Kind] {
// Not a valid kind - this is the old 3-part format: host/namespace/model
// Shift: Kind -> Namespace, s -> Host
n.Namespace = n.Kind
n.Kind = ""
scheme, host, ok := strings.Cut(s, "://")
if !ok {
host = scheme
}
n.Host = host
return n
}
// Valid kind found - continue parsing for namespace and optional host
s, n.Namespace, promised = cutPromised(s, "/")
if !promised {
n.Namespace = s
@@ -168,20 +196,32 @@ func ParseNameBare(s string) Name {
return n
}
// ParseNameFromFilepath parses a 4-part filepath as a Name. The parts are
// ParseNameFromFilepath parses a 4 or 5-part filepath as a Name. The parts are
// expected to be in the form:
//
// { host } "/" { namespace } "/" { model } "/" { tag }
// { host } "/" { namespace } "/" { kind } "/" { model } "/" { tag }
func ParseNameFromFilepath(s string) (n Name) {
parts := strings.Split(s, string(filepath.Separator))
if len(parts) != 4 {
switch len(parts) {
case 4:
// Old format: host/namespace/model/tag
n.Host = parts[0]
n.Namespace = parts[1]
n.Model = parts[2]
n.Tag = parts[3]
case 5:
// New format: host/namespace/kind/model/tag
n.Host = parts[0]
n.Namespace = parts[1]
n.Kind = parts[2]
n.Model = parts[3]
n.Tag = parts[4]
default:
return Name{}
}
n.Host = parts[0]
n.Namespace = parts[1]
n.Model = parts[2]
n.Tag = parts[3]
if !n.IsFullyQualified() {
return Name{}
}
@@ -189,11 +229,12 @@ func ParseNameFromFilepath(s string) (n Name) {
return n
}
// Merge merges the host, namespace, and tag parts of the two names,
// Merge merges the host, namespace, kind, and tag parts of the two names,
// preferring the non-empty parts of a.
func Merge(a, b Name) Name {
a.Host = cmp.Or(a.Host, b.Host)
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
a.Kind = cmp.Or(a.Kind, b.Kind)
a.Tag = cmp.Or(a.Tag, b.Tag)
return a
}
@@ -211,6 +252,10 @@ func (n Name) String() string {
b.WriteString(n.Namespace)
b.WriteByte('/')
}
if n.Kind != "" {
b.WriteString(n.Kind)
b.WriteByte('/')
}
b.WriteString(n.Model)
if n.Tag != "" {
b.WriteByte(':')
@@ -233,6 +278,12 @@ func (n Name) DisplayShortest() string {
sb.WriteByte('/')
}
// include kind if present
if n.Kind != "" {
sb.WriteString(n.Kind)
sb.WriteByte('/')
}
// always include model and tag
sb.WriteString(n.Model)
sb.WriteString(":")
@@ -256,18 +307,23 @@ func (n Name) IsValid() bool {
}
// IsFullyQualified returns true if all parts of the name are present and
// valid without the digest.
// valid without the digest. Kind is optional and only validated if non-empty.
func (n Name) IsFullyQualified() bool {
parts := []string{
n.Host,
n.Namespace,
n.Model,
n.Tag,
if !isValidPart(kindHost, n.Host) {
return false
}
for i, part := range parts {
if !isValidPart(partKind(i), part) {
return false
}
if !isValidPart(kindNamespace, n.Namespace) {
return false
}
// Kind is optional - only validate if present
if n.Kind != "" && !isValidPart(kindKind, n.Kind) {
return false
}
if !isValidPart(kindModel, n.Model) {
return false
}
if !isValidPart(kindTag, n.Tag) {
return false
}
return true
}
@@ -276,6 +332,7 @@ func (n Name) IsFullyQualified() bool {
// host to tag as a directory in the form:
//
// {host}/{namespace}/{model}/{tag}
// {host}/{namespace}/{kind}/{model}/{tag}
//
// It uses the system's filepath separator and ensures the path is clean.
//
@@ -285,6 +342,15 @@ func (n Name) Filepath() string {
if !n.IsFullyQualified() {
panic("illegal attempt to get filepath of invalid name")
}
if n.Kind != "" {
return filepath.Join(
n.Host,
n.Namespace,
n.Kind,
n.Model,
n.Tag,
)
}
return filepath.Join(
n.Host,
n.Namespace,
@@ -301,6 +367,7 @@ func (n Name) LogValue() slog.Value {
func (n Name) EqualFold(o Name) bool {
return strings.EqualFold(n.Host, o.Host) &&
strings.EqualFold(n.Namespace, o.Namespace) &&
strings.EqualFold(n.Kind, o.Kind) &&
strings.EqualFold(n.Model, o.Model) &&
strings.EqualFold(n.Tag, o.Tag)
}
@@ -317,6 +384,11 @@ func isValidLen(kind partKind, s string) bool {
}
func isValidPart(kind partKind, s string) bool {
// Kind must be one of the valid values
if kind == kindKind {
return ValidKinds[s]
}
if !isValidLen(kind, s) {
return false
}

View File

@@ -1,24 +0,0 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
```
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
go build -tags mlx .
```
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
## Image Generation
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
```
go build -o imagegen ./x/imagegen/cmd/engine
```

View File

@@ -4,7 +4,6 @@ package agent
import (
"fmt"
"os"
"path"
"path/filepath"
"strings"
"sync"
@@ -33,29 +32,10 @@ type ApprovalResult struct {
// Option labels for the selector (numbered for quick selection)
var optionLabels = []string{
"1. Execute once",
"2. Allow for this session",
"2. Always allow",
"3. Deny",
}
// toolDisplayNames maps internal tool names to human-readable display names.
var toolDisplayNames = map[string]string{
"bash": "Bash",
"web_search": "Web Search",
}
// ToolDisplayName returns the human-readable display name for a tool.
func ToolDisplayName(toolName string) string {
if displayName, ok := toolDisplayNames[toolName]; ok {
return displayName
}
// Default: capitalize first letter and replace underscores with spaces
name := strings.ReplaceAll(toolName, "_", " ")
if len(name) > 0 {
return strings.ToUpper(name[:1]) + name[1:]
}
return toolName
}
// autoAllowCommands are commands that are always allowed without prompting.
// These are zero-risk, read-only commands.
var autoAllowCommands = map[string]bool{
@@ -199,7 +179,6 @@ func FormatDeniedResult(command string, pattern string) string {
// extractBashPrefix extracts a prefix pattern from a bash command.
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
// For commands without path args, returns empty string.
// Paths with ".." traversal that escape the base directory return empty string for security.
func extractBashPrefix(command string) string {
// Split command by pipes and get the first part
parts := strings.Split(command, "|")
@@ -225,8 +204,8 @@ func extractBashPrefix(command string) string {
return ""
}
// Find the first path-like argument (must contain / or \ or start with .)
// First pass: look for clear paths (containing path separators or starting with .)
// Find the first path-like argument (must contain / or start with .)
// First pass: look for clear paths (containing / or starting with .)
for _, arg := range fields[1:] {
// Skip flags
if strings.HasPrefix(arg, "-") {
@@ -236,49 +215,19 @@ func extractBashPrefix(command string) string {
if isNumeric(arg) {
continue
}
// Only process if it looks like a path (contains / or \ or starts with .)
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
// Only process if it looks like a path (contains / or starts with .)
if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
continue
}
// Normalize to forward slashes for consistent cross-platform matching
arg = strings.ReplaceAll(arg, "\\", "/")
// Security: reject absolute paths
if path.IsAbs(arg) {
return "" // Absolute path - don't create prefix
// If arg ends with /, it's a directory - use it directly
if strings.HasSuffix(arg, "/") {
return fmt.Sprintf("%s:%s", baseCmd, arg)
}
// Normalize the path using stdlib path.Clean (resolves . and ..)
cleaned := path.Clean(arg)
// Security: reject if cleaned path escapes to parent directory
if strings.HasPrefix(cleaned, "..") {
return "" // Path escapes - don't create prefix
}
// Security: if original had "..", verify cleaned path didn't escape to sibling
// e.g., "tools/a/b/../../../etc" -> "etc" (escaped tools/ to sibling)
if strings.Contains(arg, "..") {
origBase := strings.SplitN(arg, "/", 2)[0]
cleanedBase := strings.SplitN(cleaned, "/", 2)[0]
if origBase != cleanedBase {
return "" // Path escaped to sibling directory
}
}
// Check if arg ends with / (explicit directory)
isDir := strings.HasSuffix(arg, "/")
// Get the directory part
var dir string
if isDir {
dir = cleaned
} else {
dir = path.Dir(cleaned)
}
// Get the directory part of a file path
dir := filepath.Dir(arg)
if dir == "." {
return fmt.Sprintf("%s:./", baseCmd)
// Path is just a directory like "tools" or "src" (no trailing /)
return fmt.Sprintf("%s:%s/", baseCmd, arg)
}
return fmt.Sprintf("%s:%s/", baseCmd, dir)
}
@@ -383,8 +332,6 @@ func AllowlistKey(toolName string, args map[string]any) string {
}
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
// For bash commands, hierarchical path matching is used - if "cat:tools/" is allowed,
// then "cat:tools/subdir/" is also allowed (subdirectories inherit parent permissions).
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
a.mu.RLock()
defer a.mu.RUnlock()
@@ -395,20 +342,12 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
return true
}
// For bash commands, check prefix matches with hierarchical path support
// For bash commands, check prefix matches
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
prefix := extractBashPrefix(cmd)
if prefix != "" {
// Check exact prefix match first
if a.prefixes[prefix] {
return true
}
// Check hierarchical match: if any stored prefix is a parent of current prefix
// e.g., stored "cat:tools/" should match current "cat:tools/subdir/"
if a.matchesHierarchicalPrefix(prefix) {
return true
}
if prefix != "" && a.prefixes[prefix] {
return true
}
}
}
@@ -421,40 +360,6 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
return false
}
// matchesHierarchicalPrefix checks if the given prefix matches any stored prefix hierarchically.
// For example, if "cat:tools/" is stored, it will match "cat:tools/subdir/" or "cat:tools/a/b/c/".
func (a *ApprovalManager) matchesHierarchicalPrefix(currentPrefix string) bool {
// Split prefix into command and path parts (format: "cmd:path/")
colonIdx := strings.Index(currentPrefix, ":")
if colonIdx == -1 {
return false
}
currentCmd := currentPrefix[:colonIdx]
currentPath := currentPrefix[colonIdx+1:]
for storedPrefix := range a.prefixes {
storedColonIdx := strings.Index(storedPrefix, ":")
if storedColonIdx == -1 {
continue
}
storedCmd := storedPrefix[:storedColonIdx]
storedPath := storedPrefix[storedColonIdx+1:]
// Commands must match exactly
if currentCmd != storedCmd {
continue
}
// Check if current path starts with stored path (hierarchical match)
// e.g., "tools/subdir/" starts with "tools/"
if strings.HasPrefix(currentPath, storedPath) {
return true
}
}
return false
}
// AddToAllowlist adds a tool/command to the session allowlist.
// For bash commands, it adds the prefix pattern instead of exact command.
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
@@ -494,32 +399,16 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// This prevents buffered input from causing double-press issues
flushStdin(fd)
// Check if bash command targets paths outside cwd
isWarning := false
var warningMsg string
var allowlistInfo string
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
if isCommandOutsideCwd(cmd) {
isWarning = true
warningMsg = "command targets paths outside project"
}
if prefix := extractBashPrefix(cmd); prefix != "" {
colonIdx := strings.Index(prefix, ":")
if colonIdx != -1 {
cmdName := prefix[:colonIdx]
dirPath := prefix[colonIdx+1:]
if dirPath != "./" {
allowlistInfo = fmt.Sprintf("%s in %s directory (includes subdirs)", cmdName, dirPath)
} else {
allowlistInfo = fmt.Sprintf("%s in %s directory", cmdName, dirPath)
}
}
}
isWarning = isCommandOutsideCwd(cmd)
}
}
// Run interactive selector
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning, warningMsg, allowlistInfo)
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
if err != nil {
term.Restore(fd, oldState)
return ApprovalResult{Decision: ApprovalDeny}, err
@@ -544,29 +433,27 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// formatToolDisplay creates the display string for a tool call.
func formatToolDisplay(toolName string, args map[string]any) string {
var sb strings.Builder
displayName := ToolDisplayName(toolName)
// For bash, show command directly
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
return sb.String()
}
}
// For web search, show query and internet notice
// For web search, show query
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
sb.WriteString("Uses internet via ollama.com")
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Query: %s", query))
return sb.String()
}
}
// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
if len(args) > 0 {
sb.WriteString("\nArguments: ")
first := true
@@ -583,28 +470,24 @@ func formatToolDisplay(toolName string, args map[string]any) string {
// selectorState holds the state for the interactive selector
type selectorState struct {
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command has warning
warningMessage string // dynamic warning message to display
allowlistInfo string // show what will be allowlisted (for "Allow for this session" option)
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command targets paths outside cwd (red box)
}
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool, warningMessage string, allowlistInfo string) (int, string, error) {
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
state := &selectorState{
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
warningMessage: warningMessage,
allowlistInfo: allowlistInfo,
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
}
// Get terminal size
@@ -764,7 +647,7 @@ func wrapText(text string, maxWidth int) []string {
// getHintLines returns the hint text wrapped to terminal width
func getHintLines(state *selectorState) []string {
hint := "up/down select, enter confirm, 1-3 quick select, ctrl+c cancel"
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
if state.termWidth >= len(hint)+1 {
return []string{hint}
}
@@ -774,70 +657,86 @@ func getHintLines(state *selectorState) []string {
// calculateTotalLines calculates how many lines the selector will use
func calculateTotalLines(state *selectorState) int {
toolLines := strings.Split(state.toolDisplay, "\n")
toolLines := wrapText(state.toolDisplay, state.innerWidth)
hintLines := getHintLines(state)
// warning line (if applicable) + tool lines + blank line + options + blank line + hint lines
// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
warningLines := 0
if state.isWarning {
warningLines = 2 // warning line + blank line after
warningLines = 1
}
return warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
}
// renderSelectorBox renders the selector (minimal, no box)
// renderSelectorBox renders the complete selector box
func renderSelectorBox(state *selectorState) {
toolLines := strings.Split(state.toolDisplay, "\n")
toolLines := wrapText(state.toolDisplay, state.innerWidth)
hintLines := getHintLines(state)
// Draw warning line if needed
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
if state.warningMessage != "" {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m %s\033[K\r\n", state.warningMessage)
} else {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
boxColor = "\033[91m" // bright red
}
// Draw box top
fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw warning line if needed (inside the box)
if state.isWarning {
warning := "!! OUTSIDE PROJECT !!"
padding := (state.innerWidth - len(warning)) / 2
if padding < 0 {
padding = 0
}
fmt.Fprintf(os.Stderr, "\033[K\r\n") // blank line after warning
fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
}
// Draw tool info (plain white)
// Draw tool info
for _, line := range toolLines {
fmt.Fprintf(os.Stderr, "%s\033[K\r\n", line)
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
}
// Blank line separator
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Draw separator
fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw options with numbers (Deny option includes reason input)
for i, label := range optionLabels {
if i == 2 {
if i == 2 { // Deny option - show with reason input beside it
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
} else {
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
if len(displayLabel) > state.innerWidth-2 {
displayLabel = displayLabel[:state.innerWidth-5] + "..."
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
}
}
}
// Blank line before hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Draw box bottom
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw hint (dark grey)
// Draw hint (may be multiple lines)
for i, line := range hintLines {
if i == len(hintLines)-1 {
// Last line - no newline
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
} else {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
@@ -849,39 +748,50 @@ func renderSelectorBox(state *selectorState) {
func updateSelectorOptions(state *selectorState) {
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Move up to the first option line
// Cursor is at end of last hint line, need to go up:
// (hint lines - 1) + 1 (blank line) + numOptions
// (hint lines - 1) + 1 (bottom border) + numOptions
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw options (Deny option includes reason input)
for i, label := range optionLabels {
if i == 2 {
if i == 2 { // Deny option
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
} else {
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
if len(displayLabel) > state.innerWidth-2 {
displayLabel = displayLabel[:state.innerWidth-5] + "..."
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
}
}
}
// Blank line + hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Redraw bottom and hint
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
@@ -895,26 +805,36 @@ func updateSelectorOptions(state *selectorState) {
func updateReasonInput(state *selectorState) {
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Move up to the Deny line (3rd option, index 2)
// Cursor is at end of last hint line, need to go up:
// (hint lines - 1) + 1 (blank line) + 1 (Deny is last option)
// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
linesToMove := len(hintLines) - 1 + 1 + 1
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw Deny line with reason
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if state.selected == 2 {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
// Blank line + hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Redraw bottom and hint
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
@@ -938,10 +858,11 @@ func clearSelectorBox(state *selectorState) {
// fallbackApproval handles approval when terminal control isn't available.
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, toolDisplay)
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Allow for this session [3] Deny")
fmt.Fprint(os.Stderr, "choice: ")
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
fmt.Fprint(os.Stderr, "Choice: ")
var input string
fmt.Scanln(&input)
@@ -984,16 +905,19 @@ func (a *ApprovalManager) AllowedTools() []string {
// FormatApprovalResult returns a formatted string showing the approval result.
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
var label string
displayName := ToolDisplayName(toolName)
var status string
var icon string
switch result.Decision {
case ApprovalOnce:
label = "Approved"
status = "Approved"
icon = "\033[32m✓\033[0m"
case ApprovalAlways:
label = "Always allowed"
status = "Always allowed"
icon = "\033[32m✓\033[0m"
case ApprovalDeny:
label = "Denied"
status = "Denied"
icon = "\033[31m✗\033[0m"
}
// Format based on tool type
@@ -1003,7 +927,7 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
if len(cmd) > 40 {
cmd = cmd[:37] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, cmd)
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
}
}
@@ -1013,11 +937,11 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
if len(query) > 40 {
query = query[:37] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, query)
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
}
}
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
}
// FormatDenyResult returns the tool result message when a tool is denied.
@@ -1027,78 +951,3 @@ func FormatDenyResult(toolName string, reason string) string {
}
return fmt.Sprintf("User denied execution of %s.", toolName)
}
// PromptYesNo displays a simple Yes/No prompt and returns the user's choice.
// Returns true for Yes, false for No.
func PromptYesNo(question string) (bool, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
selected := 0 // 0 = Yes, 1 = No
options := []string{"Yes", "No"}
// Hide cursor
fmt.Fprint(os.Stderr, "\033[?25l")
defer fmt.Fprint(os.Stderr, "\033[?25h")
renderYesNo := func() {
// Move to start of line and clear
fmt.Fprintf(os.Stderr, "\r\033[K")
fmt.Fprintf(os.Stderr, "%s ", question)
for i, opt := range options {
if i == selected {
fmt.Fprintf(os.Stderr, "\033[1m%s\033[0m ", opt)
} else {
fmt.Fprintf(os.Stderr, "\033[37m%s\033[0m ", opt)
}
}
}
renderYesNo()
buf := make([]byte, 3)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
return false, err
}
if n == 1 {
switch buf[0] {
case 'y', 'Y':
selected = 0
renderYesNo()
case 'n', 'N':
selected = 1
renderYesNo()
case '\r', '\n': // Enter
fmt.Fprintf(os.Stderr, "\r\033[K") // Clear line
return selected == 0, nil
case 3: // Ctrl+C
fmt.Fprintf(os.Stderr, "\r\033[K")
return false, nil
case 27: // Escape - could be arrow key
// Read more bytes for arrow keys
continue
}
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
// Arrow keys
switch buf[2] {
case 'D': // Left
if selected > 0 {
selected--
}
renderYesNo()
case 'C': // Right
if selected < len(options)-1 {
selected++
}
renderYesNo()
}
}
}
}

View File

@@ -151,27 +151,6 @@ func TestExtractBashPrefix(t *testing.T) {
command: "head -n 100",
expected: "",
},
// Path traversal security tests
{
name: "path traversal - parent escape",
command: "cat tools/../../etc/passwd",
expected: "", // Should NOT create a prefix - path escapes
},
{
name: "path traversal - deep escape",
command: "cat tools/a/b/../../../etc/passwd",
expected: "", // Normalizes to "../etc/passwd" - escapes
},
{
name: "path traversal - absolute path",
command: "cat /etc/passwd",
expected: "", // Absolute paths should not create prefix
},
{
name: "path with safe dotdot - normalized",
command: "cat tools/subdir/../file.go",
expected: "cat:tools/", // Normalizes to tools/file.go - safe, creates prefix
},
}
for _, tt := range tests {
@@ -185,34 +164,6 @@ func TestExtractBashPrefix(t *testing.T) {
}
}
func TestApprovalManager_PathTraversalBlocked(t *testing.T) {
am := NewApprovalManager()
// Allow "cat tools/file.go" - creates prefix "cat:tools/"
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Path traversal attack: should NOT be allowed
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../etc/passwd"}) {
t.Error("SECURITY: path traversal attack should NOT be allowed")
}
// Another traversal variant
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../../etc/shadow"}) {
t.Error("SECURITY: deep path traversal should NOT be allowed")
}
// Valid subdirectory access should still work
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
t.Error("expected cat tools/subdir/file.go to be allowed")
}
// Safe ".." that normalizes to within allowed directory should work
// tools/subdir/../other.go normalizes to tools/other.go which is under tools/
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/../other.go"}) {
t.Error("expected cat tools/subdir/../other.go to be allowed (normalizes to tools/other.go)")
}
}
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
am := NewApprovalManager()
@@ -235,119 +186,6 @@ func TestApprovalManager_PrefixAllowlist(t *testing.T) {
}
}
func TestApprovalManager_HierarchicalPrefixAllowlist(t *testing.T) {
am := NewApprovalManager()
// Allow "cat tools/file.go" - this creates prefix "cat:tools/"
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Should allow subdirectories (hierarchical matching)
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
t.Error("expected cat tools/subdir/file.go to be allowed via hierarchical prefix")
}
// Should allow deeply nested subdirectories
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/a/b/c/deep.go"}) {
t.Error("expected cat tools/a/b/c/deep.go to be allowed via hierarchical prefix")
}
// Should still allow same directory
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/another.go"}) {
t.Error("expected cat tools/another.go to be allowed")
}
// Should NOT allow different base directory
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
t.Error("expected cat src/main.go to NOT be allowed")
}
// Should NOT allow different command even in subdirectory
if am.IsAllowed("bash", map[string]any{"command": "ls tools/subdir/"}) {
t.Error("expected ls tools/subdir/ to NOT be allowed (different command)")
}
// Should NOT allow similar but different directory name
if am.IsAllowed("bash", map[string]any{"command": "cat toolsbin/file.go"}) {
t.Error("expected cat toolsbin/file.go to NOT be allowed (different directory)")
}
}
func TestApprovalManager_HierarchicalPrefixAllowlist_CrossPlatform(t *testing.T) {
am := NewApprovalManager()
// Allow with forward slashes (Unix-style)
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Should work with backslashes too (Windows-style) - normalized internally
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\subdir\\file.go"}) {
t.Error("expected cat tools\\subdir\\file.go to be allowed via hierarchical prefix (Windows path)")
}
// Mixed slashes should also work
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\a/b\\c/deep.go"}) {
t.Error("expected mixed slash path to be allowed via hierarchical prefix")
}
}
func TestMatchesHierarchicalPrefix(t *testing.T) {
am := NewApprovalManager()
// Add prefix for "cat:tools/"
am.prefixes["cat:tools/"] = true
tests := []struct {
name string
prefix string
expected bool
}{
{
name: "exact match",
prefix: "cat:tools/",
expected: true, // exact match also passes HasPrefix - caller handles exact match first
},
{
name: "subdirectory",
prefix: "cat:tools/subdir/",
expected: true,
},
{
name: "deeply nested",
prefix: "cat:tools/a/b/c/",
expected: true,
},
{
name: "different base directory",
prefix: "cat:src/",
expected: false,
},
{
name: "different command same path",
prefix: "ls:tools/",
expected: false,
},
{
name: "similar directory name",
prefix: "cat:toolsbin/",
expected: false,
},
{
name: "invalid prefix format",
prefix: "cattools",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := am.matchesHierarchicalPrefix(tt.prefix)
if result != tt.expected {
t.Errorf("matchesHierarchicalPrefix(%q) = %v, expected %v",
tt.prefix, result, tt.expected)
}
})
}
}
func TestFormatApprovalResult(t *testing.T) {
tests := []struct {
name string

View File

@@ -6,12 +6,10 @@ import (
"errors"
"fmt"
"io"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/spf13/cobra"
"golang.org/x/term"
@@ -24,101 +22,6 @@ import (
"github.com/ollama/ollama/x/tools"
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
localModelTokenLimit = 4000
// defaultTokenLimit is the token limit for cloud/remote models.
defaultTokenLimit = 10000
// charsPerToken is a rough estimate of characters per token.
// TODO: Estimate tokens more accurately using tokenizer if available
charsPerToken = 4
)
// isLocalModel checks if the model is running locally (not a cloud model).
// TODO: Improve local/cloud model identification - could check model metadata
func isLocalModel(modelName string) bool {
return !strings.HasSuffix(modelName, "-cloud")
}
// isLocalServer checks if connecting to a local Ollama server.
// TODO: Could also check other indicators of local vs cloud server
func isLocalServer() bool {
host := os.Getenv("OLLAMA_HOST")
if host == "" {
return true // Default is localhost:11434
}
// Parse the URL to check host
parsed, err := url.Parse(host)
if err != nil {
return true // If can't parse, assume local
}
hostname := parsed.Hostname()
return hostname == "localhost" || hostname == "127.0.0.1" || strings.Contains(parsed.Host, ":11434")
}
// truncateToolOutput truncates tool output to prevent context overflow.
// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote.
func truncateToolOutput(output, modelName string) string {
var tokenLimit int
if isLocalModel(modelName) && isLocalServer() {
tokenLimit = localModelTokenLimit
} else {
tokenLimit = defaultTokenLimit
}
maxChars := tokenLimit * charsPerToken
if len(output) > maxChars {
return output[:maxChars] + "\n... (output truncated)"
}
return output
}
// waitForOllamaSignin shows the signin URL and polls until authentication completes.
func waitForOllamaSignin(ctx context.Context) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Get signin URL from initial Whoami call
_, err = client.Whoami(ctx)
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
fmt.Fprintf(os.Stderr, " %s\n\n", aErr.SigninURL)
fmt.Fprintf(os.Stderr, " \033[90mwaiting for sign in to complete...\033[0m")
// Poll until auth succeeds
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\n")
return ctx.Err()
case <-ticker.C:
user, whoamiErr := client.Whoami(ctx)
if whoamiErr == nil && user != nil && user.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K \033[1msigned in:\033[0m %s\n", user.Name)
return nil
}
// Still waiting, show dot
fmt.Fprintf(os.Stderr, ".")
}
}
}
return err
}
return nil
}
// RunOptions contains options for running an interactive agent session.
type RunOptions struct {
Model string
@@ -134,9 +37,6 @@ type RunOptions struct {
// Agent fields (managed externally for session persistence)
Tools *tools.Registry
Approval *agent.ApprovalManager
// YoloMode skips all tool approval prompts
YoloMode bool
}
// Chat runs an agent chat loop with tool support.
@@ -177,7 +77,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagOpened bool = false
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
role := "assistant"
messages := opts.Messages
@@ -260,58 +159,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, nil
}
// Check for 401 Unauthorized - prompt user to sign in
var authErr api.AuthorizationError
if errors.As(err, &authErr) {
p.StopAndClear()
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m cloud model requires authentication\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the chat request
fmt.Fprintf(os.Stderr, "\033[90mretrying...\033[0m\n")
continue // Retry the loop
}
}
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
}
// Check for 500 errors (often tool parsing failures) - inform the model
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode >= 500 {
consecutiveErrors++
p.StopAndClear()
if consecutiveErrors >= 3 {
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m too many consecutive errors, giving up\n")
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
}
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m server error (attempt %d/3): %s\n", consecutiveErrors, statusErr.ErrorMessage)
// Include both the model's response and the error so it can learn
assistantContent := fullResponse.String()
if assistantContent == "" {
assistantContent = "(empty response)"
}
errorMsg := fmt.Sprintf("Your previous response caused an error: %s\n\nYour response was:\n%s\n\nPlease try again with a valid response.", statusErr.ErrorMessage, assistantContent)
messages = append(messages,
api.Message{Role: "user", Content: errorMsg},
)
// Reset state and retry
fullResponse.Reset()
thinkingContent.Reset()
thinkTagOpened = false
thinkTagClosed = false
pendingToolCalls = nil
state = &displayResponseState{}
p = progress.NewProgress(os.Stderr)
spinner = progress.NewSpinner("")
p.Add("", spinner)
continue
}
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
@@ -321,9 +168,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, err
}
// Reset consecutive error counter on success
consecutiveErrors = 0
// If no tool calls, we're done
if len(pendingToolCalls) == 0 || toolRegistry == nil {
break
@@ -353,8 +197,8 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
if cmd, ok := args["command"].(string); ok {
// Check if command is denied (dangerous pattern)
if denied, pattern := agent.IsDenied(cmd); denied {
fmt.Fprintf(os.Stderr, "\033[1mblocked:\033[0m %s\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, " matches dangerous pattern: %s\n", pattern)
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: agent.FormatDeniedResult(cmd, pattern),
@@ -364,21 +208,15 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
// Check if command is auto-allowed (safe command)
// TODO(parthsareen): re-enable with tighter scoped allowlist
// if agent.IsAutoAllowed(cmd) {
// fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
// skipApproval = true
// }
if agent.IsAutoAllowed(cmd) {
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
skipApproval = true
}
}
}
// Check approval (uses prefix matching for bash commands)
// In yolo mode, skip all approval prompts
if opts.YoloMode {
if !skipApproval {
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
}
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
if !skipApproval && !approval.IsAllowed(toolName, args) {
result, err := approval.RequestApproval(toolName, args)
if err != nil {
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
@@ -406,30 +244,13 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
} else if !skipApproval {
// Already allowed - show running indicator
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
}
// Execute the tool
toolResult, err := toolRegistry.Execute(call)
if err != nil {
// Check if web search needs authentication
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
// Prompt user to sign in
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m web search requires authentication\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
// Get signin URL and wait for auth completion
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the web search
fmt.Fprintf(os.Stderr, "\033[90mretrying web search...\033[0m\n")
toolResult, err = toolRegistry.Execute(call)
if err == nil {
goto toolSuccess
}
}
}
}
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m %v\n", err)
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Error: %v", err),
@@ -437,7 +258,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
})
continue
}
toolSuccess:
// Display tool output (truncated for display)
if toolResult != "" {
@@ -449,12 +269,9 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
}
// Truncate output to prevent context overflow
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: toolResultForLLM,
Content: toolResult,
ToolCallID: call.ID,
})
}
@@ -500,18 +317,17 @@ func truncateUTF8(s string, limit int) string {
// formatToolShort returns a short description of a tool call.
func formatToolShort(toolName string, args map[string]any) string {
displayName := agent.ToolDisplayName(toolName)
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(cmd, 50))
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
}
}
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(query, 50))
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
}
}
return displayName
return toolName
}
// Helper types and functions for display
@@ -633,8 +449,7 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
// If yoloMode is true, all tool approvals are skipped.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
@@ -651,7 +466,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
// Check if model supports tools
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m could not check model capabilities: %v\n", err)
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
supportsTools = false
}
@@ -659,17 +474,14 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var toolRegistry *tools.Registry
if supportsTools {
toolRegistry = tools.DefaultRegistry()
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
if toolRegistry.Has("bash") {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
fmt.Fprintln(os.Stderr, "Models can read files on your computer, or run commands (after you allow them).")
fmt.Fprintln(os.Stderr)
}
if yoloMode {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
// Check for OLLAMA_API_KEY for web search
if os.Getenv("OLLAMA_API_KEY") == "" {
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
}
} else {
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
}
// Create approval manager for session
@@ -712,9 +524,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Keyboard Shortcuts:")
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
@@ -737,7 +546,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
}
assistant, err := Chat(cmd.Context(), opts)

View File

@@ -1,180 +0,0 @@
package cmd
import (
"testing"
)
func TestIsLocalModel(t *testing.T) {
tests := []struct {
name string
modelName string
expected bool
}{
{
name: "local model without suffix",
modelName: "llama3.2",
expected: true,
},
{
name: "local model with version",
modelName: "qwen2.5:7b",
expected: true,
},
{
name: "cloud model",
modelName: "gpt-4-cloud",
expected: false,
},
{
name: "cloud model with version",
modelName: "claude-3-cloud",
expected: false,
},
{
name: "empty model name",
modelName: "",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isLocalModel(tt.modelName)
if result != tt.expected {
t.Errorf("isLocalModel(%q) = %v, expected %v", tt.modelName, result, tt.expected)
}
})
}
}
func TestIsLocalServer(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
{
name: "empty host (default)",
host: "",
expected: true,
},
{
name: "localhost",
host: "http://localhost:11434",
expected: true,
},
{
name: "127.0.0.1",
host: "http://127.0.0.1:11434",
expected: true,
},
{
name: "custom port on localhost",
host: "http://localhost:8080",
expected: true, // localhost is always considered local
},
{
name: "remote host",
host: "http://ollama.example.com:11434",
expected: true, // has :11434
},
{
name: "remote host different port",
host: "http://ollama.example.com:8080",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
result := isLocalServer()
if result != tt.expected {
t.Errorf("isLocalServer() with OLLAMA_HOST=%q = %v, expected %v", tt.host, result, tt.expected)
}
})
}
}
func TestTruncateToolOutput(t *testing.T) {
// Create outputs of different sizes
localLimitOutput := make([]byte, 20000) // > 4k tokens (16k chars)
defaultLimitOutput := make([]byte, 50000) // > 10k tokens (40k chars)
for i := range localLimitOutput {
localLimitOutput[i] = 'a'
}
for i := range defaultLimitOutput {
defaultLimitOutput[i] = 'b'
}
tests := []struct {
name string
output string
modelName string
host string
shouldTrim bool
expectedLimit int
}{
{
name: "short output local model",
output: "hello world",
modelName: "llama3.2",
host: "",
shouldTrim: false,
expectedLimit: localModelTokenLimit,
},
{
name: "long output local model - trimmed at 4k",
output: string(localLimitOutput),
modelName: "llama3.2",
host: "",
shouldTrim: true,
expectedLimit: localModelTokenLimit,
},
{
name: "long output cloud model - uses 10k limit",
output: string(localLimitOutput), // 20k chars, under 10k token limit
modelName: "gpt-4-cloud",
host: "",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
},
{
name: "very long output cloud model - trimmed at 10k",
output: string(defaultLimitOutput),
modelName: "gpt-4-cloud",
host: "",
shouldTrim: true,
expectedLimit: defaultTokenLimit,
},
{
name: "long output remote server - uses 10k limit",
output: string(localLimitOutput),
modelName: "llama3.2",
host: "http://remote.example.com:8080",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
result := truncateToolOutput(tt.output, tt.modelName)
if tt.shouldTrim {
maxLen := tt.expectedLimit * charsPerToken
if len(result) > maxLen+50 { // +50 for the truncation message
t.Errorf("expected output to be truncated to ~%d chars, got %d", maxLen, len(result))
}
if result == tt.output {
t.Error("expected output to be truncated but it wasn't")
}
} else {
if result != tt.output {
t.Error("expected output to not be truncated")
}
}
})
}
}

38
x/imagegen/.gitignore vendored
View File

@@ -1,38 +0,0 @@
# Build directories
build/
dist/
# CMake
CMakeCache.txt
CMakeFiles/
cmake_install.cmake
Makefile
*.cmake
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# macOS
.DS_Store
*.dSYM/
# Go
*.exe
*.exe~
*.dll
*.so
*.dylib
# Python
*.npy
/engine
weights
outputs
prompt.txt
negative.txt

View File

@@ -1,250 +0,0 @@
# Image Generation in Ollama (Experimental)
Generate images from text prompts using local AI models.
## Quick Start
```bash
# Run with a prompt
ollama run z-image "a sunset over mountains"
Generating: step 30/30
Image saved to: /tmp/ollama-image-1704067200.png
```
On macOS, the generated image will automatically open in Preview.
## Supported Models
| Model | VRAM Required | Notes |
|-------|---------------|-------|
| z-image | ~12GB | Based on Flux architecture |
## CLI Usage
```bash
# Generate an image
ollama run z-image "a cat playing piano"
# Check if model is running
ollama ps
# Stop the model
ollama stop z-image
```
## API
### OpenAI-Compatible Endpoint
```bash
POST /v1/images/generations
```
**Request:**
```json
{
"model": "z-image",
"prompt": "a sunset over mountains",
"size": "1024x1024",
"response_format": "b64_json"
}
```
**Response:**
```json
{
"created": 1704067200,
"data": [
{
"b64_json": "iVBORw0KGgo..."
}
]
}
```
### Example: cURL
```bash
curl http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "z-image",
"prompt": "a white cat",
"size": "1024x1024"
}'
```
### Example: Save to File
```bash
curl -s http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "z-image",
"prompt": "a white cat",
"size": "1024x1024"
}' | jq -r '.data[0].b64_json' | base64 -d > image.png
```
### Streaming Progress
Enable streaming to receive progress updates via SSE:
```bash
curl http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{"model": "z-image", "prompt": "a sunset", "stream": true}'
```
Events:
```
event: progress
data: {"step": 1, "total": 30}
event: progress
data: {"step": 2, "total": 30}
...
event: done
data: {"created": 1704067200, "data": [{"b64_json": "..."}]}
```
## Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| model | string | required | Model name |
| prompt | string | required | Text description of image |
| size | string | "1024x1024" | Image dimensions (WxH) |
| n | int | 1 | Number of images (currently only 1 supported) |
| response_format | string | "b64_json" | "b64_json" or "url" |
| stream | bool | false | Enable progress streaming |
## Requirements
- macOS with Apple Silicon (M1/M2/M3/M4)
- CUDA: tested on CUDA 12 Blackwell, more testing coming soon
- Sufficient VRAM (see model table above)
- Ollama built with MLX support
## Limitations
- macOS only (uses MLX backend)
- Single image per request
- Fixed step count (30 steps)
- Modelfiles not yet supported (use `ollama create` from model directory)
---
# Tensor Model Storage Format
Tensor models store each tensor as a separate blob with metadata in the manifest. This enables faster downloads (parallel fetching) and deduplication (shared tensors are stored once).
## Manifest Structure
The manifest follows the standard ollama format with tensor-specific layer metadata:
```json
{
"schemaVersion": 2,
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
"config": { "digest": "sha256:...", "size": 1234 },
"layers": [
{
"mediaType": "application/vnd.ollama.image.tensor",
"digest": "sha256:25b36eed...",
"size": 49807448,
"name": "text_encoder/model.layers.0.mlp.down_proj.weight",
"dtype": "BF16",
"shape": [2560, 9728]
},
{
"mediaType": "application/vnd.ollama.image.json",
"digest": "sha256:abc123...",
"size": 512,
"name": "text_encoder/config.json"
}
]
}
```
Each tensor layer includes:
- `name`: Path-style tensor name (e.g., `text_encoder/model.layers.0.mlp.down_proj.weight`)
- `dtype`: Data type (BF16, F32, etc.)
- `shape`: Tensor dimensions
Config layers use the same path-style naming (e.g., `tokenizer/tokenizer.json`).
## Blob Format
Each tensor blob is a minimal safetensors file:
```
[8 bytes: header size (uint64 LE)]
[~80 bytes: JSON header, padded to 8-byte alignment]
[N bytes: raw tensor data]
```
Header contains a single tensor named `"data"`:
```json
{"data":{"dtype":"BF16","shape":[2560,9728],"data_offsets":[0,49807360]}}
```
## Why Include the Header?
The ~88 byte safetensors header enables MLX's native `mlx_load_safetensors` function, which:
1. **Uses mmap** - Maps file directly into memory, no copies
2. **Zero-copy to GPU** - MLX reads directly from mapped pages
3. **No custom code** - Standard MLX API, battle-tested
Without the header, we'd need custom C++ code to create MLX arrays from raw mmap'd data. MLX's public API doesn't expose this - it always copies when creating arrays from external pointers.
The overhead is negligible: 88 bytes per tensor = ~100KB total for a 13GB model (0.0007%).
## Why Per-Tensor Blobs?
**Deduplication**: Blobs are content-addressed by SHA256. If two models share identical tensors (same weights, dtype, shape), they share the same blob file.
Example: Model A and Model B both use the same text encoder. The text encoder's 400 tensors are stored once, referenced by both manifests.
```
~/.ollama/models/
blobs/
sha256-25b36eed... <- shared by both models
sha256-abc123...
manifests/
library/model-a/latest <- references sha256-25b36eed
library/model-b/latest <- references sha256-25b36eed
```
## Import Flow
```
cd ./weights/Z-Image-Turbo
ollama create z-image
1. Scan component directories (text_encoder/, transformer/, vae/)
2. For each .safetensors file:
- Extract individual tensors
- Wrap each in minimal safetensors format (88B header + data)
- Write to blob store (SHA256 content-addressed)
- Add layer entry to manifest with path-style name
3. Copy config files (*.json) as config layers
4. Write manifest
```
## FP8 Quantization
Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.
### Usage
```bash
cd ./weights/Z-Image-Turbo
ollama create z-image-fp8 --quantize fp8
```
This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.

View File

@@ -1,231 +0,0 @@
package api
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/imagegen"
)
// RunnerScheduler is the interface for scheduling a model runner.
// This is implemented by server.Server to avoid circular imports.
type RunnerScheduler interface {
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
}
// RegisterRoutes registers the image generation API routes.
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
r.POST("/v1/images/generations", func(c *gin.Context) {
ImageGenerationHandler(c, scheduler)
})
}
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
var req ImageGenerationRequest
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Validate required fields
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
return
}
if req.Prompt == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
return
}
// Apply defaults
if req.N == 0 {
req.N = 1
}
if req.Size == "" {
req.Size = "1024x1024"
}
if req.ResponseFormat == "" {
req.ResponseFormat = "b64_json"
}
// Verify model exists
if imagegen.ResolveModelName(req.Model) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
return
}
// Parse size
width, height := parseSize(req.Size)
// Build options - we repurpose NumCtx/NumGPU for width/height
opts := api.Options{}
opts.NumCtx = int(width)
opts.NumGPU = int(height)
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
}
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Options: &opts,
}
if req.Stream {
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
} else {
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
}
}
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
c.SSEvent("progress", progress)
c.Writer.Flush()
}
}
})
if err != nil {
c.SSEvent("error", gin.H{"error": err.Error()})
return
}
c.SSEvent("done", buildResponse(imageBase64, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
return
}
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
}
func parseSize(size string) (int32, int32) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return 1024, 1024
}
w, _ := strconv.Atoi(parts[0])
h, _ := strconv.Atoi(parts[1])
if w == 0 {
w = 1024
}
if h == 0 {
h = 1024
}
return int32(w), int32(h)
}
func extractBase64(content string) string {
if strings.HasPrefix(content, "IMAGE_BASE64:") {
return content[13:]
}
return ""
}
func parseProgress(content string) ImageProgressEvent {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imageBase64, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imageBase64 == "" {
return resp
}
if format == "url" {
// URL format not supported when using base64 transfer
resp.Data[0].B64JSON = imageBase64
} else {
resp.Data[0].B64JSON = imageBase64
}
return resp
}
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
// This allows routes.go to delegate image generation with minimal code.
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
opts := api.Options{}
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: prompt,
Options: &opts,
}
// Stream responses via channel
ch := make(chan any)
go func() {
defer close(ch)
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
ch <- GenerateResponse{
Model: modelName,
CreatedAt: time.Now().UTC(),
Response: resp.Content,
Done: resp.Done,
}
})
if err != nil {
// Log error but don't block - channel is already being consumed
_ = err
}
}()
streamFn(c, ch)
}
// GenerateResponse matches api.GenerateResponse structure for streaming.
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
}

View File

@@ -1,31 +0,0 @@
// Package api provides OpenAI-compatible image generation API types.
package api
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`
}
// ImageData contains the generated image data.
type ImageData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
// ImageProgressEvent is sent during streaming to indicate generation progress.
type ImageProgressEvent struct {
Step int `json:"step"`
Total int `json:"total"`
}

View File

@@ -1,156 +0,0 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
type Cache interface {
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
Offset() int
Len() int
State() []*mlx.Array
}
type KVCache struct {
keys, values *mlx.Array
offset int
step int
}
func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
prev := c.offset
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if needed
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
nSteps := (c.step + seqLen - 1) / c.step
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
if c.keys != nil {
if prev%c.step != 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
}
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
c.offset += seqLen
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
}
func (c *KVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
keys, values *mlx.Array
offset int
maxSize int
step int
idx int
}
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, step: 256}
}
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
if seqLen > 1 {
return c.updateConcat(k, v, seqLen)
}
return c.updateInPlace(k, v)
}
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if not yet at max
if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
var cap int
if c.keys != nil {
cap = int(c.keys.Shape()[2])
}
newSize := min(c.step, c.maxSize-cap)
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
if c.keys != nil {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
// Rotate when hitting max
if c.idx >= c.maxSize {
c.idx = 0
}
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
c.offset++
c.idx++
validLen := int32(min(c.offset, c.maxSize))
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
}
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
if c.keys == nil {
c.keys, c.values = k, v
} else {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
}
c.offset += seqLen
// Trim to max_size to maintain sliding window
cap := int(c.keys.Shape()[2])
if trim := cap - c.maxSize; trim > 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
}
c.idx = int(c.keys.Shape()[2])
return c.keys, c.values
}
func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *RotatingKVCache) Offset() int { return c.offset }
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }

View File

@@ -1,164 +0,0 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
// StepCache caches layer outputs across diffusion denoising steps.
// Based on DeepCache (CVPR 2024) and Learning-to-Cache (NeurIPS 2024):
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
// - Single-stream: use Get/Set for the single output per layer
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
// Usage (single-stream):
//
// cache := NewStepCache(15) // cache first 15 layers
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3) // refresh every 3 steps
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// output = cache.Get(i) // reuse cached
// } else {
// output = layer.Forward(input)
// if i < 15 && refresh {
// cache.Set(i, output)
// }
// }
// }
// }
// cache.Free() // cleanup when done
//
// Usage (dual-stream):
//
// cache := NewStepCache(15)
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3)
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// imgH, txtH = cache.Get(i), cache.Get2(i)
// } else {
// imgH, txtH = layer.Forward(imgH, txtH, ...)
// if i < 15 && refresh {
// cache.Set(i, imgH)
// cache.Set2(i, txtH)
// }
// }
// }
// }
type StepCache struct {
layers []*mlx.Array // cached layer outputs (stream 1)
layers2 []*mlx.Array // cached layer outputs (stream 2, for dual-stream models)
constant *mlx.Array // optional constant (e.g., text embeddings)
}
// NewStepCache creates a cache for the given number of layers.
func NewStepCache(numLayers int) *StepCache {
return &StepCache{
layers: make([]*mlx.Array, numLayers),
layers2: make([]*mlx.Array, numLayers),
}
}
// ShouldRefresh returns true if the cache should be refreshed at this step.
// Refresh happens on step 0, interval, 2*interval, etc.
func (c *StepCache) ShouldRefresh(step, interval int) bool {
return step%interval == 0
}
// Get returns the cached output for a layer, or nil if not cached.
func (c *StepCache) Get(layer int) *mlx.Array {
if layer < len(c.layers) {
return c.layers[layer]
}
return nil
}
// Set stores a layer output (stream 1), freeing any previous value.
func (c *StepCache) Set(layer int, arr *mlx.Array) {
if layer < len(c.layers) {
if c.layers[layer] != nil {
c.layers[layer].Free()
}
c.layers[layer] = arr
}
}
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Get2(layer int) *mlx.Array {
if layer < len(c.layers2) {
return c.layers2[layer]
}
return nil
}
// Set2 stores a layer output (stream 2), freeing any previous value.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
if layer < len(c.layers2) {
if c.layers2[layer] != nil {
c.layers2[layer].Free()
}
c.layers2[layer] = arr
}
}
// GetConstant returns the cached constant value.
func (c *StepCache) GetConstant() *mlx.Array {
return c.constant
}
// SetConstant stores a constant value, freeing any previous value.
func (c *StepCache) SetConstant(arr *mlx.Array) {
if c.constant != nil {
c.constant.Free()
}
c.constant = arr
}
// Arrays returns all non-nil cached arrays (for pool.Keep).
func (c *StepCache) Arrays() []*mlx.Array {
var result []*mlx.Array
if c.constant != nil {
result = append(result, c.constant)
}
for _, arr := range c.layers {
if arr != nil {
result = append(result, arr)
}
}
for _, arr := range c.layers2 {
if arr != nil {
result = append(result, arr)
}
}
return result
}
// Free releases all cached arrays. Call when generation completes.
func (c *StepCache) Free() {
if c.constant != nil {
c.constant.Free()
c.constant = nil
}
for i, arr := range c.layers {
if arr != nil {
arr.Free()
c.layers[i] = nil
}
}
for i, arr := range c.layers2 {
if arr != nil {
arr.Free()
c.layers2[i] = nil
}
}
}
// NumLayers returns the number of layers this cache can store.
func (c *StepCache) NumLayers() int {
return len(c.layers)
}

View File

@@ -1,197 +0,0 @@
//go:build mlx
// Package cache provides caching mechanisms for diffusion model inference.
package cache
import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
// It caches the transformer output and reuses it when timestep values
// are similar between consecutive steps.
//
// For CFG (classifier-free guidance), it caches pos and neg predictions
// separately and always computes CFG fresh to avoid error amplification.
//
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
// https://github.com/ali-vilab/TeaCache
type TeaCache struct {
// Cached transformer output from last computed step (non-CFG mode)
cachedOutput *mlx.Array
// Cached CFG outputs (pos and neg separately)
cachedPosOutput *mlx.Array
cachedNegOutput *mlx.Array
// Previous timestep value for difference calculation
prevTimestep float32
// Accumulated difference for rescaling
accumulatedDiff float32
// Configuration
threshold float32 // Threshold for recomputation decision
rescaleFactor float32 // Model-specific rescaling factor
skipEarlySteps int // Number of early steps to never cache
// Statistics
cacheHits int
cacheMisses int
}
// TeaCacheConfig holds configuration for TeaCache.
type TeaCacheConfig struct {
// Threshold for recomputation. Lower = more cache hits, potential quality loss.
// Recommended: 0.05-0.15 for image models
Threshold float32
// Rescale factor to adjust timestep embedding differences.
// Model-specific, typically 1.0-2.0
RescaleFactor float32
// SkipEarlySteps: number of early steps to always compute (never cache).
// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
SkipEarlySteps int
}
// DefaultTeaCacheConfig returns default configuration for TeaCache.
func DefaultTeaCacheConfig() *TeaCacheConfig {
return &TeaCacheConfig{
Threshold: 0.1,
RescaleFactor: 1.0,
}
}
// NewTeaCache creates a new TeaCache instance.
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
if cfg == nil {
cfg = DefaultTeaCacheConfig()
}
return &TeaCache{
threshold: cfg.Threshold,
rescaleFactor: cfg.RescaleFactor,
skipEarlySteps: cfg.SkipEarlySteps,
}
}
// ShouldCompute determines if we should compute the full forward pass
// or reuse the cached output based on timestep similarity.
//
// Algorithm:
// 1. First step always computes
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
// 3. If accumulated difference > threshold, compute new output
// 4. Otherwise, reuse cached output
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
// Always compute early steps (critical for structure)
// Check both regular cache and CFG cache
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
return true
}
// Compute absolute difference between current and previous timestep
diff := timestep - tc.prevTimestep
if diff < 0 {
diff = -diff
}
// Apply rescaling factor
scaledDiff := diff * tc.rescaleFactor
// Accumulate difference (helps track drift over multiple cached steps)
tc.accumulatedDiff += scaledDiff
// Decision based on accumulated difference
if tc.accumulatedDiff > tc.threshold {
tc.accumulatedDiff = 0 // Reset accumulator
return true
}
return false
}
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
// Free previous cached output
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
}
// Store new cached values
tc.cachedOutput = output
tc.prevTimestep = timestep
tc.cacheMisses++
}
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
// This allows CFG to be computed fresh each step, avoiding error amplification.
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
// Free previous cached outputs
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
}
// Store new cached values
tc.cachedPosOutput = posOutput
tc.cachedNegOutput = negOutput
tc.prevTimestep = timestep
tc.cacheMisses++
}
// GetCached returns the cached output (non-CFG mode).
func (tc *TeaCache) GetCached() *mlx.Array {
tc.cacheHits++
return tc.cachedOutput
}
// GetCFGCached returns cached pos and neg outputs for CFG mode.
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
tc.cacheHits++
return tc.cachedPosOutput, tc.cachedNegOutput
}
// HasCFGCache returns true if CFG cache is available.
func (tc *TeaCache) HasCFGCache() bool {
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
}
// Arrays returns all arrays that should be kept alive.
func (tc *TeaCache) Arrays() []*mlx.Array {
var arrays []*mlx.Array
if tc.cachedOutput != nil {
arrays = append(arrays, tc.cachedOutput)
}
if tc.cachedPosOutput != nil {
arrays = append(arrays, tc.cachedPosOutput)
}
if tc.cachedNegOutput != nil {
arrays = append(arrays, tc.cachedNegOutput)
}
return arrays
}
// Stats returns cache hit/miss statistics.
func (tc *TeaCache) Stats() (hits, misses int) {
return tc.cacheHits, tc.cacheMisses
}
// Free releases all cached arrays.
func (tc *TeaCache) Free() {
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
tc.cachedOutput = nil
}
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
tc.cachedPosOutput = nil
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
tc.cachedNegOutput = nil
}
}

View File

@@ -1,541 +0,0 @@
// cli.go provides CLI commands for image generation models.
//
// TODO (jmorganca): Integrate these commands into cmd/cmd.go when stable.
// Currently these are separate to keep experimental code isolated.
package imagegen
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"strconv"
"strings"
"time"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
)
// ImageGenOptions holds options for image generation.
// These can be set via environment variables or interactive commands.
type ImageGenOptions struct {
Width int
Height int
Steps int
Seed int
NegativePrompt string
}
// DefaultOptions returns the default image generation options.
func DefaultOptions() ImageGenOptions {
return ImageGenOptions{
Width: 1024,
Height: 1024,
Steps: 9,
Seed: 0, // 0 means random
}
}
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}
// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("width", 1024, "Image width")
cmd.Flags().Int("height", 1024, "Image height")
cmd.Flags().Int("steps", 9, "Denoising steps")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt")
cmd.Flags().MarkHidden("width")
cmd.Flags().MarkHidden("height")
cmd.Flags().MarkHidden("steps")
cmd.Flags().MarkHidden("seed")
cmd.Flags().MarkHidden("negative")
}
// RunCLI handles the CLI for image generation models.
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
// Verify it's a valid image gen model
if ResolveModelName(name) == "" {
return fmt.Errorf("unknown image generation model: %s", name)
}
// Get options from flags (with env var defaults)
opts := DefaultOptions()
if cmd != nil && cmd.Flags() != nil {
if v, err := cmd.Flags().GetInt("width"); err == nil && v > 0 {
opts.Width = v
}
if v, err := cmd.Flags().GetInt("height"); err == nil && v > 0 {
opts.Height = v
}
if v, err := cmd.Flags().GetInt("steps"); err == nil && v > 0 {
opts.Steps = v
}
if v, err := cmd.Flags().GetInt("seed"); err == nil && v != 0 {
opts.Seed = v
}
if v, err := cmd.Flags().GetString("negative"); err == nil && v != "" {
opts.NegativePrompt = v
}
}
if interactive {
return runInteractive(cmd, name, keepAlive, opts)
}
// One-shot generation
return generateImageWithOptions(cmd, name, prompt, keepAlive, opts)
}
// generateImageWithOptions generates an image with the given options.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Build request with image gen options encoded in Options fields
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
}
if keepAlive != nil {
req.KeepAlive = keepAlive
}
// Show loading spinner until generation starts
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
var stepBar *progress.StepBar
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
}
return nil
})
p.Stop()
if err != nil {
return err
}
if imageBase64 != "" {
// Decode base64 and save to CWD
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
return fmt.Errorf("failed to decode image: %w", err)
}
// Create filename from prompt
safeName := sanitizeFilename(prompt)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
return fmt.Errorf("failed to save image: %w", err)
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
return nil
}
// runInteractive runs an interactive REPL for image generation.
func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
Placeholder: "Describe an image to generate (/help for commands)",
})
if err != nil {
return err
}
if envconfig.NoHistory() {
scanner.HistoryDisable()
}
for {
line, err := scanner.Readline()
switch {
case errors.Is(err, io.EOF):
fmt.Println()
return nil
case errors.Is(err, readline.ErrInterrupt):
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
continue
case err != nil:
return err
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
// Handle commands
switch {
case strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"):
printInteractiveHelp(opts)
continue
case strings.HasPrefix(line, "/set "):
if err := handleSetCommand(line[5:], &opts); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
}
continue
case strings.HasPrefix(line, "/show"):
printCurrentSettings(opts)
continue
case strings.HasPrefix(line, "/"):
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
continue
}
// Generate image with current options
req := &api.GenerateRequest{
Model: modelName,
Prompt: line,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
}
if keepAlive != nil {
req.KeepAlive = keepAlive
}
// Show loading spinner until generation starts
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
var stepBar *progress.StepBar
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
}
return nil
})
p.Stop()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
continue
}
// Save image to current directory with descriptive name
if imageBase64 != "" {
// Decode base64 image data
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
continue
}
// Create filename from prompt (sanitized)
safeName := sanitizeFilename(line)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
continue
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
fmt.Println()
}
}
// sanitizeFilename removes characters that aren't safe for filenames.
func sanitizeFilename(s string) string {
s = strings.ToLower(s)
s = strings.ReplaceAll(s, " ", "-")
// Remove any character that's not alphanumeric or hyphen
var result strings.Builder
for _, r := range s {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}
// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
fmt.Fprintln(os.Stderr, "Commands:")
fmt.Fprintln(os.Stderr, " /set width <n> Set image width (current:", opts.Width, ")")
fmt.Fprintln(os.Stderr, " /set height <n> Set image height (current:", opts.Height, ")")
fmt.Fprintln(os.Stderr, " /set steps <n> Set denoising steps (current:", opts.Steps, ")")
fmt.Fprintln(os.Stderr, " /set seed <n> Set random seed (current:", opts.Seed, ", 0=random)")
fmt.Fprintln(os.Stderr, " /set negative <s> Set negative prompt")
fmt.Fprintln(os.Stderr, " /show Show current settings")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "Or type a prompt to generate an image.")
fmt.Fprintln(os.Stderr)
}
// printCurrentSettings prints the current image generation settings.
func printCurrentSettings(opts ImageGenOptions) {
fmt.Fprintf(os.Stderr, "Current settings:\n")
fmt.Fprintf(os.Stderr, " width: %d\n", opts.Width)
fmt.Fprintf(os.Stderr, " height: %d\n", opts.Height)
fmt.Fprintf(os.Stderr, " steps: %d\n", opts.Steps)
fmt.Fprintf(os.Stderr, " seed: %d (0=random)\n", opts.Seed)
if opts.NegativePrompt != "" {
fmt.Fprintf(os.Stderr, " negative: %s\n", opts.NegativePrompt)
}
fmt.Fprintln(os.Stderr)
}
// handleSetCommand handles /set commands to change options.
func handleSetCommand(args string, opts *ImageGenOptions) error {
parts := strings.SplitN(args, " ", 2)
if len(parts) < 2 {
return fmt.Errorf("usage: /set <option> <value>")
}
key := strings.ToLower(parts[0])
value := strings.TrimSpace(parts[1])
switch key {
case "width", "w":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("width must be a positive integer")
}
opts.Width = v
fmt.Fprintf(os.Stderr, "Set width to %d\n", v)
case "height", "h":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("height must be a positive integer")
}
opts.Height = v
fmt.Fprintf(os.Stderr, "Set height to %d\n", v)
case "steps", "s":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("steps must be a positive integer")
}
opts.Steps = v
fmt.Fprintf(os.Stderr, "Set steps to %d\n", v)
case "seed":
v, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("seed must be an integer")
}
opts.Seed = v
fmt.Fprintf(os.Stderr, "Set seed to %d\n", v)
case "negative", "neg", "n":
opts.NegativePrompt = value
if value == "" {
fmt.Fprintln(os.Stderr, "Cleared negative prompt")
} else {
fmt.Fprintf(os.Stderr, "Set negative prompt to: %s\n", value)
}
default:
return fmt.Errorf("unknown option: %s (try /help)", key)
}
return nil
}
// displayImageInTerminal attempts to render an image inline in the terminal.
// Supports iTerm2, Kitty, WezTerm, Ghostty, and other terminals with inline image support.
// Returns true if the image was displayed, false otherwise.
func displayImageInTerminal(imagePath string) bool {
// Check if terminal supports inline images
termProgram := os.Getenv("TERM_PROGRAM")
kittyWindowID := os.Getenv("KITTY_WINDOW_ID")
weztermPane := os.Getenv("WEZTERM_PANE")
ghostty := os.Getenv("GHOSTTY_RESOURCES_DIR")
// Read the image file
data, err := os.ReadFile(imagePath)
if err != nil {
return false
}
encoded := base64.StdEncoding.EncodeToString(data)
switch {
case termProgram == "iTerm.app" || termProgram == "WezTerm" || weztermPane != "":
// iTerm2/WezTerm inline image protocol
// ESC ] 1337 ; File = [arguments] : base64 BEL
fmt.Printf("\033]1337;File=inline=1;preserveAspectRatio=1:%s\a\n", encoded)
return true
case kittyWindowID != "" || ghostty != "" || termProgram == "ghostty":
// Kitty graphics protocol (also used by Ghostty)
// Send in chunks for large images
const chunkSize = 4096
for i := 0; i < len(encoded); i += chunkSize {
end := i + chunkSize
if end > len(encoded) {
end = len(encoded)
}
chunk := encoded[i:end]
if i == 0 {
// First chunk: a=T (transmit), f=100 (PNG), m=1 (more chunks follow) or m=0 (last chunk)
more := 1
if end >= len(encoded) {
more = 0
}
fmt.Printf("\033_Ga=T,f=100,m=%d;%s\033\\", more, chunk)
} else if end >= len(encoded) {
// Last chunk
fmt.Printf("\033_Gm=0;%s\033\\", chunk)
} else {
// Middle chunk
fmt.Printf("\033_Gm=1;%s\033\\", chunk)
}
}
fmt.Println()
return true
default:
return false
}
}

View File

@@ -1,190 +0,0 @@
// Package client provides client-side model creation for tensor-based models.
//
// This package is in x/ because the tensor model storage format is under development.
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
// cannot import server. This sub-package can import server because server doesn't
// import it.
//
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
// 1. Add proper API endpoints for tensor model creation
// 2. Move tensor extraction to server-side
// 3. Remove this package
// 4. Follow the same client→server pattern as regular model creation
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
// MinOllamaVersion is the minimum Ollama version required for image generation models.
const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
status := "importing image generation model"
spinner := progress.NewSpinner(status)
p.Add("imagegen", spinner)
// Create layer callback for config files
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
// When quantize is true, returns multiple layers (weight + scales)
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
if doQuantize {
// Check if quantization is supported
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales (use _scale suffix convention)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []imagegen.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name, // Keep original name for weight
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale", // Add _scale suffix
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, imagegen.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias", // Add _qbias suffix
})
}
return layers, nil
}
// Non-quantized path: just create a single layer
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
return []imagegen.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
// Create manifest writer callback
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create a proper config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: []string{"image"},
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
}
}
return server.WriteManifest(name, configLayer, serverLayers)
}
// Progress callback
progressFn := func(msg string) {
spinner.Stop()
status = msg
spinner = progress.NewSpinner(status)
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created image generation model '%s'\n", modelName)
return nil
}

View File

@@ -1,120 +0,0 @@
//go:build mlx
package client
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
// and returns safetensors data for the quantized weights, scales, and biases.
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
if _, err := io.Copy(tmpFile, r); err != nil {
tmpFile.Close()
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
}
tmpFile.Close()
// Load the tensor using MLX's native loader
st, err := mlx.LoadSafetensorsNative(tmpPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
}
defer st.Free()
// Get the tensor (it's stored as "data" in our minimal safetensors format)
arr := st.Get("data")
if arr == nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
}
// Convert to BFloat16 if needed (quantize expects float type)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr)
}
// Quantize with affine mode: group_size=32, bits=8
// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")
// Eval and make contiguous for data access
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
mlx.Eval(qweight, scales, qbiases)
} else {
mlx.Eval(qweight, scales)
}
// Get shapes
qweightShape = qweight.Shape()
scalesShape = scales.Shape()
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
defer os.Remove(qweightPath)
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
}
qweightData, err = os.ReadFile(qweightPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
}
// Save scales using MLX's native safetensors
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
defer os.Remove(scalesPath)
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
}
scalesData, err = os.ReadFile(scalesPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
}
// Affine mode returns qbiases for zero-point offset
if qbiases != nil {
qbiasShape = qbiases.Shape()
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
defer os.Remove(qbiasPath)
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
}
qbiasData, err = os.ReadFile(qbiasPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
}
}
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
func QuantizeSupported() bool {
return true
}
// ensureTempDir creates the temp directory for quantization if it doesn't exist
func ensureTempDir() string {
tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
os.MkdirAll(tmpDir, 0755)
return tmpDir
}

View File

@@ -1,18 +0,0 @@
//go:build !mlx
package client
import (
"fmt"
"io"
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
return false
}

View File

@@ -1,35 +0,0 @@
# MLX Engine
Experimental MLX backend for running models on Apple Silicon and CUDA.
## Build
```bash
go build -tags mlx -o engine ./x/imagegen/cmd/engine
```
## Text Generation
```bash
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
```
Options:
- `-temperature` - sampling temperature (default 0.7)
- `-top-p` - nucleus sampling (default 0.9)
- `-top-k` - top-k sampling (default 40)
Supports: Llama, Gemma3, GPT-OSS
## Image Generation
```bash
./engine -zimage -model /path/to/z-image -prompt "a cat" -output cat.png
```
Options:
- `-width`, `-height` - image dimensions (default 1024x1024)
- `-steps` - denoising steps (default 9)
- `-seed` - random seed (default 42)

View File

@@ -1,359 +0,0 @@
//go:build mlx
package main
import (
"context"
"fmt"
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Dedicated stream for generation (like mlx-lm's generation_stream)
var generationStream *mlx.Stream
// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
// This handles cases where tokenizers output partial multi-byte sequences.
type utf8Streamer struct {
buffer []byte
}
// Write adds decoded text to the buffer and returns complete UTF-8 characters.
func (s *utf8Streamer) Write(text string) string {
s.buffer = append(s.buffer, text...)
// Find the last position that ends with a complete UTF-8 character
validLen := 0
for i := 0; i < len(s.buffer); {
r, size := utf8.DecodeRune(s.buffer[i:])
if r == utf8.RuneError && size == 1 {
// Invalid or incomplete UTF-8 sequence at this position
// Check if it could be a valid start of a multi-byte sequence
if len(s.buffer)-i < 4 {
// Might be incomplete, keep it in buffer
break
}
// Definitely invalid, skip this byte
i++
validLen = i
} else {
i += size
validLen = i
}
}
if validLen == 0 {
return ""
}
result := string(s.buffer[:validLen])
s.buffer = s.buffer[validLen:]
return result
}
// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
func (s *utf8Streamer) Flush() string {
if len(s.buffer) == 0 {
return ""
}
result := string(s.buffer)
s.buffer = nil
return result
}
func init() {
generationStream = mlx.NewStream()
}
// withStream runs fn with the generation stream as default
func withStream(fn func()) {
orig := mlx.GetDefaultStream()
mlx.SetDefaultStream(generationStream)
fn()
mlx.SetDefaultStream(orig)
}
type Model interface {
Tokenizer() *tokenizer.Tokenizer
VocabSize() int32
NewCache(maxSeqLen int32) []cache.Cache
Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
}
// ChatModel is an optional interface for models that support chat formatting
type ChatModel interface {
FormatPrompt(prompt string) string
}
// MultimodalModel is for models that support image input
type MultimodalModel interface {
Model
FormatPromptWithImage(prompt string) string
ExpandImageTokens(tokens []int32) []int32
ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
ImageSize() int32 // Returns expected image size for preprocessing
}
// ImageLoader loads and preprocesses an image for multimodal models
// Returns nil if path is empty
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)
type input struct {
Prompt string
Image *mlx.Array // Optional preprocessed image for multimodal models
MaxTokens int
Temperature float32
TopP float32
TopK int
WiredLimitGB int // Metal wired memory limit in GB (default 32)
}
type output struct {
Text string
Done bool
PrefillTokSec float64
GenTokSec float64
}
// Decoder wraps model + cache for autoregressive generation.
type Decoder struct {
model Model
caches []cache.Cache
vocabSize int32
temp float32
topK int
topP float32
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
}
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
caches := m.NewCache(0)
return &Decoder{
model: m,
caches: caches,
vocabSize: m.VocabSize(),
temp: temp,
topK: topK,
topP: topP,
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
}
}
// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
d.image = img
}
func (d *Decoder) prefill(inputIDs []int32) int {
processed := 0
// Track old cache state to free after each chunk
var oldCacheState []*mlx.Array
// For multimodal models with an image, we need to process all tokens together
// in the first forward pass so the image embeddings can be inserted properly.
// Skip chunking for multimodal prefill.
isMultimodal := d.image != nil
// Process all-but-1 tokens in chunks, eval cache state for memory management
// Skip chunking for multimodal - process everything in the final step
if !isMultimodal {
for len(inputIDs) > 1 {
chunkSize := min(2048, len(inputIDs)-1)
if chunkSize <= 0 {
break
}
chunk := inputIDs[:chunkSize]
// Save old cache state before forward
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
var cacheState []*mlx.Array
withStream(func() {
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
d.model.Forward(x, d.caches)
for _, c := range d.caches {
cacheState = append(cacheState, c.State()...)
}
})
mlx.Eval(cacheState...)
// Free old cache state
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
inputIDs = inputIDs[chunkSize:]
processed += chunkSize
}
}
// Save old cache state before final step
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
// Final token + sampling (or all tokens for multimodal)
withStream(func() {
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
mlx.Eval(x) // Materialize before any other evals
var logits *mlx.Array
// Use ForwardWithImage if we have an image and model supports it
if d.image != nil {
if mm, ok := d.model.(MultimodalModel); ok {
logits = mm.ForwardWithImage(x, d.image, d.caches)
d.image = nil // Only use image for first forward
} else {
logits = d.model.Forward(x, d.caches)
}
} else {
logits = d.model.Forward(x, d.caches)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Free old cache state from before final step
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
mlx.ClearCache()
return processed + len(inputIDs)
}
func (d *Decoder) step() int32 {
prevToken := d.token
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
d.oldCacheState = append(d.oldCacheState, c.State()...)
}
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
mlx.Keep(d.token)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
arr.Free()
}
return val
}
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
mlx.EnableCompile()
wiredLimit := in.WiredLimitGB
if wiredLimit <= 0 {
wiredLimit = 32 // default 32GB
}
mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)
temp := in.Temperature
if temp < 0 {
temp = 0.7
}
tok := m.Tokenizer()
dec := NewDecoder(m, temp, in.TopK, in.TopP)
// Apply chat template - use image template if we have an image
prompt := in.Prompt
var tokens []int32
if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
prompt = mm.FormatPromptWithImage(prompt)
tokens = tok.Encode(prompt, true)
tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
dec.SetImage(in.Image)
} else if cm, ok := m.(ChatModel); ok {
prompt = cm.FormatPrompt(prompt)
tokens = tok.Encode(prompt, true)
} else {
tokens = tok.Encode(prompt, true)
}
prefillStart := time.Now()
prefillTokens := dec.prefill(tokens)
// Prefill measurement should include time to first token (like mlx-lm)
// Step() waits for prefill to complete and returns first token
firstToken := dec.step()
prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()
genStart := time.Now()
maxTokens := max(in.MaxTokens, 100)
var genTokens int
// UTF-8 streamer to handle partial multi-byte characters
streamer := &utf8Streamer{}
// Handle first token
genTokens++
if tok.IsEOS(firstToken) {
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
return nil
}
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
cb(output{Text: text})
}
for n := 1; n < maxTokens; n++ {
if ctx.Err() != nil {
return ctx.Err()
}
token := dec.step()
genTokens++
if tok.IsEOS(token) {
break
}
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
cb(output{Text: text})
}
if n%256 == 0 {
mlx.ClearCache()
}
}
// Flush any remaining buffered bytes
if text := streamer.Flush(); text != "" {
cb(output{Text: text})
}
fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
cb(output{Done: true, PrefillTokSec: prefillTokSec,
GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
return nil
}

View File

@@ -1,89 +0,0 @@
//go:build mlx
package main
import (
"fmt"
"image"
"image/png"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// saveImageArray saves an MLX array as a PNG image.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func saveImageArray(arr *mlx.Array, path string) error {
img, err := arrayToImage(arr)
if err != nil {
return err
}
return savePNG(img, path)
}
func savePNG(img *image.RGBA, path string) error {
if filepath.Ext(path) != ".png" {
path = path + ".png"
}
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
return png.Encode(f, img)
}
func arrayToImage(arr *mlx.Array) (*image.RGBA, error) {
shape := arr.Shape()
if len(shape) != 4 {
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
}
// Transform to [H, W, C] for image conversion
img := mlx.Squeeze(arr, 0)
arr.Free()
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
mlx.Eval(img)
imgShape := img.Shape()
H := int(imgShape[0])
W := int(imgShape[1])
C := int(imgShape[2])
if C != 3 {
img.Free()
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
}
// Copy to CPU and free GPU memory
data := img.Data()
img.Free()
// Write directly to Pix slice (faster than SetRGBA)
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
pix := goImg.Pix
for y := 0; y < H; y++ {
for x := 0; x < W; x++ {
srcIdx := (y*W + x) * C
dstIdx := (y*W + x) * 4
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
pix[dstIdx+3] = 255
}
}
return goImg, nil
}
func clampF(v, min, max float32) float32 {
if v < min {
return min
}
if v > max {
return max
}
return v
}

View File

@@ -1,293 +0,0 @@
//go:build mlx
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"runtime/pprof"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// stringSlice is a flag type that accumulates multiple values
type stringSlice []string
func (s *stringSlice) String() string {
return fmt.Sprintf("%v", *s)
}
func (s *stringSlice) Set(value string) error {
*s = append(*s, value)
return nil
}
func main() {
modelPath := flag.String("model", "", "Model directory")
prompt := flag.String("prompt", "Hello", "Prompt")
// Text generation params
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
temperature := flag.Float64("temperature", 0.7, "Temperature")
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
topK := flag.Int("top-k", 40, "Top-k sampling")
imagePath := flag.String("image", "", "Image path for multimodal models")
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
steps := flag.Int("steps", 9, "Denoising steps")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
// Utility flags
listTensors := flag.Bool("list", false, "List tensors only")
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
flag.Parse()
if *modelPath == "" {
flag.Usage()
return
}
// CPU profiling
if *cpuProfile != "" {
f, err := os.Create(*cpuProfile)
if err != nil {
log.Fatal(err)
}
defer f.Close()
if err := pprof.StartCPUProfile(f); err != nil {
log.Fatal(err)
}
defer pprof.StopCPUProfile()
}
var err error
// Handle legacy mode flags that aren't unified yet
switch {
case *zimageFlag:
m := &zimage.Model{}
if loadErr := m.Load(*modelPath); loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
TeaCache: *teaCache,
TeaCacheThreshold: float32(*teaCacheThreshold),
FusedQKV: *fusedQKV,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImageEdit:
if len(inputImages) == 0 {
log.Fatal("qwen-image-edit requires at least one -input-image")
}
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
// For image editing, use 0 for dimensions to auto-detect from input image
// unless explicitly overridden from defaults
editWidth := int32(0)
editHeight := int32(0)
if *width != 1024 {
editWidth = int32(*width)
}
if *height != 1024 {
editHeight = int32(*height)
}
cfg := &qwen_image_edit.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: editWidth,
Height: editHeight,
Steps: *steps,
Seed: *seed,
}
var img *mlx.Array
img, err = m.EditFromConfig(inputImages, cfg)
if err == nil {
err = saveImageArray(img, *out)
}
case *listTensors:
err = listModelTensors(*modelPath)
default:
// llm path
m, err := load(*modelPath)
if err != nil {
log.Fatal(err)
}
// Load image if provided and model supports it
var image *mlx.Array
if *imagePath != "" {
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
if err != nil {
log.Fatal("load image:", err)
}
} else {
log.Fatal("model does not support image input")
}
}
err = generate(context.Background(), m, input{
Prompt: *prompt,
Image: image,
MaxTokens: *maxTokens,
Temperature: float32(*temperature),
TopP: float32(*topP),
TopK: *topK,
WiredLimitGB: *wiredLimitGB,
}, func(out output) {
if out.Text != "" {
fmt.Print(out.Text)
}
if out.Done {
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
}
})
}
if err != nil {
log.Fatal(err)
}
}
func listModelTensors(modelPath string) error {
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return err
}
for _, name := range weights.ListTensors() {
info, _ := weights.GetTensorInfo(name)
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
}
return nil
}
// loadModel builds and evaluates a model using the common load pattern.
// Release safetensors BEFORE eval - lazy arrays have captured their data,
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
func loadModel[T Model](build func() T, cleanup func()) T {
m := build()
weights := mlx.Collect(m)
cleanup()
mlx.Eval(weights...)
return m
}
func load(modelPath string) (Model, error) {
kind, err := detectModelKind(modelPath)
if err != nil {
return nil, fmt.Errorf("detect model kind: %w", err)
}
switch kind {
case "gpt_oss":
return gpt_oss.Load(modelPath)
case "gemma3":
return gemma3.Load(modelPath)
case "gemma3_text":
return gemma3.LoadText(modelPath)
default:
return llama.Load(modelPath)
}
}
func detectModelKind(modelPath string) (string, error) {
indexPath := filepath.Join(modelPath, "model_index.json")
if _, err := os.Stat(indexPath); err == nil {
data, err := os.ReadFile(indexPath)
if err != nil {
return "zimage", nil
}
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err == nil {
switch index.ClassName {
case "FluxPipeline", "ZImagePipeline":
return "zimage", nil
}
}
return "zimage", nil
}
configPath := filepath.Join(modelPath, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
}
var cfg struct {
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return "", fmt.Errorf("parse config.json: %w", err)
}
return cfg.ModelType, nil
}

View File

@@ -1,49 +0,0 @@
//go:build mlx
package main
import "github.com/ollama/ollama/x/imagegen/mlx"
// sampleTopK samples from top-k logits using global random state
func sampleTopK(scaledLogits *mlx.Array, k int) *mlx.Array {
neg := mlx.Neg(scaledLogits)
indices := mlx.Argpartition(neg, k-1, -1)
topKIdx := mlx.Slice(indices, []int32{0}, []int32{int32(k)})
values := mlx.TakeAlongAxis(scaledLogits, topKIdx, -1)
sampled := mlx.RandomCategorical(values, -1, 1)
return mlx.Take(topKIdx, sampled, -1)
}
// sampleTopP samples using nucleus sampling with global random state
func sampleTopP(scaledLogits *mlx.Array, p float32, vocabSize int32) *mlx.Array {
sorted := mlx.Argsort(mlx.Neg(scaledLogits), -1)
sortedLogits := mlx.TakeAlongAxis(scaledLogits, sorted, -1)
probs := mlx.Softmax(sortedLogits, -1)
cumProbs := mlx.Cumsum(probs, -1)
mask := mlx.LessScalar(cumProbs, p)
negInf := mlx.FullDtype(float32(-1e9), scaledLogits.Dtype(), vocabSize)
masked := mlx.Where(mask, sortedLogits, negInf)
sampled := mlx.RandomCategorical(masked, -1, 1)
return mlx.Take(sorted, sampled, -1)
}
// sample samples from logits at the last position
func sample(logits *mlx.Array, temp float32, topK int, topP float32, vocab int32) *mlx.Array {
// Get last position logits: [1, L, vocab] -> [vocab]
shape := logits.Shape()
seqLen := shape[1]
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocab})
lastLogits = mlx.Reshape(lastLogits, vocab)
if temp == 0 {
return mlx.Argmax(lastLogits, -1, false)
}
scaled := mlx.DivScalar(lastLogits, temp)
if topK > 0 && topK < int(vocab) {
return sampleTopK(scaled, topK)
}
if topP > 0 && topP < 1.0 {
return sampleTopP(scaled, topP, vocab)
}
return mlx.RandomCategorical(scaled, -1, 1)
}

View File

@@ -1,213 +0,0 @@
package imagegen
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// IsTensorModelDir checks if the directory contains a tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
var totalParams int64 // Count parameters from original tensor shapes
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"}
for _, component := range components {
componentDir := filepath.Join(modelDir, component)
if _, err := os.Stat(componentDir); os.IsNotExist(err) {
continue
}
// Find all safetensors files in this component
entries, err := os.ReadDir(componentDir)
if err != nil {
return fmt.Errorf("failed to read %s: %w", component, err)
}
for _, entry := range entries {
if !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
stPath := filepath.Join(componentDir, entry.Name())
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
tensorNames := extractor.ListTensors()
quantizeMsg := ""
if quantize == "fp8" && component != "vae" {
quantizeMsg = ", quantizing to fp8"
}
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Count parameters from original tensor shape
if len(td.Shape) > 0 {
numElements := int64(1)
for _, dim := range td.Shape {
numElements *= int64(dim)
}
totalParams += numElements
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName
// Determine if this tensor should be quantized
doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
}
layers = append(layers, newLayers...)
}
extractor.Close()
}
}
// Import config files
configFiles := []string{
"model_index.json",
"text_encoder/config.json",
"text_encoder/generation_config.json",
"transformer/config.json",
"vae/config.json",
"scheduler/scheduler_config.json",
"tokenizer/tokenizer.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
}
for _, cfgPath := range configFiles {
fullPath := filepath.Join(modelDir, cfgPath)
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
continue
}
fn(fmt.Sprintf("importing config %s", cfgPath))
var r io.Reader
// For model_index.json, normalize to Ollama format and add metadata
if cfgPath == "model_index.json" {
data, err := os.ReadFile(fullPath)
if err != nil {
return fmt.Errorf("failed to read %s: %w", cfgPath, err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("failed to parse %s: %w", cfgPath, err)
}
// Rename _class_name to architecture, remove diffusers-specific fields
if className, ok := cfg["_class_name"]; ok {
cfg["architecture"] = className
delete(cfg, "_class_name")
}
delete(cfg, "_diffusers_version")
// Add parameter count (counted from tensor shapes during import)
cfg["parameter_count"] = totalParams
// Add quantization info
if quantize == "fp8" {
cfg["quantization"] = "FP8"
} else {
cfg["quantization"] = "BF16"
}
data, err = json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
}
r = bytes.NewReader(data)
} else {
f, err := os.Open(fullPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
}
defer f.Close()
r = f
}
layer, err := createLayer(r, "application/vnd.ollama.image.json", cfgPath)
if err != nil {
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
}
// Use model_index.json as the config layer
if cfgPath == "model_index.json" {
configLayer = layer
}
layers = append(layers, layer)
}
if configLayer.Digest == "" {
return fmt.Errorf("model_index.json not found in %s", modelDir)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
return fmt.Errorf("failed to write manifest: %w", err)
}
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}

Some files were not shown because too many files have changed in this diff Show More