mirror of
https://github.com/ollama/ollama.git
synced 2026-01-11 09:00:53 -05:00
Compare commits
11 Commits
mlx-gpu-cd
...
parth/agen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c2c2b8de9 | ||
|
|
5e23c4f2f7 | ||
|
|
5c0caaff86 | ||
|
|
e28ee8524d | ||
|
|
623e539a09 | ||
|
|
51911a5f6f | ||
|
|
2c2354e980 | ||
|
|
ce6b19d8be | ||
|
|
1de00fada0 | ||
|
|
7ecae75c4c | ||
|
|
ad5c276cf6 |
7
.github/workflows/release.yaml
vendored
7
.github/workflows/release.yaml
vendored
@@ -68,7 +68,6 @@ jobs:
|
||||
name: bundles-darwin
|
||||
path: |
|
||||
dist/*.tgz
|
||||
dist/*.tar.zst
|
||||
dist/*.zip
|
||||
dist/*.dmg
|
||||
|
||||
@@ -393,13 +392,13 @@ jobs:
|
||||
done
|
||||
- run: |
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
|
||||
done
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
|
||||
path: |
|
||||
*.tar.zst
|
||||
*.tgz
|
||||
|
||||
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
|
||||
docker-build-push:
|
||||
@@ -532,7 +531,7 @@ jobs:
|
||||
- name: Upload release artifacts
|
||||
run: |
|
||||
pids=()
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
|
||||
echo "Uploading $payload"
|
||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||
pids[$!]=$!
|
||||
|
||||
@@ -2,22 +2,6 @@ cmake_minimum_required(VERSION 3.21)
|
||||
|
||||
project(Ollama C CXX)
|
||||
|
||||
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
|
||||
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
|
||||
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
|
||||
# host architecture, but downstream projects (like MLX) use it to detect the
|
||||
# target architecture.
|
||||
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
|
||||
# Single architecture specified
|
||||
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
|
||||
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
|
||||
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
|
||||
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
|
||||
set(CMAKE_SYSTEM_PROCESSOR "arm64")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
include(CheckLanguage)
|
||||
include(GNUInstallDirs)
|
||||
|
||||
@@ -28,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
set(GGML_BUILD ON)
|
||||
set(GGML_SHARED ON)
|
||||
@@ -163,48 +147,14 @@ if(CMAKE_HIP_COMPILER)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT APPLE)
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
||||
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
install(TARGETS mlx mlxc
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
|
||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB CUDART_LIBS
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
||||
if(CUDART_LIBS)
|
||||
install(FILES ${CUDART_LIBS}
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -41,7 +41,7 @@
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
|
||||
"CMAKE_CUDA_FLAGS": "-t 4",
|
||||
"CMAKE_CUDA_FLAGS": "-t 2",
|
||||
"OLLAMA_RUNNER_DIR": "cuda_v13"
|
||||
}
|
||||
},
|
||||
@@ -83,28 +83,6 @@
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "vulkan"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"inherits": [ "Default" ],
|
||||
"cacheVariables": {
|
||||
"MLX_ENGINE": "ON",
|
||||
"OLLAMA_RUNNER_DIR": "mlx"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"inherits": [ "MLX", "CUDA 12" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"inherits": [ "MLX", "CUDA 13" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||
}
|
||||
}
|
||||
],
|
||||
"buildPresets": [
|
||||
@@ -162,21 +140,6 @@
|
||||
"name": "Vulkan",
|
||||
"targets": [ "ggml-vulkan" ],
|
||||
"configurePreset": "Vulkan"
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 12"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 13"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
33
Dockerfile
33
Dockerfile
@@ -131,36 +131,8 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'Vulkan' \
|
||||
&& cmake --build --parallel --preset 'Vulkan' \
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
|
||||
FROM base AS mlx
|
||||
ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
|
||||
&& dnf install -y openblas-devel lapack-devel \
|
||||
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
|
||||
&& dnf install -y libnccl libnccl-devel
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||
ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY go.mod go.sum .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
||||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -181,7 +153,6 @@ FROM --platform=linux/amd64 scratch AS amd64
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
@@ -200,7 +171,7 @@ COPY --from=build /bin/ollama /bin/ollama
|
||||
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
|
||||
&& apt-get install -y ca-certificates libvulkan1 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=archive /bin /usr/bin
|
||||
|
||||
@@ -1,778 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Error types matching Anthropic API
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Type string `json:"type"` // always "error"
|
||||
Error Error `json:"error"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
}
|
||||
|
||||
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
|
||||
func NewError(code int, message string) ErrorResponse {
|
||||
var etype string
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
etype = "invalid_request_error"
|
||||
case http.StatusUnauthorized:
|
||||
etype = "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
etype = "permission_error"
|
||||
case http.StatusNotFound:
|
||||
etype = "not_found_error"
|
||||
case http.StatusTooManyRequests:
|
||||
etype = "rate_limit_error"
|
||||
case http.StatusServiceUnavailable, 529:
|
||||
etype = "overloaded_error"
|
||||
default:
|
||||
etype = "api_error"
|
||||
}
|
||||
|
||||
return ErrorResponse{
|
||||
Type: "error",
|
||||
Error: Error{Type: etype, Message: message},
|
||||
RequestID: generateID("req"),
|
||||
}
|
||||
}
|
||||
|
||||
// Request types
|
||||
|
||||
// MessagesRequest represents an Anthropic Messages API request
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"` // string or []ContentBlock
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Metadata *Metadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// MessageParam represents a message in the request
|
||||
type MessageParam struct {
|
||||
Role string `json:"role"` // "user" or "assistant"
|
||||
Content any `json:"content"` // string or []ContentBlock
|
||||
}
|
||||
|
||||
// ContentBlock represents a content block in a message.
|
||||
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
|
||||
// only when set, which is required for SDK streaming accumulation.
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
|
||||
|
||||
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Text *string `json:"text,omitempty"`
|
||||
|
||||
// For image blocks
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
|
||||
// For tool_use blocks
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
|
||||
// For tool_result blocks
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content any `json:"content,omitempty"` // string or []ContentBlock
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
|
||||
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Thinking *string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ImageSource represents the source of an image
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64" or "url"
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// Tool represents a tool definition
|
||||
type Tool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
||||
}
|
||||
|
||||
// ToolChoice controls how the model uses tools
|
||||
type ToolChoice struct {
|
||||
Type string `json:"type"` // "auto", "any", "tool", "none"
|
||||
Name string `json:"name,omitempty"`
|
||||
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// ThinkingConfig controls extended thinking
|
||||
type ThinkingConfig struct {
|
||||
Type string `json:"type"` // "enabled" or "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// Metadata for the request
|
||||
type Metadata struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// Response types
|
||||
|
||||
// MessagesResponse represents an Anthropic Messages API response
|
||||
type MessagesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// Usage contains token usage information
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// Streaming event types
|
||||
|
||||
// MessageStartEvent is sent at the start of streaming
|
||||
type MessageStartEvent struct {
|
||||
Type string `json:"type"` // "message_start"
|
||||
Message MessagesResponse `json:"message"`
|
||||
}
|
||||
|
||||
// ContentBlockStartEvent signals the start of a content block
|
||||
type ContentBlockStartEvent struct {
|
||||
Type string `json:"type"` // "content_block_start"
|
||||
Index int `json:"index"`
|
||||
ContentBlock ContentBlock `json:"content_block"`
|
||||
}
|
||||
|
||||
// ContentBlockDeltaEvent contains incremental content updates
|
||||
type ContentBlockDeltaEvent struct {
|
||||
Type string `json:"type"` // "content_block_delta"
|
||||
Index int `json:"index"`
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
|
||||
// Delta represents an incremental update
|
||||
type Delta struct {
|
||||
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ContentBlockStopEvent signals the end of a content block
|
||||
type ContentBlockStopEvent struct {
|
||||
Type string `json:"type"` // "content_block_stop"
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
// MessageDeltaEvent contains updates to the message
|
||||
type MessageDeltaEvent struct {
|
||||
Type string `json:"type"` // "message_delta"
|
||||
Delta MessageDelta `json:"delta"`
|
||||
Usage DeltaUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// MessageDelta contains stop information
|
||||
type MessageDelta struct {
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
}
|
||||
|
||||
// DeltaUsage contains cumulative token usage
|
||||
type DeltaUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// MessageStopEvent signals the end of the message
|
||||
type MessageStopEvent struct {
|
||||
Type string `json:"type"` // "message_stop"
|
||||
}
|
||||
|
||||
// PingEvent is a keepalive event
|
||||
type PingEvent struct {
|
||||
Type string `json:"type"` // "ping"
|
||||
}
|
||||
|
||||
// StreamErrorEvent is an error during streaming
|
||||
type StreamErrorEvent struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
||||
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
var messages []api.Message
|
||||
|
||||
if r.System != nil {
|
||||
switch sys := r.System.(type) {
|
||||
case string:
|
||||
if sys != "" {
|
||||
messages = append(messages, api.Message{Role: "system", Content: sys})
|
||||
}
|
||||
case []any:
|
||||
// System can be an array of content blocks
|
||||
var content strings.Builder
|
||||
for _, block := range sys {
|
||||
if blockMap, ok := block.(map[string]any); ok {
|
||||
if blockMap["type"] == "text" {
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
content.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if content.Len() > 0 {
|
||||
messages = append(messages, api.Message{Role: "system", Content: content.String()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range r.Messages {
|
||||
converted, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, converted...)
|
||||
}
|
||||
|
||||
options := make(map[string]any)
|
||||
|
||||
options["num_predict"] = r.MaxTokens
|
||||
|
||||
if r.Temperature != nil {
|
||||
options["temperature"] = *r.Temperature
|
||||
}
|
||||
|
||||
if r.TopP != nil {
|
||||
options["top_p"] = *r.TopP
|
||||
}
|
||||
|
||||
if r.TopK != nil {
|
||||
options["top_k"] = *r.TopK
|
||||
}
|
||||
|
||||
if len(r.StopSequences) > 0 {
|
||||
options["stop"] = r.StopSequences
|
||||
}
|
||||
|
||||
var tools api.Tools
|
||||
for _, t := range r.Tools {
|
||||
tool, err := convertTool(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
if r.Thinking != nil && r.Thinking.Type == "enabled" {
|
||||
think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
|
||||
stream := r.Stream
|
||||
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Options: options,
|
||||
Stream: &stream,
|
||||
Tools: tools,
|
||||
Think: think,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||
func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
role := strings.ToLower(msg.Role)
|
||||
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
messages = append(messages, api.Message{Role: role, Content: content})
|
||||
|
||||
case []any:
|
||||
var textContent strings.Builder
|
||||
var images []api.ImageData
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid content block format")
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent.WriteString(text)
|
||||
}
|
||||
|
||||
case "image":
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
|
||||
case "tool_use":
|
||||
id, ok := blockMap["id"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
name, ok := blockMap["name"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "tool_result":
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
|
||||
switch c := blockMap["content"].(type) {
|
||||
case string:
|
||||
resultContent = c
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
if t, ok := blockMap["thinking"].(string); ok {
|
||||
thinking = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||
m := api.Message{
|
||||
Role: role,
|
||||
Content: textContent.String(),
|
||||
Images: images,
|
||||
ToolCalls: toolCalls,
|
||||
Thinking: thinking,
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool
|
||||
func convertTool(t Tool) (api.Tool, error) {
|
||||
var params api.ToolFunctionParameters
|
||||
if len(t.InputSchema) > 0 {
|
||||
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
||||
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
||||
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
|
||||
var content []ContentBlock
|
||||
|
||||
if r.Message.Thinking != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(r.Message.Thinking),
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Content != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(r.Message.Content),
|
||||
})
|
||||
}
|
||||
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
|
||||
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
|
||||
|
||||
return MessagesResponse{
|
||||
ID: id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: r.Model,
|
||||
Content: content,
|
||||
StopReason: stopReason,
|
||||
Usage: Usage{
|
||||
InputTokens: r.Metrics.PromptEvalCount,
|
||||
OutputTokens: r.Metrics.EvalCount,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
|
||||
func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
if hasToolCalls {
|
||||
return "tool_use"
|
||||
}
|
||||
|
||||
switch reason {
|
||||
case "stop":
|
||||
return "end_turn"
|
||||
case "length":
|
||||
return "max_tokens"
|
||||
default:
|
||||
if reason != "" {
|
||||
return "stop_sequence"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// StreamEvent represents a streaming event to be sent to the client
|
||||
type StreamEvent struct {
|
||||
Event string
|
||||
Data any
|
||||
}
|
||||
|
||||
// Process converts an Ollama ChatResponse to Anthropic streaming events
|
||||
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
var events []StreamEvent
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
Data: MessageStartEvent{
|
||||
Type: "message_start",
|
||||
Message: MessagesResponse{
|
||||
ID: c.ID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: c.Model,
|
||||
Content: []ContentBlock{},
|
||||
Usage: Usage{
|
||||
InputTokens: c.inputTokens,
|
||||
OutputTokens: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Thinking != "" && !c.thinkingDone {
|
||||
if !c.thinkingStarted {
|
||||
c.thinkingStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(""),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "thinking_delta",
|
||||
Thinking: r.Message.Thinking,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Content != "" {
|
||||
if c.thinkingStarted && !c.thinkingDone {
|
||||
c.thinkingDone = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if !c.textStarted {
|
||||
c.textStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(""),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "text_delta",
|
||||
Text: r.Message.Content,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
if c.toolCallsSent[tc.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
c.textStarted = false
|
||||
}
|
||||
|
||||
argsJSON, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: map[string]any{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: string(argsJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
|
||||
c.toolCallsSent[tc.ID] = true
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
} else if c.thinkingStarted && !c.thinkingDone {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.outputTokens = r.Metrics.EvalCount
|
||||
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_delta",
|
||||
Data: MessageDeltaEvent{
|
||||
Type: "message_delta",
|
||||
Delta: MessageDelta{
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: DeltaUsage{
|
||||
OutputTokens: c.outputTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_stop",
|
||||
Data: MessageStopEvent{
|
||||
Type: "message_stop",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// generateID generates a unique ID with the given prefix using crypto/rand
|
||||
func generateID(prefix string) string {
|
||||
b := make([]byte, 12)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to time-based ID if crypto/rand fails
|
||||
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
|
||||
}
|
||||
return fmt.Sprintf("%s_%x", prefix, b)
|
||||
}
|
||||
|
||||
// GenerateMessageID generates a unique message ID
|
||||
func GenerateMessageID() string {
|
||||
return generateID("msg")
|
||||
}
|
||||
|
||||
// ptr returns a pointer to the given string value
|
||||
func ptr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// mapToArgs converts a map to ToolCallFunctionArguments
|
||||
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
@@ -1,953 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_Basic(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
|
||||
t.Errorf("unexpected message: %+v", result.Messages[0])
|
||||
}
|
||||
|
||||
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
|
||||
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
|
||||
t.Errorf("unexpected system message: %+v", result.Messages[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: []any{
|
||||
map[string]any{"type": "text", "text": "You are helpful."},
|
||||
map[string]any{"type": "text", "text": " Be concise."},
|
||||
},
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "You are helpful. Be concise." {
|
||||
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithOptions(t *testing.T) {
|
||||
temp := 0.7
|
||||
topP := 0.9
|
||||
topK := 40
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 2048,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
StopSequences: []string{"\n", "END"},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Options["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
|
||||
}
|
||||
if result.Options["top_p"] != 0.9 {
|
||||
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
|
||||
}
|
||||
if result.Options["top_k"] != 40 {
|
||||
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
|
||||
}
|
||||
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
|
||||
t.Errorf("stop sequences mismatch: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithImage(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "What's in this image?"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "What's in this image?" {
|
||||
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
|
||||
}
|
||||
|
||||
if len(result.Messages[0].Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
|
||||
}
|
||||
|
||||
if string(result.Messages[0].Images[0]) != string(imgData) {
|
||||
t.Error("image data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": map[string]any{"location": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if len(result.Messages[1].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
|
||||
}
|
||||
|
||||
tc := result.Messages[1].ToolCalls[0]
|
||||
if tc.ID != "call_123" {
|
||||
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
|
||||
}
|
||||
if tc.Function.Name != "get_weather" {
|
||||
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_123",
|
||||
"content": "The weather in Paris is sunny, 22°C",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
msg := result.Messages[0]
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if msg.ToolCallID != "call_123" {
|
||||
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
|
||||
}
|
||||
if msg.Content != "The weather in Paris is sunny, 22°C" {
|
||||
t.Errorf("unexpected content: %q", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
|
||||
}
|
||||
|
||||
tool := result.Tools[0]
|
||||
if tool.Type != "function" {
|
||||
t.Errorf("expected type 'function', got %q", tool.Type)
|
||||
}
|
||||
if tool.Function.Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
|
||||
}
|
||||
if tool.Function.Description != "Get current weather" {
|
||||
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Think == nil {
|
||||
t.Fatal("expected Think to be set")
|
||||
}
|
||||
if v, ok := result.Think.Value.(bool); !ok || !v {
|
||||
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
|
||||
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
|
||||
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this...",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
assistantMsg := result.Messages[1]
|
||||
if assistantMsg.Thinking != "Let me think about this..." {
|
||||
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"name": "get_weather",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use id")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'id' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use name")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'name' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "bad_tool",
|
||||
InputSchema: json.RawMessage(`{invalid json`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid tool schema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_Basic(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello there!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 5,
|
||||
},
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if result.ID != "msg_123" {
|
||||
t.Errorf("expected ID 'msg_123', got %q", result.ID)
|
||||
}
|
||||
if result.Type != "message" {
|
||||
t.Errorf("expected type 'message', got %q", result.Type)
|
||||
}
|
||||
if result.Role != "assistant" {
|
||||
t.Errorf("expected role 'assistant', got %q", result.Role)
|
||||
}
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
|
||||
t.Errorf("unexpected content: %+v", result.Content[0])
|
||||
}
|
||||
if result.StopReason != "end_turn" {
|
||||
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
||||
}
|
||||
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
|
||||
t.Errorf("unexpected usage: %+v", result.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "tool_use" {
|
||||
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].ID != "call_123" {
|
||||
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
|
||||
}
|
||||
if result.Content[0].Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
|
||||
}
|
||||
if result.StopReason != "tool_use" {
|
||||
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithThinking(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "The answer is 42.",
|
||||
Thinking: "Let me think about this...",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 2 {
|
||||
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "thinking" {
|
||||
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
|
||||
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
|
||||
}
|
||||
if result.Content[1].Type != "text" {
|
||||
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapStopReason(t *testing.T) {
|
||||
tests := []struct {
|
||||
reason string
|
||||
hasToolCalls bool
|
||||
want string
|
||||
}{
|
||||
{"stop", false, "end_turn"},
|
||||
{"length", false, "max_tokens"},
|
||||
{"stop", true, "tool_use"},
|
||||
{"other", false, "stop_sequence"},
|
||||
{"", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := mapStopReason(tt.reason, tt.hasToolCalls)
|
||||
if got != tt.want {
|
||||
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewError(t *testing.T) {
|
||||
tests := []struct {
|
||||
code int
|
||||
want string
|
||||
}{
|
||||
{400, "invalid_request_error"},
|
||||
{401, "authentication_error"},
|
||||
{403, "permission_error"},
|
||||
{404, "not_found_error"},
|
||||
{429, "rate_limit_error"},
|
||||
{500, "api_error"},
|
||||
{503, "overloaded_error"},
|
||||
{529, "overloaded_error"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := NewError(tt.code, "test message")
|
||||
if result.Type != "error" {
|
||||
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
|
||||
}
|
||||
if result.Error.Type != tt.want {
|
||||
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
|
||||
}
|
||||
if result.Error.Message != "test message" {
|
||||
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
|
||||
}
|
||||
if result.RequestID == "" {
|
||||
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMessageID(t *testing.T) {
|
||||
id1 := GenerateMessageID()
|
||||
id2 := GenerateMessageID()
|
||||
|
||||
if id1 == "" {
|
||||
t.Error("GenerateMessageID returned empty string")
|
||||
}
|
||||
if id1 == id2 {
|
||||
t.Error("GenerateMessageID returned duplicate IDs")
|
||||
}
|
||||
if len(id1) < 10 {
|
||||
t.Errorf("GenerateMessageID returned short ID: %q", id1)
|
||||
}
|
||||
if id1[:4] != "msg_" {
|
||||
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello",
|
||||
},
|
||||
Metrics: api.Metrics{PromptEvalCount: 10},
|
||||
}
|
||||
|
||||
events1 := conv.Process(resp1)
|
||||
if len(events1) < 3 {
|
||||
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
|
||||
}
|
||||
|
||||
// Should have message_start, content_block_start, content_block_delta
|
||||
if events1[0].Event != "message_start" {
|
||||
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
|
||||
}
|
||||
if events1[1].Event != "content_block_start" {
|
||||
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
|
||||
}
|
||||
if events1[2].Event != "content_block_delta" {
|
||||
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
|
||||
}
|
||||
|
||||
// Final chunk
|
||||
resp2 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: " world!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{EvalCount: 5},
|
||||
}
|
||||
|
||||
events2 := conv.Process(resp2)
|
||||
|
||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||
hasStop := false
|
||||
for _, e := range events2 {
|
||||
if e.Event == "message_stop" {
|
||||
hasStop = true
|
||||
}
|
||||
}
|
||||
if !hasStop {
|
||||
t.Error("expected message_stop event in final chunk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
hasToolStart := false
|
||||
hasToolDelta := false
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
hasToolStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if e.Event == "content_block_delta" {
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
|
||||
if delta.Delta.Type == "input_json_delta" {
|
||||
hasToolDelta = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasToolStart {
|
||||
t.Error("expected tool_use content_block_start event")
|
||||
}
|
||||
if !hasToolDelta {
|
||||
t.Error("expected input_json_delta event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
badArgs.Set("channel", unmarshalable)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_bad",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bad_function",
|
||||
Arguments: badArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
// Should not panic and should skip the unmarshalable tool call
|
||||
events := conv.Process(resp)
|
||||
|
||||
// Verify no tool_use block was started (since marshal failed before block start)
|
||||
hasToolStart := false
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
hasToolStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasToolStart {
|
||||
t.Error("expected no tool_use block when arguments cannot be marshaled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
badArgs.Set("channel", unmarshalable)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_good",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "good_function",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_bad",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bad_function",
|
||||
Arguments: badArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
// Count tool_use blocks - should only have 1 (the valid one)
|
||||
toolStartCount := 0
|
||||
toolDeltaCount := 0
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
toolStartCount++
|
||||
if start.ContentBlock.Name != "good_function" {
|
||||
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if e.Event == "content_block_delta" {
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
|
||||
if delta.Delta.Type == "input_json_delta" {
|
||||
toolDeltaCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if toolStartCount != 1 {
|
||||
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
|
||||
}
|
||||
if toolDeltaCount != 1 {
|
||||
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
|
||||
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
|
||||
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
|
||||
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
|
||||
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
block ContentBlock
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "text block includes empty text field",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
{
|
||||
name: "thinking block includes empty thinking field",
|
||||
block: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "thinking"},
|
||||
},
|
||||
{
|
||||
name: "text block with content",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr("hello"),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.block)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range tt.wantKeys {
|
||||
if _, ok := result[key]; !ok {
|
||||
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Content: "hello"},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundTextStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "text" {
|
||||
foundTextStart = true
|
||||
// Marshal and verify the text field is present
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["text"]; !ok {
|
||||
t.Error("content_block_start for text should include 'text' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundTextStart {
|
||||
t.Error("expected text content_block_start event")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundThinkingStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "thinking" {
|
||||
foundThinkingStart = true
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["thinking"]; !ok {
|
||||
t.Error("content_block_start for thinking should include 'thinking' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundThinkingStart {
|
||||
t.Error("expected thinking content_block_start event")
|
||||
}
|
||||
})
|
||||
}
|
||||
22
api/types.go
22
api/types.go
@@ -19,6 +19,12 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// SkillRef is an alias for model.SkillRef representing a skill reference.
|
||||
type SkillRef = model.SkillRef
|
||||
|
||||
// MCPRef is an alias for model.MCPRef representing an MCP server reference.
|
||||
type MCPRef = model.MCPRef
|
||||
|
||||
// StatusError is an error with an HTTP status code and message.
|
||||
type StatusError struct {
|
||||
StatusCode int
|
||||
@@ -690,6 +696,18 @@ type CreateRequest struct {
|
||||
// Requires is the minimum version of Ollama required by the model.
|
||||
Requires string `json:"requires,omitempty"`
|
||||
|
||||
// Skills is a list of skill references for the agent (local paths or registry refs)
|
||||
Skills []SkillRef `json:"skills,omitempty"`
|
||||
|
||||
// MCPs is a list of MCP server references for the agent
|
||||
MCPs []MCPRef `json:"mcps,omitempty"`
|
||||
|
||||
// AgentType defines the type of agent (e.g., "conversational", "task-based")
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
|
||||
// Entrypoint specifies an external command to run instead of the built-in chat loop
|
||||
Entrypoint string `json:"entrypoint,omitempty"`
|
||||
|
||||
// Info is a map of additional information for the model
|
||||
Info map[string]any `json:"info,omitempty"`
|
||||
|
||||
@@ -741,6 +759,10 @@ type ShowResponse struct {
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||
Requires string `json:"requires,omitempty"`
|
||||
Skills []SkillRef `json:"skills,omitempty"`
|
||||
MCPs []MCPRef `json:"mcps,omitempty"`
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
Entrypoint string `json:"entrypoint,omitempty"`
|
||||
}
|
||||
|
||||
// CopyRequest is the request passed to [Client.Copy].
|
||||
|
||||
402
cmd/agent_loop_test.go
Normal file
402
cmd/agent_loop_test.go
Normal file
@@ -0,0 +1,402 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TestToolMessage verifies that tool messages are constructed correctly
|
||||
// with ToolName and ToolCallID preserved from the tool call.
|
||||
func TestToolMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
call api.ToolCall
|
||||
content string
|
||||
expected api.Message
|
||||
}{
|
||||
{
|
||||
name: "basic tool message with ID",
|
||||
call: api.ToolCall{
|
||||
ID: "call_abc123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
},
|
||||
},
|
||||
},
|
||||
content: "Sunny, 22°C",
|
||||
expected: api.Message{
|
||||
Role: "tool",
|
||||
Content: "Sunny, 22°C",
|
||||
ToolName: "get_weather",
|
||||
ToolCallID: "call_abc123",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool message without ID",
|
||||
call: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"expression": "2+2",
|
||||
},
|
||||
},
|
||||
},
|
||||
content: "4",
|
||||
expected: api.Message{
|
||||
Role: "tool",
|
||||
Content: "4",
|
||||
ToolName: "calculate",
|
||||
// ToolCallID should be empty when call.ID is empty
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "MCP tool message",
|
||||
call: api.ToolCall{
|
||||
ID: "call_mcp123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "mcp_websearch_search",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"query": "ollama agents",
|
||||
},
|
||||
},
|
||||
},
|
||||
content: "Found 10 results",
|
||||
expected: api.Message{
|
||||
Role: "tool",
|
||||
Content: "Found 10 results",
|
||||
ToolName: "mcp_websearch_search",
|
||||
ToolCallID: "call_mcp123",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "skill tool message",
|
||||
call: api.ToolCall{
|
||||
ID: "call_skill456",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "run_skill_script",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"skill": "calculator",
|
||||
"command": "python scripts/calc.py 2+2",
|
||||
},
|
||||
},
|
||||
},
|
||||
content: "Result: 4",
|
||||
expected: api.Message{
|
||||
Role: "tool",
|
||||
Content: "Result: 4",
|
||||
ToolName: "run_skill_script",
|
||||
ToolCallID: "call_skill456",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := toolMessage(tt.call, tt.content)
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("toolMessage() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAssistantMessageWithThinking verifies that assistant messages
|
||||
// in the tool loop should include thinking content.
|
||||
func TestAssistantMessageConstruction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
thinking string
|
||||
toolCalls []api.ToolCall
|
||||
expectedMsg api.Message
|
||||
}{
|
||||
{
|
||||
name: "assistant with thinking and tool calls",
|
||||
content: "",
|
||||
thinking: "I need to check the weather for Paris.",
|
||||
toolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedMsg: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
Thinking: "I need to check the weather for Paris.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "assistant with content, thinking, and tool calls",
|
||||
content: "Let me check that for you.",
|
||||
thinking: "User wants weather info.",
|
||||
toolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_2",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: api.ToolCallFunctionArguments{"query": "weather"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedMsg: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Let me check that for you.",
|
||||
Thinking: "User wants weather info.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_2",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: api.ToolCallFunctionArguments{"query": "weather"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "assistant with multiple tool calls",
|
||||
content: "",
|
||||
thinking: "I'll check both cities.",
|
||||
toolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_a",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_b",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "London"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedMsg: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
Thinking: "I'll check both cities.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_a",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_b",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "London"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the assistant message construction as done in chat()
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
Content: tt.content,
|
||||
Thinking: tt.thinking,
|
||||
ToolCalls: tt.toolCalls,
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedMsg, assistantMsg); diff != "" {
|
||||
t.Errorf("assistant message mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMessageStitchingOrder verifies that messages in a tool loop
|
||||
// are stitched in the correct order:
|
||||
// 1. User message
|
||||
// 2. Assistant message with tool calls (and thinking)
|
||||
// 3. Tool result messages (one per tool call, in order)
|
||||
// 4. Next assistant response
|
||||
func TestMessageStitchingOrder(t *testing.T) {
|
||||
// Simulate a complete tool loop conversation
|
||||
messages := []api.Message{
|
||||
// Initial user message
|
||||
{Role: "user", Content: "What's the weather in Paris and London?"},
|
||||
// Assistant's first response with tool calls
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
Thinking: "I need to check the weather for both cities.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{ID: "call_1", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
|
||||
{ID: "call_2", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}},
|
||||
},
|
||||
},
|
||||
// Tool results (in order matching tool calls)
|
||||
{Role: "tool", Content: "Sunny, 22°C", ToolName: "get_weather", ToolCallID: "call_1"},
|
||||
{Role: "tool", Content: "Rainy, 15°C", ToolName: "get_weather", ToolCallID: "call_2"},
|
||||
// Final assistant response
|
||||
{Role: "assistant", Content: "Paris is sunny at 22°C, and London is rainy at 15°C.", Thinking: "Got the data, now summarizing."},
|
||||
}
|
||||
|
||||
// Verify structure
|
||||
expectedRoles := []string{"user", "assistant", "tool", "tool", "assistant"}
|
||||
for i, msg := range messages {
|
||||
if msg.Role != expectedRoles[i] {
|
||||
t.Errorf("message %d: expected role %q, got %q", i, expectedRoles[i], msg.Role)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify tool results match tool calls in order
|
||||
assistantWithTools := messages[1]
|
||||
toolResults := []api.Message{messages[2], messages[3]}
|
||||
|
||||
if len(toolResults) != len(assistantWithTools.ToolCalls) {
|
||||
t.Errorf("expected %d tool results for %d tool calls", len(assistantWithTools.ToolCalls), len(toolResults))
|
||||
}
|
||||
|
||||
for i, result := range toolResults {
|
||||
expectedToolCallID := assistantWithTools.ToolCalls[i].ID
|
||||
if result.ToolCallID != expectedToolCallID {
|
||||
t.Errorf("tool result %d: expected ToolCallID %q, got %q", i, expectedToolCallID, result.ToolCallID)
|
||||
}
|
||||
expectedToolName := assistantWithTools.ToolCalls[i].Function.Name
|
||||
if result.ToolName != expectedToolName {
|
||||
t.Errorf("tool result %d: expected ToolName %q, got %q", i, expectedToolName, result.ToolName)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify thinking is present in assistant messages
|
||||
if messages[1].Thinking == "" {
|
||||
t.Error("first assistant message should have thinking content")
|
||||
}
|
||||
if messages[4].Thinking == "" {
|
||||
t.Error("final assistant message should have thinking content")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiTurnToolLoop verifies message stitching across multiple
|
||||
// tool call iterations.
|
||||
func TestMultiTurnToolLoop(t *testing.T) {
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "What's 2+2 and also what's the weather in Paris?"},
|
||||
// First tool call: calculate
|
||||
{
|
||||
Role: "assistant",
|
||||
Thinking: "I'll start with the calculation.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{ID: "calc_1", Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expr": "2+2"}}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "4", ToolName: "calculate", ToolCallID: "calc_1"},
|
||||
// Second tool call: weather
|
||||
{
|
||||
Role: "assistant",
|
||||
Thinking: "Got the calculation. Now checking weather.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{ID: "weather_1", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny, 20°C", ToolName: "get_weather", ToolCallID: "weather_1"},
|
||||
// Final response
|
||||
{Role: "assistant", Content: "2+2 equals 4, and Paris is sunny at 20°C."},
|
||||
}
|
||||
|
||||
// Count message types
|
||||
roleCounts := map[string]int{}
|
||||
for _, msg := range messages {
|
||||
roleCounts[msg.Role]++
|
||||
}
|
||||
|
||||
if roleCounts["user"] != 1 {
|
||||
t.Errorf("expected 1 user message, got %d", roleCounts["user"])
|
||||
}
|
||||
if roleCounts["assistant"] != 3 {
|
||||
t.Errorf("expected 3 assistant messages, got %d", roleCounts["assistant"])
|
||||
}
|
||||
if roleCounts["tool"] != 2 {
|
||||
t.Errorf("expected 2 tool messages, got %d", roleCounts["tool"])
|
||||
}
|
||||
|
||||
// Verify each tool message follows an assistant with matching tool call
|
||||
for i, msg := range messages {
|
||||
if msg.Role == "tool" {
|
||||
// Find preceding assistant message with tool calls
|
||||
var precedingAssistant *api.Message
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
if messages[j].Role == "assistant" && len(messages[j].ToolCalls) > 0 {
|
||||
precedingAssistant = &messages[j]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if precedingAssistant == nil {
|
||||
t.Errorf("tool message at index %d has no preceding assistant with tool calls", i)
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify tool result matches one of the tool calls
|
||||
found := false
|
||||
for _, tc := range precedingAssistant.ToolCalls {
|
||||
if tc.ID == msg.ToolCallID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("tool message at index %d has ToolCallID %q not found in preceding tool calls", i, msg.ToolCallID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSkillCatalogRunToolCallPreservesFields tests that skill catalog
|
||||
// returns tool messages with correct fields.
|
||||
func TestSkillCatalogToolMessageFields(t *testing.T) {
|
||||
// Create a minimal test for toolMessage function
|
||||
call := api.ToolCall{
|
||||
ID: "test_id_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "run_skill_script",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"skill": "test-skill",
|
||||
"command": "echo hello",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msg := toolMessage(call, "hello")
|
||||
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if msg.Content != "hello" {
|
||||
t.Errorf("expected content 'hello', got %q", msg.Content)
|
||||
}
|
||||
if msg.ToolName != "run_skill_script" {
|
||||
t.Errorf("expected ToolName 'run_skill_script', got %q", msg.ToolName)
|
||||
}
|
||||
if msg.ToolCallID != "test_id_123" {
|
||||
t.Errorf("expected ToolCallID 'test_id_123', got %q", msg.ToolCallID)
|
||||
}
|
||||
}
|
||||
474
cmd/cmd.go
474
cmd/cmd.go
@@ -15,6 +15,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -46,8 +47,6 @@ import (
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
@@ -98,10 +97,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) {
|
||||
if filename == "" {
|
||||
// No Modelfile found - check if current directory is an image gen model
|
||||
if imagegen.IsTensorModelDir(".") {
|
||||
return imagegenclient.CreateModel(args[0], ".", p)
|
||||
}
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else {
|
||||
return errModelfileNotFound
|
||||
@@ -463,15 +458,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
|
||||
// Check if this is a known image generation model (skip Show/Pull)
|
||||
if imagegen.HasTensorLayers(name) {
|
||||
if opts.Prompt == "" && !interactive {
|
||||
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
||||
}
|
||||
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
|
||||
}
|
||||
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: name}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
@@ -510,6 +496,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
opts.ParentModel = info.Details.ParentModel
|
||||
|
||||
// Check if this is an agent
|
||||
isAgent := info.AgentType != "" || len(info.Skills) > 0 || len(info.MCPs) > 0 || info.Entrypoint != ""
|
||||
if isAgent {
|
||||
opts.IsAgent = true
|
||||
opts.AgentType = info.AgentType
|
||||
opts.Skills = info.Skills
|
||||
opts.MCPs = info.MCPs
|
||||
opts.Entrypoint = info.Entrypoint
|
||||
}
|
||||
|
||||
// Check if this is an embedding model
|
||||
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
|
||||
|
||||
@@ -535,7 +531,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
// Check for experimental flag
|
||||
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
|
||||
// If agent has entrypoint, run it instead of chat loop
|
||||
if opts.Entrypoint != "" {
|
||||
return runEntrypoint(cmd, opts)
|
||||
}
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
@@ -563,16 +562,69 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Use experimental agent loop with tools
|
||||
// Use experimental agent loop with
|
||||
if isExperimental {
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
|
||||
// For agents, use chat API even in non-interactive mode to support tools
|
||||
if opts.IsAgent {
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: opts.Prompt})
|
||||
_, err := chat(cmd, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
// runEntrypoint executes the agent's entrypoint command instead of the built-in chat loop.
|
||||
func runEntrypoint(cmd *cobra.Command, opts runOptions) error {
|
||||
entrypoint := opts.Entrypoint
|
||||
|
||||
// Check if entrypoint contains $PROMPT placeholder
|
||||
hasPlaceholder := strings.Contains(entrypoint, "$PROMPT")
|
||||
|
||||
if hasPlaceholder && opts.Prompt != "" {
|
||||
// Replace $PROMPT with the actual prompt
|
||||
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", opts.Prompt)
|
||||
} else if hasPlaceholder {
|
||||
// No prompt provided but placeholder exists - remove placeholder
|
||||
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", "")
|
||||
}
|
||||
|
||||
// Parse entrypoint into command and args
|
||||
parts := strings.Fields(entrypoint)
|
||||
if len(parts) == 0 {
|
||||
return fmt.Errorf("empty entrypoint")
|
||||
}
|
||||
|
||||
command := parts[0]
|
||||
args := parts[1:]
|
||||
|
||||
// If user provided a prompt and no placeholder was used, append it as argument
|
||||
if opts.Prompt != "" && !hasPlaceholder {
|
||||
args = append(args, opts.Prompt)
|
||||
}
|
||||
|
||||
// Look up command in PATH
|
||||
execPath, err := exec.LookPath(command)
|
||||
if err != nil {
|
||||
return fmt.Errorf("entrypoint command not found: %s", command)
|
||||
}
|
||||
|
||||
// Create subprocess
|
||||
proc := exec.Command(execPath, args...)
|
||||
proc.Stdin = os.Stdin
|
||||
proc.Stdout = os.Stdout
|
||||
proc.Stderr = os.Stderr
|
||||
|
||||
// Run and wait
|
||||
return proc.Run()
|
||||
}
|
||||
|
||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -837,11 +889,6 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||
// Check if this is an image generation model
|
||||
if imagegen.HasTensorLayers(args[0]) {
|
||||
return imagegen.Show(args[0], os.Stdout)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -937,47 +984,96 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
fmt.Fprintln(w)
|
||||
}
|
||||
|
||||
tableRender("Model", func() (rows [][]string) {
|
||||
if resp.RemoteHost != "" {
|
||||
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
|
||||
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
|
||||
}
|
||||
|
||||
if resp.ModelInfo != nil {
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
|
||||
var paramStr string
|
||||
if resp.Details.ParameterSize != "" {
|
||||
paramStr = resp.Details.ParameterSize
|
||||
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
paramStr = format.HumanNumber(uint64(f))
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
// Only show Model section if there's actual model info (not for entrypoint-only agents)
|
||||
hasModelInfo := resp.RemoteHost != "" || resp.ModelInfo != nil || resp.Details.Family != "" || resp.Details.ParameterSize != "" || resp.Details.QuantizationLevel != ""
|
||||
if hasModelInfo {
|
||||
tableRender("Model", func() (rows [][]string) {
|
||||
if resp.RemoteHost != "" {
|
||||
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
|
||||
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
|
||||
}
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
if resp.ModelInfo != nil {
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
|
||||
var paramStr string
|
||||
if resp.Details.ParameterSize != "" {
|
||||
paramStr = resp.Details.ParameterSize
|
||||
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
paramStr = format.HumanNumber(uint64(f))
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rows = append(rows, []string{"", "architecture", resp.Details.Family})
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
}
|
||||
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
|
||||
if resp.Requires != "" {
|
||||
rows = append(rows, []string{"", "requires", resp.Requires})
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
// Display agent information if this is an agent
|
||||
if resp.AgentType != "" || len(resp.Skills) > 0 || len(resp.MCPs) > 0 || resp.Entrypoint != "" {
|
||||
tableRender("Agent", func() (rows [][]string) {
|
||||
if resp.AgentType != "" {
|
||||
rows = append(rows, []string{"", "type", resp.AgentType})
|
||||
}
|
||||
if resp.Entrypoint != "" {
|
||||
rows = append(rows, []string{"", "entrypoint", resp.Entrypoint})
|
||||
}
|
||||
if len(resp.Skills) > 0 {
|
||||
for i, skill := range resp.Skills {
|
||||
label := "skill"
|
||||
if i > 0 {
|
||||
label = ""
|
||||
}
|
||||
// Show skill name or digest
|
||||
skillDisplay := skill.Name
|
||||
if skillDisplay == "" && skill.Digest != "" {
|
||||
skillDisplay = skill.Digest[:12] + "..."
|
||||
}
|
||||
rows = append(rows, []string{"", label, skillDisplay})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rows = append(rows, []string{"", "architecture", resp.Details.Family})
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
}
|
||||
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
|
||||
if resp.Requires != "" {
|
||||
rows = append(rows, []string{"", "requires", resp.Requires})
|
||||
}
|
||||
return
|
||||
})
|
||||
if len(resp.MCPs) > 0 {
|
||||
for i, mcp := range resp.MCPs {
|
||||
label := "mcp"
|
||||
if i > 0 {
|
||||
label = ""
|
||||
}
|
||||
// Show MCP name and command
|
||||
mcpDisplay := mcp.Name
|
||||
if mcp.Command != "" {
|
||||
cmdLine := mcp.Command
|
||||
if len(mcp.Args) > 0 {
|
||||
cmdLine += " " + strings.Join(mcp.Args, " ")
|
||||
}
|
||||
mcpDisplay += " (" + cmdLine + ")"
|
||||
}
|
||||
rows = append(rows, []string{"", label, mcpDisplay})
|
||||
}
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
if len(resp.Capabilities) > 0 {
|
||||
tableRender("Capabilities", func() (rows [][]string) {
|
||||
@@ -1219,6 +1315,11 @@ type runOptions struct {
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
ShowConnect bool
|
||||
IsAgent bool
|
||||
AgentType string
|
||||
Skills []api.SkillRef
|
||||
MCPs []api.MCPRef
|
||||
Entrypoint string
|
||||
}
|
||||
|
||||
func (r runOptions) Copy() runOptions {
|
||||
@@ -1248,6 +1349,12 @@ func (r runOptions) Copy() runOptions {
|
||||
think = &cThink
|
||||
}
|
||||
|
||||
var skills []api.SkillRef
|
||||
if r.Skills != nil {
|
||||
skills = make([]api.SkillRef, len(r.Skills))
|
||||
copy(skills, r.Skills)
|
||||
}
|
||||
|
||||
return runOptions{
|
||||
Model: r.Model,
|
||||
ParentModel: r.ParentModel,
|
||||
@@ -1263,6 +1370,9 @@ func (r runOptions) Copy() runOptions {
|
||||
Think: think,
|
||||
HideThinking: r.HideThinking,
|
||||
ShowConnect: r.ShowConnect,
|
||||
IsAgent: r.IsAgent,
|
||||
AgentType: r.AgentType,
|
||||
Skills: skills,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1346,6 +1456,65 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load skills for agents
|
||||
var skillsCatalog *skillCatalog
|
||||
if opts.IsAgent && len(opts.Skills) > 0 {
|
||||
skillsCatalog, err = loadSkillsFromRefs(opts.Skills)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load skills: %w", err)
|
||||
}
|
||||
if skillsCatalog != nil && len(skillsCatalog.Skills) > 0 {
|
||||
var skillNames []string
|
||||
for _, s := range skillsCatalog.Skills {
|
||||
skillNames = append(skillNames, s.Name)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Loaded skills: %s\n", strings.Join(skillNames, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
// Load MCP servers for agents (from opts and global config)
|
||||
var mcpMgr *mcpManager
|
||||
allMCPs := opts.MCPs
|
||||
|
||||
// Load global MCPs from ~/.ollama/mcp.json
|
||||
if globalConfig, err := loadMCPConfig(); err == nil && len(globalConfig.MCPServers) > 0 {
|
||||
for name, srv := range globalConfig.MCPServers {
|
||||
// Skip disabled MCPs
|
||||
if srv.Disabled {
|
||||
continue
|
||||
}
|
||||
// Check if already in opts.MCPs (model takes precedence)
|
||||
found := false
|
||||
for _, m := range opts.MCPs {
|
||||
if m.Name == name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
allMCPs = append(allMCPs, api.MCPRef{
|
||||
Name: name,
|
||||
Command: srv.Command,
|
||||
Args: srv.Args,
|
||||
Env: srv.Env,
|
||||
Type: srv.Type,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(allMCPs) > 0 {
|
||||
mcpMgr = newMCPManager()
|
||||
if err := mcpMgr.loadMCPsFromRefs(allMCPs); err != nil {
|
||||
return nil, fmt.Errorf("failed to load MCP servers: %w", err)
|
||||
}
|
||||
if mcpMgr.ToolCount() > 0 {
|
||||
fmt.Fprintf(os.Stderr, "Loaded MCP servers: %s (%d tools)\n",
|
||||
strings.Join(mcpMgr.ServerNames(), ", "), mcpMgr.ToolCount())
|
||||
}
|
||||
defer mcpMgr.Shutdown()
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.StopAndClear()
|
||||
|
||||
@@ -1369,6 +1538,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
var fullResponse strings.Builder
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
var pendingToolCalls []api.ToolCall
|
||||
|
||||
role := "assistant"
|
||||
|
||||
@@ -1409,7 +1579,13 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
if response.Message.ToolCalls != nil {
|
||||
toolCalls := response.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
fmt.Print(renderToolCalls(toolCalls, false))
|
||||
if skillsCatalog != nil || mcpMgr != nil {
|
||||
// Store tool calls for execution after response is complete
|
||||
pendingToolCalls = append(pendingToolCalls, toolCalls...)
|
||||
} else {
|
||||
// No skills catalog or MCP, just display tool calls
|
||||
fmt.Print(renderToolCalls(toolCalls, false))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1422,31 +1598,161 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
opts.Format = `"` + opts.Format + `"`
|
||||
}
|
||||
|
||||
req := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Messages: opts.Messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
// Prepare messages with agent-specific system prompt
|
||||
messages := opts.Messages
|
||||
if skillsCatalog != nil {
|
||||
// Add skills system prompt as the first system message
|
||||
skillsPrompt := skillsCatalog.SystemPrompt()
|
||||
if skillsPrompt != "" {
|
||||
// Insert skills prompt at the beginning, or append to existing system message
|
||||
if len(messages) > 0 && messages[0].Role == "system" {
|
||||
// Append to existing system message
|
||||
messages[0].Content = messages[0].Content + "\n\n" + skillsPrompt
|
||||
} else {
|
||||
// Insert new system message at the beginning
|
||||
systemMsg := api.Message{Role: "system", Content: skillsPrompt}
|
||||
messages = append([]api.Message{systemMsg}, messages...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
// Agentic loop: continue until no more tool calls
|
||||
for {
|
||||
req := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Messages: messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
// this error should ideally be wrapped properly by the client
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
fmt.Println()
|
||||
return nil, nil
|
||||
// Add tools for agents (combine skills and MCP tools)
|
||||
var allTools api.Tools
|
||||
if skillsCatalog != nil {
|
||||
allTools = append(allTools, skillsCatalog.Tools()...)
|
||||
}
|
||||
return nil, err
|
||||
if mcpMgr != nil {
|
||||
allTools = append(allTools, mcpMgr.Tools()...)
|
||||
}
|
||||
if len(allTools) > 0 {
|
||||
req.Tools = allTools
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// this error should ideally be wrapped properly by the client
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
fmt.Println()
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If no tool calls, we're done
|
||||
if len(pendingToolCalls) == 0 || (skillsCatalog == nil && mcpMgr == nil) {
|
||||
break
|
||||
}
|
||||
|
||||
// Execute tool calls and continue the conversation
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
// Add assistant's tool call message to history (include thinking for proper rendering)
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
Content: fullResponse.String(),
|
||||
Thinking: thinkingContent.String(),
|
||||
ToolCalls: pendingToolCalls,
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// Execute each tool call and collect results
|
||||
var toolResults []api.Message
|
||||
for _, call := range pendingToolCalls {
|
||||
// Show what's being executed
|
||||
switch call.Function.Name {
|
||||
case "run_skill_script":
|
||||
skillVal, _ := call.Function.Arguments.Get("skill")
|
||||
skill, _ := skillVal.(string)
|
||||
commandVal, _ := call.Function.Arguments.Get("command")
|
||||
command, _ := commandVal.(string)
|
||||
fmt.Fprintf(os.Stderr, "Running script in %s: %s\n", skill, command)
|
||||
case "read_skill_file":
|
||||
skillVal, _ := call.Function.Arguments.Get("skill")
|
||||
skill, _ := skillVal.(string)
|
||||
pathVal, _ := call.Function.Arguments.Get("path")
|
||||
path, _ := pathVal.(string)
|
||||
fmt.Fprintf(os.Stderr, "Reading file from %s: %s\n", skill, path)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Executing: %s\n", call.Function.Name)
|
||||
}
|
||||
|
||||
var result api.Message
|
||||
var handled bool
|
||||
var err error
|
||||
|
||||
// Try skill catalog first
|
||||
if skillsCatalog != nil {
|
||||
result, handled, err = skillsCatalog.RunToolCall(call)
|
||||
}
|
||||
|
||||
// If not handled by skills, try MCP
|
||||
if !handled && mcpMgr != nil {
|
||||
result, handled, err = mcpMgr.RunToolCall(call)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
// Add error result
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Error: %v", err),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if !handled {
|
||||
fmt.Fprintf(os.Stderr, "Warning: Unknown tool %s\n", call.Function.Name)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Unknown tool: %s", call.Function.Name),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Display tool output
|
||||
if result.Content != "" {
|
||||
fmt.Fprintf(os.Stderr, "Output:\n%s\n", result.Content)
|
||||
}
|
||||
|
||||
// Add tool result to messages (preserves ToolName, ToolCallID from result)
|
||||
toolResults = append(toolResults, result)
|
||||
}
|
||||
|
||||
// Add tool results to message history
|
||||
messages = append(messages, toolResults...)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
// Reset state for next iteration
|
||||
fullResponse.Reset()
|
||||
thinkingContent.Reset()
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = false
|
||||
pendingToolCalls = nil
|
||||
state = &displayResponseState{}
|
||||
|
||||
// Start new progress spinner for next API call
|
||||
p = progress.NewProgress(os.Stderr)
|
||||
spinner = progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
}
|
||||
|
||||
if len(opts.Messages) > 0 {
|
||||
@@ -1785,10 +2091,6 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
|
||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
|
||||
|
||||
// Image generation flags (width, height, steps, seed, etc.)
|
||||
imagegen.RegisterFlags(runCmd)
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
@@ -1943,6 +2245,8 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
NewSkillCommand(),
|
||||
NewMCPCommand(),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
@@ -34,6 +34,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /skills Show available skills")
|
||||
fmt.Fprintln(os.Stderr, " /skill Add or remove skills dynamically")
|
||||
fmt.Fprintln(os.Stderr, " /mcp Show/add/remove MCP servers")
|
||||
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
|
||||
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
|
||||
fmt.Fprintln(os.Stderr, " /clear Clear session context")
|
||||
@@ -444,6 +447,411 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
} else {
|
||||
usageShow()
|
||||
}
|
||||
case strings.HasPrefix(line, "/skill "):
|
||||
args := strings.Fields(line)
|
||||
if len(args) < 2 {
|
||||
fmt.Fprintln(os.Stderr, "Usage:")
|
||||
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
|
||||
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
|
||||
fmt.Fprintln(os.Stderr, " /skill list List current skills")
|
||||
continue
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "add":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /skill add <path>")
|
||||
continue
|
||||
}
|
||||
skillPath := args[2]
|
||||
|
||||
// Expand ~ to home directory
|
||||
if strings.HasPrefix(skillPath, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
fmt.Printf("Error expanding path: %v\n", err)
|
||||
continue
|
||||
}
|
||||
skillPath = filepath.Join(home, skillPath[1:])
|
||||
}
|
||||
|
||||
// Make absolute
|
||||
absPath, err := filepath.Abs(skillPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error resolving path: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify SKILL.md exists
|
||||
skillMdPath := filepath.Join(absPath, "SKILL.md")
|
||||
if _, err := os.Stat(skillMdPath); err != nil {
|
||||
fmt.Printf("Error: %s does not contain SKILL.md\n", skillPath)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract skill name from SKILL.md
|
||||
content, err := os.ReadFile(skillMdPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error reading SKILL.md: %v\n", err)
|
||||
continue
|
||||
}
|
||||
skillName, _ := extractSkillMetadata(string(content))
|
||||
if skillName == "" {
|
||||
skillName = filepath.Base(absPath)
|
||||
}
|
||||
|
||||
// Check if already added
|
||||
for _, s := range opts.Skills {
|
||||
if s.Name == skillName {
|
||||
fmt.Printf("Skill '%s' is already loaded\n", skillName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Add to skills (using path as Name, no digest for local skills)
|
||||
opts.Skills = append(opts.Skills, api.SkillRef{Name: absPath})
|
||||
opts.IsAgent = true // Enable agent mode if not already
|
||||
fmt.Printf("Added skill '%s' from %s\n", skillName, skillPath)
|
||||
|
||||
case "remove", "rm":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /skill remove <name>")
|
||||
continue
|
||||
}
|
||||
skillName := args[2]
|
||||
|
||||
found := false
|
||||
newSkills := make([]api.SkillRef, 0, len(opts.Skills))
|
||||
for _, s := range opts.Skills {
|
||||
// Match by name or by path basename
|
||||
name := s.Name
|
||||
if strings.Contains(name, string(os.PathSeparator)) {
|
||||
name = filepath.Base(name)
|
||||
}
|
||||
if name == skillName || s.Name == skillName {
|
||||
found = true
|
||||
fmt.Printf("Removed skill '%s'\n", skillName)
|
||||
} else {
|
||||
newSkills = append(newSkills, s)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
fmt.Printf("Skill '%s' not found\n", skillName)
|
||||
} else {
|
||||
opts.Skills = newSkills
|
||||
}
|
||||
|
||||
case "list", "ls":
|
||||
if len(opts.Skills) == 0 {
|
||||
fmt.Println("No skills loaded in this session.")
|
||||
} else {
|
||||
fmt.Println("Skills loaded in this session:")
|
||||
for _, skill := range opts.Skills {
|
||||
if skill.Digest != "" {
|
||||
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
|
||||
} else {
|
||||
// For local paths, show basename
|
||||
name := skill.Name
|
||||
if strings.Contains(name, string(os.PathSeparator)) {
|
||||
name = filepath.Base(name) + " (local: " + skill.Name + ")"
|
||||
}
|
||||
fmt.Printf(" %s\n", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
default:
|
||||
fmt.Printf("Unknown skill command '%s'. Use /skill add, /skill remove, or /skill list\n", args[1])
|
||||
}
|
||||
continue
|
||||
|
||||
case strings.HasPrefix(line, "/skills"):
|
||||
// Show skills from model (bundled) + session skills
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
req := &api.ShowRequest{
|
||||
Name: opts.Model,
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model info")
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine model skills with session skills
|
||||
allSkills := make([]api.SkillRef, 0)
|
||||
allSkills = append(allSkills, resp.Skills...)
|
||||
|
||||
// Add session skills that aren't already in model skills
|
||||
for _, sessionSkill := range opts.Skills {
|
||||
found := false
|
||||
for _, modelSkill := range resp.Skills {
|
||||
if modelSkill.Name == sessionSkill.Name || modelSkill.Digest == sessionSkill.Digest {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
allSkills = append(allSkills, sessionSkill)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allSkills) == 0 {
|
||||
fmt.Println("No skills available.")
|
||||
} else {
|
||||
fmt.Println("Available Skills:")
|
||||
for _, skill := range allSkills {
|
||||
if skill.Digest != "" {
|
||||
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
|
||||
} else {
|
||||
name := skill.Name
|
||||
if strings.Contains(name, string(os.PathSeparator)) {
|
||||
name = filepath.Base(name) + " (session)"
|
||||
}
|
||||
fmt.Printf(" %s\n", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
continue
|
||||
|
||||
case strings.HasPrefix(line, "/mcp"):
|
||||
args := strings.Fields(line)
|
||||
|
||||
// If just "/mcp" with no args, show all MCP servers
|
||||
if len(args) == 1 {
|
||||
// Show MCPs from model (bundled) + global config
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
req := &api.ShowRequest{
|
||||
Name: opts.Model,
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model info")
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine model MCPs with global config MCPs
|
||||
allMCPs := make([]api.MCPRef, 0)
|
||||
allMCPs = append(allMCPs, resp.MCPs...)
|
||||
|
||||
// Load global config
|
||||
globalConfig, _ := loadMCPConfig()
|
||||
globalMCPNames := make(map[string]bool)
|
||||
|
||||
if globalConfig != nil {
|
||||
for name, srv := range globalConfig.MCPServers {
|
||||
// Check if already in model MCPs
|
||||
found := false
|
||||
for _, modelMCP := range resp.MCPs {
|
||||
if modelMCP.Name == name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
allMCPs = append(allMCPs, api.MCPRef{
|
||||
Name: name,
|
||||
Command: srv.Command,
|
||||
Args: srv.Args,
|
||||
Env: srv.Env,
|
||||
Type: srv.Type,
|
||||
})
|
||||
}
|
||||
globalMCPNames[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(allMCPs) == 0 {
|
||||
fmt.Println("No MCP servers available.")
|
||||
fmt.Println("Use '/mcp add <name> <command> [args...]' to add one.")
|
||||
} else {
|
||||
fmt.Println("Available MCP Servers:")
|
||||
for _, mcp := range allMCPs {
|
||||
cmdLine := mcp.Command
|
||||
if len(mcp.Args) > 0 {
|
||||
cmdLine += " " + strings.Join(mcp.Args, " ")
|
||||
}
|
||||
source := ""
|
||||
disabled := ""
|
||||
// Check if it's from model or global config
|
||||
isFromModel := false
|
||||
for _, modelMCP := range resp.MCPs {
|
||||
if modelMCP.Name == mcp.Name {
|
||||
isFromModel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFromModel {
|
||||
source = " (model)"
|
||||
} else if globalMCPNames[mcp.Name] {
|
||||
source = " (global)"
|
||||
// Check if disabled
|
||||
if srv, ok := globalConfig.MCPServers[mcp.Name]; ok && srv.Disabled {
|
||||
disabled = " [disabled]"
|
||||
}
|
||||
}
|
||||
fmt.Printf(" %s: %s%s%s\n", mcp.Name, cmdLine, source, disabled)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
continue
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "add":
|
||||
if len(args) < 4 {
|
||||
fmt.Println("Usage: /mcp add <name> <command> [args...]")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
mcpCommand := args[3]
|
||||
mcpArgs := args[4:]
|
||||
|
||||
// Load global config
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if already exists
|
||||
if _, exists := config.MCPServers[mcpName]; exists {
|
||||
fmt.Printf("Warning: overwriting existing MCP server '%s'\n", mcpName)
|
||||
}
|
||||
|
||||
// Add to global config
|
||||
config.MCPServers[mcpName] = MCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: mcpCommand,
|
||||
Args: mcpArgs,
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
cmdLine := mcpCommand
|
||||
if len(mcpArgs) > 0 {
|
||||
cmdLine += " " + strings.Join(mcpArgs, " ")
|
||||
}
|
||||
fmt.Printf("Added MCP server '%s' (%s) to %s\n", mcpName, cmdLine, getMCPConfigPath())
|
||||
fmt.Println("Note: MCP server will be started on next message.")
|
||||
|
||||
case "remove", "rm":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /mcp remove <name>")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
|
||||
// Load global config
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := config.MCPServers[mcpName]; !exists {
|
||||
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
delete(config.MCPServers, mcpName)
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Removed MCP server '%s' from %s\n", mcpName, getMCPConfigPath())
|
||||
fmt.Println("Note: Changes will take effect on next message.")
|
||||
|
||||
case "disable":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /mcp disable <name>")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
srv, exists := config.MCPServers[mcpName]
|
||||
if !exists {
|
||||
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
if srv.Disabled {
|
||||
fmt.Printf("MCP server '%s' is already disabled\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
srv.Disabled = true
|
||||
config.MCPServers[mcpName] = srv
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Disabled MCP server '%s'\n", mcpName)
|
||||
fmt.Println("Note: Changes will take effect on next message.")
|
||||
|
||||
case "enable":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /mcp enable <name>")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
srv, exists := config.MCPServers[mcpName]
|
||||
if !exists {
|
||||
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
if !srv.Disabled {
|
||||
fmt.Printf("MCP server '%s' is already enabled\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
srv.Disabled = false
|
||||
config.MCPServers[mcpName] = srv
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Enabled MCP server '%s'\n", mcpName)
|
||||
fmt.Println("Note: Changes will take effect on next message.")
|
||||
|
||||
default:
|
||||
fmt.Printf("Unknown mcp command '%s'. Use /mcp, /mcp add, /mcp remove, /mcp disable, or /mcp enable\n", args[1])
|
||||
}
|
||||
continue
|
||||
|
||||
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
@@ -452,6 +860,20 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
usageSet()
|
||||
case "show", "/show":
|
||||
usageShow()
|
||||
case "skill", "/skill":
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
|
||||
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
|
||||
fmt.Fprintln(os.Stderr, " /skill list List current session skills")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
case "mcp", "/mcp":
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /mcp Show all MCP servers")
|
||||
fmt.Fprintln(os.Stderr, " /mcp add <name> <command> [args...] Add an MCP server to global config")
|
||||
fmt.Fprintln(os.Stderr, " /mcp remove <name> Remove an MCP server from global config")
|
||||
fmt.Fprintln(os.Stderr, " /mcp disable <name> Disable an MCP server (keep in config)")
|
||||
fmt.Fprintln(os.Stderr, " /mcp enable <name> Re-enable a disabled MCP server")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
case "shortcut", "shortcuts":
|
||||
usageShortcuts()
|
||||
}
|
||||
|
||||
570
cmd/skill_cmd.go
Normal file
570
cmd/skill_cmd.go
Normal file
@@ -0,0 +1,570 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// SkillPushHandler handles the skill push command.
|
||||
func SkillPushHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return fmt.Errorf("usage: ollama skill push NAME[:TAG] PATH")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
path := args[1]
|
||||
|
||||
// Expand path
|
||||
if strings.HasPrefix(path, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("expanding home directory: %w", err)
|
||||
}
|
||||
path = filepath.Join(home, path[1:])
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving path: %w", err)
|
||||
}
|
||||
|
||||
// Validate skill directory
|
||||
skillMdPath := filepath.Join(absPath, "SKILL.md")
|
||||
if _, err := os.Stat(skillMdPath); err != nil {
|
||||
return fmt.Errorf("skill directory must contain SKILL.md: %w", err)
|
||||
}
|
||||
|
||||
// Parse skill name (will set Kind="skill")
|
||||
n := server.ParseSkillName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid skill name: %s", name)
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
// Create skill layer
|
||||
displayName := n.DisplayShortest()
|
||||
status := fmt.Sprintf("Creating skill layer for %s", displayName)
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
|
||||
layer, err := server.CreateSkillLayer(absPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating skill layer: %w", err)
|
||||
}
|
||||
|
||||
spinner.Stop()
|
||||
|
||||
// Create skill manifest
|
||||
manifest, configLayer, err := createSkillManifest(absPath, layer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating skill manifest: %w", err)
|
||||
}
|
||||
|
||||
// Write manifest locally
|
||||
manifestPath, err := server.GetSkillManifestPath(n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting manifest path: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
|
||||
return fmt.Errorf("creating manifest directory: %w", err)
|
||||
}
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(manifestPath, manifestJSON, 0o644); err != nil {
|
||||
return fmt.Errorf("writing manifest: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Skill %s created locally\n", displayName)
|
||||
fmt.Fprintf(os.Stderr, " Config: %s (%s)\n", configLayer.Digest, format.HumanBytes(configLayer.Size))
|
||||
fmt.Fprintf(os.Stderr, " Layer: %s (%s)\n", layer.Digest, format.HumanBytes(layer.Size))
|
||||
|
||||
// Push to registry
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
}
|
||||
|
||||
insecure, _ := cmd.Flags().GetBool("insecure")
|
||||
|
||||
// For now, we'll use the existing push mechanism
|
||||
fmt.Fprintf(os.Stderr, "\nPushing to registry...\n")
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
|
||||
p.Add(resp.Digest, bar)
|
||||
} else if resp.Status != "" {
|
||||
spinner := progress.NewSpinner(resp.Status)
|
||||
p.Add(resp.Status, spinner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
req := &api.PushRequest{
|
||||
Model: displayName,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
if err := client.Push(context.Background(), req, fn); err != nil {
|
||||
// If push fails, still show success for local creation
|
||||
fmt.Fprintf(os.Stderr, "\nNote: Local skill created but push failed: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "You can try pushing later with: ollama skill push %s\n", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Successfully pushed %s\n", displayName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SkillPullHandler handles the skill pull command.
|
||||
func SkillPullHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return fmt.Errorf("usage: ollama skill pull NAME[:TAG]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
n := server.ParseSkillName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid skill name: %s", name)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
}
|
||||
|
||||
insecure, _ := cmd.Flags().GetBool("insecure")
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
|
||||
p.Add(resp.Digest, bar)
|
||||
} else if resp.Status != "" {
|
||||
spinner := progress.NewSpinner(resp.Status)
|
||||
p.Add(resp.Status, spinner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
req := &api.PullRequest{
|
||||
Model: displayName,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
if err := client.Pull(context.Background(), req, fn); err != nil {
|
||||
return fmt.Errorf("pulling skill: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Successfully pulled %s\n", displayName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SkillListHandler handles the skill list command.
|
||||
func SkillListHandler(cmd *cobra.Command, args []string) error {
|
||||
skills, err := listLocalSkills()
|
||||
if err != nil {
|
||||
return fmt.Errorf("listing skills: %w", err)
|
||||
}
|
||||
|
||||
if len(skills) == 0 {
|
||||
fmt.Println("No skills installed")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
||||
fmt.Fprintln(w, "NAME\tTAG\tSIZE\tMODIFIED")
|
||||
|
||||
for _, skill := range skills {
|
||||
fmt.Fprintf(w, "%s/%s\t%s\t%s\t%s\n",
|
||||
skill.Namespace,
|
||||
skill.Name,
|
||||
skill.Tag,
|
||||
format.HumanBytes(skill.Size),
|
||||
format.HumanTime(skill.ModifiedAt, "Never"),
|
||||
)
|
||||
}
|
||||
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
// SkillRemoveHandler handles the skill rm command.
|
||||
func SkillRemoveHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("usage: ollama skill rm NAME[:TAG] [NAME[:TAG]...]")
|
||||
}
|
||||
|
||||
for _, name := range args {
|
||||
n := server.ParseSkillName(name)
|
||||
if n.Model == "" {
|
||||
fmt.Fprintf(os.Stderr, "Invalid skill name: %s\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
manifestPath, err := server.GetSkillManifestPath(n)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error getting manifest path for %s: %v\n", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := os.Stat(manifestPath); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "Skill not found: %s\n", displayName)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.Remove(manifestPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error removing %s: %v\n", displayName, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean up empty parent directories
|
||||
dir := filepath.Dir(manifestPath)
|
||||
for dir != filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests") {
|
||||
entries, _ := os.ReadDir(dir)
|
||||
if len(entries) == 0 {
|
||||
os.Remove(dir)
|
||||
dir = filepath.Dir(dir)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Deleted '%s'\n", displayName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SkillShowHandler handles the skill show command.
|
||||
func SkillShowHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return fmt.Errorf("usage: ollama skill show NAME[:TAG]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
n := server.ParseSkillName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid skill name: %s", name)
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
manifestPath, err := server.GetSkillManifestPath(n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting manifest path: %w", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("skill not found: %s", displayName)
|
||||
}
|
||||
return fmt.Errorf("reading manifest: %w", err)
|
||||
}
|
||||
|
||||
var manifest server.Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return fmt.Errorf("parsing manifest: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Skill: %s\n\n", displayName)
|
||||
|
||||
fmt.Println("Layers:")
|
||||
for _, layer := range manifest.Layers {
|
||||
fmt.Printf(" %s %s %s\n", layer.MediaType, layer.Digest[:19], format.HumanBytes(layer.Size))
|
||||
}
|
||||
|
||||
// Try to read and display SKILL.md content
|
||||
if len(manifest.Layers) > 0 {
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.MediaType == server.MediaTypeSkill {
|
||||
skillPath, err := server.GetSkillsPath(layer.Digest)
|
||||
if err == nil {
|
||||
skillMdPath := filepath.Join(skillPath, "SKILL.md")
|
||||
if content, err := os.ReadFile(skillMdPath); err == nil {
|
||||
fmt.Println("\nContent:")
|
||||
fmt.Println(string(content))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SkillInfo represents information about an installed skill.
|
||||
type SkillInfo struct {
|
||||
Namespace string
|
||||
Name string
|
||||
Tag string
|
||||
Size int64
|
||||
ModifiedAt time.Time
|
||||
}
|
||||
|
||||
// listLocalSkills returns a list of locally installed skills.
|
||||
// Skills are stored with 5-part paths: host/namespace/kind/model/tag
|
||||
// where kind is "skill".
|
||||
func listLocalSkills() ([]SkillInfo, error) {
|
||||
manifestsPath := filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests")
|
||||
|
||||
var skills []SkillInfo
|
||||
|
||||
// Walk through all registries
|
||||
registries, err := os.ReadDir(manifestsPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return skills, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, registry := range registries {
|
||||
if !registry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk namespaces
|
||||
namespaces, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, namespace := range namespaces {
|
||||
if !namespace.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk kinds looking for "skill"
|
||||
kinds, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, kind := range kinds {
|
||||
if !kind.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only process skill kind
|
||||
if kind.Name() != server.SkillNamespace {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk skill names (model names)
|
||||
skillNames, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, skillName := range skillNames {
|
||||
if !skillName.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk tags
|
||||
tags, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
manifestPath := filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name(), tag.Name())
|
||||
fi, err := os.Stat(manifestPath)
|
||||
if err != nil || fi.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read manifest to get size
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var manifest server.Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Layers {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
|
||||
// Build display name using model.Name
|
||||
n := model.Name{
|
||||
Host: registry.Name(),
|
||||
Namespace: namespace.Name(),
|
||||
Kind: kind.Name(),
|
||||
Model: skillName.Name(),
|
||||
Tag: tag.Name(),
|
||||
}
|
||||
|
||||
skills = append(skills, SkillInfo{
|
||||
Namespace: n.Namespace + "/" + n.Kind,
|
||||
Name: n.Model,
|
||||
Tag: n.Tag,
|
||||
Size: totalSize,
|
||||
ModifiedAt: fi.ModTime(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return skills, nil
|
||||
}
|
||||
|
||||
// createSkillManifest creates a manifest for a standalone skill.
|
||||
func createSkillManifest(skillDir string, layer server.Layer) (*server.Manifest, *server.Layer, error) {
|
||||
// Read SKILL.md to extract metadata
|
||||
skillMdPath := filepath.Join(skillDir, "SKILL.md")
|
||||
content, err := os.ReadFile(skillMdPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("reading SKILL.md: %w", err)
|
||||
}
|
||||
|
||||
// Extract name and description from frontmatter
|
||||
name, description := extractSkillMetadata(string(content))
|
||||
if name == "" {
|
||||
return nil, nil, errors.New("skill name not found in SKILL.md frontmatter")
|
||||
}
|
||||
|
||||
// Create config
|
||||
config := map[string]any{
|
||||
"name": name,
|
||||
"description": description,
|
||||
"architecture": "amd64",
|
||||
"os": "linux",
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshaling config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer
|
||||
configLayer, err := server.NewLayer(strings.NewReader(string(configJSON)), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating config layer: %w", err)
|
||||
}
|
||||
|
||||
manifest := &server.Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Config: configLayer,
|
||||
Layers: []server.Layer{layer},
|
||||
}
|
||||
|
||||
return manifest, &configLayer, nil
|
||||
}
|
||||
|
||||
// extractSkillMetadata extracts name and description from SKILL.md frontmatter.
|
||||
func extractSkillMetadata(content string) (name, description string) {
|
||||
lines := strings.Split(content, "\n")
|
||||
|
||||
inFrontmatter := false
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
if trimmed == "---" {
|
||||
if !inFrontmatter {
|
||||
inFrontmatter = true
|
||||
continue
|
||||
} else {
|
||||
break // End of frontmatter
|
||||
}
|
||||
}
|
||||
|
||||
if inFrontmatter {
|
||||
if strings.HasPrefix(trimmed, "name:") {
|
||||
name = strings.TrimSpace(strings.TrimPrefix(trimmed, "name:"))
|
||||
} else if strings.HasPrefix(trimmed, "description:") {
|
||||
description = strings.TrimSpace(strings.TrimPrefix(trimmed, "description:"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return name, description
|
||||
}
|
||||
|
||||
// NewSkillCommand creates the skill parent command with subcommands.
|
||||
func NewSkillCommand() *cobra.Command {
|
||||
skillCmd := &cobra.Command{
|
||||
Use: "skill",
|
||||
Short: "Manage skills",
|
||||
Long: "Commands for managing agent skills (push, pull, list, rm, show)",
|
||||
}
|
||||
|
||||
pushCmd := &cobra.Command{
|
||||
Use: "push NAME[:TAG] PATH",
|
||||
Short: "Push a skill to a registry",
|
||||
Long: "Package a local skill directory and push it to a registry",
|
||||
Args: cobra.ExactArgs(2),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SkillPushHandler,
|
||||
}
|
||||
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
pullCmd := &cobra.Command{
|
||||
Use: "pull NAME[:TAG]",
|
||||
Short: "Pull a skill from a registry",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SkillPullHandler,
|
||||
}
|
||||
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List installed skills",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: SkillListHandler,
|
||||
}
|
||||
|
||||
rmCmd := &cobra.Command{
|
||||
Use: "rm NAME[:TAG] [NAME[:TAG]...]",
|
||||
Aliases: []string{"remove", "delete"},
|
||||
Short: "Remove a skill",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: SkillRemoveHandler,
|
||||
}
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show NAME[:TAG]",
|
||||
Short: "Show skill details",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: SkillShowHandler,
|
||||
}
|
||||
|
||||
skillCmd.AddCommand(pushCmd, pullCmd, listCmd, rmCmd, showCmd)
|
||||
|
||||
return skillCmd
|
||||
}
|
||||
591
cmd/skills.go
Normal file
591
cmd/skills.go
Normal file
@@ -0,0 +1,591 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/server"
|
||||
)
|
||||
|
||||
const (
|
||||
skillFileName = "SKILL.md"
|
||||
maxSkillDescription = 1024
|
||||
maxSkillNameLength = 64
|
||||
)
|
||||
|
||||
var skillNamePattern = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
|
||||
|
||||
type skillMetadata struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
}
|
||||
|
||||
type skillDefinition struct {
|
||||
Name string
|
||||
Description string
|
||||
Content string // Full SKILL.md content (without frontmatter)
|
||||
Dir string
|
||||
SkillPath string
|
||||
}
|
||||
|
||||
type skillCatalog struct {
|
||||
Skills []skillDefinition
|
||||
byName map[string]skillDefinition
|
||||
}
|
||||
|
||||
func loadSkills(paths []string) (*skillCatalog, error) {
|
||||
if len(paths) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var skills []skillDefinition
|
||||
byName := make(map[string]skillDefinition)
|
||||
for _, root := range paths {
|
||||
info, err := os.Stat(root)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skills directory %q: %w", root, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, fmt.Errorf("skills path %q is not a directory", root)
|
||||
}
|
||||
|
||||
err = filepath.WalkDir(root, func(path string, entry fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if entry.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if entry.Name() != skillFileName {
|
||||
return nil
|
||||
}
|
||||
|
||||
skillDir := filepath.Dir(path)
|
||||
skill, err := parseSkillFile(path, skillDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, exists := byName[skill.Name]; exists {
|
||||
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
|
||||
return nil
|
||||
}
|
||||
|
||||
byName[skill.Name] = skill
|
||||
skills = append(skills, skill)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(skills) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sort.Slice(skills, func(i, j int) bool {
|
||||
return skills[i].Name < skills[j].Name
|
||||
})
|
||||
|
||||
return &skillCatalog{Skills: skills, byName: byName}, nil
|
||||
}
|
||||
|
||||
// loadSkillsFromRefs loads skills from a list of SkillRef objects.
|
||||
// Skills can be referenced by:
|
||||
// - Digest: loaded from the extracted skill cache (for bundled/pulled skills)
|
||||
// - Name (local path): loaded from the filesystem (for development)
|
||||
func loadSkillsFromRefs(refs []api.SkillRef) (*skillCatalog, error) {
|
||||
if len(refs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var skills []skillDefinition
|
||||
byName := make(map[string]skillDefinition)
|
||||
|
||||
for _, ref := range refs {
|
||||
var skillDir string
|
||||
|
||||
if ref.Digest != "" {
|
||||
// Load from extracted skill cache
|
||||
path, err := server.GetSkillsPath(ref.Digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting skill path for %s: %w", ref.Digest, err)
|
||||
}
|
||||
|
||||
// Check if skill is already extracted
|
||||
skillMdPath := filepath.Join(path, skillFileName)
|
||||
if _, err := os.Stat(skillMdPath); os.IsNotExist(err) {
|
||||
// Try to extract the skill blob
|
||||
path, err = server.ExtractSkillBlob(ref.Digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("extracting skill %s: %w", ref.Digest, err)
|
||||
}
|
||||
}
|
||||
|
||||
skillDir = path
|
||||
} else if ref.Name != "" {
|
||||
// Check if this is a local path or a registry reference
|
||||
if !server.IsLocalSkillPath(ref.Name) {
|
||||
// Registry reference without a digest - skill needs to be pulled first
|
||||
// This happens when an agent references a skill that hasn't been bundled
|
||||
return nil, fmt.Errorf("skill %q is a registry reference but has no digest - the agent may need to be recreated or the skill pulled separately", ref.Name)
|
||||
}
|
||||
|
||||
// Local path - resolve it
|
||||
skillPath := ref.Name
|
||||
if strings.HasPrefix(skillPath, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expanding home directory: %w", err)
|
||||
}
|
||||
skillPath = filepath.Join(home, skillPath[1:])
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(skillPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving skill path %q: %w", ref.Name, err)
|
||||
}
|
||||
|
||||
// Check if this is a directory containing skills or a single skill
|
||||
info, err := os.Stat(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skill path %q: %w", ref.Name, err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// Check if it's a skill directory (has SKILL.md) or a parent of skill directories
|
||||
skillMdPath := filepath.Join(absPath, skillFileName)
|
||||
if _, err := os.Stat(skillMdPath); err == nil {
|
||||
// Direct skill directory
|
||||
skillDir = absPath
|
||||
} else {
|
||||
// Parent directory - walk to find skill subdirectories
|
||||
err := filepath.WalkDir(absPath, func(path string, entry fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if entry.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if entry.Name() != skillFileName {
|
||||
return nil
|
||||
}
|
||||
|
||||
skillSubDir := filepath.Dir(path)
|
||||
skill, err := parseSkillFile(path, skillSubDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, exists := byName[skill.Name]; exists {
|
||||
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
|
||||
return nil
|
||||
}
|
||||
|
||||
byName[skill.Name] = skill
|
||||
skills = append(skills, skill)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("skill path %q is not a directory", ref.Name)
|
||||
}
|
||||
} else {
|
||||
// Both empty - skip
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse the skill from skillDir if set
|
||||
if skillDir != "" {
|
||||
skillMdPath := filepath.Join(skillDir, skillFileName)
|
||||
skill, err := parseSkillFile(skillMdPath, skillDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing skill at %s: %w", skillDir, err)
|
||||
}
|
||||
|
||||
if _, exists := byName[skill.Name]; exists {
|
||||
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q\n", skill.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
byName[skill.Name] = skill
|
||||
skills = append(skills, skill)
|
||||
}
|
||||
}
|
||||
|
||||
if len(skills) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sort.Slice(skills, func(i, j int) bool {
|
||||
return skills[i].Name < skills[j].Name
|
||||
})
|
||||
|
||||
return &skillCatalog{Skills: skills, byName: byName}, nil
|
||||
}
|
||||
|
||||
func parseSkillFile(path, skillDir string) (skillDefinition, error) {
|
||||
rawContent, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
frontmatter, bodyContent, err := extractFrontmatterAndContent(string(rawContent))
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
var meta skillMetadata
|
||||
if err := yaml.Unmarshal([]byte(frontmatter), &meta); err != nil {
|
||||
return skillDefinition{}, fmt.Errorf("invalid frontmatter: %w", err)
|
||||
}
|
||||
|
||||
if err := validateSkillMetadata(meta, skillDir); err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
absDir, err := filepath.Abs(skillDir)
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
return skillDefinition{
|
||||
Name: meta.Name,
|
||||
Description: meta.Description,
|
||||
Content: bodyContent,
|
||||
Dir: absDir,
|
||||
SkillPath: absPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func extractFrontmatterAndContent(content string) (frontmatter string, body string, err error) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(content))
|
||||
if !scanner.Scan() {
|
||||
return "", "", errors.New("empty SKILL.md")
|
||||
}
|
||||
if strings.TrimSpace(scanner.Text()) != "---" {
|
||||
return "", "", errors.New("missing YAML frontmatter")
|
||||
}
|
||||
|
||||
var fmLines []string
|
||||
foundEnd := false
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) == "---" {
|
||||
foundEnd = true
|
||||
break
|
||||
}
|
||||
fmLines = append(fmLines, line)
|
||||
}
|
||||
if !foundEnd {
|
||||
return "", "", errors.New("frontmatter not terminated")
|
||||
}
|
||||
|
||||
// Collect remaining content as body
|
||||
var bodyLines []string
|
||||
for scanner.Scan() {
|
||||
bodyLines = append(bodyLines, scanner.Text())
|
||||
}
|
||||
|
||||
return strings.Join(fmLines, "\n"), strings.TrimSpace(strings.Join(bodyLines, "\n")), nil
|
||||
}
|
||||
|
||||
func validateSkillMetadata(meta skillMetadata, skillDir string) error {
|
||||
name := strings.TrimSpace(meta.Name)
|
||||
description := strings.TrimSpace(meta.Description)
|
||||
|
||||
switch {
|
||||
case name == "":
|
||||
return errors.New("missing skill name")
|
||||
case len(name) > maxSkillNameLength:
|
||||
return fmt.Errorf("skill name exceeds %d characters", maxSkillNameLength)
|
||||
case !skillNamePattern.MatchString(name):
|
||||
return fmt.Errorf("invalid skill name %q", name)
|
||||
}
|
||||
|
||||
if description == "" {
|
||||
return errors.New("missing skill description")
|
||||
}
|
||||
if len(description) > maxSkillDescription {
|
||||
return fmt.Errorf("skill description exceeds %d characters", maxSkillDescription)
|
||||
}
|
||||
|
||||
// Skip directory name check for digest-based paths (extracted from blobs)
|
||||
dirName := filepath.Base(skillDir)
|
||||
if !strings.HasPrefix(dirName, "sha256-") && dirName != name {
|
||||
return fmt.Errorf("skill directory %q does not match name %q", dirName, name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *skillCatalog) SystemPrompt() string {
|
||||
if c == nil || len(c.Skills) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("# Skills\n\n")
|
||||
b.WriteString("You have the following skills loaded. Each skill provides instructions and may include executable scripts.\n\n")
|
||||
b.WriteString("## Available Tools\n\n")
|
||||
b.WriteString("- `run_skill_script`: Execute a script bundled with a skill. Use this when the skill instructions tell you to run a script.\n")
|
||||
b.WriteString("- `read_skill_file`: Read additional files from a skill directory.\n\n")
|
||||
|
||||
for _, skill := range c.Skills {
|
||||
fmt.Fprintf(&b, "## Skill: %s\n\n", skill.Name)
|
||||
fmt.Fprintf(&b, "%s\n\n", skill.Content)
|
||||
b.WriteString("---\n\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (c *skillCatalog) Tools() api.Tools {
|
||||
if c == nil || len(c.Skills) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
runScriptProps := api.NewToolPropertiesMap()
|
||||
runScriptProps.Set("skill", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The name of the skill containing the script",
|
||||
})
|
||||
runScriptProps.Set("command", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The command to execute (e.g., 'python scripts/calculate.py 25 4' or './scripts/run.sh')",
|
||||
})
|
||||
|
||||
readFileProps := api.NewToolPropertiesMap()
|
||||
readFileProps.Set("skill", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The name of the skill containing the file",
|
||||
})
|
||||
readFileProps.Set("path", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The relative path to the file within the skill directory",
|
||||
})
|
||||
|
||||
return api.Tools{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "run_skill_script",
|
||||
Description: "Execute a script or command within a skill's directory. Use this to run Python scripts, shell scripts, or other executables bundled with a skill.",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"skill", "command"},
|
||||
Properties: runScriptProps,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "read_skill_file",
|
||||
Description: "Read a file from a skill's directory. Use this to read additional documentation, reference files, or data files bundled with a skill.",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"skill", "path"},
|
||||
Properties: readFileProps,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *skillCatalog) RunToolCall(call api.ToolCall) (api.Message, bool, error) {
|
||||
switch call.Function.Name {
|
||||
case "read_skill_file":
|
||||
skillName, err := requireStringArg(call.Function.Arguments, "skill")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
relPath, err := requireStringArg(call.Function.Arguments, "path")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
skill, ok := c.byName[skillName]
|
||||
if !ok {
|
||||
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
|
||||
}
|
||||
content, err := readSkillFile(skill.Dir, relPath)
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
return toolMessage(call, content), true, nil
|
||||
|
||||
case "run_skill_script":
|
||||
skillName, err := requireStringArg(call.Function.Arguments, "skill")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
command, err := requireStringArg(call.Function.Arguments, "command")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
skill, ok := c.byName[skillName]
|
||||
if !ok {
|
||||
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
|
||||
}
|
||||
output, err := runSkillScript(skill.Dir, command)
|
||||
if err != nil {
|
||||
return toolMessage(call, fmt.Sprintf("error: %v\noutput: %s", err, output)), true, nil
|
||||
}
|
||||
return toolMessage(call, output), true, nil
|
||||
|
||||
default:
|
||||
return api.Message{}, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// runSkillScript executes a shell command within a skill's directory.
|
||||
//
|
||||
// SECURITY LIMITATIONS (TODO):
|
||||
// - No sandboxing: commands run with full user permissions
|
||||
// - No path validation: model can run any command, not just scripts in skill dir
|
||||
// - Shell injection risk: sh -c is used, malicious input could be crafted
|
||||
// - No executable allowlist: any program can be called (curl, rm, etc.)
|
||||
// - No environment isolation: scripts inherit full environment variables
|
||||
//
|
||||
// POTENTIAL IMPROVEMENTS:
|
||||
// - Restrict commands to only reference files within skill directory
|
||||
// - Allowlist specific executables (python3, node, bash)
|
||||
// - Use sandboxing (Docker, nsjail, seccomp)
|
||||
// - Require explicit script registration in SKILL.md frontmatter
|
||||
// - Add per-skill configurable timeouts
|
||||
func runSkillScript(skillDir, command string) (string, error) {
|
||||
// Validate the skill directory exists
|
||||
absSkillDir, err := filepath.Abs(skillDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := os.Stat(absSkillDir); err != nil {
|
||||
return "", fmt.Errorf("skill directory not found: %w", err)
|
||||
}
|
||||
|
||||
// Create command with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", command)
|
||||
cmd.Dir = absSkillDir
|
||||
|
||||
// Inject the current working directory (where ollama run was called from)
|
||||
// as an environment variable so scripts can reference files in that directory
|
||||
workingDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get working directory: %w", err)
|
||||
}
|
||||
cmd.Env = append(os.Environ(), "OLLAMA_WORKING_DIR="+workingDir)
|
||||
|
||||
// Capture both stdout and stderr
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err = cmd.Run()
|
||||
|
||||
// Combine output
|
||||
output := stdout.String()
|
||||
if stderr.Len() > 0 {
|
||||
if output != "" {
|
||||
output += "\n"
|
||||
}
|
||||
output += stderr.String()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return output, fmt.Errorf("command timed out after 30 seconds")
|
||||
}
|
||||
return output, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func readSkillFile(skillDir, relPath string) (string, error) {
|
||||
relPath = filepath.Clean(strings.TrimSpace(relPath))
|
||||
if relPath == "" {
|
||||
return "", errors.New("path is required")
|
||||
}
|
||||
if filepath.IsAbs(relPath) {
|
||||
return "", errors.New("path must be relative to the skill directory")
|
||||
}
|
||||
|
||||
target := filepath.Join(skillDir, relPath)
|
||||
absTarget, err := filepath.Abs(target)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
absSkillDir, err := filepath.Abs(skillDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rel, err := filepath.Rel(absSkillDir, absTarget)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.HasPrefix(rel, "..") {
|
||||
return "", errors.New("path escapes the skill directory")
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absTarget)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read %q: %w", relPath, err)
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
func requireStringArg(args api.ToolCallFunctionArguments, name string) (string, error) {
|
||||
value, ok := args.Get(name)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing required argument %q", name)
|
||||
}
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("argument %q must be a string", name)
|
||||
}
|
||||
if strings.TrimSpace(str) == "" {
|
||||
return "", fmt.Errorf("argument %q cannot be empty", name)
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func toolMessage(call api.ToolCall, content string) api.Message {
|
||||
msg := api.Message{
|
||||
Role: "tool",
|
||||
Content: content,
|
||||
ToolName: call.Function.Name,
|
||||
}
|
||||
if call.ID != "" {
|
||||
msg.ToolCallID = call.ID
|
||||
}
|
||||
return msg
|
||||
}
|
||||
@@ -6,14 +6,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -21,13 +18,8 @@ type ModelParameters struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
|
||||
// TODO is this needed?
|
||||
ModelType string `json:"model_type"`
|
||||
|
||||
TextModel struct {
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
ModelType string `json:"model_type"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
} `json:"text_config"`
|
||||
}
|
||||
|
||||
@@ -41,94 +33,8 @@ type AdapterParameters struct {
|
||||
} `json:"lora_parameters"`
|
||||
}
|
||||
|
||||
type KV map[string]any
|
||||
|
||||
func (kv KV) Architecture() string {
|
||||
return kv.String("general.architecture", "unknown")
|
||||
}
|
||||
|
||||
type valueTypes interface {
|
||||
uint8 | int8 | uint16 | int16 |
|
||||
uint32 | int32 | uint64 | int64 |
|
||||
string | float32 | float64 | bool
|
||||
}
|
||||
|
||||
type arrayValueTypes interface {
|
||||
[]uint8 | []int8 | []uint16 | []int16 |
|
||||
[]uint32 | []int32 | []uint64 | []int64 |
|
||||
[]string | []float32 | []float64 | []bool
|
||||
}
|
||||
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
if val, ok := kv[key].(T); ok {
|
||||
return val, true
|
||||
}
|
||||
return defaultValue[0], false
|
||||
}
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, false)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (ModelParameters) KV(t *Tokenizer) KV {
|
||||
kv := KV{
|
||||
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||
kv := ggml.KV{
|
||||
"general.file_type": uint32(1),
|
||||
"general.quantization_version": uint32(2),
|
||||
"tokenizer.ggml.pre": t.Pre,
|
||||
@@ -157,7 +63,7 @@ func (ModelParameters) KV(t *Tokenizer) KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p AdapterParameters) KV() KV {
|
||||
func (p AdapterParameters) KV() ggml.KV {
|
||||
var alpha float32
|
||||
if p.LoraParameters.Alpha == 0 {
|
||||
alpha = float32(p.Alpha)
|
||||
@@ -165,7 +71,7 @@ func (p AdapterParameters) KV() KV {
|
||||
alpha = p.LoraParameters.Alpha
|
||||
}
|
||||
|
||||
kv := KV{
|
||||
kv := ggml.KV{
|
||||
"adapter.lora.alpha": alpha,
|
||||
"adapter.type": "lora",
|
||||
"general.file_type": uint32(1),
|
||||
@@ -182,14 +88,9 @@ func (ModelParameters) specialTokenTypes() []string {
|
||||
}
|
||||
}
|
||||
|
||||
type ModelKV interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) KV
|
||||
}
|
||||
|
||||
type ModelConverter interface {
|
||||
ModelKV
|
||||
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -206,7 +107,7 @@ type moreParser interface {
|
||||
|
||||
type AdapterConverter interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(ofs.Config) KV
|
||||
KV(ggml.KV) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -214,7 +115,7 @@ type AdapterConverter interface {
|
||||
Replacements() []string
|
||||
}
|
||||
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -225,8 +126,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
arch := baseKV.Architecture()
|
||||
if arch == "" {
|
||||
arch, ok := baseKV["general.architecture"]
|
||||
if !ok {
|
||||
return errors.New("architecture not set for the base model")
|
||||
}
|
||||
|
||||
@@ -252,19 +153,23 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
bts, err := fs.ReadFile(fsys, "config.json")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
var p ModelParameters
|
||||
if err := json.Unmarshal(bts, &p); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if len(p.Architectures) < 1 {
|
||||
return nil, nil, errors.New("unknown architecture")
|
||||
return errors.New("unknown architecture")
|
||||
}
|
||||
|
||||
var conv ModelConverter
|
||||
@@ -312,22 +217,22 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
case "DeepseekV3ForCausalLM":
|
||||
conv = &deepseek2Model{}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, conv); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if t, ok := conv.(moreParser); ok {
|
||||
if err := t.parseMore(fsys); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
|
||||
@@ -349,19 +254,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
default:
|
||||
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
||||
}
|
||||
return conv, t, nil
|
||||
}
|
||||
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
kv, t, err := LoadModelMetadata(fsys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conv := kv.(ModelConverter)
|
||||
|
||||
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
|
||||
if err != nil {
|
||||
@@ -371,7 +263,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
return writeFile(f, conv.KV(t), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
|
||||
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
|
||||
for i := range ts {
|
||||
ts[i].Shape = slices.Clone(ts[i].Shape)
|
||||
slices.Reverse(ts[i].Shape)
|
||||
|
||||
@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *bertModel) KV(t *Tokenizer) KV {
|
||||
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "bert"
|
||||
kv["bert.attention.causal"] = false
|
||||
|
||||
@@ -24,7 +24,7 @@ type commandrModel struct {
|
||||
|
||||
var _ ModelConverter = (*commandrModel)(nil)
|
||||
|
||||
func (p *commandrModel) KV(t *Tokenizer) KV {
|
||||
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "command-r"
|
||||
kv["general.name"] = "command-r"
|
||||
|
||||
@@ -47,7 +47,7 @@ type deepseek2Model struct {
|
||||
Architecture string
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) KV {
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseek2"
|
||||
kv["general.type"] = "model"
|
||||
|
||||
@@ -41,7 +41,7 @@ type deepseekocr struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *deepseekocr) KV(t *Tokenizer) KV {
|
||||
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseekocr"
|
||||
kv["block_count"] = m.LanguageConfig.HiddenLayers
|
||||
|
||||
@@ -23,7 +23,7 @@ type gemmaModel struct {
|
||||
|
||||
var _ ModelConverter = (*gemmaModel)(nil)
|
||||
|
||||
func (p *gemmaModel) KV(t *Tokenizer) KV {
|
||||
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma"
|
||||
kv["gemma.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package convert
|
||||
|
||||
import "github.com/ollama/ollama/fs/ggml"
|
||||
|
||||
type gemma2Model struct {
|
||||
gemmaModel
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
@@ -7,7 +9,7 @@ type gemma2Model struct {
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
}
|
||||
|
||||
func (p *gemma2Model) KV(t *Tokenizer) KV {
|
||||
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma2"
|
||||
kv["gemma2.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -16,7 +15,7 @@ type gemma2Adapter struct {
|
||||
|
||||
var _ AdapterConverter = (*gemma2Adapter)(nil)
|
||||
|
||||
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
|
||||
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "gemma2"
|
||||
return kv
|
||||
|
||||
@@ -3,6 +3,8 @@ package convert
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type gemma3Model struct {
|
||||
@@ -53,7 +55,7 @@ const (
|
||||
gemma27BLayerCount = 62
|
||||
)
|
||||
|
||||
func (p *gemma3Model) KV(t *Tokenizer) KV {
|
||||
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3"
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ type gemma3nModel struct {
|
||||
VisionModel struct{} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) KV {
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3n"
|
||||
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
|
||||
|
||||
@@ -37,7 +37,7 @@ type gptossModel struct {
|
||||
|
||||
var _ ModelConverter = (*gptossModel)(nil)
|
||||
|
||||
func (m *gptossModel) KV(t *Tokenizer) KV {
|
||||
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gptoss"
|
||||
kv["general.file_type"] = uint32(4)
|
||||
|
||||
@@ -48,7 +48,7 @@ type llamaModel struct {
|
||||
|
||||
var _ ModelConverter = (*llamaModel)(nil)
|
||||
|
||||
func (p *llamaModel) KV(t *Tokenizer) KV {
|
||||
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -35,7 +35,7 @@ type llama4Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (p *llama4Model) KV(t *Tokenizer) KV {
|
||||
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama4"
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -19,13 +18,13 @@ type llamaAdapter struct {
|
||||
|
||||
var _ AdapterConverter = (*llamaAdapter)(nil)
|
||||
|
||||
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
|
||||
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
|
||||
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
|
||||
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
|
||||
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
|
||||
|
||||
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
|
||||
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ type mistral3Model struct {
|
||||
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
}
|
||||
|
||||
func (p *mistral3Model) KV(t *Tokenizer) KV {
|
||||
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
|
||||
|
||||
@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
|
||||
} `json:"rope_parameters"`
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -12,7 +12,7 @@ type mixtralModel struct {
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
}
|
||||
|
||||
func (p *mixtralModel) KV(t *Tokenizer) KV {
|
||||
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.llamaModel.KV(t)
|
||||
|
||||
if p.NumLocalExperts > 0 {
|
||||
|
||||
@@ -34,7 +34,7 @@ type mllamaModel struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *mllamaModel) KV(t *Tokenizer) KV {
|
||||
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mllama"
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) KV {
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
|
||||
// Determine architecture based on MoE parameters (following qwen3 pattern)
|
||||
|
||||
@@ -34,7 +34,7 @@ type olmoModel struct {
|
||||
|
||||
var _ ModelConverter = (*olmoModel)(nil)
|
||||
|
||||
func (p *olmoModel) KV(t *Tokenizer) KV {
|
||||
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "olmo3"
|
||||
kv["olmo3.block_count"] = p.NumHiddenLayers
|
||||
|
||||
@@ -37,7 +37,7 @@ type phi3Model struct {
|
||||
|
||||
var _ ModelConverter = (*phi3Model)(nil)
|
||||
|
||||
func (p *phi3Model) KV(t *Tokenizer) KV {
|
||||
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "phi3"
|
||||
kv["phi3.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -22,7 +22,7 @@ type qwen2Model struct {
|
||||
|
||||
var _ ModelConverter = (*qwen2Model)(nil)
|
||||
|
||||
func (q *qwen2Model) KV(t *Tokenizer) KV {
|
||||
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen2"
|
||||
kv["qwen2.block_count"] = q.HiddenLayers
|
||||
|
||||
@@ -29,7 +29,7 @@ type qwen25VLModel struct {
|
||||
|
||||
var _ ModelConverter = (*qwen25VLModel)(nil)
|
||||
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen25vl"
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ type qwen3Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (q *qwen3Model) KV(t *Tokenizer) KV {
|
||||
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
|
||||
arch := "qwen3"
|
||||
if q.NumExperts > 0 {
|
||||
arch += "moe"
|
||||
|
||||
@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
|
||||
return json.Unmarshal(bts, &m.VisionModel)
|
||||
}
|
||||
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.qwen3Model.KV(t)
|
||||
|
||||
arch := "qwen3vl"
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
fsc "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -29,7 +28,7 @@ type tensorData struct {
|
||||
Shape []int `json:"shape"`
|
||||
}
|
||||
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), "f16")
|
||||
@@ -60,10 +59,9 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors)
|
||||
return r, m.KV(), m.Tensors()
|
||||
}
|
||||
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
|
||||
actual := make(map[string]string)
|
||||
for k := range kv.Keys() {
|
||||
v := kv.Value(k)
|
||||
for k, v := range kv {
|
||||
if s, ok := v.(json.Marshaler); !ok {
|
||||
actual[k] = fmt.Sprintf("%v", v)
|
||||
} else {
|
||||
@@ -279,7 +277,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
|
||||
func TestConvertAdapter(t *testing.T) {
|
||||
type AdapterCase struct {
|
||||
Name string
|
||||
BaseKV KV
|
||||
BaseKV map[string]any
|
||||
Expected map[string]string
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
* [API Reference](https://docs.ollama.com/api)
|
||||
* [Modelfile Reference](https://docs.ollama.com/modelfile)
|
||||
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
|
||||
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
|
||||
|
||||
### Resources
|
||||
|
||||
|
||||
@@ -1,406 +0,0 @@
|
||||
---
|
||||
title: Anthropic compatibility
|
||||
---
|
||||
|
||||
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
|
||||
|
||||
## Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
|
||||
|
||||
Pull a model before use:
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment variables
|
||||
|
||||
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
|
||||
### Simple `/v1/messages` example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python basic.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama', # required but ignored
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
messages=[
|
||||
{'role': 'user', 'content': 'Hello, how are you?'}
|
||||
]
|
||||
)
|
||||
print(message.content[0].text)
|
||||
```
|
||||
|
||||
```javascript basic.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama", // required but ignored
|
||||
});
|
||||
|
||||
const message = await anthropic.messages.create({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Hello, how are you?" }],
|
||||
});
|
||||
|
||||
console.log(message.content[0].text);
|
||||
```
|
||||
|
||||
```shell basic.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-api-key: ollama" \
|
||||
-H "anthropic-version: 2023-06-01" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Streaming example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python streaming.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
with client.messages.stream(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
print(text, end='', flush=True)
|
||||
```
|
||||
|
||||
```javascript streaming.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama",
|
||||
});
|
||||
|
||||
const stream = await anthropic.messages.stream({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Count from 1 to 10" }],
|
||||
});
|
||||
|
||||
for await (const event of stream) {
|
||||
if (
|
||||
event.type === "content_block_delta" &&
|
||||
event.delta.type === "text_delta"
|
||||
) {
|
||||
process.stdout.write(event.delta.text);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```shell streaming.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Tool calling example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python tools.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
tools=[
|
||||
{
|
||||
'name': 'get_weather',
|
||||
'description': 'Get the current weather in a location',
|
||||
'input_schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'location': {
|
||||
'type': 'string',
|
||||
'description': 'The city and state, e.g. San Francisco, CA'
|
||||
}
|
||||
},
|
||||
'required': ['location']
|
||||
}
|
||||
}
|
||||
],
|
||||
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
|
||||
)
|
||||
|
||||
for block in message.content:
|
||||
if block.type == 'tool_use':
|
||||
print(f'Tool: {block.name}')
|
||||
print(f'Input: {block.input}')
|
||||
```
|
||||
|
||||
```javascript tools.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama",
|
||||
});
|
||||
|
||||
const message = await anthropic.messages.create({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
tools: [
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the current weather in a location",
|
||||
input_schema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: {
|
||||
type: "string",
|
||||
description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
},
|
||||
],
|
||||
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
|
||||
});
|
||||
|
||||
for (const block of message.content) {
|
||||
if (block.type === "tool_use") {
|
||||
console.log("Tool:", block.name);
|
||||
console.log("Input:", block.input);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```shell tools.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a location",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Using with Claude Code
|
||||
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
```shell
|
||||
# Local models
|
||||
claude --model qwen3-coder
|
||||
claude --model gpt-oss:20b
|
||||
|
||||
# Cloud models
|
||||
claude --model glm-4.7:cloud
|
||||
claude --model minimax-m2.1:cloud
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
### `/v1/messages`
|
||||
|
||||
#### Supported features
|
||||
|
||||
- [x] Messages
|
||||
- [x] Streaming
|
||||
- [x] System prompts
|
||||
- [x] Multi-turn conversations
|
||||
- [x] Vision (images)
|
||||
- [x] Tools (function calling)
|
||||
- [x] Tool results
|
||||
- [x] Thinking/extended thinking
|
||||
|
||||
#### Supported request fields
|
||||
|
||||
- [x] `model`
|
||||
- [x] `max_tokens`
|
||||
- [x] `messages`
|
||||
- [x] Text `content`
|
||||
- [x] Image `content` (base64)
|
||||
- [x] Array of content blocks
|
||||
- [x] `tool_use` blocks
|
||||
- [x] `tool_result` blocks
|
||||
- [x] `thinking` blocks
|
||||
- [x] `system` (string or array)
|
||||
- [x] `stream`
|
||||
- [x] `temperature`
|
||||
- [x] `top_p`
|
||||
- [x] `top_k`
|
||||
- [x] `stop_sequences`
|
||||
- [x] `tools`
|
||||
- [x] `thinking`
|
||||
- [ ] `tool_choice`
|
||||
- [ ] `metadata`
|
||||
|
||||
#### Supported response fields
|
||||
|
||||
- [x] `id`
|
||||
- [x] `type`
|
||||
- [x] `role`
|
||||
- [x] `model`
|
||||
- [x] `content` (text, tool_use, thinking blocks)
|
||||
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
|
||||
- [x] `usage` (input_tokens, output_tokens)
|
||||
|
||||
#### Streaming events
|
||||
|
||||
- [x] `message_start`
|
||||
- [x] `content_block_start`
|
||||
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
|
||||
- [x] `content_block_stop`
|
||||
- [x] `message_delta`
|
||||
- [x] `message_stop`
|
||||
- [x] `ping`
|
||||
- [x] `error`
|
||||
|
||||
## Models
|
||||
|
||||
Ollama supports both local and cloud models.
|
||||
|
||||
### Local models
|
||||
|
||||
Pull a local model before use:
|
||||
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
```
|
||||
|
||||
Recommended local models:
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
|
||||
### Cloud models
|
||||
|
||||
Cloud models are available immediately without pulling:
|
||||
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
|
||||
### Default model names
|
||||
|
||||
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
|
||||
|
||||
```shell
|
||||
ollama cp qwen3-coder claude-3-5-sonnet
|
||||
```
|
||||
|
||||
Afterwards, this new model name can be specified in the `model` field:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "claude-3-5-sonnet",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## Differences from the Anthropic API
|
||||
|
||||
### Behavior differences
|
||||
|
||||
- API key is accepted but not validated
|
||||
- `anthropic-version` header is accepted but not used
|
||||
- Token counts are approximations based on the underlying model's tokenizer
|
||||
|
||||
### Not supported
|
||||
|
||||
The following Anthropic API features are not currently supported:
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| `/v1/messages/count_tokens` | Token counting endpoint |
|
||||
| `tool_choice` | Forcing specific tool use or disabling tools |
|
||||
| `metadata` | Request metadata (user_id) |
|
||||
| Prompt caching | `cache_control` blocks for caching prefixes |
|
||||
| Batches API | `/v1/messages/batches` for async batch processing |
|
||||
| Citations | `citations` content blocks |
|
||||
| PDF support | `document` content blocks with PDF files |
|
||||
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
|
||||
|
||||
### Partial support
|
||||
|
||||
| Feature | Status |
|
||||
|---------|--------|
|
||||
| Image content | Base64 images supported; URL images not supported |
|
||||
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |
|
||||
@@ -32,9 +32,7 @@
|
||||
"codeblocks": "system"
|
||||
},
|
||||
"contextual": {
|
||||
"options": [
|
||||
"copy"
|
||||
]
|
||||
"options": ["copy"]
|
||||
},
|
||||
"navbar": {
|
||||
"links": [
|
||||
@@ -54,9 +52,7 @@
|
||||
"display": "simple"
|
||||
},
|
||||
"examples": {
|
||||
"languages": [
|
||||
"curl"
|
||||
]
|
||||
"languages": ["curl"]
|
||||
}
|
||||
},
|
||||
"redirects": [
|
||||
@@ -101,7 +97,6 @@
|
||||
{
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/codex",
|
||||
@@ -144,8 +139,7 @@
|
||||
"/api/streaming",
|
||||
"/api/usage",
|
||||
"/api/errors",
|
||||
"/api/openai-compatibility",
|
||||
"/api/anthropic-compatibility"
|
||||
"/api/openai-compatibility"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
---
|
||||
title: Claude Code
|
||||
---
|
||||
|
||||
## Install
|
||||
|
||||
Install [Claude Code](https://code.claude.com/docs/en/overview):
|
||||
|
||||
<CodeGroup>
|
||||
|
||||
```shell macOS / Linux
|
||||
curl -fsSL https://claude.ai/install.sh | bash
|
||||
```
|
||||
|
||||
```powershell Windows
|
||||
irm https://claude.ai/install.ps1 | iex
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
2. Run Claude Code with an Ollama model:
|
||||
|
||||
```shell
|
||||
claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
|
||||
2. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=https://ollama.com
|
||||
export ANTHROPIC_API_KEY=<your-api-key>
|
||||
```
|
||||
|
||||
3. Run Claude Code with a cloud model:
|
||||
|
||||
```shell
|
||||
claude --model glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Recommended Models
|
||||
|
||||
### Cloud models
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
- `qwen3-coder:480b` - Large coding model
|
||||
|
||||
### Local models
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Linux"
|
||||
title: Linux
|
||||
---
|
||||
|
||||
## Install
|
||||
@@ -13,7 +13,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
## Manual install
|
||||
|
||||
<Note>
|
||||
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
If you are upgrading from a prior version, you should remove the old libraries
|
||||
with `sudo rm -rf /usr/lib/ollama` first.
|
||||
</Note>
|
||||
|
||||
Download and extract the package:
|
||||
@@ -112,7 +113,11 @@ sudo systemctl status ollama
|
||||
```
|
||||
|
||||
<Note>
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
kernel source, the version is older and may not support all ROCm features. We
|
||||
recommend you install the latest driver from
|
||||
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
GPU.
|
||||
</Note>
|
||||
|
||||
## Customizing
|
||||
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
|
||||
sudo userdel ollama
|
||||
sudo groupdel ollama
|
||||
sudo rm -r /usr/share/ollama
|
||||
```
|
||||
```
|
||||
|
||||
548
docs/skills.md
Normal file
548
docs/skills.md
Normal file
@@ -0,0 +1,548 @@
|
||||
# Ollama Skills
|
||||
|
||||
Skills are reusable capability packages that extend what agents can do. They bundle instructions, scripts, and data that teach an agent how to perform specific tasks.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Creating a Skill
|
||||
|
||||
Create a directory with a `SKILL.md` file:
|
||||
|
||||
```
|
||||
my-skill/
|
||||
├── SKILL.md # Required: Instructions for the agent
|
||||
└── scripts/ # Optional: Executable scripts
|
||||
└── run.py
|
||||
```
|
||||
|
||||
The `SKILL.md` file must have YAML frontmatter:
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: my-skill
|
||||
description: A brief description of what this skill does
|
||||
---
|
||||
|
||||
# My Skill
|
||||
|
||||
## Purpose
|
||||
Explain what this skill does and when to use it.
|
||||
|
||||
## Instructions
|
||||
Step-by-step instructions for the agent on how to use this skill.
|
||||
|
||||
## Examples
|
||||
Show example inputs and expected outputs.
|
||||
```
|
||||
|
||||
### Using Skills in an Agent
|
||||
|
||||
Reference skills in your Agentfile:
|
||||
|
||||
```dockerfile
|
||||
FROM llama3.2:3b
|
||||
AGENT_TYPE conversational
|
||||
|
||||
# Local skill (bundled with agent)
|
||||
SKILL ./path/to/my-skill
|
||||
|
||||
# Registry skill (pulled from ollama.com)
|
||||
SKILL library/skill/calculator:1.0.0
|
||||
|
||||
# User skill from registry
|
||||
SKILL myname/skill/calculator:1.0.0
|
||||
|
||||
SYSTEM You are a helpful assistant.
|
||||
```
|
||||
|
||||
### Managing Skills
|
||||
|
||||
```bash
|
||||
# Push a skill to the registry (uses your namespace)
|
||||
ollama skill push myname/skill/calculator:1.0.0 ./my-skill
|
||||
|
||||
# Pull a skill from the official library
|
||||
ollama skill pull skill/calculator:1.0.0
|
||||
|
||||
# Pull a skill from a user's namespace
|
||||
ollama skill pull myname/skill/calculator:1.0.0
|
||||
|
||||
# List installed skills
|
||||
ollama skill list
|
||||
|
||||
# Show skill details
|
||||
ollama skill show skill/calculator:1.0.0
|
||||
|
||||
# Remove a skill
|
||||
ollama skill rm skill/calculator:1.0.0
|
||||
```
|
||||
|
||||
### Dynamic Skills in Chat
|
||||
|
||||
You can add and remove skills dynamically during an interactive chat session:
|
||||
|
||||
```
|
||||
>>> /skills
|
||||
Available Skills:
|
||||
calculator (sha256:abc123def456...)
|
||||
|
||||
>>> /skill add ./my-local-skill
|
||||
Added skill 'my-skill' from ./my-local-skill
|
||||
|
||||
>>> /skill list
|
||||
Skills loaded in this session:
|
||||
my-skill (local: /path/to/my-local-skill)
|
||||
|
||||
>>> /skill remove my-skill
|
||||
Removed skill 'my-skill'
|
||||
```
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/skills` | Show all available skills (model + session) |
|
||||
| `/skill add <path>` | Add a skill from a local path |
|
||||
| `/skill remove <name>` | Remove a skill by name |
|
||||
| `/skill list` | List skills loaded in this session |
|
||||
|
||||
Dynamic skills take effect on the next message. This is useful for:
|
||||
- Testing skills during development
|
||||
- Temporarily adding capabilities to a model
|
||||
- Experimenting with skill combinations
|
||||
|
||||
## Skill Reference Formats
|
||||
|
||||
Skills use a 5-part name structure: `host/namespace/kind/model:tag`
|
||||
|
||||
| Format | Example | Description |
|
||||
|--------|---------|-------------|
|
||||
| Local path | `./skills/calc` | Bundled with agent at create time |
|
||||
| Library skill | `skill/calculator:1.0.0` | From the official skill library (library/skill/calculator) |
|
||||
| User skill | `alice/skill/calc:1.0.0` | From a user's namespace |
|
||||
| Full path | `registry.ollama.ai/alice/skill/calc:1.0.0` | Fully qualified with host |
|
||||
|
||||
The `kind` field distinguishes skills from models:
|
||||
- `skill` - Skill packages
|
||||
- `agent` - Agent packages (future)
|
||||
- (empty) - Regular models
|
||||
|
||||
## SKILL.md Structure
|
||||
|
||||
### Required Frontmatter
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: skill-name # Must match directory name
|
||||
description: Brief description of the skill
|
||||
---
|
||||
```
|
||||
|
||||
### Recommended Sections
|
||||
|
||||
1. **Purpose**: What the skill does and when to use it
|
||||
2. **When to use**: Trigger conditions for the agent
|
||||
3. **Instructions**: Step-by-step usage guide
|
||||
4. **Examples**: Input/output examples
|
||||
5. **Scripts**: Documentation for any bundled scripts
|
||||
|
||||
### Example: Calculator Skill
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: calculator
|
||||
description: Performs mathematical calculations using Python
|
||||
---
|
||||
|
||||
# Calculator Skill
|
||||
|
||||
## Purpose
|
||||
This skill performs mathematical calculations using a bundled Python script.
|
||||
|
||||
## When to use
|
||||
- User asks to calculate something
|
||||
- User wants to do math operations
|
||||
- Any arithmetic is needed
|
||||
|
||||
## Instructions
|
||||
1. When calculation is needed, use the `run_skill_script` tool
|
||||
2. Call: `python3 scripts/calculate.py "<expression>"`
|
||||
3. Return the result to the user
|
||||
|
||||
## Examples
|
||||
|
||||
**Input**: "What is 25 * 4?"
|
||||
**Action**: `run_skill_script` with command `python3 scripts/calculate.py '25 * 4'`
|
||||
**Output**: "25 * 4 = 100"
|
||||
```
|
||||
|
||||
## Storage Layout
|
||||
|
||||
```
|
||||
~/.ollama/models/
|
||||
├── blobs/
|
||||
│ └── sha256-<digest> # Skill tar.gz blob
|
||||
├── manifests/
|
||||
│ └── registry.ollama.ai/
|
||||
│ └── skill/ # Library skills
|
||||
│ └── calculator/
|
||||
│ └── 1.0.0
|
||||
│ └── skill-username/ # User skills
|
||||
│ └── my-skill/
|
||||
│ └── latest
|
||||
└── skills/
|
||||
└── sha256-<digest>/ # Extracted skill cache
|
||||
├── SKILL.md
|
||||
└── scripts/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
# Security Considerations
|
||||
|
||||
## Current State (Development)
|
||||
|
||||
The current implementation has several security considerations that need to be addressed before production use.
|
||||
|
||||
### 1. Script Execution
|
||||
|
||||
**Risk**: Skills can bundle arbitrary scripts that execute on the host system.
|
||||
|
||||
**Current behavior**:
|
||||
- Scripts run with the same permissions as the Ollama process
|
||||
- No sandboxing or isolation
|
||||
- Full filesystem access
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Sandbox script execution (containers, seccomp, etc.)
|
||||
- [ ] Resource limits (CPU, memory, time)
|
||||
- [ ] Filesystem isolation (read-only mounts, restricted paths)
|
||||
- [ ] Network policy controls
|
||||
- [ ] Capability dropping
|
||||
|
||||
### 2. Skill Provenance
|
||||
|
||||
**Risk**: Malicious skills could be pushed to the registry.
|
||||
|
||||
**Current behavior**:
|
||||
- No code signing or verification
|
||||
- No malware scanning
|
||||
- Trust based on namespace ownership
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Skill signing with author keys
|
||||
- [ ] Registry-side malware scanning
|
||||
- [ ] Content policy enforcement
|
||||
- [ ] Reputation system for skill authors
|
||||
|
||||
### 3. Namespace Squatting
|
||||
|
||||
**Risk**: Malicious actors could register skill names that impersonate official tools.
|
||||
|
||||
**Current behavior**:
|
||||
- First-come-first-served namespace registration
|
||||
- No verification of skill names
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Reserved namespace list (official tools, common names)
|
||||
- [ ] Trademark/name verification for popular skills
|
||||
- [ ] Clear namespacing conventions
|
||||
|
||||
### 4. Supply Chain Attacks
|
||||
|
||||
**Risk**: Compromised skills could inject malicious code into agents.
|
||||
|
||||
**Current behavior**:
|
||||
- Skills pulled without integrity verification beyond digest
|
||||
- No dependency tracking
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] SBOM (Software Bill of Materials) for skills
|
||||
- [ ] Dependency vulnerability scanning
|
||||
- [ ] Pinned versions in Agentfiles
|
||||
- [ ] Audit logging of skill usage
|
||||
|
||||
### 5. Data Exfiltration
|
||||
|
||||
**Risk**: Skills could exfiltrate sensitive data from conversations or the host.
|
||||
|
||||
**Current behavior**:
|
||||
- Skills have access to conversation context
|
||||
- Scripts can make network requests
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Network egress controls
|
||||
- [ ] Sensitive data detection/masking
|
||||
- [ ] Audit logging of script network activity
|
||||
- [ ] User consent for data access
|
||||
|
||||
### 6. Privilege Escalation
|
||||
|
||||
**Risk**: Skills could escalate privileges through script execution.
|
||||
|
||||
**Current behavior**:
|
||||
- Scripts inherit Ollama process privileges
|
||||
- No capability restrictions
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Run scripts as unprivileged user
|
||||
- [ ] Drop all capabilities
|
||||
- [ ] Mandatory access controls (SELinux/AppArmor)
|
||||
|
||||
## Recommended Security Model
|
||||
|
||||
### Skill Trust Levels
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Level 0: Untrusted (default) │
|
||||
│ - No script execution │
|
||||
│ - Instructions only │
|
||||
│ - Safe for any skill │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ Level 1: Sandboxed │
|
||||
│ - Scripts run in isolated container │
|
||||
│ - No network access │
|
||||
│ - Read-only filesystem │
|
||||
│ - Resource limits enforced │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ Level 2: Trusted │
|
||||
│ - Scripts run with network access │
|
||||
│ - Can write to designated directories │
|
||||
│ - Requires explicit user approval │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ Level 3: Privileged (admin only) │
|
||||
│ - Full host access │
|
||||
│ - System administration skills │
|
||||
│ - Requires admin approval │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Skill Manifest Security Fields (Future)
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: my-skill
|
||||
description: A skill description
|
||||
security:
|
||||
trust_level: sandboxed
|
||||
permissions:
|
||||
- network:read # Can make HTTP GET requests
|
||||
- filesystem:read:/data # Can read from /data
|
||||
resource_limits:
|
||||
max_memory: 256MB
|
||||
max_cpu_time: 30s
|
||||
max_disk: 100MB
|
||||
signature: sha256:abc... # Author signature
|
||||
---
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
# Future Considerations
|
||||
|
||||
## Feature Roadmap
|
||||
|
||||
### Phase 1: Foundation (Current)
|
||||
- [x] Skill bundling with agents
|
||||
- [x] Local skill development
|
||||
- [x] Basic CLI commands (push, pull, list, rm, show)
|
||||
- [x] Registry blob storage
|
||||
- [ ] Registry namespace configuration
|
||||
|
||||
### Phase 2: Security
|
||||
- [ ] Script sandboxing
|
||||
- [ ] Permission model
|
||||
- [ ] Skill signing
|
||||
- [ ] Audit logging
|
||||
|
||||
### Phase 3: Discovery
|
||||
- [ ] Skill search on ollama.com
|
||||
- [ ] Skill ratings and reviews
|
||||
- [ ] Usage analytics
|
||||
- [ ] Featured/trending skills
|
||||
|
||||
### Phase 4: Advanced Features
|
||||
- [ ] Skill dependencies
|
||||
- [ ] Skill versioning constraints
|
||||
- [ ] Skill composition (skills using skills)
|
||||
- [ ] Skill testing framework
|
||||
|
||||
## Open Questions
|
||||
|
||||
### 1. Skill Execution Model
|
||||
|
||||
**Question**: How should skills execute scripts?
|
||||
|
||||
Options:
|
||||
- **A) In-process**: Fast but unsafe
|
||||
- **B) Subprocess**: Current approach, moderate isolation
|
||||
- **C) Container**: Good isolation, requires container runtime
|
||||
- **D) WASM**: Portable and safe, limited capabilities
|
||||
- **E) Remote execution**: Offload to secure service
|
||||
|
||||
### 2. Skill Versioning
|
||||
|
||||
**Question**: How strict should version pinning be?
|
||||
|
||||
Options:
|
||||
- **A) Always latest**: Simple but risky
|
||||
- **B) Semantic versioning**: `^1.0.0` allows minor updates
|
||||
- **C) Exact pinning**: `=1.0.0` requires explicit updates
|
||||
- **D) Digest pinning**: `@sha256:abc` immutable reference
|
||||
|
||||
### 3. Skill Permissions
|
||||
|
||||
**Question**: How should users grant permissions to skills?
|
||||
|
||||
Options:
|
||||
- **A) All or nothing**: Accept all permissions or don't use
|
||||
- **B) Granular consent**: Approve each permission individually
|
||||
- **C) Trust levels**: Pre-defined permission bundles
|
||||
- **D) Runtime prompts**: Ask when permission is first used
|
||||
|
||||
### 4. Skill Discovery
|
||||
|
||||
**Question**: How should users find skills?
|
||||
|
||||
Options:
|
||||
- **A) Central registry only**: ollama.com/skills
|
||||
- **B) Federated registries**: Multiple skill sources
|
||||
- **C) Git repositories**: Pull from GitHub, etc.
|
||||
- **D) All of the above**: Multiple discovery mechanisms
|
||||
|
||||
### 5. Skill Monetization
|
||||
|
||||
**Question**: Should skill authors be able to monetize?
|
||||
|
||||
Options:
|
||||
- **A) Free only**: All skills are free and open
|
||||
- **B) Paid skills**: Authors can charge for skills
|
||||
- **C) Freemium**: Free tier with paid features
|
||||
- **D) Donations**: Voluntary support for authors
|
||||
|
||||
### 6. Skill Updates
|
||||
|
||||
**Question**: How should skill updates be handled?
|
||||
|
||||
Options:
|
||||
- **A) Manual**: User explicitly updates
|
||||
- **B) Auto-update**: Always use latest
|
||||
- **C) Notify**: Alert user to available updates
|
||||
- **D) Policy-based**: Organization controls update policy
|
||||
|
||||
## API Considerations
|
||||
|
||||
### Skill Metadata API
|
||||
|
||||
```
|
||||
GET /api/skills
|
||||
GET /api/skills/:namespace/:name
|
||||
GET /api/skills/:namespace/:name/versions
|
||||
GET /api/skills/:namespace/:name/readme
|
||||
```
|
||||
|
||||
### Skill Execution API
|
||||
|
||||
```
|
||||
POST /api/skills/:namespace/:name/execute
|
||||
{
|
||||
"command": "python3 scripts/run.py",
|
||||
"args": ["--input", "data"],
|
||||
"timeout": 30
|
||||
}
|
||||
```
|
||||
|
||||
### Skill Permissions API
|
||||
|
||||
```
|
||||
GET /api/skills/:namespace/:name/permissions
|
||||
POST /api/skills/:namespace/:name/permissions/grant
|
||||
DELETE /api/skills/:namespace/:name/permissions/revoke
|
||||
```
|
||||
|
||||
## Testing Considerations
|
||||
|
||||
### Skill Testing Framework
|
||||
|
||||
```bash
|
||||
# Run skill tests
|
||||
ollama skill test ./my-skill
|
||||
|
||||
# Test with specific model
|
||||
ollama skill test ./my-skill --model llama3.2:3b
|
||||
|
||||
# Generate test report
|
||||
ollama skill test ./my-skill --report
|
||||
```
|
||||
|
||||
### Test File Format
|
||||
|
||||
```yaml
|
||||
# my-skill/tests/test.yaml
|
||||
tests:
|
||||
- name: "basic calculation"
|
||||
input: "What is 2 + 2?"
|
||||
expect:
|
||||
contains: "4"
|
||||
tool_called: "run_skill_script"
|
||||
|
||||
- name: "complex expression"
|
||||
input: "Calculate 15% of 200"
|
||||
expect:
|
||||
contains: "30"
|
||||
```
|
||||
|
||||
## Compatibility Considerations
|
||||
|
||||
### Minimum Ollama Version
|
||||
|
||||
Skills should declare minimum Ollama version:
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: my-skill
|
||||
requires:
|
||||
ollama: ">=0.4.0"
|
||||
---
|
||||
```
|
||||
|
||||
### Model Compatibility
|
||||
|
||||
Skills may require specific model capabilities:
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: vision-skill
|
||||
requires:
|
||||
capabilities:
|
||||
- vision
|
||||
- tools
|
||||
---
|
||||
```
|
||||
|
||||
## Migration Path
|
||||
|
||||
### From Local to Registry
|
||||
|
||||
```bash
|
||||
# Develop locally
|
||||
SKILL ./my-skill
|
||||
|
||||
# Push when ready
|
||||
ollama skill push myname/my-skill:1.0.0 ./my-skill
|
||||
|
||||
# Update Agentfile
|
||||
SKILL skill/myname/my-skill:1.0.0
|
||||
```
|
||||
|
||||
### Version Upgrades
|
||||
|
||||
```bash
|
||||
# Check for updates
|
||||
ollama skill outdated
|
||||
|
||||
# Update specific skill
|
||||
ollama skill update calculator:1.0.0
|
||||
|
||||
# Update all skills
|
||||
ollama skill update --all
|
||||
```
|
||||
@@ -148,6 +148,16 @@ func Remotes() []string {
|
||||
return r
|
||||
}
|
||||
|
||||
// Skills returns the list of skill directories. Skills directories can be configured via the OLLAMA_SKILLS environment variable.
|
||||
// Returns empty slice if not configured.
|
||||
func Skills() []string {
|
||||
raw := strings.TrimSpace(Var("OLLAMA_SKILLS"))
|
||||
if raw == "" {
|
||||
return []string{}
|
||||
}
|
||||
return strings.Split(raw, ",")
|
||||
}
|
||||
|
||||
func BoolWithDefault(k string) func(defaultValue bool) bool {
|
||||
return func(defaultValue bool) bool {
|
||||
if s := Var(k); s != "" {
|
||||
@@ -317,6 +327,9 @@ func AsMap() map[string]EnvVar {
|
||||
ret["OLLAMA_VULKAN"] = EnvVar{"OLLAMA_VULKAN", EnableVulkan(), "Enable experimental Vulkan support"}
|
||||
}
|
||||
|
||||
// Skills configuration would go here when added
|
||||
ret["OLLAMA_SKILLS"] = EnvVar{"OLLAMA_SKILLS", Skills(), "Comma-separated list of skill directories"}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package fs
|
||||
|
||||
import "iter"
|
||||
|
||||
type Config interface {
|
||||
Architecture() string
|
||||
String(string, ...string) string
|
||||
@@ -13,8 +11,4 @@ type Config interface {
|
||||
Ints(string, ...[]int32) []int32
|
||||
Floats(string, ...[]float32) []float32
|
||||
Bools(string, ...[]bool) []bool
|
||||
|
||||
Len() int
|
||||
Keys() iter.Seq[string]
|
||||
Value(key string) any
|
||||
}
|
||||
|
||||
@@ -6,9 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -241,18 +239,6 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
|
||||
return binary.Write(w, binary.LittleEndian, s)
|
||||
}
|
||||
|
||||
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
|
||||
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
arch := kv.String("general.architecture")
|
||||
if arch == "" {
|
||||
return fmt.Errorf("architecture not set")
|
||||
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range slices.Sorted(kv.Keys()) {
|
||||
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
|
||||
for _, key := range slices.Sorted(maps.Keys(kv)) {
|
||||
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
20
go.mod
20
go.mod
@@ -15,8 +15,8 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/sys v0.39.0
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -30,8 +30,8 @@ require (
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/mod v0.31.0
|
||||
golang.org/x/tools v0.40.0
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/tools v0.38.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
@@ -81,11 +81,11 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/term v0.38.0
|
||||
golang.org/x/text v0.32.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.46.0 // indirect
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/text v0.30.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
36
go.sum
36
go.sum
@@ -233,16 +233,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
|
||||
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
@@ -264,8 +264,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -278,8 +278,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -289,8 +289,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -306,17 +306,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -330,8 +330,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/anthropic"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
|
||||
type AnthropicWriter struct {
|
||||
BaseWriter
|
||||
stream bool
|
||||
id string
|
||||
model string
|
||||
converter *anthropic.StreamConverter
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
|
||||
var errData struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &errData); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if w.stream {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
events := w.converter.Process(chatResponse)
|
||||
for _, event := range events {
|
||||
if err := w.writeEvent(event.Event, event.Data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := anthropic.ToMessagesResponse(w.id, chatResponse)
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
|
||||
func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req anthropic.MessagesRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.MaxTokens <= 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
|
||||
return
|
||||
}
|
||||
|
||||
chatReq, err := anthropic.FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
messageID := anthropic.GenerateMessageID()
|
||||
|
||||
w := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model),
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -1,584 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
|
||||
"github.com/ollama/ollama/anthropic"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
_ = json.Unmarshal(bodyBytes, capturedRequest)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
|
||||
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.ChatRequest
|
||||
err anthropic.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.ChatRequest
|
||||
stream := true
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "basic message",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with system prompt",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"system": "You are helpful.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with options",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"top_k": 40,
|
||||
"stop_sequences": ["\n", "END"],
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"num_predict": 2048,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"top_k": 40,
|
||||
"stop": []string{"\n", "END"},
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "streaming",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &stream,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"}
|
||||
],
|
||||
"tools": [{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
Tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: testProps(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with tool result",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
|
||||
]},
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
|
||||
]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with thinking enabled",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"thinking": {"type": "enabled", "budget_tokens": 1000},
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
Think: &api.ThinkValue{Value: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing model error",
|
||||
body: `{
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "model is required",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing max_tokens error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "max_tokens is required and must be positive",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing messages error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "messages is required",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_use missing id error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "tool_use", "name": "test"}
|
||||
]}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "tool_use block missing required 'id' field",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/v1/messages", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if tc.err.Type != "" {
|
||||
// Expect error
|
||||
if resp.Code == http.StatusOK {
|
||||
t.Fatalf("expected error response, got 200 OK")
|
||||
}
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error: %v", err)
|
||||
}
|
||||
if errResp.Type != tc.err.Type {
|
||||
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != tc.err.Error.Type {
|
||||
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
|
||||
}
|
||||
if errResp.Error.Message != tc.err.Error.Message {
|
||||
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if capturedRequest == nil {
|
||||
t.Fatal("request was not captured")
|
||||
}
|
||||
|
||||
// Compare relevant fields
|
||||
if capturedRequest.Model != tc.req.Model {
|
||||
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
|
||||
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
|
||||
t.Errorf("messages mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if tc.req.Stream != nil && capturedRequest.Stream != nil {
|
||||
if *tc.req.Stream != *capturedRequest.Stream {
|
||||
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.req.Think != nil {
|
||||
if capturedRequest.Think == nil {
|
||||
t.Error("expected Think to be set")
|
||||
} else if capturedRequest.Think.Value != tc.req.Think.Value {
|
||||
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("streaming sets correct headers", func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Check headers were set
|
||||
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
|
||||
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
|
||||
}
|
||||
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
|
||||
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", resp.Code)
|
||||
}
|
||||
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
if errResp.Type != "error" {
|
||||
t.Errorf("expected type 'error', got %q", errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != "invalid_request_error" {
|
||||
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicWriter_NonStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Simulate Ollama response
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello there!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 5,
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, _ = c.Writer.Write(data)
|
||||
})
|
||||
|
||||
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", resp.Code)
|
||||
}
|
||||
|
||||
var result anthropic.MessagesResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if result.Type != "message" {
|
||||
t.Errorf("expected type 'message', got %q", result.Type)
|
||||
}
|
||||
if result.Role != "assistant" {
|
||||
t.Errorf("expected role 'assistant', got %q", result.Role)
|
||||
}
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
|
||||
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
|
||||
}
|
||||
if result.StopReason != "end_turn" {
|
||||
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
||||
}
|
||||
if result.Usage.InputTokens != 10 {
|
||||
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
|
||||
}
|
||||
if result.Usage.OutputTokens != 5 {
|
||||
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
|
||||
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
|
||||
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
errorPayload any
|
||||
wantErrorType string
|
||||
wantMessage string
|
||||
}{
|
||||
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
|
||||
{
|
||||
name: "404 with gin.H error (model not found)",
|
||||
statusCode: http.StatusNotFound,
|
||||
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
|
||||
wantErrorType: "not_found_error",
|
||||
wantMessage: "model 'nonexistent' not found",
|
||||
},
|
||||
{
|
||||
name: "400 with gin.H error (bad request)",
|
||||
statusCode: http.StatusBadRequest,
|
||||
errorPayload: gin.H{"error": "model is required"},
|
||||
wantErrorType: "invalid_request_error",
|
||||
wantMessage: "model is required",
|
||||
},
|
||||
{
|
||||
name: "500 with gin.H error (internal error)",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
errorPayload: gin.H{"error": "something went wrong"},
|
||||
wantErrorType: "api_error",
|
||||
wantMessage: "something went wrong",
|
||||
},
|
||||
{
|
||||
name: "404 with api.StatusError",
|
||||
statusCode: http.StatusNotFound,
|
||||
errorPayload: api.StatusError{
|
||||
StatusCode: http.StatusNotFound,
|
||||
ErrorMessage: "model not found via StatusError",
|
||||
},
|
||||
wantErrorType: "not_found_error",
|
||||
wantMessage: "model not found via StatusError",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Simulate what routes.go does - set status and write error JSON
|
||||
data, _ := json.Marshal(tt.errorPayload)
|
||||
c.Writer.WriteHeader(tt.statusCode)
|
||||
_, _ = c.Writer.Write(data)
|
||||
})
|
||||
|
||||
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != tt.statusCode {
|
||||
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
|
||||
}
|
||||
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
|
||||
}
|
||||
|
||||
if errResp.Type != "error" {
|
||||
t.Errorf("expected type 'error', got %q", errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != tt.wantErrorType {
|
||||
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
|
||||
}
|
||||
if errResp.Error.Message != tt.wantMessage {
|
||||
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
126
parser/parser.go
126
parser/parser.go
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -58,6 +59,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
|
||||
var messages []api.Message
|
||||
var licenses []string
|
||||
var skills []api.SkillRef
|
||||
var mcps []api.MCPRef
|
||||
params := make(map[string]any)
|
||||
|
||||
for _, c := range f.Commands {
|
||||
@@ -118,6 +121,32 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
case "message":
|
||||
role, msg, _ := strings.Cut(c.Args, ": ")
|
||||
messages = append(messages, api.Message{Role: role, Content: msg})
|
||||
case "skill":
|
||||
skillName := c.Args
|
||||
// Expand local paths relative to the Agentfile directory
|
||||
if isLocalPath(skillName) {
|
||||
expanded, err := expandPath(skillName, relativeDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expanding skill path %q: %w", skillName, err)
|
||||
}
|
||||
skillName = expanded
|
||||
}
|
||||
skills = append(skills, api.SkillRef{Name: skillName})
|
||||
case "mcp":
|
||||
mcpRef, err := parseMCPArg(c.Args, relativeDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid MCP: %w", err)
|
||||
}
|
||||
mcps = append(mcps, mcpRef)
|
||||
case "agent_type":
|
||||
// Handle "AGENT TYPE conversational" -> strip "TYPE " prefix
|
||||
args := c.Args
|
||||
if strings.HasPrefix(strings.ToLower(args), "type ") {
|
||||
args = strings.TrimSpace(args[5:])
|
||||
}
|
||||
req.AgentType = args
|
||||
case "entrypoint":
|
||||
req.Entrypoint = c.Args
|
||||
default:
|
||||
if slices.Contains(deprecatedParameters, c.Name) {
|
||||
fmt.Printf("warning: parameter %s is deprecated\n", c.Name)
|
||||
@@ -150,6 +179,12 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
if len(licenses) > 0 {
|
||||
req.License = licenses
|
||||
}
|
||||
if len(skills) > 0 {
|
||||
req.Skills = skills
|
||||
}
|
||||
if len(mcps) > 0 {
|
||||
req.MCPs = mcps
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
@@ -333,7 +368,7 @@ func (c Command) String() string {
|
||||
switch c.Name {
|
||||
case "model":
|
||||
fmt.Fprintf(&sb, "FROM %s", c.Args)
|
||||
case "license", "template", "system", "adapter", "renderer", "parser", "requires":
|
||||
case "license", "template", "system", "adapter", "renderer", "parser", "requires", "skill", "agent_type", "entrypoint":
|
||||
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
|
||||
case "message":
|
||||
role, message, _ := strings.Cut(c.Args, ": ")
|
||||
@@ -359,7 +394,7 @@ const (
|
||||
var (
|
||||
errMissingFrom = errors.New("no FROM line")
|
||||
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", \"requires\", \"skill\", \"agent_type\", \"mcp\", or \"entrypoint\"")
|
||||
)
|
||||
|
||||
type ParserError struct {
|
||||
@@ -423,6 +458,9 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
||||
switch s := strings.ToLower(b.String()); s {
|
||||
case "from":
|
||||
cmd.Name = "model"
|
||||
case "agent":
|
||||
// "AGENT TYPE" -> "agent_type", consume next word
|
||||
cmd.Name = "agent_type"
|
||||
case "parameter":
|
||||
// transition to stateParameter which sets command name
|
||||
next = stateParameter
|
||||
@@ -500,6 +538,10 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
||||
if cmd.Name == "model" {
|
||||
return &f, nil
|
||||
}
|
||||
// Allow entrypoint-only agents without FROM
|
||||
if cmd.Name == "entrypoint" {
|
||||
return &f, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errMissingFrom
|
||||
@@ -518,7 +560,7 @@ func parseRuneForState(r rune, cs state) (state, rune, error) {
|
||||
}
|
||||
case stateName:
|
||||
switch {
|
||||
case isAlpha(r):
|
||||
case isAlpha(r), r == '_':
|
||||
return stateName, r, nil
|
||||
case isSpace(r):
|
||||
return stateValue, 0, nil
|
||||
@@ -619,7 +661,7 @@ func isValidMessageRole(role string) bool {
|
||||
|
||||
func isValidCommand(cmd string) bool {
|
||||
switch strings.ToLower(cmd) {
|
||||
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires":
|
||||
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires", "skill", "agent_type", "agent", "mcp", "entrypoint":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -666,3 +708,79 @@ func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User
|
||||
func expandPath(path, relativeDir string) (string, error) {
|
||||
return expandPathImpl(path, relativeDir, user.Current, user.Lookup)
|
||||
}
|
||||
|
||||
// parseMCPArg parses MCP command arguments.
|
||||
// Supports two formats:
|
||||
//
|
||||
// JSON: {"name": "web-search", "command": "uv", "args": ["run", "./script.py"]}
|
||||
// Simple: web-search uv run ./script.py (name, command, args...)
|
||||
func parseMCPArg(args string, relativeDir string) (api.MCPRef, error) {
|
||||
args = strings.TrimSpace(args)
|
||||
if args == "" {
|
||||
return api.MCPRef{}, errors.New("MCP requires arguments")
|
||||
}
|
||||
|
||||
// Try JSON format first
|
||||
if strings.HasPrefix(args, "{") {
|
||||
var ref api.MCPRef
|
||||
if err := json.Unmarshal([]byte(args), &ref); err != nil {
|
||||
return api.MCPRef{}, fmt.Errorf("invalid JSON: %w", err)
|
||||
}
|
||||
if ref.Name == "" {
|
||||
return api.MCPRef{}, errors.New("MCP name is required")
|
||||
}
|
||||
if ref.Command == "" {
|
||||
return api.MCPRef{}, errors.New("MCP command is required")
|
||||
}
|
||||
if ref.Type == "" {
|
||||
ref.Type = "stdio"
|
||||
}
|
||||
// Expand relative paths in args
|
||||
for i, arg := range ref.Args {
|
||||
if isLocalPath(arg) {
|
||||
expanded, err := expandPath(arg, relativeDir)
|
||||
if err != nil {
|
||||
return api.MCPRef{}, fmt.Errorf("expanding path %q: %w", arg, err)
|
||||
}
|
||||
ref.Args[i] = expanded
|
||||
}
|
||||
}
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// Simple format: name command args...
|
||||
parts := strings.Fields(args)
|
||||
if len(parts) < 2 {
|
||||
return api.MCPRef{}, errors.New("MCP requires at least name and command")
|
||||
}
|
||||
|
||||
ref := api.MCPRef{
|
||||
Name: parts[0],
|
||||
Command: parts[1],
|
||||
Type: "stdio",
|
||||
}
|
||||
if len(parts) > 2 {
|
||||
ref.Args = parts[2:]
|
||||
}
|
||||
|
||||
// Expand relative paths in args
|
||||
for i, arg := range ref.Args {
|
||||
if isLocalPath(arg) {
|
||||
expanded, err := expandPath(arg, relativeDir)
|
||||
if err != nil {
|
||||
return api.MCPRef{}, fmt.Errorf("expanding path %q: %w", arg, err)
|
||||
}
|
||||
ref.Args[i] = expanded
|
||||
}
|
||||
}
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// isLocalPath checks if a string looks like a local filesystem path.
|
||||
func isLocalPath(s string) bool {
|
||||
return strings.HasPrefix(s, "/") ||
|
||||
strings.HasPrefix(s, "./") ||
|
||||
strings.HasPrefix(s, "../") ||
|
||||
strings.HasPrefix(s, "~")
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -802,7 +801,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var base convert.KV = map[string]any{"general.architecture": "test"}
|
||||
base := map[string]any{"general.architecture": "test"}
|
||||
maps.Copy(base, kv)
|
||||
|
||||
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
package progress
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// StepBar displays step-based progress (e.g., for image generation steps).
|
||||
type StepBar struct {
|
||||
message string
|
||||
current int
|
||||
total int
|
||||
}
|
||||
|
||||
func NewStepBar(message string, total int) *StepBar {
|
||||
return &StepBar{message: message, total: total}
|
||||
}
|
||||
|
||||
func (s *StepBar) Set(current int) {
|
||||
s.current = current
|
||||
}
|
||||
|
||||
func (s *StepBar) String() string {
|
||||
percent := float64(s.current) / float64(s.total) * 100
|
||||
barWidth := s.total
|
||||
empty := barWidth - s.current
|
||||
|
||||
// "Generating 0% ▕ ▏ 0/9"
|
||||
return fmt.Sprintf("%s %3.0f%% ▕%s%s▏ %d/%d",
|
||||
s.message, percent,
|
||||
strings.Repeat("█", s.current), strings.Repeat(" ", empty),
|
||||
s.current, s.total)
|
||||
}
|
||||
@@ -18,7 +18,6 @@ const (
|
||||
CharCtrlL = 12
|
||||
CharEnter = 13
|
||||
CharNext = 14
|
||||
CharCtrlO = 15 // Ctrl+O - used for expanding tool output
|
||||
CharPrev = 16
|
||||
CharBckSearch = 18
|
||||
CharFwdSearch = 19
|
||||
|
||||
@@ -3,7 +3,6 @@ package runner
|
||||
import (
|
||||
"github.com/ollama/ollama/runner/llamarunner"
|
||||
"github.com/ollama/ollama/runner/ollamarunner"
|
||||
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
@@ -12,19 +11,12 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
var newRunner bool
|
||||
var imageRunner bool
|
||||
if len(args) > 0 && args[0] == "--ollama-engine" {
|
||||
if args[0] == "--ollama-engine" {
|
||||
args = args[1:]
|
||||
newRunner = true
|
||||
}
|
||||
if len(args) > 0 && args[0] == "--image-engine" {
|
||||
args = args[1:]
|
||||
imageRunner = true
|
||||
}
|
||||
|
||||
if imageRunner {
|
||||
return imagerunner.Execute(args)
|
||||
} else if newRunner {
|
||||
if newRunner {
|
||||
return ollamarunner.Execute(args)
|
||||
} else {
|
||||
return llamarunner.Execute(args)
|
||||
|
||||
@@ -42,39 +42,18 @@ shift $(( $OPTIND - 1 ))
|
||||
_build_darwin() {
|
||||
for ARCH in $ARCHS; do
|
||||
status "Building darwin $ARCH"
|
||||
INSTALL_PREFIX=dist/darwin-$ARCH/
|
||||
INSTALL_PREFIX=dist/darwin-$ARCH/
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
|
||||
if [ "$ARCH" = "amd64" ]; then
|
||||
status "Building darwin $ARCH dynamic backends"
|
||||
BUILD_DIR=build/darwin-$ARCH
|
||||
cmake -B $BUILD_DIR \
|
||||
cmake -B build/darwin-$ARCH \
|
||||
-DCMAKE_OSX_ARCHITECTURES=x86_64 \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
|
||||
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \
|
||||
-DMLX_ENGINE=ON \
|
||||
-DMLX_ENABLE_X64_MAC=ON \
|
||||
-DOLLAMA_RUNNER_DIR=./
|
||||
cmake --build $BUILD_DIR --target ggml-cpu -j
|
||||
cmake --build $BUILD_DIR --target mlx mlxc -j
|
||||
cmake --install $BUILD_DIR --component CPU
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Override CGO flags to point to the amd64 build directory
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
else
|
||||
BUILD_DIR=build
|
||||
cmake --preset MLX \
|
||||
-DOLLAMA_RUNNER_DIR=./ \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 \
|
||||
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Use default CGO flags from mlx.go for arm64
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
cmake --build build/darwin-$ARCH --target ggml-cpu -j
|
||||
cmake --install build/darwin-$ARCH --component CPU
|
||||
fi
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
done
|
||||
}
|
||||
|
||||
@@ -82,12 +61,10 @@ _sign_darwin() {
|
||||
status "Creating universal binary..."
|
||||
mkdir -p dist/darwin
|
||||
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
|
||||
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
|
||||
chmod +x dist/darwin/ollama
|
||||
chmod +x dist/darwin/imagegen
|
||||
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
|
||||
for F in dist/darwin/ollama dist/darwin-amd64/lib/ollama/*; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
|
||||
done
|
||||
|
||||
@@ -154,23 +131,17 @@ _build_macapp() {
|
||||
mkdir -p dist/Ollama.app/Contents/Resources
|
||||
if [ -d dist/darwin-amd64 ]; then
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
|
||||
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
|
||||
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
|
||||
done
|
||||
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
cp dist/darwin-amd64/lib/ollama/*.so dist/darwin-amd64/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
else
|
||||
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
|
||||
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
fi
|
||||
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
|
||||
chmod a+x dist/Ollama.app/Contents/Resources/ollama
|
||||
|
||||
# Sign
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib ; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
|
||||
done
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
|
||||
@@ -178,7 +149,7 @@ _build_macapp() {
|
||||
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
|
||||
# Notarize and Staple
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
|
||||
@@ -12,17 +12,6 @@ set -eu
|
||||
|
||||
. $(dirname $0)/env.sh
|
||||
|
||||
# Check for required tools
|
||||
if ! command -v zstd >/dev/null 2>&1; then
|
||||
echo "ERROR: zstd is required but not installed." >&2
|
||||
echo "Please install zstd:" >&2
|
||||
echo " - macOS: brew install zstd" >&2
|
||||
echo " - Debian/Ubuntu: sudo apt-get install zstd" >&2
|
||||
echo " - RHEL/CentOS/Fedora: sudo dnf install zstd" >&2
|
||||
echo " - Arch: sudo pacman -S zstd" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p dist
|
||||
|
||||
docker buildx build \
|
||||
@@ -48,68 +37,19 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
|
||||
.
|
||||
fi
|
||||
|
||||
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
|
||||
deduplicate_cuda_libs() {
|
||||
local base_dir="$1"
|
||||
echo "Deduplicating CUDA libraries in ${base_dir}..."
|
||||
|
||||
# Find all mlx_cuda_* directories
|
||||
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
|
||||
[ -d "${mlx_dir}" ] || continue
|
||||
|
||||
# Extract CUDA version (e.g., v12, v13)
|
||||
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
|
||||
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
|
||||
|
||||
# Skip if corresponding cuda_* directory doesn't exist
|
||||
[ -d "${cuda_dir}" ] || continue
|
||||
|
||||
echo " Checking ${mlx_dir} against ${cuda_dir}..."
|
||||
|
||||
# Find all .so* files in mlx directory
|
||||
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
|
||||
filename=$(basename "${mlx_file}")
|
||||
cuda_file="${cuda_dir}/${filename}"
|
||||
|
||||
# Skip if file doesn't exist in cuda directory
|
||||
[ -f "${cuda_file}" ] || continue
|
||||
|
||||
# Compare checksums
|
||||
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
|
||||
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
|
||||
|
||||
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
|
||||
echo " Deduplicating ${filename}"
|
||||
# Calculate relative path from mlx_dir to cuda_dir
|
||||
rel_path="../cuda_${cuda_version}/${filename}"
|
||||
rm -f "${mlx_file}"
|
||||
ln -s "${rel_path}" "${mlx_file}"
|
||||
fi
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
# Run deduplication for each platform output directory
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist/linux_amd64"
|
||||
deduplicate_cuda_libs "./dist/linux_arm64"
|
||||
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist"
|
||||
fi
|
||||
|
||||
# buildx behavior changes for single vs. multiplatform
|
||||
echo "Compressing linux tar bundles..."
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
|
||||
tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
|
||||
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | pigz -9vc >./dist/ollama-linux-arm64.tgz
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
|
||||
tar c -C ./dist/linux_amd64 --exclude rocm . | pigz -9vc >./dist/ollama-linux-amd64.tgz
|
||||
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
|
||||
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
|
||||
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
|
||||
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | pigz -9vc >./dist/ollama-linux-arm64.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
|
||||
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
|
||||
tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
|
||||
tar c -C ./dist/ --exclude rocm bin lib | pigz -9vc >./dist/ollama-linux-amd64.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
|
||||
fi
|
||||
|
||||
@@ -66,36 +66,6 @@ if [ -n "$NEEDS" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Function to download and extract with fallback from zst to tgz
|
||||
download_and_extract() {
|
||||
local url_base="$1"
|
||||
local dest_dir="$2"
|
||||
local filename="$3"
|
||||
|
||||
# Check if .tar.zst is available
|
||||
if curl --fail --silent --head --location "${url_base}/${filename}.tar.zst${VER_PARAM}" >/dev/null 2>&1; then
|
||||
# zst file exists - check if we have zstd tool
|
||||
if ! available zstd; then
|
||||
error "This version requires zstd for extraction. Please install zstd and try again:
|
||||
- Debian/Ubuntu: sudo apt-get install zstd
|
||||
- RHEL/CentOS/Fedora: sudo dnf install zstd
|
||||
- Arch: sudo pacman -S zstd"
|
||||
fi
|
||||
|
||||
status "Downloading ${filename}.tar.zst"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"${url_base}/${filename}.tar.zst${VER_PARAM}" | \
|
||||
zstd -d | $SUDO tar -xf - -C "${dest_dir}"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Fall back to .tgz for older versions
|
||||
status "Downloading ${filename}.tgz"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"${url_base}/${filename}.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "${dest_dir}"
|
||||
}
|
||||
|
||||
for BINDIR in /usr/local/bin /usr/bin /bin; do
|
||||
echo $PATH | grep -q $BINDIR && break || continue
|
||||
done
|
||||
@@ -108,7 +78,10 @@ fi
|
||||
status "Installing ollama to $OLLAMA_INSTALL_DIR"
|
||||
$SUDO install -o0 -g0 -m755 -d $BINDIR
|
||||
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}"
|
||||
status "Downloading Linux ${ARCH} bundle"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
|
||||
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
|
||||
status "Making ollama accessible in the PATH in $BINDIR"
|
||||
@@ -118,9 +91,15 @@ fi
|
||||
# Check for NVIDIA JetPack systems with additional downloads
|
||||
if [ -f /etc/nv_tegra_release ] ; then
|
||||
if grep R36 /etc/nv_tegra_release > /dev/null ; then
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack6"
|
||||
status "Downloading JetPack 6 components"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack6.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
elif grep R35 /etc/nv_tegra_release > /dev/null ; then
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack5"
|
||||
status "Downloading JetPack 5 components"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack5.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
else
|
||||
warning "Unsupported JetPack version detected. GPU may not be supported"
|
||||
fi
|
||||
@@ -243,7 +222,10 @@ if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdg
|
||||
fi
|
||||
|
||||
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-rocm"
|
||||
status "Downloading Linux ROCm ${ARCH} bundle"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}-rocm.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
|
||||
install_success
|
||||
status "AMD GPU ready."
|
||||
|
||||
179
server/create.go
179
server/create.go
@@ -26,7 +26,6 @@ import (
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
@@ -63,6 +62,10 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
config.Renderer = r.Renderer
|
||||
config.Parser = r.Parser
|
||||
config.Requires = r.Requires
|
||||
config.Skills = r.Skills
|
||||
config.MCPs = r.MCPs
|
||||
config.AgentType = r.AgentType
|
||||
config.Entrypoint = r.Entrypoint
|
||||
|
||||
for v := range r.Files {
|
||||
if !fs.ValidPath(v) {
|
||||
@@ -122,7 +125,10 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
|
||||
// Inherit config from base model (Renderer, Parser, Requires, Capabilities, etc.)
|
||||
// This is especially important for cloud models which don't have GGUF files
|
||||
// to detect capabilities from.
|
||||
if err == nil && !remote {
|
||||
manifest, mErr := ParseNamedManifest(fromName)
|
||||
if mErr == nil && manifest.Config.Digest != "" {
|
||||
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
|
||||
@@ -139,6 +145,29 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
if config.Requires == "" {
|
||||
config.Requires = baseConfig.Requires
|
||||
}
|
||||
// Inherit capabilities for cloud/remote models
|
||||
// (local models detect capabilities from GGUF file)
|
||||
if len(config.Capabilities) == 0 && len(baseConfig.Capabilities) > 0 {
|
||||
config.Capabilities = baseConfig.Capabilities
|
||||
}
|
||||
// Inherit remote host/model if base is a cloud model
|
||||
if config.RemoteHost == "" && baseConfig.RemoteHost != "" {
|
||||
config.RemoteHost = baseConfig.RemoteHost
|
||||
}
|
||||
if config.RemoteModel == "" && baseConfig.RemoteModel != "" {
|
||||
config.RemoteModel = baseConfig.RemoteModel
|
||||
}
|
||||
// Inherit model family for proper rendering
|
||||
if config.ModelFamily == "" && baseConfig.ModelFamily != "" {
|
||||
config.ModelFamily = baseConfig.ModelFamily
|
||||
}
|
||||
if len(config.ModelFamilies) == 0 && len(baseConfig.ModelFamilies) > 0 {
|
||||
config.ModelFamilies = baseConfig.ModelFamilies
|
||||
}
|
||||
// Inherit context length for cloud models
|
||||
if config.ContextLen == 0 && baseConfig.ContextLen > 0 {
|
||||
config.ContextLen = baseConfig.ContextLen
|
||||
}
|
||||
}
|
||||
cfgFile.Close()
|
||||
}
|
||||
@@ -158,6 +187,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
} else if r.Entrypoint != "" {
|
||||
// Entrypoint-only agent: no base model needed
|
||||
slog.Debug("create entrypoint-only agent", "entrypoint", r.Entrypoint)
|
||||
} else {
|
||||
ch <- gin.H{"error": errNeitherFromOrFiles.Error(), "status": http.StatusBadRequest}
|
||||
return
|
||||
@@ -455,7 +487,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
|
||||
func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
|
||||
for _, l := range baseLayers {
|
||||
if l.GGML != nil {
|
||||
return l.KV(), nil
|
||||
@@ -544,6 +576,18 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle skill layers for agents
|
||||
layers, config.Skills, err = setSkillLayers(layers, config.Skills, fn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle MCP layers for agents
|
||||
layers, config.MCPs, err = setMCPLayers(layers, config.MCPs, fn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configLayer, err := createConfigLayer(layers, *config)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -794,6 +838,135 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// setSkillLayers creates skill layers for local skill paths and updates the skill refs.
|
||||
// Local paths are converted to bundled skill layers with digests.
|
||||
// Registry references are kept as-is for later resolution during pull.
|
||||
func setSkillLayers(layers []Layer, skills []model.SkillRef, fn func(resp api.ProgressResponse)) ([]Layer, []model.SkillRef, error) {
|
||||
if len(skills) == 0 {
|
||||
return layers, skills, nil
|
||||
}
|
||||
|
||||
// Remove any existing skill layers
|
||||
layers = removeLayer(layers, MediaTypeSkill)
|
||||
|
||||
var updatedSkills []model.SkillRef
|
||||
|
||||
for _, skill := range skills {
|
||||
// Check if this is a local path
|
||||
if IsLocalSkillPath(skill.Name) {
|
||||
// Expand home directory if needed
|
||||
skillPath := skill.Name
|
||||
if strings.HasPrefix(skillPath, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("expanding home directory: %w", err)
|
||||
}
|
||||
skillPath = filepath.Join(home, skillPath[1:])
|
||||
}
|
||||
|
||||
// Make absolute
|
||||
absPath, err := filepath.Abs(skillPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("resolving skill path %q: %w", skill.Name, err)
|
||||
}
|
||||
|
||||
// Check if this is a direct skill directory or a parent containing skills
|
||||
skillMdPath := filepath.Join(absPath, "SKILL.md")
|
||||
if _, err := os.Stat(skillMdPath); err == nil {
|
||||
// Direct skill directory
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("packaging skill: %s", filepath.Base(absPath))})
|
||||
|
||||
layer, err := CreateSkillLayer(absPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating skill layer for %q: %w", skill.Name, err)
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
updatedSkills = append(updatedSkills, model.SkillRef{
|
||||
Name: filepath.Base(absPath),
|
||||
Digest: layer.Digest,
|
||||
})
|
||||
} else {
|
||||
// Parent directory - walk to find skill subdirectories
|
||||
err := filepath.WalkDir(absPath, func(path string, entry fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if entry.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if entry.Name() != "SKILL.md" {
|
||||
return nil
|
||||
}
|
||||
|
||||
skillDir := filepath.Dir(path)
|
||||
skillName := filepath.Base(skillDir)
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("packaging skill: %s", skillName)})
|
||||
|
||||
layer, err := CreateSkillLayer(skillDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating skill layer for %q: %w", skillDir, err)
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
updatedSkills = append(updatedSkills, model.SkillRef{
|
||||
Name: skillName,
|
||||
Digest: layer.Digest,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("walking skill directory %q: %w", skill.Name, err)
|
||||
}
|
||||
}
|
||||
} else if skill.Digest != "" {
|
||||
// Already has a digest (from a pulled agent), keep as-is
|
||||
updatedSkills = append(updatedSkills, skill)
|
||||
} else {
|
||||
// Registry reference - keep as-is for later resolution
|
||||
updatedSkills = append(updatedSkills, skill)
|
||||
}
|
||||
}
|
||||
|
||||
return layers, updatedSkills, nil
|
||||
}
|
||||
|
||||
// setMCPLayers handles MCP server references.
|
||||
// Currently, MCPs are stored as config data (command/args).
|
||||
// Future: support bundling MCP server directories as layers.
|
||||
func setMCPLayers(layers []Layer, mcps []model.MCPRef, fn func(resp api.ProgressResponse)) ([]Layer, []model.MCPRef, error) {
|
||||
if len(mcps) == 0 {
|
||||
return layers, mcps, nil
|
||||
}
|
||||
|
||||
// Remove any existing MCP layers
|
||||
layers = removeLayer(layers, MediaTypeMCP)
|
||||
|
||||
var updatedMCPs []model.MCPRef
|
||||
|
||||
for _, mcp := range mcps {
|
||||
// Validate MCP has required fields
|
||||
if mcp.Name == "" {
|
||||
return nil, nil, fmt.Errorf("MCP server requires a name")
|
||||
}
|
||||
if mcp.Command == "" {
|
||||
return nil, nil, fmt.Errorf("MCP server %q requires a command", mcp.Name)
|
||||
}
|
||||
|
||||
// Set default type if not specified
|
||||
if mcp.Type == "" {
|
||||
mcp.Type = "stdio"
|
||||
}
|
||||
|
||||
// For now, just keep MCPs as config data
|
||||
// Future: detect local paths in args and bundle them
|
||||
updatedMCPs = append(updatedMCPs, mcp)
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("configuring MCP: %s", mcp.Name)})
|
||||
}
|
||||
|
||||
return layers, updatedMCPs, nil
|
||||
}
|
||||
|
||||
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
|
||||
digests := make([]string, len(layers))
|
||||
for i, layer := range layers {
|
||||
|
||||
196
server/images.go
196
server/images.go
@@ -30,7 +30,6 @@ import (
|
||||
"github.com/ollama/ollama/thinking"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen/transfer"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -74,11 +73,6 @@ type Model struct {
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for image generation model via config capabilities
|
||||
if slices.Contains(m.Config.Capabilities, "image") {
|
||||
return []model.Capability{model.CapabilityImageGeneration}
|
||||
}
|
||||
|
||||
// Check for completion capability
|
||||
if m.ModelPath != "" {
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
@@ -238,6 +232,13 @@ func (m *Model) String() string {
|
||||
})
|
||||
}
|
||||
|
||||
if m.Config.Entrypoint != "" {
|
||||
modelfile.Commands = append(modelfile.Commands, parser.Command{
|
||||
Name: "entrypoint",
|
||||
Args: m.Config.Entrypoint,
|
||||
})
|
||||
}
|
||||
|
||||
for k, v := range m.Options {
|
||||
switch v := v.(type) {
|
||||
case []any:
|
||||
@@ -561,24 +562,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
layers = append(layers, manifest.Config)
|
||||
}
|
||||
|
||||
// Use fast transfer for models with tensor layers (many small blobs)
|
||||
if hasTensorLayers(layers) {
|
||||
// Read raw manifest JSON to preserve tensor metadata fields
|
||||
manifestPath, err := mp.GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manifestJSON, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "success"})
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
||||
@@ -644,15 +627,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
layers = append(layers, manifest.Config)
|
||||
}
|
||||
|
||||
// Use fast transfer for models with tensor layers (many small blobs)
|
||||
if hasTensorLayers(layers) {
|
||||
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "success"})
|
||||
return nil
|
||||
}
|
||||
|
||||
skipVerify := make(map[string]bool)
|
||||
for _, layer := range layers {
|
||||
cacheHit, err := downloadBlob(ctx, downloadOpts{
|
||||
@@ -667,6 +641,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
skipVerify[layer.Digest] = cacheHit
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
delete(deleteMap, manifest.Config.Digest)
|
||||
|
||||
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
|
||||
for _, layer := range layers {
|
||||
@@ -675,11 +650,13 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
if err := verifyBlob(layer.Digest); err != nil {
|
||||
if errors.Is(err, errDigestMismatch) {
|
||||
// something went wrong, delete the blob
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(fp); err != nil {
|
||||
// log this, but return the original error
|
||||
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
|
||||
}
|
||||
}
|
||||
@@ -687,10 +664,15 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
delete(deleteMap, layer.Digest)
|
||||
// Extract skill layers to the skills cache
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.MediaType == MediaTypeSkill {
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("extracting skill %s", layer.Digest)})
|
||||
if _, err := ExtractSkillBlob(layer.Digest); err != nil {
|
||||
return fmt.Errorf("extracting skill layer %s: %w", layer.Digest, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(deleteMap, manifest.Config.Digest)
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
|
||||
@@ -725,148 +707,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasTensorLayers checks if any layer has tensor media type.
|
||||
func hasTensorLayers(layers []Layer) bool {
|
||||
for _, layer := range layers {
|
||||
if layer.MediaType == MediaTypeImageTensor {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
|
||||
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
blobs := make([]transfer.Blob, len(layers))
|
||||
for i, layer := range layers {
|
||||
blobs[i] = transfer.Blob{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
}
|
||||
}
|
||||
|
||||
destDir, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
base := mp.BaseURL()
|
||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
base.Scheme = "http"
|
||||
}
|
||||
baseURL := base.String()
|
||||
|
||||
var totalSize int64
|
||||
for _, blob := range blobs {
|
||||
totalSize += blob.Size
|
||||
}
|
||||
|
||||
progress := func(completed, total int64) {
|
||||
fn(api.ProgressResponse{
|
||||
Status: "pulling model",
|
||||
Digest: "sha256:model",
|
||||
Total: total,
|
||||
Completed: completed,
|
||||
})
|
||||
}
|
||||
|
||||
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
|
||||
return getAuthorizationToken(ctx, registryChallenge{
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
})
|
||||
}
|
||||
|
||||
if err := transfer.Download(ctx, transfer.DownloadOptions{
|
||||
Blobs: blobs,
|
||||
BaseURL: baseURL,
|
||||
DestDir: destDir,
|
||||
Repository: mp.GetNamespaceRepository(),
|
||||
Progress: progress,
|
||||
Token: regOpts.Token,
|
||||
GetToken: getToken,
|
||||
Logger: slog.Default(),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write manifest
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(fp, manifestJSON, 0o644)
|
||||
}
|
||||
|
||||
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
|
||||
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
blobs := make([]transfer.Blob, len(layers))
|
||||
for i, layer := range layers {
|
||||
blobs[i] = transfer.Blob{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
From: layer.From,
|
||||
}
|
||||
}
|
||||
|
||||
srcDir, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
base := mp.BaseURL()
|
||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
base.Scheme = "http"
|
||||
}
|
||||
baseURL := base.String()
|
||||
|
||||
var totalSize int64
|
||||
for _, blob := range blobs {
|
||||
totalSize += blob.Size
|
||||
}
|
||||
|
||||
progress := func(completed, total int64) {
|
||||
fn(api.ProgressResponse{
|
||||
Status: "pushing model",
|
||||
Digest: "sha256:model",
|
||||
Total: total,
|
||||
Completed: completed,
|
||||
})
|
||||
}
|
||||
|
||||
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
|
||||
return getAuthorizationToken(ctx, registryChallenge{
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
})
|
||||
}
|
||||
|
||||
return transfer.Upload(ctx, transfer.UploadOptions{
|
||||
Blobs: blobs,
|
||||
BaseURL: baseURL,
|
||||
SrcDir: srcDir,
|
||||
Progress: progress,
|
||||
Token: regOpts.Token,
|
||||
GetToken: getToken,
|
||||
Logger: slog.Default(),
|
||||
Manifest: manifestJSON,
|
||||
ManifestRef: mp.Tag,
|
||||
Repository: mp.GetNamespaceRepository(),
|
||||
})
|
||||
}
|
||||
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
|
||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
|
||||
|
||||
@@ -47,15 +47,6 @@ func TestModelCapabilities(t *testing.T) {
|
||||
model Model
|
||||
expectedCaps []model.Capability
|
||||
}{
|
||||
{
|
||||
name: "model with image generation capability via config",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
|
||||
},
|
||||
{
|
||||
name: "model with completion capability",
|
||||
model: Model{
|
||||
|
||||
@@ -13,14 +13,9 @@ type Layer struct {
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
From string `json:"from,omitempty"`
|
||||
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
|
||||
status string
|
||||
}
|
||||
|
||||
const (
|
||||
MediaTypeImageTensor = "application/vnd.ollama.image.tensor"
|
||||
)
|
||||
|
||||
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
blobs, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
|
||||
@@ -129,11 +129,30 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(mxyng): use something less brittle
|
||||
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
|
||||
// Find both 4-part (models) and 5-part (skills/agents) manifest paths
|
||||
matches4, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
matches5, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*", "*"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Combine matches, filtering to only include files
|
||||
var matches []string
|
||||
for _, match := range matches4 {
|
||||
fi, err := os.Stat(match)
|
||||
if err == nil && !fi.IsDir() {
|
||||
matches = append(matches, match)
|
||||
}
|
||||
}
|
||||
for _, match := range matches5 {
|
||||
fi, err := os.Stat(match)
|
||||
if err == nil && !fi.IsDir() {
|
||||
matches = append(matches, match)
|
||||
}
|
||||
}
|
||||
|
||||
ms := make(map[model.Name]*Manifest)
|
||||
for _, match := range matches {
|
||||
|
||||
315
server/mcp.go
Normal file
315
server/mcp.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// MediaTypeMCP is the media type for MCP server layers in manifests.
|
||||
const MediaTypeMCP = "application/vnd.ollama.image.mcp"
|
||||
|
||||
// GetMCPsPath returns the path to the extracted MCPs cache directory.
|
||||
// If digest is empty, returns the mcps directory itself.
|
||||
// If digest is provided, returns the path to the extracted MCP for that digest.
|
||||
func GetMCPsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
path := filepath.Join(envconfig.Models(), "mcps", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// ExtractMCPBlob extracts an MCP tar.gz blob to the mcps cache.
|
||||
// The blob is expected to be at the blobs path for the given digest.
|
||||
// Returns the path to the extracted MCP directory.
|
||||
func ExtractMCPBlob(digest string) (string, error) {
|
||||
// Get the blob path
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting blob path: %w", err)
|
||||
}
|
||||
|
||||
// Get the extraction path
|
||||
mcpPath, err := GetMCPsPath(digest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting mcp path: %w", err)
|
||||
}
|
||||
|
||||
// Check if already extracted (look for any file)
|
||||
entries, err := os.ReadDir(mcpPath)
|
||||
if err == nil && len(entries) > 0 {
|
||||
return mcpPath, nil
|
||||
}
|
||||
|
||||
// Open the blob
|
||||
f, err := os.Open(blobPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("opening blob: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Create gzip reader
|
||||
gzr, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating gzip reader: %w", err)
|
||||
}
|
||||
defer gzr.Close()
|
||||
|
||||
// Create tar reader
|
||||
tr := tar.NewReader(gzr)
|
||||
|
||||
// Create the mcp directory
|
||||
if err := os.MkdirAll(mcpPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating mcp directory: %w", err)
|
||||
}
|
||||
|
||||
// Extract files
|
||||
for {
|
||||
header, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading tar: %w", err)
|
||||
}
|
||||
|
||||
// Clean the name and ensure it doesn't escape the target directory
|
||||
name := filepath.Clean(header.Name)
|
||||
if strings.HasPrefix(name, "..") {
|
||||
return "", fmt.Errorf("invalid path in archive: %s", header.Name)
|
||||
}
|
||||
|
||||
target := filepath.Join(mcpPath, name)
|
||||
|
||||
// Verify the target is within mcpPath
|
||||
if !strings.HasPrefix(target, filepath.Clean(mcpPath)+string(os.PathSeparator)) && target != filepath.Clean(mcpPath) {
|
||||
return "", fmt.Errorf("path escapes mcp directory: %s", header.Name)
|
||||
}
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.MkdirAll(target, 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating directory: %w", err)
|
||||
}
|
||||
case tar.TypeReg:
|
||||
// Ensure parent directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating parent directory: %w", err)
|
||||
}
|
||||
|
||||
outFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(outFile, tr); err != nil {
|
||||
outFile.Close()
|
||||
return "", fmt.Errorf("writing file: %w", err)
|
||||
}
|
||||
outFile.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return mcpPath, nil
|
||||
}
|
||||
|
||||
// CreateMCPLayer creates an MCP layer from a local directory.
|
||||
// The directory can optionally contain an mcp.json or package.json file.
|
||||
// Returns the created layer.
|
||||
func CreateMCPLayer(mcpDir string) (Layer, error) {
|
||||
// Verify directory exists
|
||||
info, err := os.Stat(mcpDir)
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("mcp directory not found: %w", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return Layer{}, fmt.Errorf("mcp path is not a directory: %s", mcpDir)
|
||||
}
|
||||
|
||||
// Create a temporary file for the tar.gz
|
||||
blobsPath, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("getting blobs path: %w", err)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp(blobsPath, "mcp-*.tar.gz")
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
defer func() {
|
||||
tmpFile.Close()
|
||||
os.Remove(tmpPath)
|
||||
}()
|
||||
|
||||
// Create gzip writer
|
||||
gzw := gzip.NewWriter(tmpFile)
|
||||
defer gzw.Close()
|
||||
|
||||
// Create tar writer
|
||||
tw := tar.NewWriter(gzw)
|
||||
defer tw.Close()
|
||||
|
||||
// Walk the mcp directory and add files to tar
|
||||
err = filepath.Walk(mcpDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get relative path
|
||||
relPath, err := filepath.Rel(mcpDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip the root directory itself
|
||||
if relPath == "." {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create tar header
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Name = relPath
|
||||
|
||||
if err := tw.WriteHeader(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write file contents if it's a regular file
|
||||
if !info.IsDir() {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := io.Copy(tw, f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating tar archive: %w", err)
|
||||
}
|
||||
|
||||
// Close writers to flush
|
||||
if err := tw.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing tar writer: %w", err)
|
||||
}
|
||||
if err := gzw.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing gzip writer: %w", err)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing temp file: %w", err)
|
||||
}
|
||||
|
||||
// Open the temp file for reading
|
||||
tmpFile, err = os.Open(tmpPath)
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("reopening temp file: %w", err)
|
||||
}
|
||||
defer tmpFile.Close()
|
||||
|
||||
// Create the layer (this will compute the digest and move to blobs)
|
||||
layer, err := NewLayer(tmpFile, MediaTypeMCP)
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating layer: %w", err)
|
||||
}
|
||||
|
||||
// Extract the mcp to the cache so it's ready to use
|
||||
if _, err := ExtractMCPBlob(layer.Digest); err != nil {
|
||||
return Layer{}, fmt.Errorf("extracting mcp: %w", err)
|
||||
}
|
||||
|
||||
return layer, nil
|
||||
}
|
||||
|
||||
// IsLocalMCPPath checks if an MCP reference looks like a local path.
|
||||
// Local paths are explicitly prefixed with /, ./, ../, or ~.
|
||||
func IsLocalMCPPath(name string) bool {
|
||||
return strings.HasPrefix(name, "/") ||
|
||||
strings.HasPrefix(name, "./") ||
|
||||
strings.HasPrefix(name, "../") ||
|
||||
strings.HasPrefix(name, "~")
|
||||
}
|
||||
|
||||
// MCPNamespace is the namespace used for standalone MCPs in the registry.
|
||||
const MCPNamespace = "mcp"
|
||||
|
||||
// IsMCPReference checks if a name refers to an MCP (has mcp/ prefix).
|
||||
func IsMCPReference(name string) bool {
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
|
||||
parts := strings.Split(name, "/")
|
||||
|
||||
// mcp/name or mcp/name:tag
|
||||
if len(parts) >= 1 && parts[0] == MCPNamespace {
|
||||
return true
|
||||
}
|
||||
// namespace/mcp/name (e.g., myuser/mcp/websearch)
|
||||
if len(parts) >= 2 && parts[1] == MCPNamespace {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ParseMCPName parses an MCP reference string into a model.Name.
|
||||
// The Kind field is set to "mcp".
|
||||
func ParseMCPName(name string) model.Name {
|
||||
n := model.ParseName(name)
|
||||
|
||||
// If Kind wasn't set (old format without mcp/), set it
|
||||
if n.Kind == "" {
|
||||
n.Kind = MCPNamespace
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// GetMCPManifestPath returns the path to the MCP manifest file.
|
||||
func GetMCPManifestPath(n model.Name) (string, error) {
|
||||
if n.Model == "" {
|
||||
return "", fmt.Errorf("mcp name is required")
|
||||
}
|
||||
|
||||
// Ensure Kind is set
|
||||
if n.Kind == "" {
|
||||
n.Kind = MCPNamespace
|
||||
}
|
||||
|
||||
path := filepath.Join(
|
||||
envconfig.Models(),
|
||||
"manifests",
|
||||
n.Filepath(),
|
||||
)
|
||||
|
||||
return path, nil
|
||||
}
|
||||
@@ -18,6 +18,7 @@ type ModelPath struct {
|
||||
ProtocolScheme string
|
||||
Registry string
|
||||
Namespace string
|
||||
Kind string // Optional: "skill", "agent", or empty for models
|
||||
Repository string
|
||||
Tag string
|
||||
}
|
||||
@@ -42,6 +43,7 @@ func ParseModelPath(name string) ModelPath {
|
||||
ProtocolScheme: DefaultProtocolScheme,
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Kind: "",
|
||||
Repository: "",
|
||||
Tag: DefaultTag,
|
||||
}
|
||||
@@ -55,13 +57,41 @@ func ParseModelPath(name string) ModelPath {
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
case 4:
|
||||
// host/namespace/kind/model or host/namespace/model:tag with kind
|
||||
mp.Registry = parts[0]
|
||||
mp.Namespace = parts[1]
|
||||
mp.Repository = parts[2]
|
||||
if model.ValidKinds[parts[2]] {
|
||||
mp.Kind = parts[2]
|
||||
mp.Repository = parts[3]
|
||||
} else {
|
||||
// Not a valid kind, treat as old format with extra part
|
||||
mp.Repository = parts[3]
|
||||
}
|
||||
case 3:
|
||||
// Could be: host/namespace/model OR namespace/kind/model
|
||||
if model.ValidKinds[parts[1]] {
|
||||
// namespace/kind/model
|
||||
mp.Namespace = parts[0]
|
||||
mp.Kind = parts[1]
|
||||
mp.Repository = parts[2]
|
||||
} else {
|
||||
// host/namespace/model
|
||||
mp.Registry = parts[0]
|
||||
mp.Namespace = parts[1]
|
||||
mp.Repository = parts[2]
|
||||
}
|
||||
case 2:
|
||||
mp.Namespace = parts[0]
|
||||
mp.Repository = parts[1]
|
||||
// Could be: namespace/model OR kind/model
|
||||
if model.ValidKinds[parts[0]] {
|
||||
// kind/model (library skill)
|
||||
mp.Kind = parts[0]
|
||||
mp.Repository = parts[1]
|
||||
} else {
|
||||
// namespace/model
|
||||
mp.Namespace = parts[0]
|
||||
mp.Repository = parts[1]
|
||||
}
|
||||
case 1:
|
||||
mp.Repository = parts[0]
|
||||
}
|
||||
@@ -75,20 +105,35 @@ func ParseModelPath(name string) ModelPath {
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetNamespaceRepository() string {
|
||||
if mp.Kind != "" {
|
||||
return fmt.Sprintf("%s/%s/%s", mp.Namespace, mp.Kind, mp.Repository)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetFullTagname() string {
|
||||
if mp.Kind != "" {
|
||||
return fmt.Sprintf("%s/%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetShortTagname() string {
|
||||
if mp.Registry == DefaultRegistry {
|
||||
if mp.Namespace == DefaultNamespace {
|
||||
if mp.Kind != "" {
|
||||
return fmt.Sprintf("%s/%s:%s", mp.Kind, mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
|
||||
}
|
||||
if mp.Kind != "" {
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
if mp.Kind != "" {
|
||||
return fmt.Sprintf("%s/%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
|
||||
@@ -97,6 +142,7 @@ func (mp ModelPath) GetManifestPath() (string, error) {
|
||||
name := model.Name{
|
||||
Host: mp.Registry,
|
||||
Namespace: mp.Namespace,
|
||||
Kind: mp.Kind,
|
||||
Model: mp.Repository,
|
||||
Tag: mp.Tag,
|
||||
}
|
||||
|
||||
@@ -50,8 +50,6 @@ import (
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
|
||||
)
|
||||
|
||||
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
@@ -164,29 +162,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
// ScheduleImageGenRunner schedules an image generation model runner.
|
||||
// This implements the imagegenapi.RunnerScheduler interface.
|
||||
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
|
||||
m := &Model{
|
||||
Name: modelName,
|
||||
ShortName: modelName,
|
||||
ModelPath: modelName, // For image gen, ModelPath is just the model name
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
case err := <-errCh:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner.llama, nil
|
||||
}
|
||||
|
||||
func signinURL() (string, error) {
|
||||
pubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
@@ -214,12 +189,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a known image generation model
|
||||
if imagegen.ResolveModelName(req.Model) != "" {
|
||||
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
// Ideally this is "invalid model name" but we're keeping with
|
||||
@@ -1009,6 +978,9 @@ func getExistingName(n model.Name) (model.Name, error) {
|
||||
if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
|
||||
n.Namespace = e.Namespace
|
||||
}
|
||||
if set.Kind == "" && strings.EqualFold(e.Kind, n.Kind) {
|
||||
n.Kind = e.Kind
|
||||
}
|
||||
if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
|
||||
n.Model = e.Model
|
||||
}
|
||||
@@ -1147,6 +1119,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
Capabilities: m.Capabilities(),
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
Requires: m.Config.Requires,
|
||||
Skills: m.Config.Skills,
|
||||
MCPs: m.Config.MCPs,
|
||||
AgentType: m.Config.AgentType,
|
||||
Entrypoint: m.Config.Entrypoint,
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" {
|
||||
@@ -1201,11 +1177,16 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
fmt.Fprint(&sb, m.String())
|
||||
resp.Modelfile = sb.String()
|
||||
|
||||
// skip loading tensor information if this is a remote model
|
||||
// skip loading tensor information if this is a remote model or a skill
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Skills don't have model weights, skip tensor loading
|
||||
if m.ModelPath == "" {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1575,12 +1556,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
// Experimental image generation support
|
||||
imagegenapi.RegisterRoutes(r, s)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
gocmpopts "github.com/google/go-cmp/cmp/cmpopts"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -42,7 +41,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var base convert.KV = map[string]any{"general.architecture": "test"}
|
||||
base := map[string]any{"general.architecture": "test"}
|
||||
maps.Copy(base, kv)
|
||||
|
||||
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
type LlmRequest struct {
|
||||
@@ -195,14 +194,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
|
||||
}
|
||||
|
||||
// Check for image generation model before attempting GGML load
|
||||
if slices.Contains(pending.model.Config.Capabilities, "image") {
|
||||
if s.loadImageGen(pending) {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Load model for fitting
|
||||
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
|
||||
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
|
||||
@@ -552,48 +543,6 @@ iGPUScan:
|
||||
return false
|
||||
}
|
||||
|
||||
// loadImageGen loads an image generation model.
|
||||
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
|
||||
// Use model name for imagegen (it resolves manifests by name, not file path)
|
||||
modelName := req.model.ShortName
|
||||
server, err := imagegen.NewServer(modelName)
|
||||
if err != nil {
|
||||
req.errCh <- err
|
||||
return true
|
||||
}
|
||||
|
||||
sessionDuration := envconfig.KeepAlive()
|
||||
if req.sessionDuration != nil {
|
||||
sessionDuration = req.sessionDuration.Duration
|
||||
}
|
||||
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
llama: server,
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
sessionDuration: sessionDuration,
|
||||
refCount: 1,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[req.model.ModelPath] = runner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Set up expiration timer
|
||||
runner.refMu.Lock()
|
||||
if sessionDuration > 0 {
|
||||
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
|
||||
s.expiredCh <- runner
|
||||
})
|
||||
}
|
||||
runner.refMu.Unlock()
|
||||
|
||||
req.useLoadedRunner(runner, s.finishedReqCh)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
|
||||
if len(allGpus) == 0 {
|
||||
return
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -17,7 +16,6 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -806,61 +804,3 @@ func (s *mockLlm) GetPort() int { return -
|
||||
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
|
||||
func (s *mockLlm) HasExited() bool { return false }
|
||||
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
|
||||
|
||||
// TestImageGenCapabilityDetection verifies that models with "image" capability
|
||||
// are correctly identified and routed differently from language models.
|
||||
func TestImageGenCapabilityDetection(t *testing.T) {
|
||||
// Model with image capability should be detected
|
||||
imageModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
|
||||
|
||||
// Model without image capability should not be detected
|
||||
langModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"completion"},
|
||||
},
|
||||
}
|
||||
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
|
||||
|
||||
// Empty capabilities should not match
|
||||
emptyModel := &Model{}
|
||||
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
|
||||
}
|
||||
|
||||
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
|
||||
// loaded in the scheduler can be evicted by a language model request.
|
||||
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
|
||||
// Simulate an image gen runner already loaded
|
||||
imageGenRunner := &runnerRef{
|
||||
model: &Model{Name: "z-image", ModelPath: "/fake/image/model"},
|
||||
modelPath: "/fake/image/model",
|
||||
llama: &mockLlm{vramSize: 21 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{}},
|
||||
sessionDuration: 5 * time.Millisecond,
|
||||
refCount: 0, // idle
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["/fake/image/model"] = imageGenRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Verify the image gen runner is loaded
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// findRunnerToUnload should find the idle image gen runner
|
||||
runner := s.findRunnerToUnload()
|
||||
require.NotNil(t, runner)
|
||||
require.Equal(t, "/fake/image/model", runner.modelPath)
|
||||
}
|
||||
|
||||
326
server/skill.go
Normal file
326
server/skill.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// MediaTypeSkill is the media type for skill layers in manifests.
|
||||
const MediaTypeSkill = "application/vnd.ollama.image.skill"
|
||||
|
||||
// GetSkillsPath returns the path to the extracted skills cache directory.
|
||||
// If digest is empty, returns the skills directory itself.
|
||||
// If digest is provided, returns the path to the extracted skill for that digest.
|
||||
func GetSkillsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
path := filepath.Join(envconfig.Models(), "skills", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// ExtractSkillBlob extracts a skill tar.gz blob to the skills cache.
|
||||
// The blob is expected to be at the blobs path for the given digest.
|
||||
// Returns the path to the extracted skill directory.
|
||||
func ExtractSkillBlob(digest string) (string, error) {
|
||||
// Get the blob path
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting blob path: %w", err)
|
||||
}
|
||||
|
||||
// Get the extraction path
|
||||
skillPath, err := GetSkillsPath(digest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting skill path: %w", err)
|
||||
}
|
||||
|
||||
// Check if already extracted
|
||||
if _, err := os.Stat(filepath.Join(skillPath, "SKILL.md")); err == nil {
|
||||
return skillPath, nil
|
||||
}
|
||||
|
||||
// Open the blob
|
||||
f, err := os.Open(blobPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("opening blob: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Create gzip reader
|
||||
gzr, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating gzip reader: %w", err)
|
||||
}
|
||||
defer gzr.Close()
|
||||
|
||||
// Create tar reader
|
||||
tr := tar.NewReader(gzr)
|
||||
|
||||
// Create the skill directory
|
||||
if err := os.MkdirAll(skillPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating skill directory: %w", err)
|
||||
}
|
||||
|
||||
// Extract files
|
||||
for {
|
||||
header, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading tar: %w", err)
|
||||
}
|
||||
|
||||
// Clean the name and ensure it doesn't escape the target directory
|
||||
name := filepath.Clean(header.Name)
|
||||
if strings.HasPrefix(name, "..") {
|
||||
return "", fmt.Errorf("invalid path in archive: %s", header.Name)
|
||||
}
|
||||
|
||||
target := filepath.Join(skillPath, name)
|
||||
|
||||
// Verify the target is within skillPath
|
||||
if !strings.HasPrefix(target, filepath.Clean(skillPath)+string(os.PathSeparator)) && target != filepath.Clean(skillPath) {
|
||||
return "", fmt.Errorf("path escapes skill directory: %s", header.Name)
|
||||
}
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.MkdirAll(target, 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating directory: %w", err)
|
||||
}
|
||||
case tar.TypeReg:
|
||||
// Ensure parent directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating parent directory: %w", err)
|
||||
}
|
||||
|
||||
outFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(outFile, tr); err != nil {
|
||||
outFile.Close()
|
||||
return "", fmt.Errorf("writing file: %w", err)
|
||||
}
|
||||
outFile.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return skillPath, nil
|
||||
}
|
||||
|
||||
// CreateSkillLayer creates a skill layer from a local directory.
|
||||
// The directory must contain a SKILL.md file.
|
||||
// Returns the created layer.
|
||||
func CreateSkillLayer(skillDir string) (Layer, error) {
|
||||
// Verify SKILL.md exists
|
||||
skillMdPath := filepath.Join(skillDir, "SKILL.md")
|
||||
if _, err := os.Stat(skillMdPath); err != nil {
|
||||
return Layer{}, fmt.Errorf("skill directory must contain SKILL.md: %w", err)
|
||||
}
|
||||
|
||||
// Create a temporary file for the tar.gz
|
||||
blobsPath, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("getting blobs path: %w", err)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp(blobsPath, "skill-*.tar.gz")
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
defer func() {
|
||||
tmpFile.Close()
|
||||
os.Remove(tmpPath)
|
||||
}()
|
||||
|
||||
// Create gzip writer
|
||||
gzw := gzip.NewWriter(tmpFile)
|
||||
defer gzw.Close()
|
||||
|
||||
// Create tar writer
|
||||
tw := tar.NewWriter(gzw)
|
||||
defer tw.Close()
|
||||
|
||||
// Walk the skill directory and add files to tar
|
||||
err = filepath.Walk(skillDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get relative path
|
||||
relPath, err := filepath.Rel(skillDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip the root directory itself
|
||||
if relPath == "." {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create tar header
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Name = relPath
|
||||
|
||||
if err := tw.WriteHeader(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write file contents if it's a regular file
|
||||
if !info.IsDir() {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := io.Copy(tw, f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating tar archive: %w", err)
|
||||
}
|
||||
|
||||
// Close writers to flush
|
||||
if err := tw.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing tar writer: %w", err)
|
||||
}
|
||||
if err := gzw.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing gzip writer: %w", err)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing temp file: %w", err)
|
||||
}
|
||||
|
||||
// Open the temp file for reading
|
||||
tmpFile, err = os.Open(tmpPath)
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("reopening temp file: %w", err)
|
||||
}
|
||||
defer tmpFile.Close()
|
||||
|
||||
// Create the layer (this will compute the digest and move to blobs)
|
||||
layer, err := NewLayer(tmpFile, MediaTypeSkill)
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating layer: %w", err)
|
||||
}
|
||||
|
||||
// Extract the skill to the cache so it's ready to use
|
||||
if _, err := ExtractSkillBlob(layer.Digest); err != nil {
|
||||
return Layer{}, fmt.Errorf("extracting skill: %w", err)
|
||||
}
|
||||
|
||||
return layer, nil
|
||||
}
|
||||
|
||||
// IsLocalSkillPath checks if a skill reference looks like a local path.
|
||||
// Local paths are explicitly prefixed with /, ./, ../, or ~.
|
||||
// Registry references like "skill/calculator:1.0.0" should NOT be treated as local paths.
|
||||
func IsLocalSkillPath(name string) bool {
|
||||
// Local paths are explicitly indicated by path prefixes
|
||||
return strings.HasPrefix(name, "/") ||
|
||||
strings.HasPrefix(name, "./") ||
|
||||
strings.HasPrefix(name, "../") ||
|
||||
strings.HasPrefix(name, "~")
|
||||
}
|
||||
|
||||
// SkillNamespace is the namespace used for standalone skills in the registry.
|
||||
const SkillNamespace = "skill"
|
||||
|
||||
// IsSkillReference checks if a name refers to a skill (has skill/ prefix).
|
||||
func IsSkillReference(name string) bool {
|
||||
// Check for skill/ prefix (handles both "skill/foo" and "registry/skill/foo")
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
|
||||
parts := strings.Split(name, "/")
|
||||
|
||||
// skill/name or skill/name:tag
|
||||
if len(parts) >= 1 && parts[0] == SkillNamespace {
|
||||
return true
|
||||
}
|
||||
// namespace/skill/name (e.g., myuser/skill/calc) - not a skill ref
|
||||
// registry/skill/name (e.g., registry.ollama.ai/skill/calc)
|
||||
if len(parts) >= 2 && parts[1] == SkillNamespace {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ParseSkillName parses a skill reference string into a model.Name.
|
||||
// The Kind field is set to "skill".
|
||||
// Examples:
|
||||
// - "calculator" -> library/skill/calculator:latest
|
||||
// - "myname/calculator" -> myname/skill/calculator:latest
|
||||
// - "myname/skill/calculator:1.0.0" -> myname/skill/calculator:1.0.0
|
||||
func ParseSkillName(name string) model.Name {
|
||||
// Use the standard parser which now handles Kind
|
||||
n := model.ParseName(name)
|
||||
|
||||
// If Kind wasn't set (old format without skill/), set it
|
||||
if n.Kind == "" {
|
||||
n.Kind = SkillNamespace
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// SkillDisplayName returns a user-friendly display name for a skill.
|
||||
func SkillDisplayName(n model.Name) string {
|
||||
return n.DisplayShortest()
|
||||
}
|
||||
|
||||
// GetSkillManifestPath returns the path to the skill manifest file.
|
||||
// Uses the 5-part structure: host/namespace/kind/model/tag
|
||||
func GetSkillManifestPath(n model.Name) (string, error) {
|
||||
if n.Model == "" {
|
||||
return "", fmt.Errorf("skill name is required")
|
||||
}
|
||||
|
||||
// Ensure Kind is set
|
||||
if n.Kind == "" {
|
||||
n.Kind = SkillNamespace
|
||||
}
|
||||
|
||||
path := filepath.Join(
|
||||
envconfig.Models(),
|
||||
"manifests",
|
||||
n.Filepath(),
|
||||
)
|
||||
|
||||
return path, nil
|
||||
}
|
||||
@@ -381,28 +381,6 @@ func (t templateTools) String() string {
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateArgs is a map type with JSON string output for templates.
|
||||
type templateArgs map[string]any
|
||||
|
||||
func (t templateArgs) String() string {
|
||||
if t == nil {
|
||||
return "{}"
|
||||
}
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateProperties is a map type with JSON string output for templates.
|
||||
type templateProperties map[string]api.ToolProperty
|
||||
|
||||
func (t templateProperties) String() string {
|
||||
if t == nil {
|
||||
return "{}"
|
||||
}
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateTool is a template-compatible representation of api.Tool
|
||||
// with Properties as a regular map for template ranging.
|
||||
type templateTool struct {
|
||||
@@ -418,11 +396,11 @@ type templateToolFunction struct {
|
||||
}
|
||||
|
||||
type templateToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties templateProperties `json:"properties"`
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
// templateToolCall is a template-compatible representation of api.ToolCall
|
||||
@@ -435,7 +413,7 @@ type templateToolCall struct {
|
||||
type templateToolCallFunction struct {
|
||||
Index int
|
||||
Name string
|
||||
Arguments templateArgs
|
||||
Arguments map[string]any
|
||||
}
|
||||
|
||||
// templateMessage is a template-compatible representation of api.Message
|
||||
@@ -468,7 +446,7 @@ func convertToolsForTemplate(tools api.Tools) templateTools {
|
||||
Defs: tool.Function.Parameters.Defs,
|
||||
Items: tool.Function.Parameters.Items,
|
||||
Required: tool.Function.Parameters.Required,
|
||||
Properties: templateProperties(tool.Function.Parameters.Properties.ToMap()),
|
||||
Properties: tool.Function.Parameters.Properties.ToMap(),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -490,7 +468,7 @@ func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
|
||||
Function: templateToolCallFunction{
|
||||
Index: tc.Function.Index,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: templateArgs(tc.Function.Arguments.ToMap()),
|
||||
Arguments: tc.Function.Arguments.ToMap(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -613,159 +613,3 @@ func TestCollate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateArgumentsJSON(t *testing.T) {
|
||||
// Test that {{ .Function.Arguments }} outputs valid JSON, not map[key:value]
|
||||
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{ .Function.Arguments }}{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("location", "Tokyo")
|
||||
args.Set("unit", "celsius")
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
// Should be valid JSON, not "map[location:Tokyo unit:celsius]"
|
||||
if strings.HasPrefix(got, "map[") {
|
||||
t.Errorf("Arguments output as Go map format: %s", got)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
|
||||
t.Errorf("Arguments not valid JSON: %s, error: %v", got, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplatePropertiesJSON(t *testing.T) {
|
||||
// Test that {{ .Function.Parameters.Properties }} outputs valid JSON
|
||||
// Note: template must reference .Messages to trigger the modern code path that converts Tools
|
||||
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{ .Function.Parameters.Properties }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{Role: "user", Content: "test"}},
|
||||
Tools: api.Tools{{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
// Should be valid JSON, not "map[location:{...}]"
|
||||
if strings.HasPrefix(got, "map[") {
|
||||
t.Errorf("Properties output as Go map format: %s", got)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
|
||||
t.Errorf("Properties not valid JSON: %s, error: %v", got, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateArgumentsRange(t *testing.T) {
|
||||
// Test that we can range over Arguments in templates
|
||||
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{- range $k, $v := .Function.Arguments }}{{ $k }}={{ $v }};{{- end }}{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("city", "Tokyo")
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
if got != "city=Tokyo;" {
|
||||
t.Errorf("Range over Arguments failed, got: %s, want: city=Tokyo;", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplatePropertiesRange(t *testing.T) {
|
||||
// Test that we can range over Properties in templates
|
||||
// Note: template must reference .Messages to trigger the modern code path that converts Tools
|
||||
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{- range $name, $prop := .Function.Parameters.Properties }}{{ $name }}:{{ $prop.Type }};{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{Role: "user", Content: "test"}},
|
||||
Tools: api.Tools{{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
if got != "location:string;" {
|
||||
t.Errorf("Range over Properties failed, got: %s, want: location:string;", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,13 +3,12 @@ package model
|
||||
type Capability string
|
||||
|
||||
const (
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
CapabilityImageGeneration = Capability("image")
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
|
||||
@@ -1,5 +1,29 @@
|
||||
package model
|
||||
|
||||
// SkillRef represents a reference to a skill, either by local path or by registry digest.
|
||||
type SkillRef struct {
|
||||
// Name is the local path (for development) or registry name (e.g., "skill/calculator:1.0.0")
|
||||
Name string `json:"name,omitempty"`
|
||||
// Digest is the content-addressable digest of the skill blob (e.g., "sha256:abc123...")
|
||||
Digest string `json:"digest,omitempty"`
|
||||
}
|
||||
|
||||
// MCPRef represents a reference to an MCP (Model Context Protocol) server.
|
||||
type MCPRef struct {
|
||||
// Name is the identifier for the MCP server (used for tool namespacing)
|
||||
Name string `json:"name,omitempty"`
|
||||
// Digest is the content-addressable digest of the bundled MCP server blob
|
||||
Digest string `json:"digest,omitempty"`
|
||||
// Command is the executable to run (e.g., "uv", "node", "python3")
|
||||
Command string `json:"command,omitempty"`
|
||||
// Args are the arguments to pass to the command
|
||||
Args []string `json:"args,omitempty"`
|
||||
// Env is optional environment variables for the MCP server
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
// Type is the transport type (currently only "stdio" is supported)
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
// ConfigV2 represents the configuration metadata for a model.
|
||||
type ConfigV2 struct {
|
||||
ModelFormat string `json:"model_format"`
|
||||
@@ -20,6 +44,12 @@ type ConfigV2 struct {
|
||||
EmbedLen int `json:"embedding_length,omitempty"`
|
||||
BaseName string `json:"base_name,omitempty"`
|
||||
|
||||
// agent-specific fields
|
||||
Skills []SkillRef `json:"skills,omitempty"`
|
||||
MCPs []MCPRef `json:"mcps,omitempty"`
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
Entrypoint string `json:"entrypoint,omitempty"`
|
||||
|
||||
// required by spec
|
||||
Architecture string `json:"architecture"`
|
||||
OS string `json:"os"`
|
||||
|
||||
@@ -59,6 +59,7 @@ type partKind int
|
||||
const (
|
||||
kindHost partKind = iota
|
||||
kindNamespace
|
||||
kindKind
|
||||
kindModel
|
||||
kindTag
|
||||
kindDigest
|
||||
@@ -70,6 +71,8 @@ func (k partKind) String() string {
|
||||
return "host"
|
||||
case kindNamespace:
|
||||
return "namespace"
|
||||
case kindKind:
|
||||
return "kind"
|
||||
case kindModel:
|
||||
return "model"
|
||||
case kindTag:
|
||||
@@ -89,6 +92,7 @@ func (k partKind) String() string {
|
||||
type Name struct {
|
||||
Host string
|
||||
Namespace string
|
||||
Kind string // Optional: "skill", "agent", or empty for models
|
||||
Model string
|
||||
Tag string
|
||||
}
|
||||
@@ -97,34 +101,27 @@ type Name struct {
|
||||
// format of a valid name string is:
|
||||
//
|
||||
// s:
|
||||
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
|
||||
// { host } "/" { namespace } "/" { kind } "/" { model } ":" { tag }
|
||||
// { host } "/" { namespace } "/" { model } ":" { tag }
|
||||
// { host } "/" { namespace } "/" { model } "@" { digest }
|
||||
// { host } "/" { namespace } "/" { model }
|
||||
// { namespace } "/" { model } ":" { tag } "@" { digest }
|
||||
// { namespace } "/" { kind } "/" { model } ":" { tag }
|
||||
// { namespace } "/" { model } ":" { tag }
|
||||
// { namespace } "/" { model } "@" { digest }
|
||||
// { namespace } "/" { model }
|
||||
// { model } ":" { tag } "@" { digest }
|
||||
// { model } ":" { tag }
|
||||
// { model } "@" { digest }
|
||||
// { model }
|
||||
// "@" { digest }
|
||||
// host:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." | ":" }*
|
||||
// length: [1, 350]
|
||||
// namespace:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
|
||||
// length: [1, 80]
|
||||
// kind:
|
||||
// pattern: "skill" | "agent" | "" (empty for models)
|
||||
// length: [0, 80]
|
||||
// model:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// tag:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// digest:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
|
||||
// length: [1, 80]
|
||||
//
|
||||
// Most users should use [ParseName] instead, unless need to support
|
||||
// different defaults than DefaultName.
|
||||
@@ -136,6 +133,13 @@ func ParseName(s string) Name {
|
||||
return Merge(ParseNameBare(s), DefaultName())
|
||||
}
|
||||
|
||||
// ValidKinds are the allowed values for the Kind field
|
||||
var ValidKinds = map[string]bool{
|
||||
"skill": true,
|
||||
"agent": true,
|
||||
"mcp": true,
|
||||
}
|
||||
|
||||
// ParseNameBare parses s as a name string and returns a Name. No merge with
|
||||
// [DefaultName] is performed.
|
||||
func ParseNameBare(s string) Name {
|
||||
@@ -153,6 +157,30 @@ func ParseNameBare(s string) Name {
|
||||
return n
|
||||
}
|
||||
|
||||
s, n.Kind, promised = cutPromised(s, "/")
|
||||
if !promised {
|
||||
// Only 2 parts: namespace/model - what we parsed as Kind is actually Namespace
|
||||
n.Namespace = n.Kind
|
||||
n.Kind = ""
|
||||
return n
|
||||
}
|
||||
|
||||
// Check if what we parsed as Kind is actually a valid kind value
|
||||
if !ValidKinds[n.Kind] {
|
||||
// Not a valid kind - this is the old 3-part format: host/namespace/model
|
||||
// Shift: Kind -> Namespace, s -> Host
|
||||
n.Namespace = n.Kind
|
||||
n.Kind = ""
|
||||
|
||||
scheme, host, ok := strings.Cut(s, "://")
|
||||
if !ok {
|
||||
host = scheme
|
||||
}
|
||||
n.Host = host
|
||||
return n
|
||||
}
|
||||
|
||||
// Valid kind found - continue parsing for namespace and optional host
|
||||
s, n.Namespace, promised = cutPromised(s, "/")
|
||||
if !promised {
|
||||
n.Namespace = s
|
||||
@@ -168,20 +196,32 @@ func ParseNameBare(s string) Name {
|
||||
return n
|
||||
}
|
||||
|
||||
// ParseNameFromFilepath parses a 4-part filepath as a Name. The parts are
|
||||
// ParseNameFromFilepath parses a 4 or 5-part filepath as a Name. The parts are
|
||||
// expected to be in the form:
|
||||
//
|
||||
// { host } "/" { namespace } "/" { model } "/" { tag }
|
||||
// { host } "/" { namespace } "/" { kind } "/" { model } "/" { tag }
|
||||
func ParseNameFromFilepath(s string) (n Name) {
|
||||
parts := strings.Split(s, string(filepath.Separator))
|
||||
if len(parts) != 4 {
|
||||
|
||||
switch len(parts) {
|
||||
case 4:
|
||||
// Old format: host/namespace/model/tag
|
||||
n.Host = parts[0]
|
||||
n.Namespace = parts[1]
|
||||
n.Model = parts[2]
|
||||
n.Tag = parts[3]
|
||||
case 5:
|
||||
// New format: host/namespace/kind/model/tag
|
||||
n.Host = parts[0]
|
||||
n.Namespace = parts[1]
|
||||
n.Kind = parts[2]
|
||||
n.Model = parts[3]
|
||||
n.Tag = parts[4]
|
||||
default:
|
||||
return Name{}
|
||||
}
|
||||
|
||||
n.Host = parts[0]
|
||||
n.Namespace = parts[1]
|
||||
n.Model = parts[2]
|
||||
n.Tag = parts[3]
|
||||
if !n.IsFullyQualified() {
|
||||
return Name{}
|
||||
}
|
||||
@@ -189,11 +229,12 @@ func ParseNameFromFilepath(s string) (n Name) {
|
||||
return n
|
||||
}
|
||||
|
||||
// Merge merges the host, namespace, and tag parts of the two names,
|
||||
// Merge merges the host, namespace, kind, and tag parts of the two names,
|
||||
// preferring the non-empty parts of a.
|
||||
func Merge(a, b Name) Name {
|
||||
a.Host = cmp.Or(a.Host, b.Host)
|
||||
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
|
||||
a.Kind = cmp.Or(a.Kind, b.Kind)
|
||||
a.Tag = cmp.Or(a.Tag, b.Tag)
|
||||
return a
|
||||
}
|
||||
@@ -211,6 +252,10 @@ func (n Name) String() string {
|
||||
b.WriteString(n.Namespace)
|
||||
b.WriteByte('/')
|
||||
}
|
||||
if n.Kind != "" {
|
||||
b.WriteString(n.Kind)
|
||||
b.WriteByte('/')
|
||||
}
|
||||
b.WriteString(n.Model)
|
||||
if n.Tag != "" {
|
||||
b.WriteByte(':')
|
||||
@@ -233,6 +278,12 @@ func (n Name) DisplayShortest() string {
|
||||
sb.WriteByte('/')
|
||||
}
|
||||
|
||||
// include kind if present
|
||||
if n.Kind != "" {
|
||||
sb.WriteString(n.Kind)
|
||||
sb.WriteByte('/')
|
||||
}
|
||||
|
||||
// always include model and tag
|
||||
sb.WriteString(n.Model)
|
||||
sb.WriteString(":")
|
||||
@@ -256,18 +307,23 @@ func (n Name) IsValid() bool {
|
||||
}
|
||||
|
||||
// IsFullyQualified returns true if all parts of the name are present and
|
||||
// valid without the digest.
|
||||
// valid without the digest. Kind is optional and only validated if non-empty.
|
||||
func (n Name) IsFullyQualified() bool {
|
||||
parts := []string{
|
||||
n.Host,
|
||||
n.Namespace,
|
||||
n.Model,
|
||||
n.Tag,
|
||||
if !isValidPart(kindHost, n.Host) {
|
||||
return false
|
||||
}
|
||||
for i, part := range parts {
|
||||
if !isValidPart(partKind(i), part) {
|
||||
return false
|
||||
}
|
||||
if !isValidPart(kindNamespace, n.Namespace) {
|
||||
return false
|
||||
}
|
||||
// Kind is optional - only validate if present
|
||||
if n.Kind != "" && !isValidPart(kindKind, n.Kind) {
|
||||
return false
|
||||
}
|
||||
if !isValidPart(kindModel, n.Model) {
|
||||
return false
|
||||
}
|
||||
if !isValidPart(kindTag, n.Tag) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -276,6 +332,7 @@ func (n Name) IsFullyQualified() bool {
|
||||
// host to tag as a directory in the form:
|
||||
//
|
||||
// {host}/{namespace}/{model}/{tag}
|
||||
// {host}/{namespace}/{kind}/{model}/{tag}
|
||||
//
|
||||
// It uses the system's filepath separator and ensures the path is clean.
|
||||
//
|
||||
@@ -285,6 +342,15 @@ func (n Name) Filepath() string {
|
||||
if !n.IsFullyQualified() {
|
||||
panic("illegal attempt to get filepath of invalid name")
|
||||
}
|
||||
if n.Kind != "" {
|
||||
return filepath.Join(
|
||||
n.Host,
|
||||
n.Namespace,
|
||||
n.Kind,
|
||||
n.Model,
|
||||
n.Tag,
|
||||
)
|
||||
}
|
||||
return filepath.Join(
|
||||
n.Host,
|
||||
n.Namespace,
|
||||
@@ -301,6 +367,7 @@ func (n Name) LogValue() slog.Value {
|
||||
func (n Name) EqualFold(o Name) bool {
|
||||
return strings.EqualFold(n.Host, o.Host) &&
|
||||
strings.EqualFold(n.Namespace, o.Namespace) &&
|
||||
strings.EqualFold(n.Kind, o.Kind) &&
|
||||
strings.EqualFold(n.Model, o.Model) &&
|
||||
strings.EqualFold(n.Tag, o.Tag)
|
||||
}
|
||||
@@ -317,6 +384,11 @@ func isValidLen(kind partKind, s string) bool {
|
||||
}
|
||||
|
||||
func isValidPart(kind partKind, s string) bool {
|
||||
// Kind must be one of the valid values
|
||||
if kind == kindKind {
|
||||
return ValidKinds[s]
|
||||
}
|
||||
|
||||
if !isValidLen(kind, s) {
|
||||
return false
|
||||
}
|
||||
|
||||
24
x/README.md
24
x/README.md
@@ -1,24 +0,0 @@
|
||||
# Experimental Features
|
||||
|
||||
## MLX Backend
|
||||
|
||||
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
|
||||
|
||||
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
|
||||
|
||||
```
|
||||
cmake --preset MLX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install --component MLX
|
||||
go build -tags mlx .
|
||||
```
|
||||
|
||||
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
|
||||
|
||||
## Image Generation
|
||||
|
||||
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
|
||||
|
||||
```
|
||||
go build -o imagegen ./x/imagegen/cmd/engine
|
||||
```
|
||||
@@ -4,7 +4,6 @@ package agent
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -33,29 +32,10 @@ type ApprovalResult struct {
|
||||
// Option labels for the selector (numbered for quick selection)
|
||||
var optionLabels = []string{
|
||||
"1. Execute once",
|
||||
"2. Allow for this session",
|
||||
"2. Always allow",
|
||||
"3. Deny",
|
||||
}
|
||||
|
||||
// toolDisplayNames maps internal tool names to human-readable display names.
|
||||
var toolDisplayNames = map[string]string{
|
||||
"bash": "Bash",
|
||||
"web_search": "Web Search",
|
||||
}
|
||||
|
||||
// ToolDisplayName returns the human-readable display name for a tool.
|
||||
func ToolDisplayName(toolName string) string {
|
||||
if displayName, ok := toolDisplayNames[toolName]; ok {
|
||||
return displayName
|
||||
}
|
||||
// Default: capitalize first letter and replace underscores with spaces
|
||||
name := strings.ReplaceAll(toolName, "_", " ")
|
||||
if len(name) > 0 {
|
||||
return strings.ToUpper(name[:1]) + name[1:]
|
||||
}
|
||||
return toolName
|
||||
}
|
||||
|
||||
// autoAllowCommands are commands that are always allowed without prompting.
|
||||
// These are zero-risk, read-only commands.
|
||||
var autoAllowCommands = map[string]bool{
|
||||
@@ -199,7 +179,6 @@ func FormatDeniedResult(command string, pattern string) string {
|
||||
// extractBashPrefix extracts a prefix pattern from a bash command.
|
||||
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
|
||||
// For commands without path args, returns empty string.
|
||||
// Paths with ".." traversal that escape the base directory return empty string for security.
|
||||
func extractBashPrefix(command string) string {
|
||||
// Split command by pipes and get the first part
|
||||
parts := strings.Split(command, "|")
|
||||
@@ -225,8 +204,8 @@ func extractBashPrefix(command string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the first path-like argument (must contain / or \ or start with .)
|
||||
// First pass: look for clear paths (containing path separators or starting with .)
|
||||
// Find the first path-like argument (must contain / or start with .)
|
||||
// First pass: look for clear paths (containing / or starting with .)
|
||||
for _, arg := range fields[1:] {
|
||||
// Skip flags
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
@@ -236,49 +215,19 @@ func extractBashPrefix(command string) string {
|
||||
if isNumeric(arg) {
|
||||
continue
|
||||
}
|
||||
// Only process if it looks like a path (contains / or \ or starts with .)
|
||||
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
|
||||
// Only process if it looks like a path (contains / or starts with .)
|
||||
if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
|
||||
continue
|
||||
}
|
||||
// Normalize to forward slashes for consistent cross-platform matching
|
||||
arg = strings.ReplaceAll(arg, "\\", "/")
|
||||
|
||||
// Security: reject absolute paths
|
||||
if path.IsAbs(arg) {
|
||||
return "" // Absolute path - don't create prefix
|
||||
// If arg ends with /, it's a directory - use it directly
|
||||
if strings.HasSuffix(arg, "/") {
|
||||
return fmt.Sprintf("%s:%s", baseCmd, arg)
|
||||
}
|
||||
|
||||
// Normalize the path using stdlib path.Clean (resolves . and ..)
|
||||
cleaned := path.Clean(arg)
|
||||
|
||||
// Security: reject if cleaned path escapes to parent directory
|
||||
if strings.HasPrefix(cleaned, "..") {
|
||||
return "" // Path escapes - don't create prefix
|
||||
}
|
||||
|
||||
// Security: if original had "..", verify cleaned path didn't escape to sibling
|
||||
// e.g., "tools/a/b/../../../etc" -> "etc" (escaped tools/ to sibling)
|
||||
if strings.Contains(arg, "..") {
|
||||
origBase := strings.SplitN(arg, "/", 2)[0]
|
||||
cleanedBase := strings.SplitN(cleaned, "/", 2)[0]
|
||||
if origBase != cleanedBase {
|
||||
return "" // Path escaped to sibling directory
|
||||
}
|
||||
}
|
||||
|
||||
// Check if arg ends with / (explicit directory)
|
||||
isDir := strings.HasSuffix(arg, "/")
|
||||
|
||||
// Get the directory part
|
||||
var dir string
|
||||
if isDir {
|
||||
dir = cleaned
|
||||
} else {
|
||||
dir = path.Dir(cleaned)
|
||||
}
|
||||
|
||||
// Get the directory part of a file path
|
||||
dir := filepath.Dir(arg)
|
||||
if dir == "." {
|
||||
return fmt.Sprintf("%s:./", baseCmd)
|
||||
// Path is just a directory like "tools" or "src" (no trailing /)
|
||||
return fmt.Sprintf("%s:%s/", baseCmd, arg)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s/", baseCmd, dir)
|
||||
}
|
||||
@@ -383,8 +332,6 @@ func AllowlistKey(toolName string, args map[string]any) string {
|
||||
}
|
||||
|
||||
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
|
||||
// For bash commands, hierarchical path matching is used - if "cat:tools/" is allowed,
|
||||
// then "cat:tools/subdir/" is also allowed (subdirectories inherit parent permissions).
|
||||
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
@@ -395,20 +342,12 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// For bash commands, check prefix matches with hierarchical path support
|
||||
// For bash commands, check prefix matches
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
prefix := extractBashPrefix(cmd)
|
||||
if prefix != "" {
|
||||
// Check exact prefix match first
|
||||
if a.prefixes[prefix] {
|
||||
return true
|
||||
}
|
||||
// Check hierarchical match: if any stored prefix is a parent of current prefix
|
||||
// e.g., stored "cat:tools/" should match current "cat:tools/subdir/"
|
||||
if a.matchesHierarchicalPrefix(prefix) {
|
||||
return true
|
||||
}
|
||||
if prefix != "" && a.prefixes[prefix] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -421,40 +360,6 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchesHierarchicalPrefix checks if the given prefix matches any stored prefix hierarchically.
|
||||
// For example, if "cat:tools/" is stored, it will match "cat:tools/subdir/" or "cat:tools/a/b/c/".
|
||||
func (a *ApprovalManager) matchesHierarchicalPrefix(currentPrefix string) bool {
|
||||
// Split prefix into command and path parts (format: "cmd:path/")
|
||||
colonIdx := strings.Index(currentPrefix, ":")
|
||||
if colonIdx == -1 {
|
||||
return false
|
||||
}
|
||||
currentCmd := currentPrefix[:colonIdx]
|
||||
currentPath := currentPrefix[colonIdx+1:]
|
||||
|
||||
for storedPrefix := range a.prefixes {
|
||||
storedColonIdx := strings.Index(storedPrefix, ":")
|
||||
if storedColonIdx == -1 {
|
||||
continue
|
||||
}
|
||||
storedCmd := storedPrefix[:storedColonIdx]
|
||||
storedPath := storedPrefix[storedColonIdx+1:]
|
||||
|
||||
// Commands must match exactly
|
||||
if currentCmd != storedCmd {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if current path starts with stored path (hierarchical match)
|
||||
// e.g., "tools/subdir/" starts with "tools/"
|
||||
if strings.HasPrefix(currentPath, storedPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddToAllowlist adds a tool/command to the session allowlist.
|
||||
// For bash commands, it adds the prefix pattern instead of exact command.
|
||||
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
|
||||
@@ -494,32 +399,16 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
|
||||
// This prevents buffered input from causing double-press issues
|
||||
flushStdin(fd)
|
||||
|
||||
// Check if bash command targets paths outside cwd
|
||||
isWarning := false
|
||||
var warningMsg string
|
||||
var allowlistInfo string
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
if isCommandOutsideCwd(cmd) {
|
||||
isWarning = true
|
||||
warningMsg = "command targets paths outside project"
|
||||
}
|
||||
if prefix := extractBashPrefix(cmd); prefix != "" {
|
||||
colonIdx := strings.Index(prefix, ":")
|
||||
if colonIdx != -1 {
|
||||
cmdName := prefix[:colonIdx]
|
||||
dirPath := prefix[colonIdx+1:]
|
||||
if dirPath != "./" {
|
||||
allowlistInfo = fmt.Sprintf("%s in %s directory (includes subdirs)", cmdName, dirPath)
|
||||
} else {
|
||||
allowlistInfo = fmt.Sprintf("%s in %s directory", cmdName, dirPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
isWarning = isCommandOutsideCwd(cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// Run interactive selector
|
||||
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning, warningMsg, allowlistInfo)
|
||||
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
|
||||
if err != nil {
|
||||
term.Restore(fd, oldState)
|
||||
return ApprovalResult{Decision: ApprovalDeny}, err
|
||||
@@ -544,29 +433,27 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
|
||||
// formatToolDisplay creates the display string for a tool call.
|
||||
func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
var sb strings.Builder
|
||||
displayName := ToolDisplayName(toolName)
|
||||
|
||||
// For bash, show command directly
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
|
||||
// For web search, show query and internet notice
|
||||
// For web search, show query
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
|
||||
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
|
||||
sb.WriteString("Uses internet via ollama.com")
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Query: %s", query))
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
|
||||
// Generic display
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
|
||||
if len(args) > 0 {
|
||||
sb.WriteString("\nArguments: ")
|
||||
first := true
|
||||
@@ -583,28 +470,24 @@ func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
|
||||
// selectorState holds the state for the interactive selector
|
||||
type selectorState struct {
|
||||
toolDisplay string
|
||||
selected int
|
||||
totalLines int
|
||||
termWidth int
|
||||
termHeight int
|
||||
boxWidth int
|
||||
innerWidth int
|
||||
denyReason string // deny reason (always visible in box)
|
||||
isWarning bool // true if command has warning
|
||||
warningMessage string // dynamic warning message to display
|
||||
allowlistInfo string // show what will be allowlisted (for "Allow for this session" option)
|
||||
toolDisplay string
|
||||
selected int
|
||||
totalLines int
|
||||
termWidth int
|
||||
termHeight int
|
||||
boxWidth int
|
||||
innerWidth int
|
||||
denyReason string // deny reason (always visible in box)
|
||||
isWarning bool // true if command targets paths outside cwd (red box)
|
||||
}
|
||||
|
||||
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
|
||||
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
|
||||
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool, warningMessage string, allowlistInfo string) (int, string, error) {
|
||||
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
|
||||
state := &selectorState{
|
||||
toolDisplay: toolDisplay,
|
||||
selected: 0,
|
||||
isWarning: isWarning,
|
||||
warningMessage: warningMessage,
|
||||
allowlistInfo: allowlistInfo,
|
||||
toolDisplay: toolDisplay,
|
||||
selected: 0,
|
||||
isWarning: isWarning,
|
||||
}
|
||||
|
||||
// Get terminal size
|
||||
@@ -764,7 +647,7 @@ func wrapText(text string, maxWidth int) []string {
|
||||
|
||||
// getHintLines returns the hint text wrapped to terminal width
|
||||
func getHintLines(state *selectorState) []string {
|
||||
hint := "up/down select, enter confirm, 1-3 quick select, ctrl+c cancel"
|
||||
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
|
||||
if state.termWidth >= len(hint)+1 {
|
||||
return []string{hint}
|
||||
}
|
||||
@@ -774,70 +657,86 @@ func getHintLines(state *selectorState) []string {
|
||||
|
||||
// calculateTotalLines calculates how many lines the selector will use
|
||||
func calculateTotalLines(state *selectorState) int {
|
||||
toolLines := strings.Split(state.toolDisplay, "\n")
|
||||
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||
hintLines := getHintLines(state)
|
||||
// warning line (if applicable) + tool lines + blank line + options + blank line + hint lines
|
||||
// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
|
||||
warningLines := 0
|
||||
if state.isWarning {
|
||||
warningLines = 2 // warning line + blank line after
|
||||
warningLines = 1
|
||||
}
|
||||
return warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
|
||||
return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
|
||||
}
|
||||
|
||||
// renderSelectorBox renders the selector (minimal, no box)
|
||||
// renderSelectorBox renders the complete selector box
|
||||
func renderSelectorBox(state *selectorState) {
|
||||
toolLines := strings.Split(state.toolDisplay, "\n")
|
||||
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Draw warning line if needed
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
if state.warningMessage != "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m %s\033[K\r\n", state.warningMessage)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Draw box top
|
||||
fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw warning line if needed (inside the box)
|
||||
if state.isWarning {
|
||||
warning := "!! OUTSIDE PROJECT !!"
|
||||
padding := (state.innerWidth - len(warning)) / 2
|
||||
if padding < 0 {
|
||||
padding = 0
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n") // blank line after warning
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
|
||||
strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
|
||||
}
|
||||
|
||||
// Draw tool info (plain white)
|
||||
// Draw tool info
|
||||
for _, line := range toolLines {
|
||||
fmt.Fprintf(os.Stderr, "%s\033[K\r\n", line)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
|
||||
}
|
||||
|
||||
// Blank line separator
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
// Draw separator
|
||||
fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw options with numbers (Deny option includes reason input)
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 {
|
||||
if i == 2 { // Deny option - show with reason input beside it
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if inputDisplay == "" {
|
||||
inputDisplay = "\033[90m(optional reason)\033[0m"
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
} else {
|
||||
displayLabel := label
|
||||
if i == 1 && state.allowlistInfo != "" {
|
||||
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
|
||||
if len(displayLabel) > state.innerWidth-2 {
|
||||
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Blank line before hint
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
// Draw box bottom
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw hint (dark grey)
|
||||
// Draw hint (may be multiple lines)
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
// Last line - no newline
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||
@@ -849,39 +748,50 @@ func renderSelectorBox(state *selectorState) {
|
||||
func updateSelectorOptions(state *selectorState) {
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Move up to the first option line
|
||||
// Cursor is at end of last hint line, need to go up:
|
||||
// (hint lines - 1) + 1 (blank line) + numOptions
|
||||
// (hint lines - 1) + 1 (bottom border) + numOptions
|
||||
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||
|
||||
// Redraw options (Deny option includes reason input)
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 {
|
||||
if i == 2 { // Deny option
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if inputDisplay == "" {
|
||||
inputDisplay = "\033[90m(optional reason)\033[0m"
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
} else {
|
||||
displayLabel := label
|
||||
if i == 1 && state.allowlistInfo != "" {
|
||||
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
|
||||
if len(displayLabel) > state.innerWidth-2 {
|
||||
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Blank line + hint
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
// Redraw bottom and hint
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
@@ -895,26 +805,36 @@ func updateSelectorOptions(state *selectorState) {
|
||||
func updateReasonInput(state *selectorState) {
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Move up to the Deny line (3rd option, index 2)
|
||||
// Cursor is at end of last hint line, need to go up:
|
||||
// (hint lines - 1) + 1 (blank line) + 1 (Deny is last option)
|
||||
// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
|
||||
linesToMove := len(hintLines) - 1 + 1 + 1
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||
|
||||
// Redraw Deny line with reason
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if inputDisplay == "" {
|
||||
inputDisplay = "\033[90m(optional reason)\033[0m"
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if state.selected == 2 {
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
|
||||
// Blank line + hint
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
// Redraw bottom and hint
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
@@ -938,10 +858,11 @@ func clearSelectorBox(state *selectorState) {
|
||||
// fallbackApproval handles approval when terminal control isn't available.
|
||||
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
fmt.Fprintln(os.Stderr, toolDisplay)
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Allow for this session [3] Deny")
|
||||
fmt.Fprint(os.Stderr, "choice: ")
|
||||
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
|
||||
fmt.Fprint(os.Stderr, "Choice: ")
|
||||
|
||||
var input string
|
||||
fmt.Scanln(&input)
|
||||
@@ -984,16 +905,19 @@ func (a *ApprovalManager) AllowedTools() []string {
|
||||
|
||||
// FormatApprovalResult returns a formatted string showing the approval result.
|
||||
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
|
||||
var label string
|
||||
displayName := ToolDisplayName(toolName)
|
||||
var status string
|
||||
var icon string
|
||||
|
||||
switch result.Decision {
|
||||
case ApprovalOnce:
|
||||
label = "Approved"
|
||||
status = "Approved"
|
||||
icon = "\033[32m✓\033[0m"
|
||||
case ApprovalAlways:
|
||||
label = "Always allowed"
|
||||
status = "Always allowed"
|
||||
icon = "\033[32m✓\033[0m"
|
||||
case ApprovalDeny:
|
||||
label = "Denied"
|
||||
status = "Denied"
|
||||
icon = "\033[31m✗\033[0m"
|
||||
}
|
||||
|
||||
// Format based on tool type
|
||||
@@ -1003,7 +927,7 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
|
||||
if len(cmd) > 40 {
|
||||
cmd = cmd[:37] + "..."
|
||||
}
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, cmd)
|
||||
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1013,11 +937,11 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
|
||||
if len(query) > 40 {
|
||||
query = query[:37] + "..."
|
||||
}
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, query)
|
||||
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
|
||||
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
|
||||
}
|
||||
|
||||
// FormatDenyResult returns the tool result message when a tool is denied.
|
||||
@@ -1027,78 +951,3 @@ func FormatDenyResult(toolName string, reason string) string {
|
||||
}
|
||||
return fmt.Sprintf("User denied execution of %s.", toolName)
|
||||
}
|
||||
|
||||
// PromptYesNo displays a simple Yes/No prompt and returns the user's choice.
|
||||
// Returns true for Yes, false for No.
|
||||
func PromptYesNo(question string) (bool, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
selected := 0 // 0 = Yes, 1 = No
|
||||
options := []string{"Yes", "No"}
|
||||
|
||||
// Hide cursor
|
||||
fmt.Fprint(os.Stderr, "\033[?25l")
|
||||
defer fmt.Fprint(os.Stderr, "\033[?25h")
|
||||
|
||||
renderYesNo := func() {
|
||||
// Move to start of line and clear
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
fmt.Fprintf(os.Stderr, "%s ", question)
|
||||
for i, opt := range options {
|
||||
if i == selected {
|
||||
fmt.Fprintf(os.Stderr, "\033[1m%s\033[0m ", opt)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[37m%s\033[0m ", opt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
renderYesNo()
|
||||
|
||||
buf := make([]byte, 3)
|
||||
for {
|
||||
n, err := os.Stdin.Read(buf)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if n == 1 {
|
||||
switch buf[0] {
|
||||
case 'y', 'Y':
|
||||
selected = 0
|
||||
renderYesNo()
|
||||
case 'n', 'N':
|
||||
selected = 1
|
||||
renderYesNo()
|
||||
case '\r', '\n': // Enter
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K") // Clear line
|
||||
return selected == 0, nil
|
||||
case 3: // Ctrl+C
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return false, nil
|
||||
case 27: // Escape - could be arrow key
|
||||
// Read more bytes for arrow keys
|
||||
continue
|
||||
}
|
||||
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
|
||||
// Arrow keys
|
||||
switch buf[2] {
|
||||
case 'D': // Left
|
||||
if selected > 0 {
|
||||
selected--
|
||||
}
|
||||
renderYesNo()
|
||||
case 'C': // Right
|
||||
if selected < len(options)-1 {
|
||||
selected++
|
||||
}
|
||||
renderYesNo()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,27 +151,6 @@ func TestExtractBashPrefix(t *testing.T) {
|
||||
command: "head -n 100",
|
||||
expected: "",
|
||||
},
|
||||
// Path traversal security tests
|
||||
{
|
||||
name: "path traversal - parent escape",
|
||||
command: "cat tools/../../etc/passwd",
|
||||
expected: "", // Should NOT create a prefix - path escapes
|
||||
},
|
||||
{
|
||||
name: "path traversal - deep escape",
|
||||
command: "cat tools/a/b/../../../etc/passwd",
|
||||
expected: "", // Normalizes to "../etc/passwd" - escapes
|
||||
},
|
||||
{
|
||||
name: "path traversal - absolute path",
|
||||
command: "cat /etc/passwd",
|
||||
expected: "", // Absolute paths should not create prefix
|
||||
},
|
||||
{
|
||||
name: "path with safe dotdot - normalized",
|
||||
command: "cat tools/subdir/../file.go",
|
||||
expected: "cat:tools/", // Normalizes to tools/file.go - safe, creates prefix
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -185,34 +164,6 @@ func TestExtractBashPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_PathTraversalBlocked(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Allow "cat tools/file.go" - creates prefix "cat:tools/"
|
||||
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||
|
||||
// Path traversal attack: should NOT be allowed
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../etc/passwd"}) {
|
||||
t.Error("SECURITY: path traversal attack should NOT be allowed")
|
||||
}
|
||||
|
||||
// Another traversal variant
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../../etc/shadow"}) {
|
||||
t.Error("SECURITY: deep path traversal should NOT be allowed")
|
||||
}
|
||||
|
||||
// Valid subdirectory access should still work
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
|
||||
t.Error("expected cat tools/subdir/file.go to be allowed")
|
||||
}
|
||||
|
||||
// Safe ".." that normalizes to within allowed directory should work
|
||||
// tools/subdir/../other.go normalizes to tools/other.go which is under tools/
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/../other.go"}) {
|
||||
t.Error("expected cat tools/subdir/../other.go to be allowed (normalizes to tools/other.go)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
@@ -235,119 +186,6 @@ func TestApprovalManager_PrefixAllowlist(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_HierarchicalPrefixAllowlist(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Allow "cat tools/file.go" - this creates prefix "cat:tools/"
|
||||
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||
|
||||
// Should allow subdirectories (hierarchical matching)
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
|
||||
t.Error("expected cat tools/subdir/file.go to be allowed via hierarchical prefix")
|
||||
}
|
||||
|
||||
// Should allow deeply nested subdirectories
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/a/b/c/deep.go"}) {
|
||||
t.Error("expected cat tools/a/b/c/deep.go to be allowed via hierarchical prefix")
|
||||
}
|
||||
|
||||
// Should still allow same directory
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/another.go"}) {
|
||||
t.Error("expected cat tools/another.go to be allowed")
|
||||
}
|
||||
|
||||
// Should NOT allow different base directory
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
|
||||
t.Error("expected cat src/main.go to NOT be allowed")
|
||||
}
|
||||
|
||||
// Should NOT allow different command even in subdirectory
|
||||
if am.IsAllowed("bash", map[string]any{"command": "ls tools/subdir/"}) {
|
||||
t.Error("expected ls tools/subdir/ to NOT be allowed (different command)")
|
||||
}
|
||||
|
||||
// Should NOT allow similar but different directory name
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat toolsbin/file.go"}) {
|
||||
t.Error("expected cat toolsbin/file.go to NOT be allowed (different directory)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_HierarchicalPrefixAllowlist_CrossPlatform(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Allow with forward slashes (Unix-style)
|
||||
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||
|
||||
// Should work with backslashes too (Windows-style) - normalized internally
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\subdir\\file.go"}) {
|
||||
t.Error("expected cat tools\\subdir\\file.go to be allowed via hierarchical prefix (Windows path)")
|
||||
}
|
||||
|
||||
// Mixed slashes should also work
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\a/b\\c/deep.go"}) {
|
||||
t.Error("expected mixed slash path to be allowed via hierarchical prefix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesHierarchicalPrefix(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Add prefix for "cat:tools/"
|
||||
am.prefixes["cat:tools/"] = true
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
prefix: "cat:tools/",
|
||||
expected: true, // exact match also passes HasPrefix - caller handles exact match first
|
||||
},
|
||||
{
|
||||
name: "subdirectory",
|
||||
prefix: "cat:tools/subdir/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "deeply nested",
|
||||
prefix: "cat:tools/a/b/c/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different base directory",
|
||||
prefix: "cat:src/",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different command same path",
|
||||
prefix: "ls:tools/",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "similar directory name",
|
||||
prefix: "cat:toolsbin/",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid prefix format",
|
||||
prefix: "cattools",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := am.matchesHierarchicalPrefix(tt.prefix)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchesHierarchicalPrefix(%q) = %v, expected %v",
|
||||
tt.prefix, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatApprovalResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
234
x/cmd/run.go
234
x/cmd/run.go
@@ -6,12 +6,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
@@ -24,101 +22,6 @@ import (
|
||||
"github.com/ollama/ollama/x/tools"
|
||||
)
|
||||
|
||||
// Tool output capping constants
|
||||
const (
|
||||
// localModelTokenLimit is the token limit for local models (smaller context).
|
||||
localModelTokenLimit = 4000
|
||||
|
||||
// defaultTokenLimit is the token limit for cloud/remote models.
|
||||
defaultTokenLimit = 10000
|
||||
|
||||
// charsPerToken is a rough estimate of characters per token.
|
||||
// TODO: Estimate tokens more accurately using tokenizer if available
|
||||
charsPerToken = 4
|
||||
)
|
||||
|
||||
// isLocalModel checks if the model is running locally (not a cloud model).
|
||||
// TODO: Improve local/cloud model identification - could check model metadata
|
||||
func isLocalModel(modelName string) bool {
|
||||
return !strings.HasSuffix(modelName, "-cloud")
|
||||
}
|
||||
|
||||
// isLocalServer checks if connecting to a local Ollama server.
|
||||
// TODO: Could also check other indicators of local vs cloud server
|
||||
func isLocalServer() bool {
|
||||
host := os.Getenv("OLLAMA_HOST")
|
||||
if host == "" {
|
||||
return true // Default is localhost:11434
|
||||
}
|
||||
|
||||
// Parse the URL to check host
|
||||
parsed, err := url.Parse(host)
|
||||
if err != nil {
|
||||
return true // If can't parse, assume local
|
||||
}
|
||||
|
||||
hostname := parsed.Hostname()
|
||||
return hostname == "localhost" || hostname == "127.0.0.1" || strings.Contains(parsed.Host, ":11434")
|
||||
}
|
||||
|
||||
// truncateToolOutput truncates tool output to prevent context overflow.
|
||||
// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote.
|
||||
func truncateToolOutput(output, modelName string) string {
|
||||
var tokenLimit int
|
||||
if isLocalModel(modelName) && isLocalServer() {
|
||||
tokenLimit = localModelTokenLimit
|
||||
} else {
|
||||
tokenLimit = defaultTokenLimit
|
||||
}
|
||||
|
||||
maxChars := tokenLimit * charsPerToken
|
||||
if len(output) > maxChars {
|
||||
return output[:maxChars] + "\n... (output truncated)"
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
// waitForOllamaSignin shows the signin URL and polls until authentication completes.
|
||||
func waitForOllamaSignin(ctx context.Context) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get signin URL from initial Whoami call
|
||||
_, err = client.Whoami(ctx)
|
||||
if err != nil {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
||||
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
|
||||
fmt.Fprintf(os.Stderr, " %s\n\n", aErr.SigninURL)
|
||||
fmt.Fprintf(os.Stderr, " \033[90mwaiting for sign in to complete...\033[0m")
|
||||
|
||||
// Poll until auth succeeds
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
user, whoamiErr := client.Whoami(ctx)
|
||||
if whoamiErr == nil && user != nil && user.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K \033[1msigned in:\033[0m %s\n", user.Name)
|
||||
return nil
|
||||
}
|
||||
// Still waiting, show dot
|
||||
fmt.Fprintf(os.Stderr, ".")
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunOptions contains options for running an interactive agent session.
|
||||
type RunOptions struct {
|
||||
Model string
|
||||
@@ -134,9 +37,6 @@ type RunOptions struct {
|
||||
// Agent fields (managed externally for session persistence)
|
||||
Tools *tools.Registry
|
||||
Approval *agent.ApprovalManager
|
||||
|
||||
// YoloMode skips all tool approval prompts
|
||||
YoloMode bool
|
||||
}
|
||||
|
||||
// Chat runs an agent chat loop with tool support.
|
||||
@@ -177,7 +77,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
var pendingToolCalls []api.ToolCall
|
||||
var consecutiveErrors int // Track consecutive 500 errors for retry limit
|
||||
|
||||
role := "assistant"
|
||||
messages := opts.Messages
|
||||
@@ -260,58 +159,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check for 401 Unauthorized - prompt user to sign in
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) {
|
||||
p.StopAndClear()
|
||||
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m cloud model requires authentication\n")
|
||||
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
|
||||
if promptErr == nil && result {
|
||||
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
|
||||
// Retry the chat request
|
||||
fmt.Fprintf(os.Stderr, "\033[90mretrying...\033[0m\n")
|
||||
continue // Retry the loop
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
|
||||
}
|
||||
|
||||
// Check for 500 errors (often tool parsing failures) - inform the model
|
||||
var statusErr api.StatusError
|
||||
if errors.As(err, &statusErr) && statusErr.StatusCode >= 500 {
|
||||
consecutiveErrors++
|
||||
p.StopAndClear()
|
||||
|
||||
if consecutiveErrors >= 3 {
|
||||
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m too many consecutive errors, giving up\n")
|
||||
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m server error (attempt %d/3): %s\n", consecutiveErrors, statusErr.ErrorMessage)
|
||||
|
||||
// Include both the model's response and the error so it can learn
|
||||
assistantContent := fullResponse.String()
|
||||
if assistantContent == "" {
|
||||
assistantContent = "(empty response)"
|
||||
}
|
||||
errorMsg := fmt.Sprintf("Your previous response caused an error: %s\n\nYour response was:\n%s\n\nPlease try again with a valid response.", statusErr.ErrorMessage, assistantContent)
|
||||
messages = append(messages,
|
||||
api.Message{Role: "user", Content: errorMsg},
|
||||
)
|
||||
|
||||
// Reset state and retry
|
||||
fullResponse.Reset()
|
||||
thinkingContent.Reset()
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = false
|
||||
pendingToolCalls = nil
|
||||
state = &displayResponseState{}
|
||||
p = progress.NewProgress(os.Stderr)
|
||||
spinner = progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
@@ -321,9 +168,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset consecutive error counter on success
|
||||
consecutiveErrors = 0
|
||||
|
||||
// If no tool calls, we're done
|
||||
if len(pendingToolCalls) == 0 || toolRegistry == nil {
|
||||
break
|
||||
@@ -353,8 +197,8 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
// Check if command is denied (dangerous pattern)
|
||||
if denied, pattern := agent.IsDenied(cmd); denied {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mblocked:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, " matches dangerous pattern: %s\n", pattern)
|
||||
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: agent.FormatDeniedResult(cmd, pattern),
|
||||
@@ -364,21 +208,15 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
}
|
||||
|
||||
// Check if command is auto-allowed (safe command)
|
||||
// TODO(parthsareen): re-enable with tighter scoped allowlist
|
||||
// if agent.IsAutoAllowed(cmd) {
|
||||
// fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
// skipApproval = true
|
||||
// }
|
||||
if agent.IsAutoAllowed(cmd) {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
skipApproval = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check approval (uses prefix matching for bash commands)
|
||||
// In yolo mode, skip all approval prompts
|
||||
if opts.YoloMode {
|
||||
if !skipApproval {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
}
|
||||
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
|
||||
if !skipApproval && !approval.IsAllowed(toolName, args) {
|
||||
result, err := approval.RequestApproval(toolName, args)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
|
||||
@@ -406,30 +244,13 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
}
|
||||
} else if !skipApproval {
|
||||
// Already allowed - show running indicator
|
||||
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
toolResult, err := toolRegistry.Execute(call)
|
||||
if err != nil {
|
||||
// Check if web search needs authentication
|
||||
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
|
||||
// Prompt user to sign in
|
||||
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m web search requires authentication\n")
|
||||
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
|
||||
if promptErr == nil && result {
|
||||
// Get signin URL and wait for auth completion
|
||||
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
|
||||
// Retry the web search
|
||||
fmt.Fprintf(os.Stderr, "\033[90mretrying web search...\033[0m\n")
|
||||
toolResult, err = toolRegistry.Execute(call)
|
||||
if err == nil {
|
||||
goto toolSuccess
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Error: %v", err),
|
||||
@@ -437,7 +258,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
})
|
||||
continue
|
||||
}
|
||||
toolSuccess:
|
||||
|
||||
// Display tool output (truncated for display)
|
||||
if toolResult != "" {
|
||||
@@ -449,12 +269,9 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
|
||||
}
|
||||
|
||||
// Truncate output to prevent context overflow
|
||||
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: toolResultForLLM,
|
||||
Content: toolResult,
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
}
|
||||
@@ -500,18 +317,17 @@ func truncateUTF8(s string, limit int) string {
|
||||
|
||||
// formatToolShort returns a short description of a tool call.
|
||||
func formatToolShort(toolName string, args map[string]any) string {
|
||||
displayName := agent.ToolDisplayName(toolName)
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(cmd, 50))
|
||||
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
|
||||
}
|
||||
}
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(query, 50))
|
||||
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
|
||||
}
|
||||
}
|
||||
return displayName
|
||||
return toolName
|
||||
}
|
||||
|
||||
// Helper types and functions for display
|
||||
@@ -633,8 +449,7 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
|
||||
|
||||
// GenerateInteractive runs an interactive agent session.
|
||||
// This is called from cmd.go when --experimental flag is set.
|
||||
// If yoloMode is true, all tool approvals are skipped.
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
@@ -651,7 +466,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
// Check if model supports tools
|
||||
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m could not check model capabilities: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
|
||||
supportsTools = false
|
||||
}
|
||||
|
||||
@@ -659,17 +474,14 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var toolRegistry *tools.Registry
|
||||
if supportsTools {
|
||||
toolRegistry = tools.DefaultRegistry()
|
||||
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
|
||||
if toolRegistry.Has("bash") {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
|
||||
fmt.Fprintln(os.Stderr, "Models can read files on your computer, or run commands (after you allow them).")
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
if yoloMode {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
|
||||
// Check for OLLAMA_API_KEY for web search
|
||||
if os.Getenv("OLLAMA_API_KEY") == "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
|
||||
}
|
||||
|
||||
// Create approval manager for session
|
||||
@@ -712,9 +524,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Keyboard Shortcuts:")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
|
||||
@@ -737,7 +546,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
KeepAlive: keepAlive,
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
YoloMode: yoloMode,
|
||||
}
|
||||
|
||||
assistant, err := Chat(cmd.Context(), opts)
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsLocalModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "local model without suffix",
|
||||
modelName: "llama3.2",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "local model with version",
|
||||
modelName: "qwen2.5:7b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "cloud model",
|
||||
modelName: "gpt-4-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with version",
|
||||
modelName: "claude-3-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty model name",
|
||||
modelName: "",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isLocalModel(tt.modelName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isLocalModel(%q) = %v, expected %v", tt.modelName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalServer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty host (default)",
|
||||
host: "",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
host: "http://localhost:11434",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1",
|
||||
host: "http://127.0.0.1:11434",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "custom port on localhost",
|
||||
host: "http://localhost:8080",
|
||||
expected: true, // localhost is always considered local
|
||||
},
|
||||
{
|
||||
name: "remote host",
|
||||
host: "http://ollama.example.com:11434",
|
||||
expected: true, // has :11434
|
||||
},
|
||||
{
|
||||
name: "remote host different port",
|
||||
host: "http://ollama.example.com:8080",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", tt.host)
|
||||
result := isLocalServer()
|
||||
if result != tt.expected {
|
||||
t.Errorf("isLocalServer() with OLLAMA_HOST=%q = %v, expected %v", tt.host, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateToolOutput(t *testing.T) {
|
||||
// Create outputs of different sizes
|
||||
localLimitOutput := make([]byte, 20000) // > 4k tokens (16k chars)
|
||||
defaultLimitOutput := make([]byte, 50000) // > 10k tokens (40k chars)
|
||||
for i := range localLimitOutput {
|
||||
localLimitOutput[i] = 'a'
|
||||
}
|
||||
for i := range defaultLimitOutput {
|
||||
defaultLimitOutput[i] = 'b'
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
modelName string
|
||||
host string
|
||||
shouldTrim bool
|
||||
expectedLimit int
|
||||
}{
|
||||
{
|
||||
name: "short output local model",
|
||||
output: "hello world",
|
||||
modelName: "llama3.2",
|
||||
host: "",
|
||||
shouldTrim: false,
|
||||
expectedLimit: localModelTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "long output local model - trimmed at 4k",
|
||||
output: string(localLimitOutput),
|
||||
modelName: "llama3.2",
|
||||
host: "",
|
||||
shouldTrim: true,
|
||||
expectedLimit: localModelTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "long output cloud model - uses 10k limit",
|
||||
output: string(localLimitOutput), // 20k chars, under 10k token limit
|
||||
modelName: "gpt-4-cloud",
|
||||
host: "",
|
||||
shouldTrim: false,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "very long output cloud model - trimmed at 10k",
|
||||
output: string(defaultLimitOutput),
|
||||
modelName: "gpt-4-cloud",
|
||||
host: "",
|
||||
shouldTrim: true,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "long output remote server - uses 10k limit",
|
||||
output: string(localLimitOutput),
|
||||
modelName: "llama3.2",
|
||||
host: "http://remote.example.com:8080",
|
||||
shouldTrim: false,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", tt.host)
|
||||
result := truncateToolOutput(tt.output, tt.modelName)
|
||||
|
||||
if tt.shouldTrim {
|
||||
maxLen := tt.expectedLimit * charsPerToken
|
||||
if len(result) > maxLen+50 { // +50 for the truncation message
|
||||
t.Errorf("expected output to be truncated to ~%d chars, got %d", maxLen, len(result))
|
||||
}
|
||||
if result == tt.output {
|
||||
t.Error("expected output to be truncated but it wasn't")
|
||||
}
|
||||
} else {
|
||||
if result != tt.output {
|
||||
t.Error("expected output to not be truncated")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,185 +0,0 @@
|
||||
# grammar
|
||||
|
||||
Grammar-constrained decoding for LLM outputs using MLX.
|
||||
|
||||
## Performance
|
||||
|
||||
Performance depends on hardware, vocabulary size, grammar, and whether you
|
||||
evaluate the MLX graph. See [Benchmarks](#benchmarks) for how to measure on your
|
||||
setup.
|
||||
|
||||
### Design choices that keep masking fast
|
||||
|
||||
| Technique | Impact |
|
||||
|-----------|--------|
|
||||
| Precomputed token analysis | Terminal matches computed once at startup |
|
||||
| Mask caching by grammar state signature | Reuse masks for repeated parser states |
|
||||
| Partitioned tokens | Exact matches separated from DP candidates |
|
||||
|
||||
### Comparison Notes
|
||||
|
||||
- **llama.cpp**: Decodes each token to UTF-8, checks against PDA. No caching.
|
||||
- **Outlines**: FSM-based. Compilation can take 40s-10min for complex schemas. Fast after compile.
|
||||
- **XGrammar**: PDA with 99% context-independent tokens precomputed. State-of-the-art before this.
|
||||
- **x/grammar**: Precomputed token analysis + mask caching by grammar state signature.
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/grammar/schema"
|
||||
)
|
||||
|
||||
// Use built-in JSON grammar
|
||||
g, _ := grammar.JSONGrammar()
|
||||
|
||||
// Or from JSON Schema (OpenAI-compatible)
|
||||
g, _ := schema.Grammar(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}`)
|
||||
|
||||
// Or parse custom EBNF
|
||||
g, _ := grammar.ParseEBNF(myGrammar, "root")
|
||||
|
||||
// Create engine with model vocabulary
|
||||
engine, _ := grammar.NewEngine(g, vocab)
|
||||
defer engine.Close()
|
||||
|
||||
// Generation loop
|
||||
for !engine.IsComplete() {
|
||||
logits := model.Forward(tokens)
|
||||
masked := engine.ApplyMask(logits) // Invalid tokens → -inf
|
||||
nextToken := sample(masked)
|
||||
engine.Accept(nextToken)
|
||||
}
|
||||
// Output conforms to the grammar when you only sample from masked tokens and call Accept
|
||||
```
|
||||
|
||||
## EBNF Syntax
|
||||
|
||||
```ebnf
|
||||
rule = expression . # Rule definition (ends with .)
|
||||
"literal" # Literal string
|
||||
"a" … "z" # Character range (inclusive)
|
||||
( a | b ) # Grouping with alternation
|
||||
[ optional ] # Optional (0 or 1)
|
||||
{ repeated } # Repetition (0 or more)
|
||||
```
|
||||
|
||||
### Example: JSON Grammar
|
||||
|
||||
```ebnf
|
||||
json = value .
|
||||
|
||||
value = object | array | string | number | "true" | "false" | "null" .
|
||||
|
||||
object = "{" ws "}" | "{" members "}" .
|
||||
members = member { "," member } .
|
||||
member = ws string ws ":" element .
|
||||
|
||||
array = "[" ws "]" | "[" elements "]" .
|
||||
elements = element { "," element } .
|
||||
element = ws value ws .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
|
||||
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
```
|
||||
|
||||
### Example: Custom Schema
|
||||
|
||||
```ebnf
|
||||
root = "{" ws name_field "," ws age_field ws "}" .
|
||||
|
||||
name_field = "\"name\"" ws ":" ws string .
|
||||
age_field = "\"age\"" ws ":" ws number .
|
||||
|
||||
string = "\"" { char } "\"" .
|
||||
char = " " | "!" | "#" … "~" .
|
||||
|
||||
number = [ "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
|
||||
ws = { " " | "\n" } .
|
||||
```
|
||||
|
||||
## JSON Schema Support
|
||||
|
||||
OpenAI-compatible JSON Schema support with automatic EBNF generation:
|
||||
|
||||
```go
|
||||
schema := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {"$ref": "#/$defs/User"}
|
||||
},
|
||||
"required": ["user"],
|
||||
"$defs": {
|
||||
"User": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"email": {"type": "string", "format": "email"},
|
||||
"role": {"enum": ["admin", "user", "guest"]}
|
||||
},
|
||||
"required": ["name", "email", "role"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
grammar, _ := schema.Grammar(schema)
|
||||
```
|
||||
|
||||
### Supported Features
|
||||
|
||||
| Feature | Example |
|
||||
|---------|---------|
|
||||
| Basic types | `string`, `integer`, `number`, `boolean`, `null` |
|
||||
| Objects | `properties`, `required` |
|
||||
| Arrays | `items`, `minItems`, `maxItems` |
|
||||
| Enums | `enum: ["a", "b", "c"]` |
|
||||
| Constants | `const: "value"` |
|
||||
| Union types | `anyOf`, `oneOf`, `type: ["string", "null"]` |
|
||||
| References | `$ref: "#/$defs/Name"`, `$defs` |
|
||||
| Formats | `date`, `time`, `date-time`, `email`, `uuid`, `ipv4` |
|
||||
|
||||
## Benchmarks
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
go test -tags mlx ./x/grammar/...
|
||||
|
||||
# Run benchmarks
|
||||
go test -tags mlx ./x/grammar/ -bench=.
|
||||
|
||||
# Compare with llama.cpp (outputs JSON)
|
||||
go run -tags mlx ./x/grammar/cmd/compare -vocab-size 128000 -iterations 500
|
||||
|
||||
# Compare with a more complex schema
|
||||
go run -tags mlx ./x/grammar/cmd/compare \
|
||||
-gbnf x/grammar/cmd/compare/complex.gbnf \
|
||||
-schema x/grammar/cmd/compare/complex.schema.json \
|
||||
-vocab-size 128000 -iterations 500
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [XGrammar Paper](https://arxiv.org/abs/2411.15100) - Flexible and Efficient Structured Generation
|
||||
- [Outlines](https://github.com/dottxt-ai/outlines) - Structured Text Generation
|
||||
- [JSONSchemaBench](https://arxiv.org/abs/2501.10868) - Benchmark for Structured Outputs
|
||||
@@ -1,161 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
// terminalTokenGroups contains pre-partitioned tokens for a terminal.
|
||||
// This enables O(1) lookup of tokens that exactly match vs need DP validation.
|
||||
type terminalTokenGroups struct {
|
||||
// ExactMatches are tokens that exactly match this terminal (O(1) validation)
|
||||
ExactMatches []int32
|
||||
|
||||
// DPCandidates are tokens that start with this terminal but need DP validation
|
||||
DPCandidates []int
|
||||
}
|
||||
|
||||
// tokenAnalysis contains precomputed terminal matches for a token
|
||||
type tokenAnalysis struct {
|
||||
// The token string
|
||||
Token string
|
||||
|
||||
// TokenID in the vocabulary
|
||||
TokenID int
|
||||
|
||||
// Matches at each byte position
|
||||
// MatchesAtPos[i] = terminals matching at position i with their lengths
|
||||
MatchesAtPos [][]terminalMatch
|
||||
|
||||
// Fast path: if token exactly matches one terminal
|
||||
// -1 if no exact match
|
||||
exactMatch int
|
||||
|
||||
// Whether this token can be consumed at all (has at least one match)
|
||||
HasMatches bool
|
||||
}
|
||||
|
||||
// analyzer precomputes terminal matches for a vocabulary
|
||||
type analyzer struct {
|
||||
matcher *terminalMatcher
|
||||
analyses []tokenAnalysis // Indexed by token ID
|
||||
vocab []string
|
||||
|
||||
// Pre-partitioned tokens by terminal (exact match vs DP candidates)
|
||||
// This enables direct slice appends instead of per-token branching
|
||||
tokensByTerminal []terminalTokenGroups
|
||||
}
|
||||
|
||||
// newAnalyzer creates an analyzer for the given vocabulary and terminals
|
||||
func newAnalyzer(vocab []string, matcher *terminalMatcher) *analyzer {
|
||||
a := &analyzer{
|
||||
matcher: matcher,
|
||||
analyses: make([]tokenAnalysis, len(vocab)),
|
||||
vocab: vocab,
|
||||
}
|
||||
|
||||
// Precompute analysis for each token
|
||||
for i, token := range vocab {
|
||||
a.analyses[i] = a.analyze(token, i)
|
||||
}
|
||||
|
||||
// Build pre-partitioned token groups for fast ApplyMask
|
||||
a.buildTokenPartitions()
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// analyze computes terminal matches for a single token
|
||||
func (a *analyzer) analyze(token string, tokenID int) tokenAnalysis {
|
||||
analysis := tokenAnalysis{
|
||||
Token: token,
|
||||
TokenID: tokenID,
|
||||
MatchesAtPos: make([][]terminalMatch, len(token)),
|
||||
exactMatch: -1,
|
||||
HasMatches: false,
|
||||
}
|
||||
|
||||
if len(token) == 0 {
|
||||
return analysis
|
||||
}
|
||||
|
||||
// Compute matches at each position
|
||||
data := []byte(token)
|
||||
for pos := 0; pos < len(data); pos++ {
|
||||
matches := a.matcher.matchesAt(data, pos)
|
||||
analysis.MatchesAtPos[pos] = matches
|
||||
if len(matches) > 0 {
|
||||
analysis.HasMatches = true
|
||||
}
|
||||
}
|
||||
|
||||
// Exact match is only valid when a single terminal spans the entire token
|
||||
if len(analysis.MatchesAtPos) > 0 {
|
||||
var exactID int = -1
|
||||
for _, match := range analysis.MatchesAtPos[0] {
|
||||
if match.Length != len(token) {
|
||||
continue
|
||||
}
|
||||
if exactID >= 0 && exactID != match.TerminalID {
|
||||
exactID = -1
|
||||
break
|
||||
}
|
||||
exactID = match.TerminalID
|
||||
}
|
||||
analysis.exactMatch = exactID
|
||||
}
|
||||
|
||||
return analysis
|
||||
}
|
||||
|
||||
// analysis returns the precomputed analysis for a token ID
|
||||
func (a *analyzer) analysis(tokenID int) tokenAnalysis {
|
||||
if tokenID < 0 || tokenID >= len(a.analyses) {
|
||||
return tokenAnalysis{exactMatch: -1}
|
||||
}
|
||||
return a.analyses[tokenID]
|
||||
}
|
||||
|
||||
// vocabSize returns the vocabulary size
|
||||
func (a *analyzer) vocabSize() int {
|
||||
return len(a.vocab)
|
||||
}
|
||||
|
||||
// buildTokenPartitions pre-partitions tokens into exact-match vs needs-DP groups per terminal.
|
||||
// This enables ApplyMask to use direct slice appends instead of per-token branching.
|
||||
func (a *analyzer) buildTokenPartitions() {
|
||||
numTerminals := a.matcher.terminalCount()
|
||||
a.tokensByTerminal = make([]terminalTokenGroups, numTerminals)
|
||||
|
||||
for tokenID, analysis := range a.analyses {
|
||||
if !analysis.HasMatches {
|
||||
continue
|
||||
}
|
||||
|
||||
if analysis.exactMatch >= 0 {
|
||||
// Token exactly matches one terminal - fast path (O(1) validation)
|
||||
tid := analysis.exactMatch
|
||||
a.tokensByTerminal[tid].ExactMatches = append(
|
||||
a.tokensByTerminal[tid].ExactMatches, int32(tokenID))
|
||||
} else {
|
||||
// Token needs DP validation - add to all terminals it can start with
|
||||
// This way, when a terminal is valid, we know exactly which tokens need DP
|
||||
if len(analysis.MatchesAtPos) > 0 {
|
||||
seen := make(map[int]bool)
|
||||
for _, match := range analysis.MatchesAtPos[0] {
|
||||
tid := match.TerminalID
|
||||
if !seen[tid] {
|
||||
seen[tid] = true
|
||||
a.tokensByTerminal[tid].DPCandidates = append(
|
||||
a.tokensByTerminal[tid].DPCandidates, tokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// terminalGroups returns the pre-partitioned token groups for a terminal ID
|
||||
func (a *analyzer) terminalGroups(terminalID int) terminalTokenGroups {
|
||||
if terminalID < 0 || terminalID >= len(a.tokensByTerminal) {
|
||||
return terminalTokenGroups{}
|
||||
}
|
||||
return a.tokensByTerminal[terminalID]
|
||||
}
|
||||
@@ -1,648 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"hash/fnv"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// visitedMapPool reduces allocations for visited maps in bridge operations
|
||||
var visitedMapPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make(map[stateStackKey]bool, 16)
|
||||
},
|
||||
}
|
||||
|
||||
// getVisitedMap gets a map from the pool
|
||||
func getVisitedMap() map[stateStackKey]bool {
|
||||
return visitedMapPool.Get().(map[stateStackKey]bool)
|
||||
}
|
||||
|
||||
// putVisitedMap returns a map to the pool after clearing it
|
||||
func putVisitedMap(m map[stateStackKey]bool) {
|
||||
for k := range m {
|
||||
delete(m, k)
|
||||
}
|
||||
visitedMapPool.Put(m)
|
||||
}
|
||||
|
||||
// parserConfig represents a pda state+stack combination
|
||||
type parserConfig struct {
|
||||
state state
|
||||
Stack []stackSymbol
|
||||
}
|
||||
|
||||
// clone creates a deep copy of the config
|
||||
func (c *parserConfig) clone() *parserConfig {
|
||||
newStack := make([]stackSymbol, len(c.Stack))
|
||||
copy(newStack, c.Stack)
|
||||
return &parserConfig{
|
||||
state: c.state,
|
||||
Stack: newStack,
|
||||
}
|
||||
}
|
||||
|
||||
// key returns a unique key for this config for deduplication
|
||||
func (c *parserConfig) key() uint64 {
|
||||
h := fnv.New64a()
|
||||
var buf [8]byte
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(c.state))
|
||||
h.Write(buf[:])
|
||||
for _, sym := range c.Stack {
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
|
||||
h.Write(buf[:])
|
||||
}
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
// configSet represents a set of parser configurations (for nondeterminism)
|
||||
type configSet struct {
|
||||
configs []*parserConfig
|
||||
normalized bool // true if already deduplicated and sorted
|
||||
cachedSig uint64 // cached signature after normalization
|
||||
}
|
||||
|
||||
// newConfigSet creates a new config set with a single configuration
|
||||
func newConfigSet(state state, stack []stackSymbol) *configSet {
|
||||
return &configSet{
|
||||
configs: []*parserConfig{
|
||||
{state: state, Stack: stack},
|
||||
},
|
||||
normalized: true, // single config is already normalized
|
||||
}
|
||||
}
|
||||
|
||||
// normalize deduplicates and sorts configs for stable signatures
|
||||
func (c *configSet) normalize() {
|
||||
if c.normalized || len(c.configs) <= 1 {
|
||||
c.normalized = true
|
||||
return
|
||||
}
|
||||
|
||||
// Deduplicate using a map
|
||||
seen := make(map[uint64]*parserConfig, len(c.configs))
|
||||
for _, cfg := range c.configs {
|
||||
key := cfg.key()
|
||||
if _, exists := seen[key]; !exists {
|
||||
seen[key] = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// Extract unique configs
|
||||
unique := make([]*parserConfig, 0, len(seen))
|
||||
for _, cfg := range seen {
|
||||
unique = append(unique, cfg)
|
||||
}
|
||||
|
||||
// Sort by key for deterministic ordering
|
||||
sort.Slice(unique, func(i, j int) bool {
|
||||
return unique[i].key() < unique[j].key()
|
||||
})
|
||||
|
||||
c.configs = unique
|
||||
c.normalized = true
|
||||
}
|
||||
|
||||
// signature returns a hash for cache lookup (normalizes first)
|
||||
func (c *configSet) signature() uint64 {
|
||||
c.normalize()
|
||||
|
||||
// Return cached signature if available
|
||||
if c.cachedSig != 0 {
|
||||
return c.cachedSig
|
||||
}
|
||||
|
||||
h := fnv.New64a()
|
||||
|
||||
// Hash number of configs
|
||||
var buf [8]byte
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(len(c.configs)))
|
||||
h.Write(buf[:])
|
||||
|
||||
// Hash each config (already sorted)
|
||||
for _, cfg := range c.configs {
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(cfg.state))
|
||||
h.Write(buf[:])
|
||||
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(len(cfg.Stack)))
|
||||
h.Write(buf[:])
|
||||
|
||||
for _, sym := range cfg.Stack {
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
|
||||
h.Write(buf[:])
|
||||
}
|
||||
}
|
||||
|
||||
c.cachedSig = h.Sum64()
|
||||
return c.cachedSig
|
||||
}
|
||||
|
||||
// isEmpty returns true if there are no configurations
|
||||
func (c *configSet) isEmpty() bool {
|
||||
return len(c.configs) == 0
|
||||
}
|
||||
|
||||
// clone creates a deep copy of the config set
|
||||
func (c *configSet) clone() *configSet {
|
||||
newConfigs := make([]*parserConfig, len(c.configs))
|
||||
for i, cfg := range c.configs {
|
||||
newConfigs[i] = cfg.clone()
|
||||
}
|
||||
return &configSet{configs: newConfigs}
|
||||
}
|
||||
|
||||
// bridge connects token analysis to pda validation
|
||||
type bridge struct {
|
||||
pda *pda
|
||||
analyzer *analyzer
|
||||
}
|
||||
|
||||
// newBridge creates a new bridge
|
||||
func newBridge(pda *pda, analyzer *analyzer) *bridge {
|
||||
return &bridge{
|
||||
pda: pda,
|
||||
analyzer: analyzer,
|
||||
}
|
||||
}
|
||||
|
||||
// IsTokenValid checks if token T can be consumed from the current config
|
||||
// This is the main entry point for token validation
|
||||
func (b *bridge) IsTokenValid(tokenID int, config *configSet) bool {
|
||||
analysis := b.analyzer.analysis(tokenID)
|
||||
|
||||
if !analysis.HasMatches {
|
||||
return false
|
||||
}
|
||||
|
||||
// Fast path: exact terminal match
|
||||
if analysis.exactMatch >= 0 {
|
||||
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
|
||||
return b.canAcceptTerminal(config, terminal.Pattern)
|
||||
}
|
||||
|
||||
// General path: DP over (pos, config)
|
||||
return b.dpValidate(&analysis, config)
|
||||
}
|
||||
|
||||
// canAcceptTerminal checks if any config can accept the terminal
|
||||
func (b *bridge) canAcceptTerminal(config *configSet, pattern string) bool {
|
||||
for _, cfg := range config.configs {
|
||||
if b.canConfigAcceptTerminal(cfg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// canConfigAcceptTerminal checks if a single config can accept the terminal
|
||||
func (b *bridge) canConfigAcceptTerminal(cfg *parserConfig, pattern string) bool {
|
||||
// Use pooled visited map to reduce allocations
|
||||
visited := getVisitedMap()
|
||||
result := b.tryAcceptTerminal(cfg.state, cfg.Stack, pattern, visited)
|
||||
putVisitedMap(visited)
|
||||
return result
|
||||
}
|
||||
|
||||
// tryAcceptTerminal recursively tries to accept a terminal from a state
|
||||
func (b *bridge) tryAcceptTerminal(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool) bool {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
// Check stack constraint
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
// Can't pop more than we have
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern == pattern {
|
||||
// Direct match
|
||||
return true
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
// Epsilon transition - follow it
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
// Pop
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
|
||||
// Push
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
if b.tryAcceptTerminal(t.ToState, newStack, pattern, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// dpValidate runs DP for multi-terminal tokens
|
||||
func (b *bridge) dpValidate(analysis *tokenAnalysis, startConfig *configSet) bool {
|
||||
// state: (pos, configSet)
|
||||
// Memoize by (pos, configSig)
|
||||
type dpKey struct {
|
||||
pos int
|
||||
sig uint64
|
||||
}
|
||||
memo := make(map[dpKey]bool)
|
||||
|
||||
var dp func(pos int, config *configSet) bool
|
||||
dp = func(pos int, config *configSet) bool {
|
||||
if pos == len(analysis.Token) {
|
||||
return true // Consumed entire token
|
||||
}
|
||||
|
||||
if config.isEmpty() {
|
||||
return false
|
||||
}
|
||||
|
||||
key := dpKey{pos, config.signature()}
|
||||
if result, ok := memo[key]; ok {
|
||||
return result
|
||||
}
|
||||
|
||||
// Try each terminal that matches at this position
|
||||
for _, match := range analysis.MatchesAtPos[pos] {
|
||||
terminal := b.analyzer.matcher.terminals[match.TerminalID]
|
||||
newConfig := b.advanceConfig(config, terminal.Pattern)
|
||||
if newConfig != nil && !newConfig.isEmpty() && dp(pos+match.Length, newConfig) {
|
||||
memo[key] = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
memo[key] = false
|
||||
return false
|
||||
}
|
||||
|
||||
return dp(0, startConfig)
|
||||
}
|
||||
|
||||
// advanceConfig advances all configs that can accept the terminal
|
||||
func (b *bridge) advanceConfig(config *configSet, pattern string) *configSet {
|
||||
var newConfigs []*parserConfig
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
advanced := b.advanceSingleConfig(cfg, pattern)
|
||||
newConfigs = append(newConfigs, advanced...)
|
||||
}
|
||||
|
||||
if len(newConfigs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &configSet{configs: newConfigs}
|
||||
}
|
||||
|
||||
// advanceSingleConfig advances a single config by accepting a terminal
|
||||
func (b *bridge) advanceSingleConfig(cfg *parserConfig, pattern string) []*parserConfig {
|
||||
var results []*parserConfig
|
||||
visited := getVisitedMap()
|
||||
b.collectAdvanced(cfg.state, cfg.Stack, pattern, visited, &results)
|
||||
putVisitedMap(visited)
|
||||
return results
|
||||
}
|
||||
|
||||
// collectAdvanced collects all configs reachable by accepting the pattern
|
||||
func (b *bridge) collectAdvanced(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool, results *[]*parserConfig) {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
// Check stack constraint
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
// Can't pop more than we have
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern == pattern {
|
||||
// Match! Create new config after transition
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
*results = append(*results, &parserConfig{
|
||||
state: t.ToState,
|
||||
Stack: newStack,
|
||||
})
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
// Epsilon transition - follow it
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
b.collectAdvanced(t.ToState, newStack, pattern, visited, results)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validTokens returns all token IDs that are valid from the given config
|
||||
func (b *bridge) validTokens(config *configSet) []int {
|
||||
var valid []int
|
||||
for tokenID := 0; tokenID < b.analyzer.vocabSize(); tokenID++ {
|
||||
if b.IsTokenValid(tokenID, config) {
|
||||
valid = append(valid, tokenID)
|
||||
}
|
||||
}
|
||||
return valid
|
||||
}
|
||||
|
||||
// acceptToken attempts to accept a token and returns the new config set
|
||||
// Returns nil if the token is not valid from this config
|
||||
func (b *bridge) acceptToken(tokenID int, config *configSet) *configSet {
|
||||
analysis := b.analyzer.analysis(tokenID)
|
||||
|
||||
if !analysis.HasMatches {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fast path: exact terminal match
|
||||
if analysis.exactMatch >= 0 {
|
||||
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
|
||||
newConfig := b.advanceConfig(config, terminal.Pattern)
|
||||
if newConfig != nil && !newConfig.isEmpty() {
|
||||
newConfig.normalize()
|
||||
return newConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// General path: DP to find final config after consuming token
|
||||
return b.dpAccept(&analysis, config)
|
||||
}
|
||||
|
||||
// dpAccept runs DP to accept a multi-terminal token and return final config
|
||||
// Returns the union of all possible end configurations (preserves nondeterminism)
|
||||
func (b *bridge) dpAccept(analysis *tokenAnalysis, startConfig *configSet) *configSet {
|
||||
type dpKey struct {
|
||||
pos int
|
||||
sig uint64
|
||||
}
|
||||
// Memoize the configs reachable at each (pos, sig)
|
||||
memo := make(map[dpKey]*configSet)
|
||||
|
||||
var dp func(pos int, config *configSet) *configSet
|
||||
dp = func(pos int, config *configSet) *configSet {
|
||||
if pos == len(analysis.Token) {
|
||||
return config // Consumed entire token, return final config
|
||||
}
|
||||
|
||||
if config.isEmpty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := dpKey{pos, config.signature()}
|
||||
if result, ok := memo[key]; ok {
|
||||
return result
|
||||
}
|
||||
|
||||
// Collect all valid result configs from all possible paths
|
||||
var allConfigs []*parserConfig
|
||||
|
||||
// Try each terminal that matches at this position
|
||||
for _, match := range analysis.MatchesAtPos[pos] {
|
||||
terminal := b.analyzer.matcher.terminals[match.TerminalID]
|
||||
newConfig := b.advanceConfig(config, terminal.Pattern)
|
||||
if newConfig != nil && !newConfig.isEmpty() {
|
||||
finalConfig := dp(pos+match.Length, newConfig)
|
||||
if finalConfig != nil {
|
||||
// Collect all configs, don't return early
|
||||
allConfigs = append(allConfigs, finalConfig.configs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build result: nil if no valid paths, normalized configSet otherwise
|
||||
var result *configSet
|
||||
if len(allConfigs) > 0 {
|
||||
result = &configSet{configs: allConfigs}
|
||||
result.normalize() // Dedup using parserConfig.key(), sort for consistent signature
|
||||
}
|
||||
memo[key] = result // Cache normalized result
|
||||
return result
|
||||
}
|
||||
|
||||
return dp(0, startConfig)
|
||||
}
|
||||
|
||||
// isAccepting returns true if any config can reach an accepting state
|
||||
func (b *bridge) isAccepting(config *configSet) bool {
|
||||
visited := getVisitedMap()
|
||||
defer putVisitedMap(visited)
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
// Clear visited for each config check
|
||||
for k := range visited {
|
||||
delete(visited, k)
|
||||
}
|
||||
if b.canReachAccept(cfg.state, cfg.Stack, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// canReachAccept checks if we can reach an accepting state via epsilon transitions
|
||||
func (b *bridge) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
|
||||
// Check if this state is accepting with empty stack
|
||||
if b.pda.AcceptStates[state] && len(stack) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
// Try epsilon transitions
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
if t.Pattern != "" {
|
||||
continue // Not epsilon
|
||||
}
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
if b.canReachAccept(t.ToState, newStack, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// validTerminals returns the valid terminal patterns from the given config
|
||||
func (b *bridge) validTerminals(config *configSet) []string {
|
||||
seen := make(map[string]bool)
|
||||
var terminals []string
|
||||
|
||||
visited := getVisitedMap()
|
||||
defer putVisitedMap(visited)
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
// Clear visited for each config
|
||||
for k := range visited {
|
||||
delete(visited, k)
|
||||
}
|
||||
b.collectValidTerminals(cfg.state, cfg.Stack, visited, seen, &terminals)
|
||||
}
|
||||
|
||||
return terminals
|
||||
}
|
||||
|
||||
// collectValidTerminals collects all reachable terminals
|
||||
func (b *bridge) collectValidTerminals(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[string]bool, terminals *[]string) {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern != "" && !seen[t.Pattern] {
|
||||
seen[t.Pattern] = true
|
||||
*terminals = append(*terminals, t.Pattern)
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
b.collectValidTerminals(t.ToState, newStack, visited, seen, terminals)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validTerminalIDs returns the IDs of valid terminals from the given config
|
||||
func (b *bridge) validTerminalIDs(config *configSet) []int {
|
||||
seen := make(map[int]bool)
|
||||
var terminalIDs []int
|
||||
|
||||
visited := getVisitedMap()
|
||||
defer putVisitedMap(visited)
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
// Clear visited for each config
|
||||
for k := range visited {
|
||||
delete(visited, k)
|
||||
}
|
||||
b.collectValidTerminalIDs(cfg.state, cfg.Stack, visited, seen, &terminalIDs)
|
||||
}
|
||||
|
||||
return terminalIDs
|
||||
}
|
||||
|
||||
// collectValidTerminalIDs collects IDs of all reachable terminals
|
||||
func (b *bridge) collectValidTerminalIDs(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[int]bool, terminalIDs *[]int) {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern != "" {
|
||||
// Look up terminal ID from pattern
|
||||
if tid, ok := b.analyzer.matcher.patternToID[t.Pattern]; ok && !seen[tid] {
|
||||
seen[tid] = true
|
||||
*terminalIDs = append(*terminalIDs, tid)
|
||||
}
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
b.collectValidTerminalIDs(t.ToState, newStack, visited, seen, terminalIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
root ::= ws "{" ws id-field "," ws kind-field "," ws items-field "," ws alt-field "," ws flags-field "," ws meta-field "," ws priority-field ws "}" ws
|
||||
|
||||
id-field ::= "\"id\"" ws ":" ws uuid
|
||||
kind-field ::= "\"kind\"" ws ":" ws kind
|
||||
items-field ::= "\"items\"" ws ":" ws items
|
||||
alt-field ::= "\"alt\"" ws ":" ws alt
|
||||
flags-field ::= "\"flags\"" ws ":" ws flags
|
||||
meta-field ::= "\"meta\"" ws ":" ws meta
|
||||
priority-field ::= "\"priority\"" ws ":" ws int
|
||||
|
||||
kind ::= "\"order\"" | "\"invoice\"" | "\"shipment\""
|
||||
status ::= "\"new\"" | "\"backorder\"" | "\"shipped\""
|
||||
flag ::= "\"fragile\"" | "\"gift\"" | "\"priority\"" | "\"insured\""
|
||||
source ::= "\"api\"" | "\"batch\"" | "\"import\""
|
||||
|
||||
items ::= "[" ws item ( "," ws item )? ( "," ws item )? ws "]"
|
||||
flags ::= "[" ws "]" | "[" ws flag ( "," ws flag )? ( "," ws flag )? ( "," ws flag )? ws "]"
|
||||
|
||||
item ::= "{" ws item-sku "," ws item-qty "," ws item-status "," ws item-notes ws "}"
|
||||
item-sku ::= "\"sku\"" ws ":" ws string
|
||||
item-qty ::= "\"qty\"" ws ":" ws int
|
||||
item-status ::= "\"status\"" ws ":" ws status
|
||||
item-notes ::= "\"notes\"" ws ":" ws string
|
||||
|
||||
meta ::= "{" ws meta-created "," ws meta-source "," ws meta-ip ws "}"
|
||||
meta-created ::= "\"created\"" ws ":" ws date-time
|
||||
meta-source ::= "\"source\"" ws ":" ws source
|
||||
meta-ip ::= "\"ip\"" ws ":" ws ipv4
|
||||
|
||||
alt ::= string | int | "null"
|
||||
|
||||
uuid ::= "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\""
|
||||
date-time ::= "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\""
|
||||
ipv4 ::= "\"" digit+ "." digit+ "." digit+ "." digit+ "\""
|
||||
|
||||
string ::= "\"" characters "\""
|
||||
characters ::= character*
|
||||
character ::= [^"\\] | "\\" escape
|
||||
escape ::= ["\\bfnrt]
|
||||
|
||||
int ::= "-"? digit+
|
||||
digit ::= [0-9]
|
||||
hex ::= [0-9a-fA-F]
|
||||
|
||||
ws ::= [ \t\n\r]*
|
||||
@@ -1,46 +0,0 @@
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": { "type": "string", "format": "uuid" },
|
||||
"kind": { "enum": ["order", "invoice", "shipment"] },
|
||||
"items": {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"maxItems": 3,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sku": { "type": "string" },
|
||||
"qty": { "type": "integer" },
|
||||
"status": { "enum": ["new", "backorder", "shipped"] },
|
||||
"notes": { "type": "string" }
|
||||
},
|
||||
"required": ["sku", "qty", "status", "notes"]
|
||||
}
|
||||
},
|
||||
"alt": {
|
||||
"oneOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "null" },
|
||||
{ "type": "integer" }
|
||||
]
|
||||
},
|
||||
"flags": {
|
||||
"type": "array",
|
||||
"minItems": 0,
|
||||
"maxItems": 4,
|
||||
"items": { "enum": ["fragile", "gift", "priority", "insured"] }
|
||||
},
|
||||
"meta": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created": { "type": "string", "format": "date-time" },
|
||||
"source": { "enum": ["api", "batch", "import"] },
|
||||
"ip": { "type": "string", "format": "ipv4" }
|
||||
},
|
||||
"required": ["created", "source", "ip"]
|
||||
},
|
||||
"priority": { "type": "integer" }
|
||||
},
|
||||
"required": ["id", "kind", "items", "alt", "flags", "meta", "priority"]
|
||||
}
|
||||
@@ -1,235 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/grammar/schema"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
const jsonGBNF = `
|
||||
root ::= value
|
||||
value ::= object | array | string | number | "true" | "false" | "null"
|
||||
object ::= "{" ws "}" | "{" members "}"
|
||||
members ::= member ("," member)*
|
||||
member ::= ws string ws ":" element
|
||||
array ::= "[" ws "]" | "[" elements "]"
|
||||
elements ::= element ("," element)*
|
||||
element ::= ws value ws
|
||||
string ::= "\"" characters "\""
|
||||
characters ::= character*
|
||||
character ::= [^"\\] | "\\" escape
|
||||
escape ::= ["\\bfnrt]
|
||||
number ::= "-"? integer fraction? exponent?
|
||||
integer ::= "0" | [1-9] [0-9]*
|
||||
fraction ::= "." [0-9]+
|
||||
exponent ::= [eE] [+-]? [0-9]+
|
||||
ws ::= [ \t\n\r]*
|
||||
`
|
||||
|
||||
type result struct {
|
||||
vocabSize int `json:"vocab_size"`
|
||||
Iterations int `json:"iterations"`
|
||||
Warmup int `json:"warmup"`
|
||||
ConstrainedSource string `json:"constrained_source"`
|
||||
LlamaSource string `json:"llama_source"`
|
||||
LlamaApply string `json:"llama_apply"`
|
||||
ConstrainedGraph string `json:"constrained_graph"`
|
||||
ConstrainedWithEval string `json:"constrained_with_eval,omitempty"`
|
||||
EvalOnly string `json:"eval_only,omitempty"`
|
||||
ConstrainedEvalNet string `json:"constrained_eval_net,omitempty"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
vocabSize = flag.Int("vocab-size", 128000, "Vocabulary size")
|
||||
iterations = flag.Int("iterations", 500, "Benchmark iterations")
|
||||
warmup = flag.Int("warmup", 50, "Warmup iterations")
|
||||
withEval = flag.Bool("eval", true, "Measure ApplyMask with mlx.Eval")
|
||||
gbnfPath = flag.String("gbnf", "", "GBNF grammar file for llama.cpp")
|
||||
schemaPath = flag.String("schema", "", "JSON Schema file for grammar constraints")
|
||||
ebnfPath = flag.String("ebnf", "", "EBNF grammar file for grammar constraints")
|
||||
startRule = flag.String("start", "root", "Start rule for EBNF")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
if *vocabSize <= 0 || *iterations <= 0 || *warmup < 0 {
|
||||
fmt.Fprintln(os.Stderr, "invalid flags")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
vocab := createVocab(*vocabSize)
|
||||
|
||||
if *schemaPath != "" && *ebnfPath != "" {
|
||||
fmt.Fprintln(os.Stderr, "only one of -schema or -ebnf may be set")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
var constrainedSource string
|
||||
var compiled *grammar.Grammar
|
||||
var err error
|
||||
switch {
|
||||
case *schemaPath != "":
|
||||
data, readErr := os.ReadFile(*schemaPath)
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "read schema: %v\n", readErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
compiled, err = schema.Grammar(string(data))
|
||||
constrainedSource = "schema:" + *schemaPath
|
||||
case *ebnfPath != "":
|
||||
data, readErr := os.ReadFile(*ebnfPath)
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "read ebnf: %v\n", readErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
compiled, err = grammar.ParseEBNF(string(data), *startRule)
|
||||
constrainedSource = "ebnf:" + *ebnfPath
|
||||
default:
|
||||
compiled, err = grammar.JSONGrammar()
|
||||
constrainedSource = "json"
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "grammar: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
engine, err := grammar.NewEngine(compiled, vocab)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "engine: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
logits := mlx.Ones(int32(*vocabSize))
|
||||
mlx.Keep(logits)
|
||||
|
||||
for i := 0; i < *warmup; i++ {
|
||||
masked := engine.ApplyMask(logits)
|
||||
if *withEval {
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
graphAvg := measure(*iterations, func() {
|
||||
_ = engine.ApplyMask(logits)
|
||||
})
|
||||
|
||||
var evalAvg time.Duration
|
||||
var evalOnlyAvg time.Duration
|
||||
if *withEval {
|
||||
evalOnlyAvg = measure(*iterations, func() {
|
||||
baseline := mlx.MulScalar(logits, 1)
|
||||
mlx.Eval(baseline)
|
||||
baseline.Free()
|
||||
})
|
||||
|
||||
evalAvg = measure(*iterations, func() {
|
||||
masked := engine.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
})
|
||||
}
|
||||
|
||||
vocabIDs := make([]uint32, *vocabSize)
|
||||
for i := range vocabIDs {
|
||||
vocabIDs[i] = uint32(i)
|
||||
}
|
||||
eogTokens := []int32{0}
|
||||
|
||||
gbnf := jsonGBNF
|
||||
llamaSource := "json"
|
||||
if *gbnfPath != "" {
|
||||
data, readErr := os.ReadFile(*gbnfPath)
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "read gbnf: %v\n", readErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
gbnf = string(data)
|
||||
llamaSource = *gbnfPath
|
||||
}
|
||||
|
||||
llamaGrammar := llama.NewGrammar(gbnf, vocabIDs, vocab, eogTokens)
|
||||
if llamaGrammar == nil {
|
||||
fmt.Fprintln(os.Stderr, "llama grammar initialization failed")
|
||||
os.Exit(1)
|
||||
}
|
||||
defer llamaGrammar.Free()
|
||||
|
||||
llamaTokens := make([]llama.TokenData, *vocabSize)
|
||||
|
||||
for i := 0; i < *warmup; i++ {
|
||||
for j := range llamaTokens {
|
||||
llamaTokens[j].Logit = 1.0
|
||||
}
|
||||
llamaGrammar.Apply(llamaTokens)
|
||||
}
|
||||
|
||||
llamaAvg := measure(*iterations, func() {
|
||||
for j := range llamaTokens {
|
||||
llamaTokens[j].Logit = 1.0
|
||||
}
|
||||
llamaGrammar.Apply(llamaTokens)
|
||||
})
|
||||
|
||||
out := result{
|
||||
vocabSize: *vocabSize,
|
||||
Iterations: *iterations,
|
||||
Warmup: *warmup,
|
||||
LlamaApply: llamaAvg.String(),
|
||||
ConstrainedGraph: graphAvg.String(),
|
||||
ConstrainedSource: constrainedSource,
|
||||
LlamaSource: llamaSource,
|
||||
}
|
||||
if *withEval {
|
||||
out.ConstrainedWithEval = evalAvg.String()
|
||||
out.EvalOnly = evalOnlyAvg.String()
|
||||
if evalAvg > evalOnlyAvg {
|
||||
out.ConstrainedEvalNet = (evalAvg - evalOnlyAvg).String()
|
||||
} else {
|
||||
out.ConstrainedEvalNet = "0s"
|
||||
}
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(os.Stdout)
|
||||
if err := enc.Encode(out); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "encode: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func measure(iterations int, fn func()) time.Duration {
|
||||
start := time.Now()
|
||||
for i := 0; i < iterations; i++ {
|
||||
fn()
|
||||
}
|
||||
return time.Since(start) / time.Duration(iterations)
|
||||
}
|
||||
|
||||
func createVocab(size int) []string {
|
||||
vocab := make([]string, size)
|
||||
|
||||
jsonTokens := []string{
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
"true", "false", "null",
|
||||
" ", "\n", "\t", "\r",
|
||||
"\"",
|
||||
}
|
||||
for i, t := range jsonTokens {
|
||||
if i < size {
|
||||
vocab[i] = t
|
||||
}
|
||||
}
|
||||
|
||||
for i := len(jsonTokens); i < size; i++ {
|
||||
vocab[i] = fmt.Sprintf("tok%d", i)
|
||||
}
|
||||
|
||||
return vocab
|
||||
}
|
||||
@@ -1,320 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Grammar is the compiled form of an EBNF grammar.
|
||||
// It contains terminals, parse tables, and the start state.
|
||||
// Use ParseEBNF or JSONGrammar to create a Grammar.
|
||||
type Grammar struct {
|
||||
// The underlying pda
|
||||
pda *pda
|
||||
|
||||
// Compiled terminal matcher
|
||||
matcher *terminalMatcher
|
||||
}
|
||||
|
||||
// ParseEBNF compiles an EBNF grammar string into a Grammar.
|
||||
// startRule is the name of the start rule (e.g., "root", "json").
|
||||
func ParseEBNF(ebnf string, startRule string) (*Grammar, error) {
|
||||
pda, err := compileString(ebnf, startRule)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile EBNF: %w", err)
|
||||
}
|
||||
|
||||
matcher, err := compileTerminalsStrict(pda)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile terminals: %w", err)
|
||||
}
|
||||
|
||||
return &Grammar{
|
||||
pda: pda,
|
||||
matcher: matcher,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// JSONGrammar returns the compiled JSON grammar.
|
||||
// This is a convenience wrapper for ParseEBNF(JSONGrammarEBNF, "json").
|
||||
func JSONGrammar() (*Grammar, error) {
|
||||
return ParseEBNF(JSONGrammarEBNF, "json")
|
||||
}
|
||||
|
||||
// JSONObjectGrammar returns a JSON grammar that only allows objects at the top level.
|
||||
// Use this when you want to ensure the output is a JSON object (starts with {).
|
||||
func JSONObjectGrammar() (*Grammar, error) {
|
||||
return ParseEBNF(JSONObjectGrammarEBNF, "json")
|
||||
}
|
||||
|
||||
// compileTerminalsStrict builds a matcher that properly handles:
|
||||
// - Escaped literals ("\n", \"", \uXXXX)
|
||||
// - Unicode ranges (rune-based, not byte-based)
|
||||
// - Rejects unsupported patterns with an error (no silent fallback)
|
||||
func compileTerminalsStrict(pda *pda) (*terminalMatcher, error) {
|
||||
m := &terminalMatcher{
|
||||
literalTrie: &trieNode{terminalID: -1},
|
||||
ranges: make([]terminal, 0),
|
||||
terminals: make([]terminal, 0, len(pda.Terminals)),
|
||||
patternToID: make(map[string]int),
|
||||
}
|
||||
|
||||
// Track which pattern produced each unescaped value for collision detection
|
||||
unescapedSource := make(map[string]string) // unescaped -> original pattern
|
||||
|
||||
for i, pattern := range pda.Terminals {
|
||||
terminal, err := parseTerminalPattern(pattern, i)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("terminal %q: %w", pattern, err)
|
||||
}
|
||||
|
||||
if terminal.Type == terminalLiteral {
|
||||
// Use the unescaped pattern for trie matching
|
||||
m.addLiteralToTrie(terminal.Unescaped, i)
|
||||
|
||||
// Detect collisions between literals that unescape to the same value
|
||||
if existingPattern, exists := unescapedSource[terminal.Unescaped]; exists {
|
||||
if existingPattern != pattern {
|
||||
return nil, fmt.Errorf("collision: patterns %q and %q both unescape to %q",
|
||||
existingPattern, pattern, terminal.Unescaped)
|
||||
}
|
||||
} else {
|
||||
unescapedSource[terminal.Unescaped] = pattern
|
||||
}
|
||||
} else if terminal.Type == terminalRange {
|
||||
m.ranges = append(m.ranges, terminal)
|
||||
}
|
||||
|
||||
m.terminals = append(m.terminals, terminal)
|
||||
m.patternToID[pattern] = i
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// parseTerminalPattern parses a terminal pattern and returns a terminal.
|
||||
// Supports:
|
||||
// - Literal strings (with escape sequences)
|
||||
// - Character ranges [X-Y] (unicode-aware)
|
||||
func parseTerminalPattern(pattern string, id int) (terminal, error) {
|
||||
if len(pattern) == 0 {
|
||||
return terminal{}, fmt.Errorf("empty pattern")
|
||||
}
|
||||
|
||||
// Check for range pattern: [X-Y]
|
||||
if isUnicodeRangePattern(pattern) {
|
||||
lowRune, highRune, err := parseUnicodeRange(pattern)
|
||||
if err != nil {
|
||||
return terminal{}, err
|
||||
}
|
||||
return terminal{
|
||||
ID: id,
|
||||
Type: terminalRange,
|
||||
Pattern: pattern,
|
||||
Unescaped: pattern,
|
||||
LowRune: lowRune,
|
||||
HighRune: highRune,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// It's a literal - unescape it
|
||||
unescaped, err := unescapeLiteral(pattern)
|
||||
if err != nil {
|
||||
return terminal{}, fmt.Errorf("invalid escape sequence: %w", err)
|
||||
}
|
||||
|
||||
return terminal{
|
||||
ID: id,
|
||||
Type: terminalLiteral,
|
||||
Pattern: pattern,
|
||||
Unescaped: unescaped,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isUnicodeRangePattern checks if pattern is a character range like [a-z] or [\u0000-\uFFFF]
|
||||
func isUnicodeRangePattern(pattern string) bool {
|
||||
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
|
||||
return false
|
||||
}
|
||||
// Find the dash that separates low-high
|
||||
inner := pattern[1 : len(pattern)-1]
|
||||
dashIdx := strings.Index(inner, "-")
|
||||
// Handle escaped dash at start
|
||||
if dashIdx <= 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// parseUnicodeRange parses [X-Y] into low and high runes
|
||||
func parseUnicodeRange(pattern string) (rune, rune, error) {
|
||||
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
|
||||
return 0, 0, fmt.Errorf("invalid range pattern")
|
||||
}
|
||||
|
||||
inner := pattern[1 : len(pattern)-1]
|
||||
|
||||
// Simple case: [a-z] where a and z are single chars
|
||||
if len(inner) == 3 && inner[1] == '-' {
|
||||
return rune(inner[0]), rune(inner[2]), nil
|
||||
}
|
||||
|
||||
// Handle escaped characters like [\u0000-\uFFFF]
|
||||
dashIdx := findRangeDash(inner)
|
||||
if dashIdx < 0 {
|
||||
return 0, 0, fmt.Errorf("no dash in range")
|
||||
}
|
||||
|
||||
lowStr := inner[:dashIdx]
|
||||
highStr := inner[dashIdx+1:]
|
||||
|
||||
lowRune, err := parseRune(lowStr)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid low bound: %w", err)
|
||||
}
|
||||
|
||||
highRune, err := parseRune(highStr)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid high bound: %w", err)
|
||||
}
|
||||
|
||||
if lowRune > highRune {
|
||||
return 0, 0, fmt.Errorf("low bound > high bound")
|
||||
}
|
||||
|
||||
return lowRune, highRune, nil
|
||||
}
|
||||
|
||||
// findRangeDash finds the dash separating low-high in a range pattern
|
||||
func findRangeDash(inner string) int {
|
||||
i := 0
|
||||
for i < len(inner) {
|
||||
if inner[i] == '\\' && i+1 < len(inner) {
|
||||
// Skip escape sequence
|
||||
if inner[i+1] == 'u' && i+6 <= len(inner) {
|
||||
i += 6 // \uXXXX
|
||||
} else {
|
||||
i += 2 // \n, \t, etc.
|
||||
}
|
||||
continue
|
||||
}
|
||||
if inner[i] == '-' && i > 0 {
|
||||
return i
|
||||
}
|
||||
i++
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// parseRune parses a single rune from a string (handles escapes)
|
||||
func parseRune(s string) (rune, error) {
|
||||
if len(s) == 0 {
|
||||
return 0, fmt.Errorf("empty rune")
|
||||
}
|
||||
|
||||
// Handle escape sequences
|
||||
if s[0] == '\\' {
|
||||
if len(s) < 2 {
|
||||
return 0, fmt.Errorf("incomplete escape")
|
||||
}
|
||||
switch s[1] {
|
||||
case 'n':
|
||||
return '\n', nil
|
||||
case 't':
|
||||
return '\t', nil
|
||||
case 'r':
|
||||
return '\r', nil
|
||||
case '\\':
|
||||
return '\\', nil
|
||||
case '"':
|
||||
return '"', nil
|
||||
case '\'':
|
||||
return '\'', nil
|
||||
case 'u':
|
||||
if len(s) < 6 {
|
||||
return 0, fmt.Errorf("incomplete unicode escape")
|
||||
}
|
||||
val, err := strconv.ParseInt(s[2:6], 16, 32)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid unicode escape: %w", err)
|
||||
}
|
||||
return rune(val), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown escape: \\%c", s[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Plain character
|
||||
r, _ := utf8.DecodeRuneInString(s)
|
||||
if r == utf8.RuneError {
|
||||
return 0, fmt.Errorf("invalid utf8")
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// unescapeLiteral unescapes a literal pattern string
|
||||
func unescapeLiteral(pattern string) (string, error) {
|
||||
// Try strconv.Unquote if it looks quoted
|
||||
if len(pattern) >= 2 && pattern[0] == '"' && pattern[len(pattern)-1] == '"' {
|
||||
unquoted, err := strconv.Unquote(pattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return unquoted, nil
|
||||
}
|
||||
|
||||
// If no backslashes, return as-is
|
||||
if !strings.Contains(pattern, "\\") {
|
||||
return pattern, nil
|
||||
}
|
||||
|
||||
// Manual unescape
|
||||
var result strings.Builder
|
||||
i := 0
|
||||
for i < len(pattern) {
|
||||
if pattern[i] == '\\' && i+1 < len(pattern) {
|
||||
switch pattern[i+1] {
|
||||
case 'n':
|
||||
result.WriteByte('\n')
|
||||
i += 2
|
||||
case 't':
|
||||
result.WriteByte('\t')
|
||||
i += 2
|
||||
case 'r':
|
||||
result.WriteByte('\r')
|
||||
i += 2
|
||||
case '\\':
|
||||
result.WriteByte('\\')
|
||||
i += 2
|
||||
case '"':
|
||||
result.WriteByte('"')
|
||||
i += 2
|
||||
case '\'':
|
||||
result.WriteByte('\'')
|
||||
i += 2
|
||||
case 'u':
|
||||
if i+6 <= len(pattern) {
|
||||
val, err := strconv.ParseInt(pattern[i+2:i+6], 16, 32)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid unicode escape at %d", i)
|
||||
}
|
||||
result.WriteRune(rune(val))
|
||||
i += 6
|
||||
} else {
|
||||
return "", fmt.Errorf("incomplete unicode escape at %d", i)
|
||||
}
|
||||
default:
|
||||
// Reject unknown escape sequences
|
||||
return "", fmt.Errorf("unknown escape sequence: \\%c at position %d", pattern[i+1], i)
|
||||
}
|
||||
} else {
|
||||
result.WriteByte(pattern[i])
|
||||
i++
|
||||
}
|
||||
}
|
||||
return result.String(), nil
|
||||
}
|
||||
@@ -1,329 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// maskCache provides LRU caching for computed masks.
|
||||
type maskCache struct {
|
||||
cache map[uint64]*list.Element
|
||||
order *list.List
|
||||
maxSize int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type maskEntry struct {
|
||||
sig uint64
|
||||
mask *mlx.Array
|
||||
}
|
||||
|
||||
// newMaskCache creates a new mask cache with the given max size
|
||||
// If maxSize <= 0, the cache is disabled (Get/Put are no-ops)
|
||||
func newMaskCache(maxSize int) *maskCache {
|
||||
if maxSize <= 0 {
|
||||
return &maskCache{
|
||||
cache: make(map[uint64]*list.Element),
|
||||
order: list.New(),
|
||||
maxSize: 0, // Signals disabled
|
||||
}
|
||||
}
|
||||
return &maskCache{
|
||||
cache: make(map[uint64]*list.Element),
|
||||
order: list.New(),
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// get retrieves a cached mask, returning nil if not found.
|
||||
// Updates LRU order on cache hit.
|
||||
func (c *maskCache) get(sig uint64) *mlx.Array {
|
||||
if c.maxSize <= 0 {
|
||||
return nil // Cache disabled
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if elem, ok := c.cache[sig]; ok {
|
||||
c.order.MoveToFront(elem)
|
||||
return elem.Value.(*maskEntry).mask
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// put stores a mask in the cache with LRU eviction.
|
||||
func (c *maskCache) put(sig uint64, mask *mlx.Array) {
|
||||
if c.maxSize <= 0 {
|
||||
return // Cache disabled
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if elem, exists := c.cache[sig]; exists {
|
||||
c.order.MoveToFront(elem)
|
||||
return
|
||||
}
|
||||
|
||||
// Evict oldest if at capacity (safe since maxSize > 0)
|
||||
if c.order.Len() >= c.maxSize {
|
||||
oldest := c.order.Back()
|
||||
if oldest != nil {
|
||||
entry := oldest.Value.(*maskEntry)
|
||||
entry.mask.Free()
|
||||
delete(c.cache, entry.sig)
|
||||
c.order.Remove(oldest)
|
||||
}
|
||||
}
|
||||
|
||||
elem := c.order.PushFront(&maskEntry{sig: sig, mask: mask})
|
||||
c.cache[sig] = elem
|
||||
}
|
||||
|
||||
// clear frees all cached masks.
|
||||
func (c *maskCache) clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for elem := c.order.Front(); elem != nil; elem = elem.Next() {
|
||||
elem.Value.(*maskEntry).mask.Free()
|
||||
}
|
||||
c.cache = make(map[uint64]*list.Element)
|
||||
c.order.Init()
|
||||
}
|
||||
|
||||
// size returns the number of cached masks.
|
||||
func (c *maskCache) size() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.cache)
|
||||
}
|
||||
|
||||
// Engine applies grammar constraints to model outputs using MLX.
|
||||
// It uses a token→pda bridge for strict correctness with arbitrary BPE tokens.
|
||||
type Engine struct {
|
||||
// The compiled grammar
|
||||
grammar *Grammar
|
||||
|
||||
// bridge for token validation
|
||||
bridge *bridge
|
||||
analyzer *analyzer
|
||||
|
||||
// Current parser state (configSet for nondeterminism)
|
||||
configSet *configSet
|
||||
|
||||
// Token vocabulary from the model
|
||||
vocab []string
|
||||
tokenToID map[string]int // O(1) lookup for AcceptString
|
||||
|
||||
// Mask cache: configSig → valid token mask (LRU)
|
||||
maskCache *maskCache
|
||||
|
||||
// Cached negative infinity mask for invalid tokens
|
||||
negInfMask *mlx.Array
|
||||
|
||||
// Threshold for comparison (0.5 since mask values are 0 or 1)
|
||||
threshold *mlx.Array
|
||||
|
||||
// Vocabulary size
|
||||
vocabSize int32
|
||||
|
||||
// Reusable buffers for candidate filtering (avoid allocations)
|
||||
candidateMark []bool // indexed by tokenID, true if in candidate set
|
||||
touched []int // tokenIDs that were marked (for reset)
|
||||
dpCandidates []int // candidates requiring DP validation
|
||||
|
||||
// Reusable buffer for valid token indices (for GPU scatter)
|
||||
validTokenIDs []int32
|
||||
}
|
||||
|
||||
// EngineOption configures an Engine
|
||||
type EngineOption func(*Engine)
|
||||
|
||||
// WithMaskCacheSize sets the mask cache size (default 1024)
|
||||
func WithMaskCacheSize(size int) EngineOption {
|
||||
return func(e *Engine) {
|
||||
e.maskCache = newMaskCache(size)
|
||||
}
|
||||
}
|
||||
|
||||
// NewEngine creates a new constrained decoding engine.
|
||||
// grammar is the compiled grammar (use JSONGrammar() or ParseEBNF()).
|
||||
// vocab is the list of token strings from the model's tokenizer.
|
||||
func NewEngine(grammar *Grammar, vocab []string, opts ...EngineOption) (*Engine, error) {
|
||||
if grammar == nil {
|
||||
return nil, fmt.Errorf("grammar cannot be nil")
|
||||
}
|
||||
|
||||
// Build analyzer and bridge
|
||||
analyzer := newAnalyzer(vocab, grammar.matcher)
|
||||
bridge := newBridge(grammar.pda, analyzer)
|
||||
|
||||
// Initialize config set from pda initial state
|
||||
initialConfig := newConfigSet(grammar.pda.StartState, nil)
|
||||
|
||||
// Build token lookup map for O(1) AcceptString
|
||||
tokenToID := make(map[string]int, len(vocab))
|
||||
for i, tok := range vocab {
|
||||
tokenToID[tok] = i
|
||||
}
|
||||
|
||||
e := &Engine{
|
||||
grammar: grammar,
|
||||
bridge: bridge,
|
||||
analyzer: analyzer,
|
||||
configSet: initialConfig,
|
||||
vocab: vocab,
|
||||
tokenToID: tokenToID,
|
||||
maskCache: newMaskCache(1024),
|
||||
vocabSize: int32(len(vocab)),
|
||||
candidateMark: make([]bool, len(vocab)),
|
||||
touched: make([]int, 0, 10000),
|
||||
validTokenIDs: make([]int32, 0, 10000),
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, opt := range opts {
|
||||
opt(e)
|
||||
}
|
||||
|
||||
// Create the negative infinity mask and threshold
|
||||
if e.vocabSize > 0 {
|
||||
e.negInfMask = mlx.FullDtype(float32(math.Inf(-1)), mlx.DtypeFloat32, e.vocabSize)
|
||||
mlx.Keep(e.negInfMask)
|
||||
|
||||
e.threshold = mlx.NewScalarArray(0.5)
|
||||
mlx.Keep(e.threshold)
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
// ApplyMask applies grammar constraints to logits.
|
||||
// Returns logits with invalid tokens set to -inf.
|
||||
func (e *Engine) ApplyMask(logits *mlx.Array) *mlx.Array {
|
||||
sig := e.configSet.signature()
|
||||
|
||||
// Check state cache first (exact state match)
|
||||
if cached := e.maskCache.get(sig); cached != nil {
|
||||
condition := mlx.GreaterEqual(cached, e.threshold)
|
||||
return mlx.Where(condition, logits, e.negInfMask)
|
||||
}
|
||||
|
||||
// Compute valid tokens using candidate filtering:
|
||||
// 1. Get valid terminal IDs from current grammar state
|
||||
// 2. Get candidate tokens (those that START with valid terminals)
|
||||
// 3. Run DP validation only on candidates
|
||||
// This is O(candidates) instead of O(vocab_size)
|
||||
|
||||
validTerminalIDs := e.bridge.validTerminalIDs(e.configSet)
|
||||
|
||||
// Use pre-partitioned token groups for fast candidate building
|
||||
// This eliminates per-token branching - just direct slice appends
|
||||
e.validTokenIDs = e.validTokenIDs[:0]
|
||||
e.dpCandidates = e.dpCandidates[:0]
|
||||
e.touched = e.touched[:0]
|
||||
|
||||
for _, tid := range validTerminalIDs {
|
||||
groups := e.analyzer.terminalGroups(tid)
|
||||
|
||||
// Direct append of exact matches (no per-token check needed)
|
||||
e.validTokenIDs = append(e.validTokenIDs, groups.ExactMatches...)
|
||||
|
||||
// Collect DP candidates (may have duplicates across terminals)
|
||||
for _, tokenID := range groups.DPCandidates {
|
||||
if !e.candidateMark[tokenID] {
|
||||
e.candidateMark[tokenID] = true
|
||||
e.dpCandidates = append(e.dpCandidates, tokenID)
|
||||
e.touched = append(e.touched, tokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset marks for next call
|
||||
for _, id := range e.touched {
|
||||
e.candidateMark[id] = false
|
||||
}
|
||||
|
||||
for _, tokenID := range e.dpCandidates {
|
||||
if e.bridge.IsTokenValid(tokenID, e.configSet) {
|
||||
e.validTokenIDs = append(e.validTokenIDs, int32(tokenID))
|
||||
}
|
||||
}
|
||||
|
||||
// Create and cache the mask on GPU using index updates
|
||||
mask := mlx.Zeros([]int32{e.vocabSize})
|
||||
if len(e.validTokenIDs) > 0 {
|
||||
indices := mlx.NewArrayInt32(e.validTokenIDs, []int32{int32(len(e.validTokenIDs))})
|
||||
values := mlx.Ones(int32(len(e.validTokenIDs)))
|
||||
mask = mlx.PutAlongAxis(mask, indices, values, 0)
|
||||
}
|
||||
mlx.Keep(mask)
|
||||
|
||||
// Cache by state signature
|
||||
e.maskCache.put(sig, mask)
|
||||
|
||||
// Apply mask
|
||||
condition := mlx.GreaterEqual(mask, e.threshold)
|
||||
return mlx.Where(condition, logits, e.negInfMask)
|
||||
}
|
||||
|
||||
// Accept processes a token and updates the parser state.
|
||||
// Returns true if the token was valid and accepted.
|
||||
func (e *Engine) Accept(tokenID int) bool {
|
||||
if tokenID < 0 || tokenID >= len(e.vocab) {
|
||||
return false
|
||||
}
|
||||
|
||||
newConfig := e.bridge.acceptToken(tokenID, e.configSet)
|
||||
if newConfig == nil {
|
||||
return false
|
||||
}
|
||||
e.configSet = newConfig
|
||||
return true
|
||||
}
|
||||
|
||||
// AcceptString processes a token string directly.
|
||||
// Returns true if the token was valid and accepted.
|
||||
func (e *Engine) AcceptString(token string) bool {
|
||||
if id, ok := e.tokenToID[token]; ok {
|
||||
return e.Accept(id)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsComplete returns true if the current state is accepting.
|
||||
func (e *Engine) IsComplete() bool {
|
||||
return e.bridge.isAccepting(e.configSet)
|
||||
}
|
||||
|
||||
// Reset resets the engine to initial state.
|
||||
func (e *Engine) Reset() {
|
||||
e.configSet = newConfigSet(e.grammar.pda.StartState, nil)
|
||||
}
|
||||
|
||||
// validTokens returns the indices of tokens that are currently valid.
|
||||
func (e *Engine) validTokens() []int {
|
||||
return e.bridge.validTokens(e.configSet)
|
||||
}
|
||||
|
||||
// validTerminals returns the valid terminal patterns from the current state.
|
||||
func (e *Engine) validTerminals() []string {
|
||||
return e.bridge.validTerminals(e.configSet)
|
||||
}
|
||||
|
||||
// Close releases MLX resources.
|
||||
func (e *Engine) Close() {
|
||||
if e.maskCache != nil {
|
||||
e.maskCache.clear()
|
||||
}
|
||||
if e.negInfMask != nil {
|
||||
e.negInfMask.Free()
|
||||
}
|
||||
if e.threshold != nil {
|
||||
e.threshold.Free()
|
||||
}
|
||||
}
|
||||
@@ -1,414 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// newBenchEngine creates a JSON engine for benchmarks
|
||||
func newBenchEngine(b *testing.B, vocab []string) *Engine {
|
||||
b.Helper()
|
||||
grammar, err := JSONGrammar()
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create JSON grammar: %v", err)
|
||||
}
|
||||
e, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Vocabulary sizes to test (matching real models)
|
||||
var vocabSizes = []int{
|
||||
32000, // Llama 2
|
||||
128000, // Llama 3
|
||||
256000, // Large models
|
||||
}
|
||||
|
||||
// createBenchVocabN creates a vocabulary of size n with realistic token distribution
|
||||
func createBenchVocabN(n int) []string {
|
||||
vocab := make([]string, n)
|
||||
|
||||
// JSON structural tokens (first 20)
|
||||
jsonTokens := []string{
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
"true", "false", "null",
|
||||
" ", "\n", "\t", "\r",
|
||||
"\"", "'",
|
||||
}
|
||||
for i, t := range jsonTokens {
|
||||
if i < n {
|
||||
vocab[i] = t
|
||||
}
|
||||
}
|
||||
|
||||
// String tokens (indices 20-1000)
|
||||
stringIdx := 20
|
||||
for i := 0; i < 980 && stringIdx+i < n; i++ {
|
||||
vocab[stringIdx+i] = fmt.Sprintf("\"token%d\"", i)
|
||||
}
|
||||
|
||||
// Number tokens (indices 1000-2000)
|
||||
numberIdx := 1000
|
||||
for i := 0; i < 1000 && numberIdx+i < n; i++ {
|
||||
vocab[numberIdx+i] = fmt.Sprintf("%d", i)
|
||||
}
|
||||
|
||||
// Generic tokens (rest)
|
||||
for i := 2000; i < n; i++ {
|
||||
vocab[i] = fmt.Sprintf("tok%d", i)
|
||||
}
|
||||
|
||||
return vocab
|
||||
}
|
||||
|
||||
// ============ Core Performance Benchmarks ============
|
||||
|
||||
// BenchmarkApplyMask_32k measures mask application with 32k vocab
|
||||
func BenchmarkApplyMask_32k(b *testing.B) {
|
||||
benchmarkApplyMask(b, 32000)
|
||||
}
|
||||
|
||||
// BenchmarkApplyMask_128k measures mask application with 128k vocab
|
||||
func BenchmarkApplyMask_128k(b *testing.B) {
|
||||
benchmarkApplyMask(b, 128000)
|
||||
}
|
||||
|
||||
// BenchmarkApplyMask_256k measures mask application with 256k vocab
|
||||
func BenchmarkApplyMask_256k(b *testing.B) {
|
||||
benchmarkApplyMask(b, 256000)
|
||||
}
|
||||
|
||||
func benchmarkApplyMask(b *testing.B, vocabSize int) {
|
||||
vocab := createBenchVocabN(vocabSize)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(vocabSize))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Warm up
|
||||
for i := 0; i < 10; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(vocabSize), "vocab_size")
|
||||
}
|
||||
|
||||
// ============ state-Dependent Benchmarks ============
|
||||
|
||||
// BenchmarkApplyMaskAfterBrace measures mask after { (STRING or } valid)
|
||||
func BenchmarkApplyMaskAfterBrace(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
e.AcceptString("{")
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkApplyMaskMidObject measures mask in middle of object
|
||||
func BenchmarkApplyMaskMidObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// state: {"key": _value_
|
||||
e.AcceptString("{")
|
||||
e.AcceptString("\"key\"")
|
||||
e.AcceptString(":")
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Token Sequence Benchmarks ============
|
||||
|
||||
// BenchmarkSequence_SimpleObject benchmarks {"key": "value"}
|
||||
func BenchmarkSequence_SimpleObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// BenchmarkSequence_NestedObject benchmarks {"a": {"b": {"c": 1}}}
|
||||
func BenchmarkSequence_NestedObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{
|
||||
"{", "\"a\"", ":", "{", "\"b\"", ":", "{", "\"c\"", ":", "1", "}", "}", "}",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// BenchmarkSequence_LargeArray benchmarks [1, 2, 3, ..., 100]
|
||||
func BenchmarkSequence_LargeArray(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Build sequence: [1, 2, 3, ..., 50]
|
||||
sequence := []string{"["}
|
||||
for i := 1; i <= 50; i++ {
|
||||
sequence = append(sequence, fmt.Sprintf("%d", i))
|
||||
if i < 50 {
|
||||
sequence = append(sequence, ",")
|
||||
}
|
||||
}
|
||||
sequence = append(sequence, "]")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// BenchmarkSequence_MixedTypes benchmarks complex mixed-type object
|
||||
func BenchmarkSequence_MixedTypes(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{
|
||||
"{",
|
||||
"\"name\"", ":", "\"test\"", ",",
|
||||
"\"count\"", ":", "42", ",",
|
||||
"\"enabled\"", ":", "true", ",",
|
||||
"\"data\"", ":", "null", ",",
|
||||
"\"items\"", ":", "[", "1", ",", "2", ",", "3", "]",
|
||||
"}",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// ============ Component Benchmarks ============
|
||||
|
||||
// BenchmarkValidInputs measures pda valid input computation
|
||||
func BenchmarkValidInputs(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = e.validTerminals()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStateTransition measures pda state transition
|
||||
func BenchmarkStateTransition(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConstrainedGrammar_128k benchmarks x/grammar (graph only, no eval).
|
||||
func BenchmarkConstrainedGrammar_128k(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Warm up
|
||||
for i := 0; i < 10; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = e.ApplyMask(logits) // Graph only, no eval
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNewEngine measures one-time engine initialization.
|
||||
func BenchmarkNewEngine_32k(b *testing.B) {
|
||||
benchmarkNewEngine(b, 32000)
|
||||
}
|
||||
|
||||
func BenchmarkNewEngine_128k(b *testing.B) {
|
||||
benchmarkNewEngine(b, 128000)
|
||||
}
|
||||
|
||||
func benchmarkNewEngine(b *testing.B, vocabSize int) {
|
||||
vocab := createBenchVocabN(vocabSize)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e := newBenchEngine(b, vocab)
|
||||
e.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Memory Benchmarks ============
|
||||
|
||||
func BenchmarkMemoryAllocs_32k(b *testing.B) {
|
||||
benchmarkMemoryAllocs(b, 32000)
|
||||
}
|
||||
|
||||
func BenchmarkMemoryAllocs_128k(b *testing.B) {
|
||||
benchmarkMemoryAllocs(b, 128000)
|
||||
}
|
||||
|
||||
func benchmarkMemoryAllocs(b *testing.B, vocabSize int) {
|
||||
vocab := createBenchVocabN(vocabSize)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(vocabSize))
|
||||
mlx.Keep(logits)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
// ============ No-Eval Benchmarks (simulating LLM graph integration) ============
|
||||
|
||||
// BenchmarkApplyMaskNoEval_128k measures mask generation WITHOUT GPU sync
|
||||
// This simulates adding mask to LLM compute graph
|
||||
func BenchmarkApplyMaskNoEval_128k(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Warm up
|
||||
for i := 0; i < 10; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = e.ApplyMask(logits) // No Eval - just build graph
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSequenceNoEval simulates real LLM usage - build graph, eval once at end
|
||||
func BenchmarkSequenceNoEval_SimpleObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
var lastMasked *mlx.Array
|
||||
for _, token := range sequence {
|
||||
lastMasked = e.ApplyMask(logits) // Build graph only
|
||||
e.AcceptString(token)
|
||||
}
|
||||
mlx.Eval(lastMasked) // Single eval at end
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
@@ -1,689 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// newTestEngine creates a JSON engine for testing
|
||||
func newTestEngine(t testing.TB, vocab []string) *Engine {
|
||||
t.Helper()
|
||||
grammar, err := JSONGrammar()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JSON grammar: %v", err)
|
||||
}
|
||||
e, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Mock vocabulary for testing
|
||||
func testVocab() []string {
|
||||
return []string{
|
||||
"{", // 0: object start
|
||||
"}", // 1: object end
|
||||
"[", // 2: array start
|
||||
"]", // 3: array end
|
||||
":", // 4: colon
|
||||
",", // 5: comma
|
||||
"\"key\"", // 6: string (quoted)
|
||||
"\"val\"", // 7: string (quoted)
|
||||
"123", // 8: number
|
||||
"-42.5", // 9: number
|
||||
"true", // 10: boolean
|
||||
"false", // 11: boolean
|
||||
"null", // 12: null
|
||||
" ", // 13: whitespace (should be ignored)
|
||||
"\n", // 14: whitespace (should be ignored)
|
||||
"subword", // 15: bare word (NOT valid JSON - requires quotes)
|
||||
"hello", // 16: bare word (NOT valid JSON - requires quotes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEngine(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
if e.vocabSize != int32(len(vocab)) {
|
||||
t.Errorf("vocabSize = %d, want %d", e.vocabSize, len(vocab))
|
||||
}
|
||||
|
||||
// Verify grammar is set
|
||||
if e.grammar == nil {
|
||||
t.Error("grammar should not be nil")
|
||||
}
|
||||
|
||||
// Verify analyzer is set
|
||||
if e.analyzer == nil {
|
||||
t.Error("analyzer should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineValidTokens(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// At start, any value type should be valid
|
||||
validTokens := e.validTokens()
|
||||
|
||||
// Should include object start, array start, strings, numbers, booleans, null
|
||||
// Note: bare words like "subword" and "hello" are NOT valid JSON strings
|
||||
// (JSON strings must be quoted)
|
||||
expectedTokens := map[int]bool{
|
||||
0: true, // {
|
||||
2: true, // [
|
||||
6: true, // "key"
|
||||
7: true, // "val"
|
||||
8: true, // 123
|
||||
9: true, // -42.5
|
||||
10: true, // true
|
||||
11: true, // false
|
||||
12: true, // null
|
||||
}
|
||||
|
||||
// Check that expected tokens are present
|
||||
validSet := make(map[int]bool)
|
||||
for _, idx := range validTokens {
|
||||
validSet[idx] = true
|
||||
}
|
||||
|
||||
for idx := range expectedTokens {
|
||||
if !validSet[idx] {
|
||||
t.Errorf("expected token %d (%s) to be valid", idx, vocab[idx])
|
||||
}
|
||||
}
|
||||
|
||||
if validSet[15] || validSet[16] {
|
||||
t.Error("bare words should not be valid JSON at the start state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineAccept(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept { should work
|
||||
if !e.Accept(0) { // {
|
||||
t.Error("should accept {")
|
||||
}
|
||||
|
||||
// After {, valid tokens should be STRING or }
|
||||
validTokens := e.validTokens()
|
||||
|
||||
validSet := make(map[int]bool)
|
||||
for _, idx := range validTokens {
|
||||
validSet[idx] = true
|
||||
}
|
||||
|
||||
// STRING tokens (indices 6, 7) and } (index 1) should be valid
|
||||
if !validSet[1] {
|
||||
t.Error("} should be valid after {")
|
||||
}
|
||||
if !validSet[6] && !validSet[7] {
|
||||
t.Error("STRING should be valid after { (for keys)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineAcceptSequence(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept {"key": "val"}
|
||||
sequence := []int{0, 6, 4, 7, 1} // {, "key", :, "val", }
|
||||
|
||||
for i, tokenID := range sequence {
|
||||
if !e.Accept(tokenID) {
|
||||
t.Fatalf("failed to accept token %d (%s) at position %d",
|
||||
tokenID, vocab[tokenID], i)
|
||||
}
|
||||
}
|
||||
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be in complete state after valid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineReset(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept some tokens
|
||||
e.Accept(0) // {
|
||||
e.Accept(1) // }
|
||||
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after {}")
|
||||
}
|
||||
|
||||
// Reset
|
||||
e.Reset()
|
||||
|
||||
// Should be back to initial state
|
||||
if e.IsComplete() {
|
||||
t.Error("should not be complete after reset")
|
||||
}
|
||||
|
||||
// Should be able to accept new sequence
|
||||
if !e.Accept(0) { // {
|
||||
t.Error("should accept { after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineInvalidTokenRejection(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept { first
|
||||
if !e.Accept(0) {
|
||||
t.Fatal("should accept {")
|
||||
}
|
||||
|
||||
// Now try to accept [ which is invalid after {
|
||||
// (After {, only STRING or } are valid)
|
||||
if e.Accept(2) { // [
|
||||
t.Error("should not accept [ after { (expecting STRING or })")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineAcceptString(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept using string directly
|
||||
if !e.AcceptString("{") {
|
||||
t.Error("should accept {")
|
||||
}
|
||||
if !e.AcceptString("\"key\"") {
|
||||
t.Error("should accept string key")
|
||||
}
|
||||
if !e.AcceptString(":") {
|
||||
t.Error("should accept :")
|
||||
}
|
||||
if !e.AcceptString("123") {
|
||||
t.Error("should accept number")
|
||||
}
|
||||
if !e.AcceptString("}") {
|
||||
t.Error("should accept }")
|
||||
}
|
||||
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after valid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONBackslashEscape(t *testing.T) {
|
||||
vocab := []string{`"`, `\`, "n", "a"}
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Valid escape: "\n"
|
||||
if !e.AcceptString(`"`) {
|
||||
t.Fatal("should accept string start")
|
||||
}
|
||||
if !e.AcceptString(`\`) {
|
||||
t.Fatal("should accept escape prefix")
|
||||
}
|
||||
if !e.AcceptString("n") {
|
||||
t.Fatal("should accept escape code")
|
||||
}
|
||||
if !e.AcceptString(`"`) {
|
||||
t.Fatal("should accept string end")
|
||||
}
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after escaped string")
|
||||
}
|
||||
|
||||
// Invalid escape: "\a"
|
||||
e.Reset()
|
||||
if !e.AcceptString(`"`) {
|
||||
t.Fatal("should accept string start")
|
||||
}
|
||||
if !e.AcceptString(`\`) {
|
||||
t.Fatal("should accept escape prefix")
|
||||
}
|
||||
if e.AcceptString("a") {
|
||||
t.Error("should reject invalid escape code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineNegInfMask(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Verify negInfMask exists and has correct shape
|
||||
if e.negInfMask == nil {
|
||||
t.Fatal("negInfMask should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineMaskCache(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Create test logits
|
||||
logits := mlx.Ones(int32(len(vocab)))
|
||||
|
||||
// Apply mask - should populate cache
|
||||
_ = e.ApplyMask(logits)
|
||||
|
||||
// Check cache was populated
|
||||
cacheSize := e.maskCache.size()
|
||||
if cacheSize == 0 {
|
||||
t.Error("mask cache should have at least one entry after ApplyMask")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineEmptyVocab(t *testing.T) {
|
||||
e := newTestEngine(t, []string{})
|
||||
defer e.Close()
|
||||
|
||||
if e.vocabSize != 0 {
|
||||
t.Errorf("vocabSize = %d, want 0", e.vocabSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineLargeVocab(t *testing.T) {
|
||||
// Create a large vocabulary (simulating real model vocab)
|
||||
vocab := make([]string, 32000)
|
||||
for i := range vocab {
|
||||
vocab[i] = "token"
|
||||
}
|
||||
// Add some actual JSON tokens
|
||||
vocab[0] = "{"
|
||||
vocab[1] = "}"
|
||||
vocab[2] = "["
|
||||
vocab[3] = "]"
|
||||
vocab[4] = ":"
|
||||
vocab[5] = ","
|
||||
vocab[6] = "\"test\""
|
||||
vocab[7] = "123"
|
||||
vocab[8] = "true"
|
||||
vocab[9] = "false"
|
||||
vocab[10] = "null"
|
||||
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
if e.vocabSize != 32000 {
|
||||
t.Errorf("vocabSize = %d, want 32000", e.vocabSize)
|
||||
}
|
||||
|
||||
// Test that it still works correctly
|
||||
if !e.Accept(0) { // {
|
||||
t.Error("should accept {")
|
||||
}
|
||||
if !e.Accept(1) { // }
|
||||
t.Error("should accept }")
|
||||
}
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after {}")
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_JSONDecoding tests end-to-end JSON constrained decoding.
|
||||
func TestE2E_JSONDecoding(t *testing.T) {
|
||||
// Create a realistic vocabulary with JSON tokens
|
||||
vocab := []string{
|
||||
// Structural tokens
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
// Keywords
|
||||
"true", "false", "null",
|
||||
// Quoted strings
|
||||
`"name"`, `"value"`, `"items"`, `"count"`, `"enabled"`,
|
||||
`"hello"`, `"world"`, `"test"`,
|
||||
// Numbers
|
||||
"0", "1", "2", "3", "42", "123", "-1", "-42",
|
||||
// Whitespace
|
||||
" ", "\n", "\t",
|
||||
// Multi-terminal tokens (span multiple JSON lexemes)
|
||||
`"key":`, `},`, `],`, `{"`, `["`,
|
||||
// Partial/invalid tokens (should be rejected)
|
||||
"invalid", "foo", "bar",
|
||||
}
|
||||
|
||||
grammar, err := JSONGrammar()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JSON grammar: %v", err)
|
||||
}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens []string
|
||||
wantPass bool
|
||||
}{
|
||||
// Simple values
|
||||
{"empty object", []string{"{", "}"}, true},
|
||||
{"empty array", []string{"[", "]"}, true},
|
||||
{"true literal", []string{"true"}, true},
|
||||
{"null literal", []string{"null"}, true},
|
||||
{"number", []string{"42"}, true},
|
||||
{"negative number", []string{"-42"}, true},
|
||||
{"quoted string", []string{`"hello"`}, true},
|
||||
|
||||
// Objects
|
||||
{"simple object", []string{"{", `"name"`, ":", `"value"`, "}"}, true},
|
||||
{"object with single-digit numbers", []string{"{", `"count"`, ":", "1", ",", `"value"`, ":", "2", "}"}, true},
|
||||
{"multi-terminal key", []string{"{", `"key":`, `"value"`, "}"}, true},
|
||||
|
||||
// Arrays
|
||||
{"array of numbers", []string{"[", "42", "]"}, true},
|
||||
{"array of single digits", []string{"[", "1", ",", "2", "]"}, true},
|
||||
{"array of strings", []string{"[", `"hello"`, ",", `"world"`, "]"}, true},
|
||||
{"nested array", []string{"[", "[", "42", "]", "]"}, true},
|
||||
|
||||
// Nested structures
|
||||
{"nested object", []string{"{", `"items"`, ":", "{", `"count"`, ":", "42", "}", "}"}, true},
|
||||
{"object with array", []string{"{", `"items"`, ":", "[", "42", "]", "}"}, true},
|
||||
|
||||
// Invalid sequences
|
||||
{"unclosed object", []string{"{", `"name"`, ":"}, false}, // incomplete
|
||||
{"double comma", []string{"[", "42", ",", ",", "42", "]"}, false}, // invalid
|
||||
{"missing value", []string{"{", `"name"`, ":", "}"}, false}, // missing value
|
||||
{"bare word", []string{"invalid"}, false}, // not valid JSON
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine.Reset()
|
||||
|
||||
// Process each token
|
||||
allAccepted := true
|
||||
for i, token := range tt.tokens {
|
||||
if !engine.AcceptString(token) {
|
||||
if tt.wantPass {
|
||||
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
|
||||
}
|
||||
allAccepted = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantPass {
|
||||
if !allAccepted {
|
||||
return // Already reported error
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Errorf("expected complete parse, but not in accepting state")
|
||||
}
|
||||
} else {
|
||||
// For invalid sequences, we expect either rejection or incomplete
|
||||
if allAccepted && engine.IsComplete() {
|
||||
t.Errorf("expected rejection or incomplete, but parse succeeded")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_SimpleExpressionGrammar tests a custom expression grammar.
|
||||
func TestE2E_SimpleExpressionGrammar(t *testing.T) {
|
||||
// Simple expression grammar: expr = term { ("+" | "-") term }
|
||||
// term = number | "(" expr ")"
|
||||
// number = digit { digit }
|
||||
// digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
|
||||
exprGrammar := `
|
||||
expr = term { addop term } .
|
||||
addop = "+" | "-" .
|
||||
term = factor { mulop factor } .
|
||||
mulop = "*" | "/" .
|
||||
factor = number | "(" expr ")" .
|
||||
number = digit { digit } .
|
||||
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(exprGrammar, "expr")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse expression grammar: %v", err)
|
||||
}
|
||||
|
||||
// Vocabulary for expression tokens
|
||||
vocab := []string{
|
||||
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
|
||||
"+", "-", "*", "/",
|
||||
"(", ")",
|
||||
// Multi-digit numbers as single tokens
|
||||
"10", "42", "100", "123",
|
||||
// Invalid tokens
|
||||
"x", "y", "invalid",
|
||||
}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens []string
|
||||
wantPass bool
|
||||
}{
|
||||
{"single digit", []string{"5"}, true},
|
||||
{"multi-digit", []string{"1", "2", "3"}, true},
|
||||
{"addition", []string{"1", "+", "2"}, true},
|
||||
{"subtraction", []string{"5", "-", "3"}, true},
|
||||
{"multiplication", []string{"2", "*", "3"}, true},
|
||||
{"division", []string{"8", "/", "2"}, true},
|
||||
{"complex expr", []string{"1", "+", "2", "*", "3"}, true},
|
||||
{"parentheses", []string{"(", "1", "+", "2", ")", "*", "3"}, true},
|
||||
{"nested parens", []string{"(", "(", "1", ")", ")"}, true},
|
||||
|
||||
// Invalid
|
||||
{"just operator", []string{"+"}, false},
|
||||
{"double operator", []string{"1", "+", "+", "2"}, false},
|
||||
{"unclosed paren", []string{"(", "1", "+", "2"}, false},
|
||||
{"variable", []string{"x"}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine.Reset()
|
||||
|
||||
allAccepted := true
|
||||
for i, token := range tt.tokens {
|
||||
if !engine.AcceptString(token) {
|
||||
if tt.wantPass {
|
||||
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
|
||||
}
|
||||
allAccepted = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantPass {
|
||||
if !allAccepted {
|
||||
return
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Errorf("expected complete parse, but not in accepting state")
|
||||
}
|
||||
} else {
|
||||
if allAccepted && engine.IsComplete() {
|
||||
t.Errorf("expected rejection or incomplete, but parse succeeded")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_IdentifierGrammar tests a grammar with character ranges.
|
||||
func TestE2E_IdentifierGrammar(t *testing.T) {
|
||||
// Identifier grammar using character ranges
|
||||
identGrammar := `
|
||||
ident = letter { letter | digit } .
|
||||
letter = "a" … "z" | "A" … "Z" | "_" .
|
||||
digit = "0" … "9" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(identGrammar, "ident")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse identifier grammar: %v", err)
|
||||
}
|
||||
|
||||
// Vocabulary with letters and digits
|
||||
vocab := []string{
|
||||
"a", "b", "c", "x", "y", "z",
|
||||
"A", "B", "C", "X", "Y", "Z",
|
||||
"_",
|
||||
"0", "1", "2", "9",
|
||||
// Multi-char tokens
|
||||
"foo", "bar", "myVar", "test123",
|
||||
// Invalid starting chars
|
||||
"1abc", "123",
|
||||
}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens []string
|
||||
wantPass bool
|
||||
}{
|
||||
{"single letter", []string{"a"}, true},
|
||||
{"uppercase", []string{"A"}, true},
|
||||
{"underscore", []string{"_"}, true},
|
||||
{"multi-letter", []string{"a", "b", "c"}, true},
|
||||
{"letter then digit", []string{"x", "1"}, true},
|
||||
{"underscore prefix", []string{"_", "a", "1"}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine.Reset()
|
||||
|
||||
allAccepted := true
|
||||
for i, token := range tt.tokens {
|
||||
if !engine.AcceptString(token) {
|
||||
if tt.wantPass {
|
||||
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
|
||||
}
|
||||
allAccepted = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantPass && allAccepted && !engine.IsComplete() {
|
||||
t.Errorf("expected complete parse, but not in accepting state")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_UnicodeRange ensures unicode ranges compile and match tokens.
|
||||
func TestE2E_UnicodeRange(t *testing.T) {
|
||||
greekGrammar := `
|
||||
greek = "α" … "ω" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(greekGrammar, "greek")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse unicode grammar: %v", err)
|
||||
}
|
||||
|
||||
vocab := []string{"α", "β", "ω", "a"}
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
if !engine.AcceptString("β") {
|
||||
t.Error("should accept beta")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete after single rune")
|
||||
}
|
||||
|
||||
engine.Reset()
|
||||
if engine.AcceptString("a") {
|
||||
t.Error("should reject ASCII outside unicode range")
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_NondeterminismPreserved tests that nondeterministic paths are preserved.
|
||||
func TestE2E_NondeterminismPreserved(t *testing.T) {
|
||||
// This grammar has nondeterminism: "ab" could be parsed as
|
||||
// a single token or as two tokens "a" "b"
|
||||
ambiguousGrammar := `
|
||||
start = item item .
|
||||
item = "a" | "b" | "ab" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(ambiguousGrammar, "start")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse grammar: %v", err)
|
||||
}
|
||||
|
||||
// Vocabulary with both single and combined tokens
|
||||
vocab := []string{"a", "b", "ab"}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
// Test: "ab" "a" should be valid (ab as first item, a as second)
|
||||
t.Run("ab then a", func(t *testing.T) {
|
||||
engine.Reset()
|
||||
if !engine.AcceptString("ab") {
|
||||
t.Error("should accept ab")
|
||||
}
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept a after ab")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a then ab", func(t *testing.T) {
|
||||
engine.Reset()
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept a")
|
||||
}
|
||||
if !engine.AcceptString("ab") {
|
||||
t.Error("should accept ab after a")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a then a", func(t *testing.T) {
|
||||
engine.Reset()
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept first a")
|
||||
}
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept second a")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,614 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package grammar provides GPU-accelerated constrained decoding using MLX.
|
||||
// It compiles EBNF grammars to pushdown automata (pda) with precomputed token masks.
|
||||
// For JSON Schema conversion, see the grammar/schema subpackage.
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/ebnf"
|
||||
)
|
||||
|
||||
// stackSymbol represents a symbol that can be pushed onto the pda stack.
|
||||
type stackSymbol int
|
||||
|
||||
const (
|
||||
stackEmpty stackSymbol = iota
|
||||
// Additional stack symbols will be generated per-grammar
|
||||
)
|
||||
|
||||
// state represents a pda state.
|
||||
type state int
|
||||
|
||||
const (
|
||||
stateError state = -1
|
||||
stateStart state = 0
|
||||
stateAccept state = 1
|
||||
// Additional states will be generated per-grammar
|
||||
)
|
||||
|
||||
// transition represents a pda transition.
|
||||
// On input matching Pattern, from FromState with stackTop:
|
||||
// - Move to ToState
|
||||
// - Pop StackPop symbols, push StackPush symbols
|
||||
type transition struct {
|
||||
FromState state
|
||||
stackTop stackSymbol // What must be on stack top (stackEmpty = don't care)
|
||||
Pattern string // Input pattern to match (token or character class)
|
||||
ToState state
|
||||
StackPop int // Number of symbols to pop
|
||||
StackPush []stackSymbol // Symbols to push (in order, first pushed first)
|
||||
}
|
||||
|
||||
// pda represents a compiled pushdown automaton.
|
||||
type pda struct {
|
||||
States int // Total number of states
|
||||
StackSymbols int // Total number of stack symbols
|
||||
StartState state // Initial state
|
||||
AcceptStates map[state]bool // Set of accepting states
|
||||
Transitions map[state][]transition // Transitions indexed by from-state
|
||||
|
||||
// For token-level matching
|
||||
Terminals []string // All terminal symbols (patterns to match)
|
||||
}
|
||||
|
||||
// newPDA creates an empty pda.
|
||||
func newPDA() *pda {
|
||||
return &pda{
|
||||
States: 2, // Error and Start
|
||||
StackSymbols: 1, // Empty
|
||||
StartState: stateStart,
|
||||
AcceptStates: make(map[state]bool),
|
||||
Transitions: make(map[state][]transition),
|
||||
Terminals: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// addState adds a new state and returns its ID.
|
||||
func (p *pda) addState() state {
|
||||
s := state(p.States)
|
||||
p.States++
|
||||
return s
|
||||
}
|
||||
|
||||
// addStackSymbol adds a new stack symbol and returns its ID.
|
||||
func (p *pda) addStackSymbol() stackSymbol {
|
||||
s := stackSymbol(p.StackSymbols)
|
||||
p.StackSymbols++
|
||||
return s
|
||||
}
|
||||
|
||||
// addTransition adds a transition to the pda.
|
||||
func (p *pda) addTransition(t transition) {
|
||||
p.Transitions[t.FromState] = append(p.Transitions[t.FromState], t)
|
||||
}
|
||||
|
||||
// addTerminal registers a terminal pattern and returns its index.
|
||||
func (p *pda) addTerminal(pattern string) int {
|
||||
for i, t := range p.Terminals {
|
||||
if t == pattern {
|
||||
return i
|
||||
}
|
||||
}
|
||||
p.Terminals = append(p.Terminals, pattern)
|
||||
return len(p.Terminals) - 1
|
||||
}
|
||||
|
||||
// compiler compiles EBNF grammars to PDAs.
|
||||
type compiler struct {
|
||||
grammar ebnf.Grammar
|
||||
pda *pda
|
||||
|
||||
// Maps production names to their entry/exit states
|
||||
prodEntry map[string]state
|
||||
prodExit map[string]state
|
||||
}
|
||||
|
||||
// compile parses an EBNF grammar and compiles it to a pda.
|
||||
func compile(name string, src io.Reader, start string) (*pda, error) {
|
||||
grammar, err := ebnf.Parse(name, src)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse grammar: %w", err)
|
||||
}
|
||||
|
||||
if err := ebnf.Verify(grammar, start); err != nil {
|
||||
return nil, fmt.Errorf("verify grammar: %w", err)
|
||||
}
|
||||
|
||||
c := &compiler{
|
||||
grammar: grammar,
|
||||
pda: newPDA(),
|
||||
prodEntry: make(map[string]state),
|
||||
prodExit: make(map[string]state),
|
||||
}
|
||||
|
||||
// Create entry/exit states for each production
|
||||
for name := range grammar {
|
||||
c.prodEntry[name] = c.pda.addState()
|
||||
c.prodExit[name] = c.pda.addState()
|
||||
}
|
||||
|
||||
// compile each production
|
||||
for name, prod := range grammar {
|
||||
if err := c.compileProduction(name, prod); err != nil {
|
||||
return nil, fmt.Errorf("compile production %q: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set start state to entry of start production
|
||||
if entry, ok := c.prodEntry[start]; ok {
|
||||
// Add epsilon transition from pda start to grammar start
|
||||
c.pda.addTransition(transition{
|
||||
FromState: stateStart,
|
||||
Pattern: "", // epsilon
|
||||
ToState: entry,
|
||||
})
|
||||
} else {
|
||||
return nil, fmt.Errorf("start production %q not found", start)
|
||||
}
|
||||
|
||||
// Mark exit of start production as accepting
|
||||
if exit, ok := c.prodExit[start]; ok {
|
||||
c.pda.AcceptStates[exit] = true
|
||||
}
|
||||
|
||||
return c.pda, nil
|
||||
}
|
||||
|
||||
// compileString is a convenience function to compile from a string.
|
||||
func compileString(grammar string, start string) (*pda, error) {
|
||||
return compile("grammar", strings.NewReader(grammar), start)
|
||||
}
|
||||
|
||||
func (c *compiler) compileProduction(name string, prod *ebnf.Production) error {
|
||||
entry := c.prodEntry[name]
|
||||
exit := c.prodExit[name]
|
||||
|
||||
return c.compileExpr(prod.Expr, entry, exit)
|
||||
}
|
||||
|
||||
func (c *compiler) compileExpr(expr ebnf.Expression, entry, exit state) error {
|
||||
switch e := expr.(type) {
|
||||
case *ebnf.Name:
|
||||
return c.compileName(e, entry, exit)
|
||||
case *ebnf.Token:
|
||||
return c.compileToken(e, entry, exit)
|
||||
case ebnf.Sequence:
|
||||
return c.compileSequence(e, entry, exit)
|
||||
case ebnf.Alternative:
|
||||
return c.compileAlternative(e, entry, exit)
|
||||
case *ebnf.Option:
|
||||
return c.compileOption(e, entry, exit)
|
||||
case *ebnf.Repetition:
|
||||
return c.compileRepetition(e, entry, exit)
|
||||
case *ebnf.Group:
|
||||
return c.compileExpr(e.Body, entry, exit)
|
||||
case *ebnf.Range:
|
||||
return c.compileRange(e, entry, exit)
|
||||
case nil:
|
||||
// Empty production - direct epsilon transition
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported expression type: %T", expr)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *compiler) compileName(n *ebnf.Name, entry, exit state) error {
|
||||
// Reference to another production
|
||||
prodName := n.String
|
||||
|
||||
prodEntry, ok := c.prodEntry[prodName]
|
||||
if !ok {
|
||||
return fmt.Errorf("undefined production: %s", prodName)
|
||||
}
|
||||
prodExit := c.prodExit[prodName]
|
||||
// Use a unique stack symbol per call site so returns are unambiguous.
|
||||
stackSym := c.pda.addStackSymbol()
|
||||
|
||||
// Push return address, go to production entry
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "", // epsilon
|
||||
ToState: prodEntry,
|
||||
StackPush: []stackSymbol{stackSym},
|
||||
})
|
||||
|
||||
// On production exit, pop and return
|
||||
c.pda.addTransition(transition{
|
||||
FromState: prodExit,
|
||||
stackTop: stackSym,
|
||||
Pattern: "", // epsilon
|
||||
ToState: exit,
|
||||
StackPop: 1,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileToken(t *ebnf.Token, entry, exit state) error {
|
||||
// terminal symbol - add transition that consumes this token
|
||||
pattern := t.String
|
||||
c.pda.addTerminal(pattern)
|
||||
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: pattern,
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileSequence(seq ebnf.Sequence, entry, exit state) error {
|
||||
if len(seq) == 0 {
|
||||
// Empty sequence - epsilon transition
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Chain: entry -> s1 -> s2 -> ... -> exit
|
||||
current := entry
|
||||
for i, expr := range seq {
|
||||
var next state
|
||||
if i == len(seq)-1 {
|
||||
next = exit
|
||||
} else {
|
||||
next = c.pda.addState()
|
||||
}
|
||||
|
||||
if err := c.compileExpr(expr, current, next); err != nil {
|
||||
return err
|
||||
}
|
||||
current = next
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileAlternative(alt ebnf.Alternative, entry, exit state) error {
|
||||
// Each alternative goes from entry to exit
|
||||
for _, expr := range alt {
|
||||
if err := c.compileExpr(expr, entry, exit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileOption(opt *ebnf.Option, entry, exit state) error {
|
||||
// Optional: can skip (epsilon) or take the body
|
||||
|
||||
// Epsilon transition (skip)
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
// Or take the body
|
||||
return c.compileExpr(opt.Body, entry, exit)
|
||||
}
|
||||
|
||||
func (c *compiler) compileRepetition(rep *ebnf.Repetition, entry, exit state) error {
|
||||
// Repetition {body}: zero or more
|
||||
// entry -> exit (skip)
|
||||
// entry -> body -> entry (loop back)
|
||||
|
||||
// Skip transition
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
// Loop: entry -> (body) -> entry
|
||||
return c.compileExpr(rep.Body, entry, entry)
|
||||
}
|
||||
|
||||
func (c *compiler) compileRange(r *ebnf.Range, entry, exit state) error {
|
||||
// Character range like "a" … "z" or "\u03b1" … "\u03c9"
|
||||
begin := strings.Trim(r.Begin.String, "\"")
|
||||
end := strings.Trim(r.End.String, "\"")
|
||||
|
||||
// Unescape bounds first (so "\u03b1" works)
|
||||
beginUnesc, err := unescapeLiteral(begin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid range begin: %w", err)
|
||||
}
|
||||
endUnesc, err := unescapeLiteral(end)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid range end: %w", err)
|
||||
}
|
||||
|
||||
// Validate as single runes (not bytes) for Unicode support
|
||||
beginRunes := []rune(beginUnesc)
|
||||
endRunes := []rune(endUnesc)
|
||||
if len(beginRunes) != 1 || len(endRunes) != 1 {
|
||||
return fmt.Errorf("range bounds must be single characters: %q..%q", r.Begin.String, r.End.String)
|
||||
}
|
||||
|
||||
// Use unescaped rune strings in pattern (consistent with matcher)
|
||||
pattern := fmt.Sprintf("[%s-%s]", string(beginRunes[0]), string(endRunes[0]))
|
||||
c.pda.addTerminal(pattern)
|
||||
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: pattern,
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runtime represents a pda execution instance.
|
||||
type runtime struct {
|
||||
pda *pda
|
||||
state state
|
||||
stack []stackSymbol
|
||||
}
|
||||
|
||||
// newRuntime creates a new pda runtime.
|
||||
func newRuntime(pda *pda) *runtime {
|
||||
return &runtime{
|
||||
pda: pda,
|
||||
state: pda.StartState,
|
||||
stack: make([]stackSymbol, 0, 32),
|
||||
}
|
||||
}
|
||||
|
||||
// stackTop returns the top of the stack, or stackEmpty if empty.
|
||||
func (r *runtime) stackTop() stackSymbol {
|
||||
if len(r.stack) == 0 {
|
||||
return stackEmpty
|
||||
}
|
||||
return r.stack[len(r.stack)-1]
|
||||
}
|
||||
|
||||
// isAccepting returns true if we can reach an accepting state via epsilon transitions
|
||||
// with an empty stack.
|
||||
func (r *runtime) isAccepting() bool {
|
||||
return r.canReachAccept(r.state, r.stack, make(map[stateStackKey]bool))
|
||||
}
|
||||
|
||||
func (r *runtime) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
|
||||
// Check if this state is accepting with empty stack
|
||||
if r.pda.AcceptStates[state] && len(stack) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Avoid infinite loops
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
// Try epsilon transitions
|
||||
for _, t := range r.pda.Transitions[state] {
|
||||
if t.Pattern != "" {
|
||||
continue // Not epsilon
|
||||
}
|
||||
|
||||
// Check stack constraint
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
// Simulate stack operations
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
if t.StackPop > 0 && len(newStack) >= t.StackPop {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
if r.canReachAccept(t.ToState, newStack, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Reset resets the runtime to initial state.
|
||||
func (r *runtime) Reset() {
|
||||
r.state = r.pda.StartState
|
||||
r.stack = r.stack[:0]
|
||||
}
|
||||
|
||||
// validInputs returns all valid input patterns from current state.
|
||||
func (r *runtime) validInputs() []string {
|
||||
var valid []string
|
||||
seen := make(map[string]bool)
|
||||
visited := make(map[stateStackKey]bool)
|
||||
|
||||
// Make a copy of the stack for simulation
|
||||
simStack := make([]stackSymbol, len(r.stack))
|
||||
copy(simStack, r.stack)
|
||||
|
||||
r.collectValidInputs(r.state, simStack, seen, visited, &valid)
|
||||
return valid
|
||||
}
|
||||
|
||||
// stateStackKey is used to detect cycles in epsilon closure
|
||||
type stateStackKey struct {
|
||||
state state
|
||||
stackSig string
|
||||
}
|
||||
|
||||
func stackSignature(stack []stackSymbol) string {
|
||||
if len(stack) == 0 {
|
||||
return ""
|
||||
}
|
||||
buf := make([]byte, len(stack)*8)
|
||||
for i, sym := range stack {
|
||||
binary.LittleEndian.PutUint64(buf[i*8:], uint64(sym))
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func (r *runtime) collectValidInputs(state state, simStack []stackSymbol, seen map[string]bool, visited map[stateStackKey]bool, valid *[]string) {
|
||||
// Get stack top for comparisons
|
||||
stackTop := stackEmpty
|
||||
if len(simStack) > 0 {
|
||||
stackTop = simStack[len(simStack)-1]
|
||||
}
|
||||
|
||||
// Check for cycles to avoid infinite loops
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(simStack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
transitions := r.pda.Transitions[state]
|
||||
|
||||
for _, t := range transitions {
|
||||
// Check stack constraint
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
// Epsilon transition - simulate stack operations
|
||||
newStack := make([]stackSymbol, len(simStack))
|
||||
copy(newStack, simStack)
|
||||
|
||||
// Pop
|
||||
if t.StackPop > 0 {
|
||||
if len(newStack) < t.StackPop {
|
||||
continue // Can't pop, skip this transition
|
||||
}
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
|
||||
// Push
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
r.collectValidInputs(t.ToState, newStack, seen, visited, valid)
|
||||
} else {
|
||||
// terminal - add if not seen
|
||||
if !seen[t.Pattern] {
|
||||
seen[t.Pattern] = true
|
||||
*valid = append(*valid, t.Pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// matchesPattern checks if input matches a pattern.
|
||||
// Patterns can be:
|
||||
// - Exact strings: "a", "{", "true"
|
||||
// - Character ranges: "[a-z]", "[0-9]", "[#-~]"
|
||||
func matchesPattern(input, pattern string) bool {
|
||||
// Exact match
|
||||
if input == pattern {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for character range pattern [X-Y]
|
||||
if len(pattern) == 5 && pattern[0] == '[' && pattern[2] == '-' && pattern[4] == ']' {
|
||||
if len(input) != 1 {
|
||||
return false
|
||||
}
|
||||
ch := input[0]
|
||||
low := pattern[1]
|
||||
high := pattern[3]
|
||||
return ch >= low && ch <= high
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Accept tries to accept an input, returning true if successful.
|
||||
func (r *runtime) Accept(input string) bool {
|
||||
return r.accept(input, make(map[stateStackKey]bool))
|
||||
}
|
||||
|
||||
func (r *runtime) accept(input string, visited map[stateStackKey]bool) bool {
|
||||
key := stateStackKey{state: r.state, stackSig: stackSignature(r.stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
transitions := r.pda.Transitions[r.state]
|
||||
|
||||
// First, process any epsilon transitions to reach a state that can accept input
|
||||
// This is a simplified version - full implementation would need epsilon closure
|
||||
for _, t := range transitions {
|
||||
if matchesPattern(input, t.Pattern) {
|
||||
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(r.stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply transition
|
||||
r.applyTransition(t)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Try epsilon transitions first
|
||||
for _, t := range transitions {
|
||||
if t.Pattern == "" {
|
||||
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(r.stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Save state for backtracking
|
||||
oldState := r.state
|
||||
oldStack := make([]stackSymbol, len(r.stack))
|
||||
copy(oldStack, r.stack)
|
||||
|
||||
r.applyTransition(t)
|
||||
|
||||
if r.accept(input, visited) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Backtrack
|
||||
r.state = oldState
|
||||
r.stack = oldStack
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *runtime) applyTransition(t transition) {
|
||||
// Pop
|
||||
if t.StackPop > 0 && len(r.stack) >= t.StackPop {
|
||||
r.stack = r.stack[:len(r.stack)-t.StackPop]
|
||||
}
|
||||
|
||||
// Push
|
||||
r.stack = append(r.stack, t.StackPush...)
|
||||
|
||||
// Move to new state
|
||||
r.state = t.ToState
|
||||
}
|
||||
@@ -1,540 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileSimpleGrammar(t *testing.T) {
|
||||
// Simple grammar: S = "a" "b" .
|
||||
grammar := `S = "a" "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
if pda == nil {
|
||||
t.Fatal("pda is nil")
|
||||
}
|
||||
|
||||
// Should have terminals "a" and "b"
|
||||
if len(pda.Terminals) != 2 {
|
||||
t.Errorf("expected 2 terminals, got %d: %v", len(pda.Terminals), pda.Terminals)
|
||||
}
|
||||
|
||||
// Test runtime
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Should accept "a" then "b"
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept 'a'")
|
||||
}
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be in accepting state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileAlternative(t *testing.T) {
|
||||
// Grammar: S = "a" | "b" .
|
||||
grammar := `S = "a" | "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// Test accepting "a"
|
||||
rt := newRuntime(pda)
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'a'")
|
||||
}
|
||||
|
||||
// Test accepting "b"
|
||||
rt.Reset()
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'b'")
|
||||
}
|
||||
|
||||
// Test rejecting "c"
|
||||
rt.Reset()
|
||||
if rt.Accept("c") {
|
||||
t.Error("should not accept 'c'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileRepetition(t *testing.T) {
|
||||
// Grammar: S = {"a"} .
|
||||
grammar := `S = {"a"} .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// Empty should be accepted (zero repetitions)
|
||||
rt := newRuntime(pda)
|
||||
if !rt.isAccepting() {
|
||||
t.Error("empty should be accepting")
|
||||
}
|
||||
|
||||
// "a" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept first 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after one 'a'")
|
||||
}
|
||||
|
||||
// "aa" should be accepted
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept second 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after two 'a's")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileOption(t *testing.T) {
|
||||
// Grammar: S = ["a"] "b" .
|
||||
grammar := `S = ["a"] "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// "b" alone should be accepted
|
||||
rt := newRuntime(pda)
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b' alone")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'b'")
|
||||
}
|
||||
|
||||
// "ab" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept 'a'")
|
||||
}
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b' after 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'ab'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileRecursive(t *testing.T) {
|
||||
// Grammar with recursion: S = "(" S ")" | "x" .
|
||||
grammar := `S = "(" S ")" | "x" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// "x" should be accepted
|
||||
rt := newRuntime(pda)
|
||||
if !rt.Accept("x") {
|
||||
t.Error("should accept 'x'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'x'")
|
||||
}
|
||||
|
||||
// "(x)" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("(") {
|
||||
t.Error("should accept '('")
|
||||
}
|
||||
if !rt.Accept("x") {
|
||||
t.Error("should accept 'x' inside parens")
|
||||
}
|
||||
if !rt.Accept(")") {
|
||||
t.Error("should accept ')'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after '(x)'")
|
||||
}
|
||||
|
||||
// "((x))" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("(") {
|
||||
t.Error("should accept first '('")
|
||||
}
|
||||
if !rt.Accept("(") {
|
||||
t.Error("should accept second '('")
|
||||
}
|
||||
if !rt.Accept("x") {
|
||||
t.Error("should accept 'x'")
|
||||
}
|
||||
if !rt.Accept(")") {
|
||||
t.Error("should accept first ')'")
|
||||
}
|
||||
if !rt.Accept(")") {
|
||||
t.Error("should accept second ')'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after '((x))'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidInputs(t *testing.T) {
|
||||
// Grammar: S = "a" | "b" .
|
||||
grammar := `S = "a" | "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
valid := rt.validInputs()
|
||||
|
||||
// Should have both "a" and "b" as valid
|
||||
hasA, hasB := false, false
|
||||
for _, v := range valid {
|
||||
if v == "a" {
|
||||
hasA = true
|
||||
}
|
||||
if v == "b" {
|
||||
hasB = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasA {
|
||||
t.Error("'a' should be valid input")
|
||||
}
|
||||
if !hasB {
|
||||
t.Error("'b' should be valid input")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsAfterAccept tests that validInputs returns correct values
|
||||
// after accepting tokens, ensuring proper stack simulation.
|
||||
func TestValidInputsAfterAccept(t *testing.T) {
|
||||
// Grammar: S = "a" "b" "c" .
|
||||
grammar := `S = "a" "b" "c" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Initially only "a" should be valid
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "a" {
|
||||
t.Errorf("initially expected only 'a', got %v", valid)
|
||||
}
|
||||
|
||||
// After accepting "a", only "b" should be valid
|
||||
if !rt.Accept("a") {
|
||||
t.Fatal("failed to accept 'a'")
|
||||
}
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "b" {
|
||||
t.Errorf("after 'a', expected only 'b', got %v", valid)
|
||||
}
|
||||
|
||||
// After accepting "b", only "c" should be valid
|
||||
if !rt.Accept("b") {
|
||||
t.Fatal("failed to accept 'b'")
|
||||
}
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "c" {
|
||||
t.Errorf("after 'ab', expected only 'c', got %v", valid)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsWithRepetitionInProduction tests the critical case where
|
||||
// a repetition exists inside a called production. This requires proper
|
||||
// stack simulation to determine when closing symbols are valid.
|
||||
func TestValidInputsWithRepetitionInProduction(t *testing.T) {
|
||||
// Grammar similar to JSON:
|
||||
// S = "(" items ")" .
|
||||
// items = item { "," item } .
|
||||
// item = "x" .
|
||||
grammar := `
|
||||
S = "(" items ")" .
|
||||
items = item { "," item } .
|
||||
item = "x" .
|
||||
`
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Initially only "(" should be valid
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "(" {
|
||||
t.Errorf("initially expected only '(', got %v", valid)
|
||||
}
|
||||
|
||||
// Accept "("
|
||||
if !rt.Accept("(") {
|
||||
t.Fatal("failed to accept '('")
|
||||
}
|
||||
// After "(", should be able to accept "x" (item)
|
||||
valid = rt.validInputs()
|
||||
hasX := false
|
||||
for _, v := range valid {
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasX {
|
||||
t.Errorf("after '(', expected 'x' to be valid, got %v", valid)
|
||||
}
|
||||
|
||||
// Accept first item "x"
|
||||
if !rt.Accept("x") {
|
||||
t.Fatal("failed to accept 'x'")
|
||||
}
|
||||
// After "(x", should be able to accept "," (more items) OR ")" (end)
|
||||
valid = rt.validInputs()
|
||||
hasComma, hasClose := false, false
|
||||
for _, v := range valid {
|
||||
if v == "," {
|
||||
hasComma = true
|
||||
}
|
||||
if v == ")" {
|
||||
hasClose = true
|
||||
}
|
||||
}
|
||||
if !hasComma {
|
||||
t.Errorf("after '(x', expected ',' to be valid, got %v", valid)
|
||||
}
|
||||
if !hasClose {
|
||||
t.Errorf("after '(x', expected ')' to be valid, got %v", valid)
|
||||
}
|
||||
|
||||
// Accept comma for another item
|
||||
if !rt.Accept(",") {
|
||||
t.Fatal("failed to accept ','")
|
||||
}
|
||||
// After "(x,", should only be able to accept "x" (next item)
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "x" {
|
||||
t.Errorf("after '(x,', expected only 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// Accept second item "x"
|
||||
if !rt.Accept("x") {
|
||||
t.Fatal("failed to accept second 'x'")
|
||||
}
|
||||
// CRITICAL: After "(x,x", should be able to accept "," OR ")"
|
||||
// This tests the stack simulation fix - we need to properly
|
||||
// follow epsilon transitions through the production call stack.
|
||||
valid = rt.validInputs()
|
||||
hasComma, hasClose = false, false
|
||||
for _, v := range valid {
|
||||
if v == "," {
|
||||
hasComma = true
|
||||
}
|
||||
if v == ")" {
|
||||
hasClose = true
|
||||
}
|
||||
}
|
||||
if !hasComma {
|
||||
t.Errorf("after '(x,x', expected ',' to be valid, got %v", valid)
|
||||
}
|
||||
if !hasClose {
|
||||
t.Errorf("after '(x,x', expected ')' to be valid, got %v", valid)
|
||||
}
|
||||
|
||||
// Close with ")"
|
||||
if !rt.Accept(")") {
|
||||
t.Fatal("failed to accept ')'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after '(x,x)'")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsNestedCalls tests validInputs with deeply nested production calls.
|
||||
func TestValidInputsNestedCalls(t *testing.T) {
|
||||
// Grammar: A = "start" B "end" . B = "middle" .
|
||||
grammar := `
|
||||
A = "start" B "end" .
|
||||
B = "middle" .
|
||||
`
|
||||
pda, err := compileString(grammar, "A")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// After "start", should accept "middle" (from B)
|
||||
rt.Accept("start")
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "middle" {
|
||||
t.Errorf("after 'start', expected 'middle', got %v", valid)
|
||||
}
|
||||
|
||||
// After "start middle", should accept "end"
|
||||
rt.Accept("middle")
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "end" {
|
||||
t.Errorf("after 'start middle', expected 'end', got %v", valid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnAddressDisambiguation(t *testing.T) {
|
||||
// Grammar where the same production is called from different contexts:
|
||||
// S = A "x" | "c" A "y" .
|
||||
// A = "a" .
|
||||
grammar := `
|
||||
S = A "x" | "c" A "y" .
|
||||
A = "a" .
|
||||
`
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
if !rt.Accept("c") {
|
||||
t.Fatal("failed to accept 'c'")
|
||||
}
|
||||
if !rt.Accept("a") {
|
||||
t.Fatal("failed to accept 'a'")
|
||||
}
|
||||
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "y" {
|
||||
t.Errorf("after 'ca', expected only 'y', got %v", valid)
|
||||
}
|
||||
|
||||
rt.Reset()
|
||||
rt.Accept("c")
|
||||
rt.Accept("a")
|
||||
if rt.Accept("x") {
|
||||
t.Error("should not accept 'x' after 'ca'")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsRecursiveWithStack tests validInputs with recursive grammars
|
||||
// which heavily exercise the stack simulation.
|
||||
func TestValidInputsRecursiveWithStack(t *testing.T) {
|
||||
// Grammar: S = "(" S ")" | "x" .
|
||||
grammar := `S = "(" S ")" | "x" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Initially: "(" or "x" should be valid
|
||||
valid := rt.validInputs()
|
||||
hasParen, hasX := false, false
|
||||
for _, v := range valid {
|
||||
if v == "(" {
|
||||
hasParen = true
|
||||
}
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasParen || !hasX {
|
||||
t.Errorf("initially expected '(' and 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// After "(": "(" or "x" should be valid (nested S)
|
||||
rt.Accept("(")
|
||||
valid = rt.validInputs()
|
||||
hasParen, hasX = false, false
|
||||
for _, v := range valid {
|
||||
if v == "(" {
|
||||
hasParen = true
|
||||
}
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasParen || !hasX {
|
||||
t.Errorf("after '(', expected '(' and 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// After "((": "(" or "x" should still be valid
|
||||
rt.Accept("(")
|
||||
valid = rt.validInputs()
|
||||
hasParen, hasX = false, false
|
||||
for _, v := range valid {
|
||||
if v == "(" {
|
||||
hasParen = true
|
||||
}
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasParen || !hasX {
|
||||
t.Errorf("after '((', expected '(' and 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// After "((x": only ")" should be valid
|
||||
rt.Accept("x")
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != ")" {
|
||||
t.Errorf("after '((x', expected only ')', got %v", valid)
|
||||
}
|
||||
|
||||
// After "((x)": only ")" should be valid (closing outer)
|
||||
rt.Accept(")")
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != ")" {
|
||||
t.Errorf("after '((x)', expected only ')', got %v", valid)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRejectionAfterValid tests that invalid inputs are rejected
|
||||
// at various points in the grammar.
|
||||
func TestRejectionAfterValid(t *testing.T) {
|
||||
// Grammar: S = "a" "b" .
|
||||
grammar := `S = "a" "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// "b" should be rejected initially
|
||||
if rt.Accept("b") {
|
||||
t.Error("'b' should be rejected initially")
|
||||
}
|
||||
|
||||
// Accept "a"
|
||||
rt.Accept("a")
|
||||
|
||||
// "a" should be rejected after "a"
|
||||
if rt.Accept("a") {
|
||||
t.Error("'a' should be rejected after 'a'")
|
||||
}
|
||||
|
||||
// "c" should be rejected (not in grammar)
|
||||
if rt.Accept("c") {
|
||||
t.Error("'c' should be rejected (not in grammar)")
|
||||
}
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
# Example Grammars
|
||||
|
||||
This directory contains example EBNF grammars for constrained decoding.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
go run -tags mlx ./x/imagegen/cmd/engine/ \
|
||||
-model /path/to/model \
|
||||
-prompt "Your prompt" \
|
||||
-grammar x/grammar/grammars/json.ebnf \
|
||||
-grammar-start value
|
||||
```
|
||||
|
||||
## Available Grammars
|
||||
|
||||
| File | Start Rule | Description |
|
||||
|------|------------|-------------|
|
||||
| `json.ebnf` | `value` | Standard JSON (RFC 8259) |
|
||||
| `expression.ebnf` | `expr` | Arithmetic expressions (+, -, *, /, parens) |
|
||||
| `identifier.ebnf` | `ident` | Programming language identifiers |
|
||||
| `boolean.ebnf` | `expr` | Boolean expressions (AND, OR, NOT) |
|
||||
| `list.ebnf` | `list` | Comma-separated word list |
|
||||
| `yesno.ebnf` | `response` | Simple yes/no responses |
|
||||
| `date.ebnf` | `date` | Dates in YYYY-MM-DD format |
|
||||
| `email.ebnf` | `email` | Basic email addresses |
|
||||
| `phone.ebnf` | `phone` | US phone numbers |
|
||||
| `hexcolor.ebnf` | `color` | CSS hex colors (#RGB or #RRGGBB) |
|
||||
| `url.ebnf` | `url` | HTTP/HTTPS URLs |
|
||||
|
||||
## Grammar Syntax
|
||||
|
||||
**Note:** Comments are not supported. Grammar files must contain only EBNF productions.
|
||||
|
||||
The grammars use EBNF notation:
|
||||
|
||||
- `=` defines a production rule
|
||||
- `|` is alternation (or)
|
||||
- `{ }` is repetition (zero or more)
|
||||
- `[ ]` is optional (zero or one)
|
||||
- `" "` is a literal string
|
||||
- `…` is a character range (e.g., `"a" … "z"`)
|
||||
- `.` ends a production
|
||||
|
||||
## Writing Custom Grammars
|
||||
|
||||
1. Define your grammar in a `.ebnf` file
|
||||
2. Choose a start rule name
|
||||
3. Pass `-grammar path/to/grammar.ebnf -grammar-start rulename`
|
||||
|
||||
Example custom grammar for RGB colors:
|
||||
|
||||
```ebnf
|
||||
color = "#" hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
|
||||
hexdigit = "0" … "9" | "a" … "f" | "A" … "F" .
|
||||
```
|
||||
@@ -1,7 +0,0 @@
|
||||
expr = term { " OR " term } .
|
||||
term = factor { " AND " factor } .
|
||||
factor = "NOT " factor | atom | "(" expr ")" .
|
||||
atom = "true" | "false" | ident .
|
||||
ident = letter { letter | digit } .
|
||||
letter = "a" … "z" | "A" … "Z" .
|
||||
digit = "0" … "9" .
|
||||
@@ -1,6 +0,0 @@
|
||||
date = year "-" month "-" day .
|
||||
year = digit digit digit digit .
|
||||
month = ( "0" digit1to9 ) | ( "1" ( "0" | "1" | "2" ) ) .
|
||||
day = ( "0" digit1to9 ) | ( ( "1" | "2" ) digit ) | ( "3" ( "0" | "1" ) ) .
|
||||
digit1to9 = "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
@@ -1,5 +0,0 @@
|
||||
email = localpart "@" domain .
|
||||
localpart = word { "." word } .
|
||||
domain = word { "." word } .
|
||||
word = alphanum { alphanum | "-" | "_" } .
|
||||
alphanum = "a" … "z" | "A" … "Z" | "0" … "9" .
|
||||
@@ -1,7 +0,0 @@
|
||||
expr = term { addop term } .
|
||||
addop = "+" | "-" .
|
||||
term = factor { mulop factor } .
|
||||
mulop = "*" | "/" .
|
||||
factor = number | "(" expr ")" .
|
||||
number = [ "-" ] digit { digit } .
|
||||
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
@@ -1,4 +0,0 @@
|
||||
color = "#" ( hex6 | hex3 ) .
|
||||
hex6 = hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
|
||||
hex3 = hexdigit hexdigit hexdigit .
|
||||
hexdigit = "0" … "9" | "a" … "f" | "A" … "F" .
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user